#include "pkcs11_handler.h"

/*
 * Add one to a big-endian counter stored in `array` (last byte is the least
 * significant). The carry ripples towards index 0 and is dropped if every
 * byte wraps, so an all-0xFF array becomes all zeros.
 * A length of zero or less leaves the array untouched.
 */
void increment_byte_array(unsigned char *array, int array_length) {
    int pos = array_length;
    while (pos-- > 0) {
        // A byte that did not wrap to zero absorbs the carry; we are done.
        if (++array[pos] != 0) {
            return;
        }
    }
}

/*
 * XOR the first `array_length` bytes of `b_array` into `a_array` in place.
 * `b_array` is only read.
 */
void xor_byte_array(unsigned char *a_array, const unsigned char *b_array, unsigned long array_length) {
    unsigned char *dst = a_array;
    const unsigned char *src = b_array;
    for (unsigned long remaining = array_length; remaining > 0; remaining--) {
        *dst++ ^= *src++;
    }
}

/*
 * Build the mechanism descriptor for AES-ECB.
 * The descriptor carries no parameter (NULL pointer, length 0).
 */
CK_MECHANISM get_aes_ecb_mechanism() {
    CK_MECHANISM ecb;
    ecb.ulParameterLen = 0;
    ecb.pParameter = NULL;
    ecb.mechanism = CKM_AES_ECB;
    return ecb;
}

/*
 * Build the mechanism descriptor for AES-CBC with an all-zero IV.
 *
 * The IV is a freshly calloc'ed buffer of AES_BLOCK_SIZE bytes; this function
 * never frees it, so every call hands a new allocation to the caller.
 * The allocation result is not checked here, so pParameter may be NULL.
 */
CK_MECHANISM get_aes_cbc_mechanism() {
    void *zero_iv = calloc(AES_BLOCK_SIZE, 1);

    CK_MECHANISM cbc;
    cbc.mechanism = CKM_AES_CBC;
    cbc.pParameter = zero_iv;
    cbc.ulParameterLen = AES_BLOCK_SIZE;
    return cbc;
}

/*
 * Encrypt `data` with AES-256-CTR in software, using a key derived on the
 * token: the key is the token-side SHA256-HMAC of `iv` (TLS_CRYPT_V2_TAG_LEN
 * bytes) under pkcs_ctx->aes_key.
 *
 * @param pkcs_ctx  token context; aes_key is used as the HMAC key handle
 * @param data      input bytes
 * @param data_len  number of input bytes (passed to OpenSSL as an int)
 * @param out       receives the output bytes
 * @param out_len   receives the number of bytes written to `out`
 * @param iv        used both as HMAC input and as the CTR initial counter
 * @return true on success, false on any failure
 *
 * The derived key is wiped and the cipher context released on every path
 * that reaches error_exit. CRYPTO_ECHECK is defined outside this file; it is
 * presumed to jump to error_exit when its condition is true.
 */
int perform_aes_key_derivation(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len,
                               unsigned char *out, unsigned long *out_len, const unsigned char *iv) {
    unsigned char aes_client_server_key[TLS_CRYPT_V2_KEY_LEN] = {0};
    unsigned long key_len = TLS_CRYPT_V2_KEY_LEN;
    // Declared up front (and NULL) so the cleanup code can always free it.
    EVP_CIPHER_CTX *evp_ctx = NULL;
    // EVP_EncryptUpdate reports its length through an int; use a real int
    // instead of casting the caller's unsigned long pointer.
    int written = 0;
    int exit_code = false;

    CRYPTO_ECHECK(!perform_sha_256_hmac_keyfreedom(pkcs_ctx,
                                                   iv, TLS_CRYPT_V2_TAG_LEN,
                                                   aes_client_server_key, &key_len,
                                                   pkcs_ctx->aes_key),
                  "perform_sha_256_hmac_keyfreedom() failed");

    evp_ctx = EVP_CIPHER_CTX_new();
    CRYPTO_ECHECK(!evp_ctx,
                  "EVP_CIPHER_CTX_new() failed");

    CRYPTO_ECHECK(!EVP_CIPHER_CTX_init(evp_ctx),
                  "EVP_CIPHER_CTX_init() failed");
    CRYPTO_ECHECK(!EVP_EncryptInit_ex2(evp_ctx, EVP_aes_256_ctr(), aes_client_server_key, iv, NULL),
                  "EVP_EncryptInit_ex2() failed");
    CRYPTO_ECHECK(!EVP_EncryptUpdate(evp_ctx, out, &written, data, (int) data_len),
                  "EVP_EncryptUpdate() failed");
    *out_len = (unsigned long) written;

    exit_code = true;

error_exit:
    // EVP_CIPHER_CTX_free(NULL) is a no-op, so this is safe on early failures.
    EVP_CIPHER_CTX_free(evp_ctx);
    ovpn_secure_memzero(aes_client_server_key, sizeof(aes_client_server_key));
    return exit_code;
}

/*
 * CTR-style encryption built on a single multi-block token call: the counter
 * blocks nonce, nonce+1, ... are laid out back to back, encrypted in place by
 * encrypt_block_n() with pkcs_ctx->aes_key, and the result is XORed over a
 * copy of `data`.
 *
 * @param pkcs_ctx  token context
 * @param data      input bytes
 * @param data_len  number of input bytes
 * @param out       receives data_len output bytes
 * @param out_len   receives data_len
 * @param nonce     AES_BLOCK_SIZE bytes used as the initial counter block
 * @return true on success; ERROR_CHECK (defined elsewhere) handles failure of
 *         the token call
 *
 * NOTE(review): the keystream buffer is a stack VLA of roughly data_len
 * bytes; callers are assumed to pass small inputs.
 */
int perform_aes_with_ecb(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len, unsigned char *out,
                unsigned long *out_len, const unsigned char *nonce) {
    // Nothing to encrypt: avoid a zero-length VLA and the unconditional
    // AES_BLOCK_SIZE copy into it.
    if (data_len == 0) {
        *out_len = 0;
        return true;
    }

    // Integer ceiling division instead of a round trip through double.
    int block_count = (int) (data_len / AES_BLOCK_SIZE + (data_len % AES_BLOCK_SIZE != 0 ? 1 : 0));
    unsigned char iv[AES_BLOCK_SIZE];
    unsigned char iv_sequence[AES_BLOCK_SIZE * block_count];
    memcpy(iv, nonce, AES_BLOCK_SIZE);
    memcpy(iv_sequence, nonce, AES_BLOCK_SIZE);
    memcpy(out, data, data_len);
    for (int block_index = 1; block_index < block_count; ++block_index) {
        increment_byte_array(iv, AES_BLOCK_SIZE);
        memcpy(iv_sequence + block_index * AES_BLOCK_SIZE, iv, AES_BLOCK_SIZE);
    }

    // Encrypt all counter blocks in place; iv_sequence becomes the keystream.
    ERROR_CHECK(! encrypt_block_n(pkcs_ctx, iv_sequence, iv_sequence, pkcs_ctx->aes_key, block_count),
                "encrypt_block_n() failed");

    // Only data_len bytes of keystream are consumed; the tail of the last
    // block is discarded.
    xor_byte_array(out, iv_sequence, data_len);
    *out_len = data_len;

    return true;
}

/*
 * CTR-style encryption using one token call per block: each counter value is
 * encrypted with encrypt_block_n() under pkcs_ctx->aes_key and XORed over the
 * matching slice of the output, then the counter is incremented.
 *
 * @param pkcs_ctx  token context
 * @param data      input bytes
 * @param data_len  number of input bytes
 * @param out       receives data_len output bytes
 * @param out_len   receives the number of bytes processed
 * @param nonce     AES_BLOCK_SIZE bytes used as the initial counter block
 * @return true on success; ERROR_CHECK (defined elsewhere) handles failure of
 *         the token call
 */
int perform_aes(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len, unsigned char *out,
                unsigned long *out_len, const unsigned char *nonce) {
    unsigned char encrypted_iv[AES_BLOCK_SIZE];
    unsigned char iv[AES_BLOCK_SIZE];
    memcpy(iv, nonce, AES_BLOCK_SIZE);
    memcpy(out, data, data_len);

    // Unsigned bookkeeping so the comparison with data_len has no
    // signed/unsigned mismatch.
    unsigned long processed = 0;
    while (processed < data_len) {
        unsigned long remaining = data_len - processed;
        unsigned long block_size = remaining < AES_BLOCK_SIZE ? remaining : AES_BLOCK_SIZE;

        ERROR_CHECK(! encrypt_block_n(pkcs_ctx, iv, encrypted_iv, pkcs_ctx->aes_key, 1),
                    "encrypt_block_n() failed");
        // The keystream block belongs to the current slice of the output,
        // not to the start of the buffer.
        xor_byte_array(out + processed, encrypted_iv, block_size);
        increment_byte_array(iv, AES_BLOCK_SIZE);

        processed += block_size;
    }

    *out_len = processed;

    return true;
}

/*
 * Encrypt `data` on the token with the CKM_AES_CTR mechanism and
 * pkcs_ctx->aes_key.
 *
 * The first TLS_CRYPT_V2_IV_LENGTH bytes of `nonce` become the counter block
 * and the counter width is set to TLS_CRYPT_V2_IV_LENGTH * 8 bits.
 * `out_len` is passed straight to C_Encrypt.
 * Returns true on success; ERROR_CHECK (defined elsewhere) handles failures.
 */
int perform_aes_256_ctr(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len,
                        unsigned char *out, unsigned long *out_len, const unsigned char *nonce) {
    ERROR_CHECK(!pkcs_ctx->aes_key,
                "AES Key not loaded");

    // Counter parameters: width first, then the initial counter block.
    CK_AES_CTR_PARAMS params;
    params.ulCounterBits = TLS_CRYPT_V2_IV_LENGTH * 8;
    memcpy(params.cb, nonce, TLS_CRYPT_V2_IV_LENGTH);

    // Mechanism descriptor pointing at the stack-resident parameters above.
    CK_MECHANISM aesEncryptMechanism;
    aesEncryptMechanism.ulParameterLen = sizeof(params);
    aesEncryptMechanism.pParameter = &params;
    aesEncryptMechanism.mechanism = CKM_AES_CTR;

    ERROR_CHECK(pkcs_ctx->p11_functions->C_EncryptInit(pkcs_ctx->session, &aesEncryptMechanism, pkcs_ctx->aes_key),
                "C_EncryptInit() failed");

    ERROR_CHECK(pkcs_ctx->p11_functions->C_Encrypt(pkcs_ctx->session, (CK_BYTE_PTR) data, data_len, out, out_len),
                "C_Encrypt() failed");

    return true;
}

/*
 * Encrypt `block_count` AES blocks from `data` into `out` in a single
 * C_EncryptInit / C_Encrypt pair, using pkcs_ctx->aes_mechanism and the given
 * key handle. Both buffers are treated as AES_BLOCK_SIZE * block_count bytes.
 * Returns true on success; ERROR_CHECK (defined elsewhere) handles failures.
 */
int
encrypt_block_n(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned char *out,
                CK_OBJECT_HANDLE key_handle, int block_count) {
    // In: capacity of `out`; the token call may overwrite it with the number
    // of bytes produced.
    CK_ULONG out_length = AES_BLOCK_SIZE * block_count;

    ERROR_CHECK(pkcs_ctx->p11_functions->C_EncryptInit(pkcs_ctx->session, &pkcs_ctx->aes_mechanism, key_handle),
                "C_EncryptInit() failed");

    ERROR_CHECK(pkcs_ctx->p11_functions->C_Encrypt(pkcs_ctx->session, (CK_BYTE_PTR) data, AES_BLOCK_SIZE * block_count, out,
                                             &out_length),
                "C_Encrypt() failed");

    return true;
}

/*
 * Compute an HMAC-SHA256 in software with a key derived on the token: the
 * first two AES blocks of `data` are encrypted with ctx->hmac_key via
 * encrypt_block_n() and the result is used as the HMAC key. The MAC covers a
 * 2-byte big-endian length prefix followed by `data`.
 *
 * @param ctx       token context
 * @param data      input bytes
 * @param data_len  number of input bytes; must fit, together with the length
 *                  prefix, in a TLS_CRYPT_V2_MAX_WKC_LEN-byte buffer
 * @param out       receives the MAC
 * @param out_len   receives the MAC length
 * @return true on success, false on any failure
 *
 * NOTE(review): encrypt_block_n() reads 2 * AES_BLOCK_SIZE bytes from `data`
 * regardless of data_len — confirm callers always supply at least that much.
 * CRYPTO_ECHECK is defined outside this file; it is presumed to jump to
 * error_exit when its condition is true.
 */
int perform_hmac_key_derivation(struct PKCS_Context *ctx, const unsigned char *data, unsigned long data_len,
                                unsigned char *out, unsigned long *out_len) {
    unsigned char hmac_client_server_key[TLS_CRYPT_V2_KEY_LEN] = {0};
    unsigned char data_buffer[TLS_CRYPT_V2_MAX_WKC_LEN] = {0};
    // HMAC() reports its length through an unsigned int; use a real one
    // instead of casting the caller's unsigned long pointer.
    unsigned int mac_len = 0;
    int exit_code = false;

    // Reject inputs that would overflow data_buffer once the length prefix is
    // added (assumes TLS_CRYPT_V2_LEN_LEN is the 2-byte prefix written below).
    CRYPTO_ECHECK(data_len > sizeof(data_buffer) - TLS_CRYPT_V2_LEN_LEN,
                  "data too long for HMAC buffer");

    // A failed derivation must not fall through to an HMAC under the
    // all-zero placeholder key.
    CRYPTO_ECHECK(!encrypt_block_n(ctx, data, hmac_client_server_key, ctx->hmac_key, 2),
                  "encrypt_block_n() failed");

    data_buffer[0] = (data_len >> 8) & 0xFF;
    data_buffer[1] = data_len & 0xFF;
    memcpy(data_buffer + 2, data, data_len);

    CRYPTO_ECHECK(!HMAC(EVP_sha256(),
                        hmac_client_server_key, TLS_CRYPT_V2_KEY_LEN,
                        data_buffer, data_len + TLS_CRYPT_V2_LEN_LEN,
                        out, &mac_len),
                  "HMAC() failed");
    *out_len = mac_len;
    exit_code = true;

error_exit:
    ovpn_secure_memzero(hmac_client_server_key, sizeof(hmac_client_server_key));
    return exit_code;
}

/*
 * Compute a SHA256-HMAC on the token over `data` with an arbitrary key handle.
 *
 * @param ctx         token context
 * @param data        input bytes
 * @param data_len    number of input bytes
 * @param out         receives the MAC
 * @param out_len     passed to C_Sign as the signature-length in/out argument
 * @param key_handle  key object used for signing
 * @return true on success; ERROR_CHECK (defined elsewhere) handles failures
 *
 * NOTE(review): the "not loaded" guard tests ctx->hmac_key even when a
 * different key_handle is supplied; kept as-is for compatibility.
 */
int perform_sha_256_hmac_keyfreedom(struct PKCS_Context *ctx, const unsigned char *data, unsigned long data_len,
                                    unsigned char *out, unsigned long *out_len, CK_OBJECT_HANDLE key_handle) {
    ERROR_CHECK(!ctx->hmac_key, "HMAC Key not loaded");

    // Fill in every field: the parameter pointer and length were previously
    // left as stack garbage and handed to the token.
    CK_MECHANISM hmacMechanism;
    hmacMechanism.mechanism = CKM_SHA256_HMAC;
    hmacMechanism.pParameter = NULL;
    hmacMechanism.ulParameterLen = 0;

    ERROR_CHECK(ctx->p11_functions->C_SignInit(ctx->session, &hmacMechanism, key_handle),
                "C_SignInit() failed");

    ERROR_CHECK(ctx->p11_functions->C_Sign(ctx->session, (CK_BYTE_PTR) data, data_len, out, out_len),
                "C_Sign() failed");

    return true;
}

int
perform_sha_256_hmac(struct PKCS_Context *ctx, const unsigned char *data, unsigned long data_len, unsigned char *out,
                     unsigned long *out_len) {
    return perform_sha_256_hmac_keyfreedom(ctx, data, data_len, out, out_len, ctx->hmac_key);
}

/*
 * Look up a single object on the token matching the attribute template.
 *
 * @param ctx            token context
 * @param key            receives the handle when exactly one object matches;
 *                       left untouched otherwise
 * @param attr_list      search template
 * @param attr_list_len  number of entries in attr_list
 * @return true when exactly one object matched; false when none or several
 *         matched. ERROR_CHECK (defined elsewhere) handles token-call failures.
 */
int find_key(struct PKCS_Context *ctx, CK_OBJECT_HANDLE *key, CK_ATTRIBUTE *attr_list, ulong attr_list_len) {
    // The token may return up to MAX_OBJECT_COUNT handles; collect them
    // locally because `key` has room for only one.
    CK_OBJECT_HANDLE matches[MAX_OBJECT_COUNT];
    CK_ULONG ulObjectCount = 0;

    ERROR_CHECK(ctx->p11_functions->C_FindObjectsInit(ctx->session, attr_list, attr_list_len),
                "C_FindObjectsInit() failed");

    // Always close the search, even when C_FindObjects fails, so the session
    // is not left with an active find operation.
    CK_RV find_rv = ctx->p11_functions->C_FindObjects(ctx->session, matches, MAX_OBJECT_COUNT, &ulObjectCount);
    CK_RV final_rv = ctx->p11_functions->C_FindObjectsFinal(ctx->session);

    ERROR_CHECK(find_rv != CKR_OK,
                "C_FindObjects() failed");
    ERROR_CHECK(final_rv != CKR_OK,
                "C_FindObjectsFinal() failed");

    if (ulObjectCount != 1) {
        plog(PLOG_DEBUG, "No key was found");
        return false;
    }

    *key = matches[0];
    return true;
}

/*
 * Find a token-resident key by label. The search template requires
 * CKA_TOKEN = true, the given CKA_LABEL (without its NUL terminator) and
 * CKA_VALUE_LEN = TLS_CRYPT_V2_KEY_LEN. Returns whatever find_key() returns.
 */
int find_key_by_label(struct PKCS_Context *pkcs_ctx, CK_OBJECT_HANDLE *key_handle, char *label) {
    ulong expected_len = TLS_CRYPT_V2_KEY_LEN;
    CK_BBOOL on_token = CK_TRUE;

    CK_ATTRIBUTE search_template[] = {
            {CKA_TOKEN,     &on_token,     sizeof(on_token)},
            {CKA_LABEL,     label,         strlen(label)},
            {CKA_VALUE_LEN, &expected_len, sizeof(expected_len)}
    };
    const ulong template_len = sizeof(search_template) / sizeof(search_template[0]);

    return find_key(pkcs_ctx, key_handle, search_template, template_len);
}

/* Look up the key labelled AES_KEY_LABEL; result is that of find_key_by_label(). */
int find_aes_key(struct PKCS_Context *ctx, CK_OBJECT_HANDLE *key) {
    int found = find_key_by_label(ctx, key, AES_KEY_LABEL);
    return found;
}

/* Look up the key labelled HMAC_KEY_LABEL; result is that of find_key_by_label(). */
int find_hmac_key(struct PKCS_Context *ctx, CK_OBJECT_HANDLE *key) {
    int found = find_key_by_label(ctx, key, HMAC_KEY_LABEL);
    return found;
}

/*
 * Load both key handles into the PKCS context. The HMAC key is only looked up
 * when the AES key was found. Returns 1 when both lookups succeed, else 0.
 */
int find_keys(struct plugin_ctx *plugin_ctx) {
    if (!find_aes_key(plugin_ctx->pkcs_ctx, &plugin_ctx->pkcs_ctx->aes_key)) {
        return 0;
    }
    if (!find_hmac_key(plugin_ctx->pkcs_ctx, &plugin_ctx->pkcs_ctx->hmac_key)) {
        return 0;
    }
    return 1;
}

/*
 * Destroy AES key objects for as long as find_aes_key() keeps reporting one.
 * Returns true once no further key is found; ERROR_CHECK (defined elsewhere)
 * handles a failing C_DestroyObject.
 */
int delete_aes_key(struct PKCS_Context *ctx) {
    CK_OBJECT_HANDLE aes_key;

    for (;;) {
        // Stop as soon as the lookup no longer yields a key.
        if (!find_aes_key(ctx, &aes_key)) {
            break;
        }
        ERROR_CHECK(ctx->p11_functions->C_DestroyObject(ctx->session, aes_key),
                    "C_DestroyObject() failed");
    }

    return true;
}

/*
 * Destroy HMAC key objects for as long as find_hmac_key() keeps reporting one.
 * Returns true once no further key is found; ERROR_CHECK (defined elsewhere)
 * handles a failing C_DestroyObject.
 */
int delete_hmac_key(struct PKCS_Context *ctx) {
    CK_OBJECT_HANDLE hmac_key;

    for (;;) {
        // Stop as soon as the lookup no longer yields a key.
        if (!find_hmac_key(ctx, &hmac_key)) {
            break;
        }
        ERROR_CHECK(ctx->p11_functions->C_DestroyObject(ctx->session, hmac_key),
                    "C_DestroyObject() failed");
    }

    return true;
}

int create_aes_key(struct PKCS_Context *ctx, char *label) {
    CK_MECHANISM aesKeyMechanism;
    aesKeyMechanism.mechanism = CKM_AES_KEY_GEN;

    CK_BBOOL true_var = CK_TRUE;
    ulong key_type = CKK_AES;
    ulong value_len = TLS_CRYPT_V2_KEY_LEN;
    CK_ATTRIBUTE attr_list[] = {
            {CKA_TOKEN,     &true_var,  sizeof(true_var)},
            {CKA_LABEL,     label,      strlen(label)},
            {CKA_ENCRYPT,   &true_var,  sizeof(true_var)},
            {CKA_DECRYPT,   &true_var,  sizeof(true_var)},
            {CKA_KEY_TYPE,  &key_type,  sizeof(key_type)},
            {CKA_VALUE_LEN, &value_len, sizeof(value_len)}
    };
    CK_OBJECT_HANDLE aes_key;
    ERROR_CHECK(ctx->p11_functions->C_GenerateKey(ctx->session, &aesKeyMechanism, attr_list,
                                                  sizeof(attr_list) / sizeof(attr_list[0]), &aes_key),
                "C_GenerateKey() failed");

    return true;
}

int create_hmac_key(struct PKCS_Context *ctx, char *label) {
    CK_MECHANISM hmacKeyMechanism;
    hmacKeyMechanism.mechanism = CKM_GENERIC_SECRET_KEY_GEN;

    CK_BBOOL true_var = CK_TRUE;
    ulong key_type = CKK_GENERIC_SECRET;
    ulong value_len = TLS_CRYPT_V2_KEY_LEN;
    CK_ATTRIBUTE attr_list[] = {
            {CKA_TOKEN,     &true_var,  sizeof(true_var)},
            {CKA_LABEL,     label,      strlen(label)},
            {CKA_SIGN,      &true_var,  sizeof(true_var)},
            {CKA_KEY_TYPE,  &key_type,  sizeof(key_type)},
            {CKA_VALUE_LEN, &value_len, sizeof(value_len)}
    };
    CK_OBJECT_HANDLE hmac_key;
    ERROR_CHECK(ctx->p11_functions->C_GenerateKey(ctx->session, &hmacKeyMechanism, attr_list,
                                                  sizeof(attr_list) / sizeof(attr_list[0]), &hmac_key),
                "C_GenerateKey() failed");

    return true;
}

int load_pkcs11_functions(struct PKCS_Context *ctx, const char *lib_path) {
    CK_C_GetFunctionList function_symbol_list = NULL;

    /* Get handle to the library */
    ctx->lib_handle = dlopen(lib_path, RTLD_NOW);
    ERROR_CHECK(!ctx->lib_handle,
                "dlopen() failed");

    /* Obtain address of a symbols in a shared library */
    function_symbol_list = (CK_C_GetFunctionList) dlsym(ctx->lib_handle, "C_GetFunctionList");

    /* Get pkcs11 function list pointer */
    ERROR_CHECK(!function_symbol_list,
                "dlsym() failed");

    ERROR_CHECK((function_symbol_list(&ctx->p11_functions) != CKR_OK
                 || !ctx->p11_functions),
                "function_symbol_list() failed");

    // Initialize PKCS 11 function library
    ERROR_CHECK(ctx->p11_functions->C_Initialize(NULL_PTR) != CKR_OK,
                "C_Initialize() failed");

    return true;
}


int connect_to_pkcs11_token(struct plugin_ctx *ctx) {
    CK_SLOT_ID *slot_list = (CK_SLOT_ID_PTR) calloc(SLOT_LIST_MAX_SIZE, sizeof(CK_SLOT_ID));
    CK_ULONG slot_list_size = SLOT_LIST_MAX_SIZE;
    ERROR_CHECK(ctx->pkcs_ctx->p11_functions->C_GetSlotList(true, slot_list, &slot_list_size),
                "C_GetSlotList() failed");
    ERROR_CHECK(slot_list_size < 1,
                "No Slots / Tokens present!");

    ERROR_CHECK(
            ctx->pkcs_ctx->p11_functions->C_OpenSession(slot_list[0], CKF_SERIAL_SESSION | CKF_RW_SESSION, NULL, NULL,
                                                        &ctx->pkcs_ctx->session),
            "C_OpenSession() failed");

    ctx->pkcs_ctx->slot = slot_list[0];

    ERROR_CHECK(ctx->pkcs_ctx->p11_functions->C_Login(ctx->pkcs_ctx->session, CKU_USER, (CK_UTF8CHAR_PTR) ctx->pin,
                                                strlen(ctx->pin)),
                "Could not log into the token!");

    return true;
}

/*
 * Report whether the token's key-size range for `mechanism` on the current
 * slot includes TLS_CRYPT_V2_KEY_LEN.
 *
 * @return non-zero when the range includes the key length; false when it
 *         does not or when the mechanism info cannot be queried.
 */
int verify_supported_aes_length(struct PKCS_Context *pkcs_ctx, CK_MECHANISM_TYPE mechanism) {
    CK_MECHANISM_INFO mech_info;
    // Without this check a failed query would leave mech_info uninitialised
    // and the comparison below would read garbage.
    if (pkcs_ctx->p11_functions->C_GetMechanismInfo(pkcs_ctx->slot, mechanism, &mech_info) != CKR_OK) {
        return false;
    }
    return mech_info.ulMaxKeySize >= TLS_CRYPT_V2_KEY_LEN && mech_info.ulMinKeySize <= TLS_CRYPT_V2_KEY_LEN;
}

int determine_token_capabilities(struct plugin_ctx *ctx) {
    unsigned long mech_size;
    CK_RV res;
    res = ctx->pkcs_ctx->p11_functions->C_GetMechanismList(ctx->pkcs_ctx->slot, NULL, &mech_size);
    if (res != CKR_OK) {
        return false;
    }
    CK_MECHANISM_TYPE *list = calloc(mech_size, sizeof(CK_MECHANISM_TYPE));
    res = ctx->pkcs_ctx->p11_functions->C_GetMechanismList(ctx->pkcs_ctx->slot, list, &mech_size);
    if (res != CKR_OK) {
        return false;
    }

    bool ctr_support = false;
    bool ecb_support = false;
    bool aes_support = false;
    bool hmac_sha256_support = false;
    for (int i = 0; i < mech_size; i++) {
        CK_MECHANISM_TYPE mech = list[i];
        switch (mech) {
            case CKM_AES_CTR:
                aes_support = true;
                ctr_support = true;
                break;
            case CKM_AES_ECB:
                if (verify_supported_aes_length(ctx->pkcs_ctx, get_aes_ecb_mechanism().mechanism))
                {
                    aes_support = true;
                    ecb_support = true;
                    ctx->pkcs_ctx->aes_mechanism = get_aes_ecb_mechanism();
                }
                break;
            case CKM_AES_CBC:
                if (verify_supported_aes_length(ctx->pkcs_ctx, get_aes_cbc_mechanism().mechanism))
                {
                    aes_support = true;
                    ctx->pkcs_ctx->aes_mechanism = get_aes_cbc_mechanism();
                }
                break;
            case CKM_SHA256_HMAC:
                hmac_sha256_support = true;
                break;
            default:
                // Don't Care
                break;
        }
    }

    if (aes_support && hmac_sha256_support) {
        ctx->cipher_function = &perform_aes;
        ctx->authentication_function = &perform_sha_256_hmac;
        ctx->cipher_key_generation = &create_aes_key;
        ctx->authentication_key_generation = &create_hmac_key;
        plog(PLOG_NOTE, "Token does support AES");
        plog(PLOG_NOTE, "Token does support HMAC-SHA256");
    } else if (!aes_support && hmac_sha256_support) {
        ctx->cipher_function = &perform_aes_key_derivation;
        ctx->authentication_function = &perform_sha_256_hmac;
        ctx->cipher_key_generation = &create_hmac_key;
        ctx->authentication_key_generation = &create_hmac_key;
        plog(PLOG_WARN, "Token does not support AES. Using key derivation function.");
        plog(PLOG_NOTE, "Token does support HMAC-SHA256");
    } else if (aes_support) {
        ctx->cipher_function = &perform_aes;
        ctx->authentication_function = &perform_hmac_key_derivation;
        ctx->cipher_key_generation = &create_aes_key;
        ctx->authentication_key_generation = &create_aes_key;
        plog(PLOG_NOTE, "Token does support AES");
        plog(PLOG_WARN, "Token does not support HMAC-SHA256. Using key derivation function.");
    }

    if (ecb_support) {
        ctx->cipher_function = &perform_aes_with_ecb;
        plog(PLOG_NOTE, "Token does support AES-256-ECB");
    }

    if (ctr_support) {
        ctx->cipher_function = &perform_aes_256_ctr;
        plog(PLOG_NOTE, "Token does support AES-256-CTR");
    }

    if (!ctx->cipher_function || !ctx->authentication_function) {
        // Token does not support needed capabilities
        plog(PLOG_ERR, "Token does not support needed mechanisms!");
        return false;
    }

    return true;
}

/*
 * Bring the token up in three stages: load and initialise the PKCS#11 library
 * at lib_path, open a session and log in, then determine which mechanisms the
 * token offers. Each stage must return exactly `true` to continue.
 * Returns true when all stages succeed; ERROR_CHECK (defined elsewhere)
 * handles a failing stage.
 */
int pkcs11_startup(struct plugin_ctx *plugin_ctx, const char *lib_path) {
    // Stage 1: library
    ERROR_CHECK(!(load_pkcs11_functions(plugin_ctx->pkcs_ctx, lib_path) == true),
                "Failed loading pkcs library!");

    // Stage 2: session + login
    ERROR_CHECK(!(connect_to_pkcs11_token(plugin_ctx) == true),
                "Failed connecting to token!");

    // Stage 3: capability probing
    ERROR_CHECK(!(determine_token_capabilities(plugin_ctx) == true),
                "determine_token_capabilities() failed!");

    return true;
}

int disconnect_from_pkcs11_token(struct PKCS_Context *pkcs_ctx) {
    return pkcs_ctx->p11_functions->C_CloseSession(pkcs_ctx->session) == CKR_OK;
}

/*
 * Drop the current session and establish a new one.
 * The result of the disconnect is ignored; the return value is that of
 * connect_to_pkcs11_token().
 */
int try_reconnect(struct plugin_ctx *plugin_ctx) {
    (void) disconnect_from_pkcs11_token(plugin_ctx->pkcs_ctx);

    int reconnected = connect_to_pkcs11_token(plugin_ctx);
    return reconnected;
}

int main() {
    struct plugin_ctx plugin_ctx;
    plugin_ctx.pin = "1234";
    struct PKCS_Context ctx;
    plugin_ctx.pkcs_ctx = &ctx;
    if (!pkcs11_startup(&plugin_ctx, "/usr/lib/pkcs11/libsofthsm2.so")) {
        return 1;
    }

    CK_MECHANISM_INFO info;
    ctx.p11_functions->C_GetMechanismInfo(ctx.slot, CKM_AES_CBC, &info);
    disconnect_from_pkcs11_token(&ctx);


} /* main */