#include "pkcs11_handler.h"

/**
 * Add one to a big-endian byte counter, in place.
 *
 * The last byte is the least significant one. A byte that wraps from 0xFF to
 * 0x00 carries into the byte before it; a counter consisting only of 0xFF
 * bytes therefore wraps around to all zeros.
 *
 * @param array         counter bytes, modified in place
 * @param array_length  number of bytes in array; nothing is touched if <= 0
 */
void increment_byte_array(unsigned char *array, int array_length)
{
    for (int pos = array_length; pos > 0; pos--)
    {
        // A non-zero result means this byte did not overflow: no further carry.
        if (++array[pos - 1] != 0)
        {
            break;
        }
    }
}

/**
 * XOR one byte array onto another, in place.
 *
 * @param a_array       destination; each byte becomes a_array[i] ^ b_array[i]
 * @param b_array       source bytes, left unmodified
 * @param array_length  number of bytes to process from both arrays
 */
void xor_byte_array(unsigned char *a_array, const unsigned char *b_array, unsigned long array_length)
{
    unsigned char *dst = a_array;
    const unsigned char *src = b_array;

    for (unsigned long remaining = array_length; remaining > 0; remaining--)
    {
        *dst++ ^= *src++;
    }
}

CK_MECHANISM
get_aes_ecb_mechanism()
{
    CK_MECHANISM aesEncryptMechanism;
    aesEncryptMechanism.mechanism = CKM_AES_ECB;
    aesEncryptMechanism.pParameter = NULL;
    aesEncryptMechanism.ulParameterLen = 0;
    return aesEncryptMechanism;
}

CK_MECHANISM
get_aes_cbc_mechanism()
{
    CK_MECHANISM aesEncryptMechanism;
    aesEncryptMechanism.mechanism = CKM_AES_CBC;
    aesEncryptMechanism.pParameter = calloc(AES_BLOCK_SIZE, 1);
    aesEncryptMechanism.ulParameterLen = AES_BLOCK_SIZE;
    return aesEncryptMechanism;
}

/**
 * Encrypt data with AES-256-CTR in software, using a key derived on the token.
 *
 * The AES key is the SHA-256 HMAC (computed via perform_sha_256_hmac_keyfreedom()
 * with pkcs_ctx->aes_key) over the first TLS_CRYPT_V2_TAG_LEN bytes of iv. The
 * derived key then drives OpenSSL's EVP AES-256-CTR over data.
 *
 * @param pkcs_ctx  PKCS#11 context providing the session and key handle
 * @param data      input bytes
 * @param data_len  number of input bytes (passed to OpenSSL as an int)
 * @param out       output buffer, must hold at least data_len bytes
 * @param out_len   receives the number of bytes written to out on success
 * @param iv        HMAC input for the key derivation and IV for the CTR cipher
 * @return true on success, false on any failure
 *
 * NOTE(review): CRYPTO_ECHECK is defined outside this file; it is assumed to
 * jump to error_exit when its condition is non-zero - confirm.
 */
int perform_aes_key_derivation(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len,
                               unsigned char *out, unsigned long *out_len, const unsigned char *iv)
{
    unsigned char aes_client_server_key[TLS_CRYPT_V2_KEY_LEN] = {0};
    unsigned long key_len = TLS_CRYPT_V2_KEY_LEN;
    // Declared up front and NULL so the cleanup path can always free it.
    EVP_CIPHER_CTX *evp_ctx = NULL;
    // EVP_EncryptUpdate() reports its length through an int, not an unsigned long.
    int written = 0;
    int exit_code = false;

    CRYPTO_ECHECK(! perform_sha_256_hmac_keyfreedom(
                      pkcs_ctx, iv, TLS_CRYPT_V2_TAG_LEN, aes_client_server_key, &key_len, pkcs_ctx->aes_key),
                  "perform_sha_256_hmac_keyfreedom() failed");

    evp_ctx = EVP_CIPHER_CTX_new();
    CRYPTO_ECHECK(! evp_ctx,
                  "EVP_CIPHER_CTX_new() failed");

    CRYPTO_ECHECK(! EVP_CIPHER_CTX_init(evp_ctx),
                  "EVP_CIPHER_CTX_init() failed");
    CRYPTO_ECHECK(! EVP_EncryptInit_ex2(evp_ctx, EVP_aes_256_ctr(), aes_client_server_key, iv, NULL),
                  "EVP_EncryptInit_ex2() failed");
    CRYPTO_ECHECK(! EVP_EncryptUpdate(evp_ctx, out, &written, data, (int) data_len),
                  "EVP_EncryptUpdate() failed");

    *out_len = (unsigned long) written;
    exit_code = true;

error_exit:
    // Runs on success and on every failure; EVP_CIPHER_CTX_free(NULL) is a no-op.
    EVP_CIPHER_CTX_free(evp_ctx);
    ovpn_secure_memzero(aes_client_server_key, sizeof(aes_client_server_key));
    return exit_code;
}

/**
 * Counter-mode encryption built from a single multi-block encrypt_block_n() call.
 *
 * A sequence of counter blocks is generated starting at nonce (incremented as
 * a big-endian counter per block), encrypted in place with pkcs_ctx->aes_key,
 * and the result is XORed onto a copy of data.
 *
 * @param pkcs_ctx  PKCS#11 context providing the session, mechanism and key
 * @param data      input bytes
 * @param data_len  number of input bytes
 * @param out       output buffer, must hold at least data_len bytes
 * @param out_len   receives data_len on success
 * @param nonce     initial counter block, AES_BLOCK_SIZE bytes are read
 * @return true on success; the ERROR_CHECK failure path is assumed to return
 *         false (macro defined outside this file - confirm)
 */
int perform_aes_with_ecb(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len,
                         unsigned char *out, unsigned long *out_len, const unsigned char *nonce)
{
    // Empty input: a zero-length VLA is undefined, and the first counter block
    // would be copied into it. There is nothing to encrypt anyway.
    if (data_len == 0)
    {
        *out_len = 0;
        return true;
    }

    // Integer ceiling division; avoids the round trip through double.
    int block_count = (int) (data_len / AES_BLOCK_SIZE + (data_len % AES_BLOCK_SIZE != 0 ? 1 : 0));
    unsigned char iv[AES_BLOCK_SIZE];
    unsigned char iv_sequence[AES_BLOCK_SIZE * block_count];

    memcpy(iv, nonce, AES_BLOCK_SIZE);
    memcpy(iv_sequence, nonce, AES_BLOCK_SIZE);
    memcpy(out, data, data_len);

    // Blocks 1..n-1 hold nonce+1, nonce+2, ...
    for (int block_index = 1; block_index < block_count; ++block_index)
    {
        increment_byte_array(iv, AES_BLOCK_SIZE);
        memcpy(iv_sequence + block_index * AES_BLOCK_SIZE, iv, AES_BLOCK_SIZE);
    }

    // Encrypt the counter blocks in place to obtain the keystream.
    ERROR_CHECK(! encrypt_block_n(pkcs_ctx, iv_sequence, iv_sequence, pkcs_ctx->aes_key, block_count),
                "encrypt_block_n() failed");

    xor_byte_array(out, iv_sequence, data_len);
    *out_len = data_len;

    return true;
}

/**
 * Counter-mode encryption that encrypts one counter block per token call.
 *
 * For every AES_BLOCK_SIZE chunk of data the current counter is encrypted with
 * encrypt_block_n() (one block, pkcs_ctx->aes_key), the result is XORed onto
 * the matching chunk of out, and the counter is incremented.
 *
 * @param pkcs_ctx  PKCS#11 context providing the session, mechanism and key
 * @param data      input bytes
 * @param data_len  number of input bytes
 * @param out       output buffer, must hold at least data_len bytes
 * @param out_len   receives the number of bytes processed (data_len) on success
 * @param nonce     initial counter block, AES_BLOCK_SIZE bytes are read
 * @return true on success; the ERROR_CHECK failure path is assumed to return
 *         false (macro defined outside this file - confirm)
 */
int perform_aes(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len, unsigned char *out,
                unsigned long *out_len, const unsigned char *nonce)
{
    unsigned char encrypted_iv[AES_BLOCK_SIZE];
    unsigned char iv[AES_BLOCK_SIZE];
    // Same type as data_len so the loop comparison is not signed vs. unsigned.
    unsigned long processed = 0;

    memcpy(iv, nonce, AES_BLOCK_SIZE);
    memcpy(out, data, data_len);

    while (processed < data_len)
    {
        unsigned long remaining = data_len - processed;
        unsigned long block_size = remaining < AES_BLOCK_SIZE ? remaining : AES_BLOCK_SIZE;

        ERROR_CHECK(! encrypt_block_n(pkcs_ctx, iv, encrypted_iv, pkcs_ctx->aes_key, 1),
                    "encrypt_block_n() failed");
        // Apply the keystream to the chunk being processed, not to the start of out.
        xor_byte_array(out + processed, encrypted_iv, block_size);
        increment_byte_array(iv, AES_BLOCK_SIZE);
        processed += block_size;
    }

    *out_len = processed;

    return true;
}

int perform_aes_256_ctr(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned long data_len,
                        unsigned char *out, unsigned long *out_len, const unsigned char *nonce)
{
    CK_MECHANISM aesEncryptMechanism;
    CK_AES_CTR_PARAMS params;
    memcpy(params.cb, nonce, TLS_CRYPT_V2_IV_LENGTH);
    params.ulCounterBits = TLS_CRYPT_V2_IV_LENGTH * 8;
    aesEncryptMechanism.mechanism = CKM_AES_CTR;
    aesEncryptMechanism.pParameter = &params;
    aesEncryptMechanism.ulParameterLen = sizeof(params);

    ERROR_CHECK(pkcs_ctx->p11_functions->C_EncryptInit(pkcs_ctx->session, &aesEncryptMechanism, pkcs_ctx->aes_key),
                "C_EncryptInit() failed");

    ERROR_CHECK(pkcs_ctx->p11_functions->C_Encrypt(pkcs_ctx->session, (CK_BYTE_PTR) data, data_len, out, out_len),
                "C_Encrypt() failed");

    return true;
}

/**
 * Encrypt block_count AES blocks on the token in a single C_Encrypt call.
 *
 * The mechanism is taken from pkcs_ctx->aes_mechanism. Exactly
 * AES_BLOCK_SIZE * block_count bytes are read from data, and the same number
 * is announced as the capacity of out.
 *
 * @param pkcs_ctx     PKCS#11 context providing the function list, session and mechanism
 * @param data         input blocks
 * @param out          output buffer (callers in this file also pass data itself)
 * @param key_handle   token key to encrypt with
 * @param block_count  number of AES blocks to process
 * @return true when both PKCS#11 calls return zero
 */
int encrypt_block_n(struct PKCS_Context *pkcs_ctx, const unsigned char *data, unsigned char *out,
                    CK_OBJECT_HANDLE key_handle, int block_count)
{
    const CK_ULONG total_bytes = AES_BLOCK_SIZE * block_count;
    CK_ULONG produced = total_bytes;
    CK_RV rv;

    rv = pkcs_ctx->p11_functions->C_EncryptInit(pkcs_ctx->session, &pkcs_ctx->aes_mechanism, key_handle);
    ERROR_CHECK(rv,
                "C_EncryptInit() failed");

    rv = pkcs_ctx->p11_functions->C_Encrypt(pkcs_ctx->session, (CK_BYTE_PTR) data, total_bytes, out, &produced);
    ERROR_CHECK(rv,
                "C_Encrypt() failed");

    return true;
}

/**
 * Compute an HMAC-SHA256 in software with a key derived via token-side AES.
 *
 * The HMAC key is the encryption of the first two AES blocks of data under
 * ctx->hmac_key (encrypt_block_n()). The authenticated message is a two-byte
 * big-endian length followed by data.
 *
 * @param ctx       PKCS#11 context providing the session, mechanism and key
 * @param data      input bytes
 * @param data_len  number of input bytes; must fit into the internal buffer
 *                  together with the length prefix
 * @param out       receives the MAC
 * @param out_len   receives the MAC length on success
 * @return true on success, false on any failure
 *
 * NOTE(review): encrypt_block_n() reads 2 * AES_BLOCK_SIZE bytes from data
 * regardless of data_len - confirm callers always supply at least that much.
 * NOTE(review): CRYPTO_ECHECK is defined outside this file; it is assumed to
 * jump to error_exit when its condition is non-zero - confirm.
 */
int perform_hmac_key_derivation(struct PKCS_Context *ctx, const unsigned char *data, unsigned long data_len,
                                unsigned char *out, unsigned long *out_len)
{
    unsigned char hmac_client_server_key[TLS_CRYPT_V2_KEY_LEN] = {0};
    unsigned char data_buffer[TLS_CRYPT_V2_MAX_WKC_LEN] = {0};
    // HMAC() reports the MAC length through an unsigned int, not an unsigned long.
    unsigned int mac_len = 0;
    int exit_code = false;

    // Reject input that would overflow the fixed-size stack buffer below.
    CRYPTO_ECHECK(data_len > sizeof(data_buffer) - TLS_CRYPT_V2_LEN_LEN,
                  "HMAC input too long");

    // A failed derivation would otherwise silently HMAC with an all-zero key.
    CRYPTO_ECHECK(! encrypt_block_n(ctx, data, hmac_client_server_key, ctx->hmac_key, 2),
                  "encrypt_block_n() failed");

    data_buffer[0] = (data_len >> 8) & 0xFF;
    data_buffer[1] = data_len & 0xFF;
    memcpy(data_buffer + 2, data, data_len);

    CRYPTO_ECHECK(! HMAC(EVP_sha256(),
                         hmac_client_server_key,
                         TLS_CRYPT_V2_KEY_LEN,
                         data_buffer,
                         data_len + TLS_CRYPT_V2_LEN_LEN,
                         out,
                         &mac_len),
                  "HMAC() failed");
    *out_len = mac_len;
    exit_code = true;
error_exit:
    ovpn_secure_memzero(hmac_client_server_key, sizeof(hmac_client_server_key));
    // data_buffer is a copy of the caller's input; do not leave it on the stack.
    ovpn_secure_memzero(data_buffer, sizeof(data_buffer));
    return exit_code;
}

/**
 * Compute an HMAC-SHA256 on the token with a caller-chosen key.
 *
 * @param ctx         PKCS#11 context providing the function list and session
 * @param data        input bytes
 * @param data_len    number of input bytes
 * @param out         buffer handed to C_Sign for the MAC
 * @param out_len     in/out length argument handed to C_Sign
 * @param key_handle  token key to sign with
 * @return true when both PKCS#11 calls return zero; the ERROR_CHECK failure
 *         path is assumed to return false (macro defined outside this file -
 *         confirm)
 */
int perform_sha_256_hmac_keyfreedom(struct PKCS_Context *ctx, const unsigned char *data, unsigned long data_len,
                                    unsigned char *out, unsigned long *out_len, CK_OBJECT_HANDLE key_handle)
{
    // Validate the key that is actually used for signing, not ctx->hmac_key,
    // since callers may pass a different handle.
    ERROR_CHECK(! key_handle,
                "HMAC Key not loaded");

    // CKM_SHA256_HMAC takes no parameter; set every field so the token never
    // sees an indeterminate pointer or length.
    CK_MECHANISM hmacMechanism = {CKM_SHA256_HMAC, NULL, 0};

    ERROR_CHECK(ctx->p11_functions->C_SignInit(ctx->session, &hmacMechanism, key_handle),
                "C_SignInit() failed");

    ERROR_CHECK(ctx->p11_functions->C_Sign(ctx->session, (CK_BYTE_PTR) data, data_len, out, out_len),
                "C_Sign() failed");

    return true;
}

int perform_sha_256_hmac(struct PKCS_Context *ctx, const unsigned char *data, unsigned long data_len,
                         unsigned char *out, unsigned long *out_len)
{
    return perform_sha_256_hmac_keyfreedom(ctx, data, data_len, out, out_len, ctx->hmac_key);
}

/**
 * Look up exactly one token object matching an attribute template.
 *
 * @param ctx            PKCS#11 context providing the function list and session
 * @param key            receives the object handle when exactly one object matches;
 *                       left untouched otherwise
 * @param attr_list      search template
 * @param attr_list_len  number of entries in attr_list
 * @return true if exactly one object matched, false if none or several matched
 *         or a PKCS#11 call failed (the ERROR_CHECK failure path is assumed to
 *         return false - macro defined outside this file, confirm)
 */
int find_key(struct PKCS_Context *ctx, CK_OBJECT_HANDLE *key, CK_ATTRIBUTE *attr_list, ulong attr_list_len)
{
    // Results go into a local buffer sized for the requested maximum: callers
    // pass a pointer to a single handle, which must not receive more than one.
    CK_OBJECT_HANDLE found[MAX_OBJECT_COUNT];
    CK_ULONG ulObjectCount = 0;
    CK_RV find_rv;
    CK_RV final_rv;

    ERROR_CHECK(ctx->p11_functions->C_FindObjectsInit(ctx->session, attr_list, attr_list_len),
                "C_FindObjectsInit() failed");

    find_rv = ctx->p11_functions->C_FindObjects(ctx->session, found, MAX_OBJECT_COUNT, &ulObjectCount);

    // Always terminate the search, even if C_FindObjects failed; otherwise the
    // session keeps an active find operation and the next lookup cannot start.
    final_rv = ctx->p11_functions->C_FindObjectsFinal(ctx->session);

    ERROR_CHECK(find_rv != CKR_OK,
                "C_FindObjects() failed");

    ERROR_CHECK(final_rv != CKR_OK,
                "C_FindObjectsFinal() failed");

    if (ulObjectCount != 1)
    {
        plog(PLOG_DEBUG, "No key was found");
        return false;
    }

    *key = found[0];
    return true;
}

/**
 * Look up a token key by its label.
 *
 * The search template requires CKA_TOKEN to be true, CKA_LABEL to equal label
 * (compared over strlen(label) bytes) and CKA_VALUE_LEN to equal
 * TLS_CRYPT_V2_KEY_LEN. The actual search is delegated to find_key().
 *
 * @param pkcs_ctx    PKCS#11 context
 * @param key_handle  receives the handle found by find_key()
 * @param label       NUL-terminated label to search for
 * @return whatever find_key() returns
 */
int find_key_by_label(struct PKCS_Context *pkcs_ctx, CK_OBJECT_HANDLE *key_handle, char *label)
{
    CK_BBOOL on_token = CK_TRUE;
    ulong key_bytes = TLS_CRYPT_V2_KEY_LEN;

    CK_ATTRIBUTE search_template[] = {
        {CKA_TOKEN, &on_token, sizeof(on_token)},
        {CKA_LABEL, label, strlen(label)},
        {CKA_VALUE_LEN, &key_bytes, sizeof(key_bytes)},
    };
    const ulong template_len = sizeof(search_template) / sizeof(search_template[0]);

    return find_key(pkcs_ctx, key_handle, search_template, template_len);
}

/**
 * Look up the key labelled AES_KEY_LABEL on the token.
 *
 * @param ctx  PKCS#11 context
 * @param key  receives the handle found by find_key_by_label()
 * @return whatever find_key_by_label() returns
 */
int find_aes_key(struct PKCS_Context *ctx, CK_OBJECT_HANDLE *key)
{
    int found = find_key_by_label(ctx, key, AES_KEY_LABEL);

    return found;
}

/**
 * Look up the key labelled HMAC_KEY_LABEL on the token.
 *
 * @param ctx  PKCS#11 context
 * @param key  receives the handle found by find_key_by_label()
 * @return whatever find_key_by_label() returns
 */
int find_hmac_key(struct PKCS_Context *ctx, CK_OBJECT_HANDLE *key)
{
    int found = find_key_by_label(ctx, key, HMAC_KEY_LABEL);

    return found;
}

int find_keys(struct plugin_ctx *plugin_ctx)
{
    return find_aes_key(plugin_ctx->pkcs_ctx, &plugin_ctx->pkcs_ctx->aes_key) &&
           find_hmac_key(plugin_ctx->pkcs_ctx, &plugin_ctx->pkcs_ctx->hmac_key);
}

/**
 * Destroy token objects found by find_aes_key(), one at a time.
 *
 * The loop ends as soon as find_aes_key() returns false. A non-zero result
 * from C_DestroyObject is handled by ERROR_CHECK.
 *
 * @param ctx  PKCS#11 context providing the function list and session
 * @return true once find_aes_key() no longer reports a key
 */
int delete_aes_key(struct PKCS_Context *ctx)
{
    CK_OBJECT_HANDLE doomed;

    for (;;)
    {
        if (! find_aes_key(ctx, &doomed))
        {
            break;
        }
        ERROR_CHECK(ctx->p11_functions->C_DestroyObject(ctx->session, doomed),
                    "C_DestroyObject() failed");
    }

    return true;
}

/**
 * Destroy token objects found by find_hmac_key(), one at a time.
 *
 * The loop ends as soon as find_hmac_key() returns false. A non-zero result
 * from C_DestroyObject is handled by ERROR_CHECK.
 *
 * @param ctx  PKCS#11 context providing the function list and session
 * @return true once find_hmac_key() no longer reports a key
 */
int delete_hmac_key(struct PKCS_Context *ctx)
{
    CK_OBJECT_HANDLE doomed;

    for (;;)
    {
        if (! find_hmac_key(ctx, &doomed))
        {
            break;
        }
        ERROR_CHECK(ctx->p11_functions->C_DestroyObject(ctx->session, doomed),
                    "C_DestroyObject() failed");
    }

    return true;
}

/**
 * Generate an AES key on the token.
 *
 * The key is created with CKA_TOKEN and CKA_ENCRYPT set to true, key type
 * CKK_AES, a value length of TLS_CRYPT_V2_KEY_LEN and the given label. The
 * handle returned by C_GenerateKey is not handed back to the caller.
 *
 * @param ctx    PKCS#11 context providing the function list and session
 * @param label  NUL-terminated label for the new key
 * @return true when C_GenerateKey returns zero; the ERROR_CHECK failure path
 *         is assumed to return false (macro defined outside this file - confirm)
 */
int create_aes_key(struct PKCS_Context *ctx, char *label)
{
    // CKM_AES_KEY_GEN takes no parameter; set every field so the token never
    // sees an indeterminate pointer or length.
    CK_MECHANISM aesKeyMechanism = {CKM_AES_KEY_GEN, NULL, 0};

    CK_BBOOL true_var = CK_TRUE;
    ulong key_type = CKK_AES;
    ulong value_len = TLS_CRYPT_V2_KEY_LEN;
    CK_ATTRIBUTE attr_list[] = {{CKA_TOKEN, &true_var, sizeof(true_var)},
                                {CKA_LABEL, label, strlen(label)},
                                {CKA_ENCRYPT, &true_var, sizeof(true_var)},
                                {CKA_KEY_TYPE, &key_type, sizeof(key_type)},
                                {CKA_VALUE_LEN, &value_len, sizeof(value_len)}};
    CK_OBJECT_HANDLE aes_key;
    ERROR_CHECK(ctx->p11_functions->C_GenerateKey(
                    ctx->session, &aesKeyMechanism, attr_list, sizeof(attr_list) / sizeof(attr_list[0]), &aes_key),
                "C_GenerateKey() failed");
    return true;
}

/**
 * Generate a generic secret key on the token for signing.
 *
 * The key is created with CKA_TOKEN and CKA_SIGN set to true, key type
 * CKK_GENERIC_SECRET, a value length of TLS_CRYPT_V2_KEY_LEN and the given
 * label. The handle returned by C_GenerateKey is not handed back to the caller.
 *
 * @param ctx    PKCS#11 context providing the function list and session
 * @param label  NUL-terminated label for the new key
 * @return true when C_GenerateKey returns zero; the ERROR_CHECK failure path
 *         is assumed to return false (macro defined outside this file - confirm)
 */
int create_hmac_key(struct PKCS_Context *ctx, char *label)
{
    // CKM_GENERIC_SECRET_KEY_GEN takes no parameter; set every field so the
    // token never sees an indeterminate pointer or length.
    CK_MECHANISM hmacKeyMechanism = {CKM_GENERIC_SECRET_KEY_GEN, NULL, 0};

    CK_BBOOL true_var = CK_TRUE;
    ulong key_type = CKK_GENERIC_SECRET;
    ulong value_len = TLS_CRYPT_V2_KEY_LEN;
    CK_ATTRIBUTE attr_list[] = {{CKA_TOKEN, &true_var, sizeof(true_var)},
                                {CKA_LABEL, label, strlen(label)},
                                {CKA_SIGN, &true_var, sizeof(true_var)},
                                {CKA_KEY_TYPE, &key_type, sizeof(key_type)},
                                {CKA_VALUE_LEN, &value_len, sizeof(value_len)}};
    CK_OBJECT_HANDLE hmac_key;
    ERROR_CHECK(ctx->p11_functions->C_GenerateKey(
                    ctx->session, &hmacKeyMechanism, attr_list, sizeof(attr_list) / sizeof(attr_list[0]), &hmac_key),
                "C_GenerateKey() failed");

    return true;
}

/*
 * Resolve C_GetFunctionList in the already opened library, fetch the function
 * list into ctx->p11_functions and call C_Initialize.
 */
static int resolve_and_initialize_pkcs11(struct PKCS_Context *ctx)
{
    CK_C_GetFunctionList function_symbol_list = NULL;

    // Obtain address of symbols in shared library
    function_symbol_list = (CK_C_GetFunctionList) dlsym(ctx->lib_handle, "C_GetFunctionList");

    // Get pkcs11 function list pointer
    ERROR_CHECK(! function_symbol_list,
                "dlsym() failed");

    ERROR_CHECK((function_symbol_list(&ctx->p11_functions) != CKR_OK || ! ctx->p11_functions),
                "function_symbol_list() failed");

    // Initialize PKCS 11 function library
    ERROR_CHECK(ctx->p11_functions->C_Initialize(NULL_PTR) != CKR_OK,
                "C_Initialize() failed");

    return true;
}

/**
 * Load a PKCS#11 module and initialise it.
 *
 * Opens lib_path with dlopen(), stores the handle in ctx->lib_handle, obtains
 * the module's function list into ctx->p11_functions and calls C_Initialize.
 * If any step after dlopen() fails, the library is closed again and both
 * fields are reset to NULL so no dangling pointers remain in ctx.
 *
 * @param ctx       PKCS#11 context to populate
 * @param lib_path  path of the shared library to load
 * @return true on success, false on failure
 */
int load_pkcs11_functions(struct PKCS_Context *ctx, const char *lib_path)
{
    // Get handle to the library
    ctx->lib_handle = dlopen(lib_path, RTLD_NOW);
    ERROR_CHECK(! ctx->lib_handle,
                "dlopen() failed");

    if (! resolve_and_initialize_pkcs11(ctx))
    {
        // The function list points into the library that is being unloaded.
        dlclose(ctx->lib_handle);
        ctx->lib_handle = NULL;
        ctx->p11_functions = NULL;
        return false;
    }

    return true;
}

int connect_to_pkcs11_token(struct plugin_ctx *ctx)
{
    CK_SLOT_ID slot_list[SLOT_LIST_MAX_SIZE] = {0};
    CK_ULONG slot_list_size = SLOT_LIST_MAX_SIZE;
    ERROR_CHECK(ctx->pkcs_ctx->p11_functions->C_GetSlotList(true, slot_list, &slot_list_size),
                "C_GetSlotList() failed");
    ERROR_CHECK(slot_list_size < 1,
                "No Slots / Tokens present! Is the PKCS#11-capable token connected?");

    ERROR_CHECK(ctx->pkcs_ctx->p11_functions->C_OpenSession(
                    slot_list[0], CKF_SERIAL_SESSION | CKF_RW_SESSION, NULL, NULL, &ctx->pkcs_ctx->session),
                "C_OpenSession() failed");

    ctx->pkcs_ctx->slot = slot_list[0];

    ERROR_CHECK(ctx->pkcs_ctx->p11_functions->C_Login(
                    ctx->pkcs_ctx->session, CKU_USER, (CK_UTF8CHAR_PTR) ctx->pin, strlen(ctx->pin)),
                "Could not log into the token!");

    return true;
}

/**
 * Check whether a mechanism accepts keys of TLS_CRYPT_V2_KEY_LEN.
 *
 * Queries C_GetMechanismInfo for the mechanism on pkcs_ctx->slot and tests
 * that TLS_CRYPT_V2_KEY_LEN lies within [ulMinKeySize, ulMaxKeySize].
 *
 * @param pkcs_ctx   PKCS#11 context providing the function list and slot
 * @param mechanism  mechanism type to query
 * @return non-zero if the key length is within the reported range, zero if it
 *         is not or if the mechanism info could not be obtained
 */
int verify_supported_aes_length(struct PKCS_Context *pkcs_ctx, CK_MECHANISM_TYPE mechanism)
{
    CK_MECHANISM_INFO mech_info = {0};
    CK_RV rv = pkcs_ctx->p11_functions->C_GetMechanismInfo(pkcs_ctx->slot, mechanism, &mech_info);

    // Without a successful query mech_info holds no meaningful limits.
    if (rv != CKR_OK)
    {
        return false;
    }

    return mech_info.ulMaxKeySize >= TLS_CRYPT_V2_KEY_LEN && mech_info.ulMinKeySize <= TLS_CRYPT_V2_KEY_LEN;
}

/**
 * Inspect the token's mechanism list and select cipher / authentication strategies.
 *
 * Cipher preference: native AES-CTR, then CTR emulated on AES-ECB, then CTR
 * emulated on AES-CBC, then software AES with a key derived via HMAC-SHA256.
 * If an ECB/CBC based cipher was chosen but the mechanism does not accept
 * TLS_CRYPT_V2_KEY_LEN keys, the HMAC based derivation is used instead when
 * available. Authentication preference: native HMAC-SHA256, then software HMAC
 * with a key derived via AES-ECB/CBC.
 *
 * On return the four function pointers in ctx (cipher_function,
 * cipher_key_generation, authentication_function,
 * authentication_key_generation) are either set to a selected implementation
 * or NULL. ctx->pkcs_ctx->aes_mechanism is set whenever ECB or CBC is
 * available, because encrypt_block_n() relies on it.
 *
 * @param ctx  plugin context; ctx->pkcs_ctx must have a loaded function list
 *             and a valid slot
 * @return true if both a cipher and an authentication function were selected;
 *         the ERROR_CHECK failure path is assumed to return false (macro
 *         defined outside this file - confirm)
 */
int determine_token_capabilities(struct plugin_ctx *ctx)
{
    unsigned long mech_size = 0;
    CK_RV res;

    // Start from a known state so the final NULL check is meaningful even if
    // the caller handed in an uninitialised context.
    ctx->cipher_function = NULL;
    ctx->cipher_key_generation = NULL;
    ctx->authentication_function = NULL;
    ctx->authentication_key_generation = NULL;

    res = ctx->pkcs_ctx->p11_functions->C_GetMechanismList(ctx->pkcs_ctx->slot, NULL, &mech_size);
    ERROR_CHECK(res != CKR_OK, "C_GetMechanismList() 1 failed");

    // An empty list cannot satisfy any requirement; also avoids a zero-length VLA.
    ERROR_CHECK(mech_size == 0,
                "Token does not support required mechanisms!");

    CK_MECHANISM_TYPE mech_list[mech_size];
    res = ctx->pkcs_ctx->p11_functions->C_GetMechanismList(ctx->pkcs_ctx->slot, mech_list, &mech_size);
    ERROR_CHECK(res != CKR_OK, "C_GetMechanismList() 2 failed");

    // Each flag needs its own initialiser; "a, b, c, d = false" only sets d.
    bool ctr_support = false;
    bool ecb_support = false;
    bool cbc_support = false;
    bool hmac_sha256_support = false;
    for (unsigned long i = 0; i < mech_size; i++)
    {
        CK_MECHANISM_TYPE mech = mech_list[i];
        switch (mech)
        {
            case CKM_AES_CTR:
                ctr_support = true;
                break;
            case CKM_AES_ECB:
                ecb_support = true;
                break;
            case CKM_AES_CBC:
                cbc_support = true;
                break;
            case CKM_SHA256_HMAC:
                hmac_sha256_support = true;
                break;
            default:
                // Don't Care
                break;
        }
    }

    // encrypt_block_n() reads aes_mechanism. Set it whenever a block mode is
    // available, even if native CTR ends up being the cipher, because the
    // authentication fallback below may still use it.
    if (ecb_support)
    {
        ctx->pkcs_ctx->aes_mechanism = get_aes_ecb_mechanism();
    }
    else if (cbc_support)
    {
        ctx->pkcs_ctx->aes_mechanism = get_aes_cbc_mechanism();
    }

    if (ctr_support)
    {
        ctx->cipher_function = &perform_aes_256_ctr;
        ctx->cipher_key_generation = &create_aes_key;
        plog(PLOG_NOTE, "Token does support AES-256-CTR");
    }
    else if (ecb_support)
    {
        ctx->cipher_function = &perform_aes_with_ecb;
        ctx->cipher_key_generation = &create_aes_key;
        plog(PLOG_NOTE, "Token does support AES-256-ECB");
    }
    else if (cbc_support)
    {
        ctx->cipher_function = &perform_aes;
        ctx->cipher_key_generation = &create_aes_key;
        plog(PLOG_NOTE, "Token does support AES-256-CBC");
    }
    else if (hmac_sha256_support)
    {
        ctx->cipher_function = &perform_aes_key_derivation;
        ctx->cipher_key_generation = &create_hmac_key;
        plog(PLOG_WARN, "Token does not support AES. Using key derivation function.");
    }

    // Only second-guess the cipher when an ECB/CBC based one was actually
    // selected; a native CTR selection must not be discarded here.
    if (! ctr_support && (ecb_support || cbc_support) &&
        ! verify_supported_aes_length(ctx->pkcs_ctx, ctx->pkcs_ctx->aes_mechanism.mechanism))
    {
        ctx->cipher_function = NULL;
        ctx->cipher_key_generation = NULL;
        if (hmac_sha256_support)
        {
            ctx->cipher_function = &perform_aes_key_derivation;
            ctx->cipher_key_generation = &create_hmac_key;
            plog(PLOG_NOTE, "Token does not support required AES key length. "
                            "Falling back to key derivation");
        }
    }

    if (hmac_sha256_support)
    {
        ctx->authentication_function = &perform_sha_256_hmac;
        ctx->authentication_key_generation = &create_hmac_key;
        plog(PLOG_NOTE, "Token does support HMAC-SHA256");
    }
    else if (ecb_support || cbc_support)
    {
        ctx->authentication_function = &perform_hmac_key_derivation;
        ctx->authentication_key_generation = &create_aes_key;
        plog(PLOG_WARN, "Token does not support HMAC-SHA256. Using key derivation function.");
    }

    ERROR_CHECK(! ctx->cipher_function || ! ctx->authentication_function,
                "Token does not support required mechanisms!");

    return true;
}

/**
 * Bring the PKCS#11 backend up: load the module, connect to the token and
 * determine which mechanisms to use, in that order.
 *
 * Each step's failure is reported through ERROR_CHECK with its own message,
 * and later steps only run after the earlier ones succeeded.
 *
 * @param plugin_ctx  plugin context; plugin_ctx->pkcs_ctx receives the module state
 * @param lib_path    path of the PKCS#11 shared library
 * @return true when all three steps succeed
 */
int pkcs11_startup(struct plugin_ctx *plugin_ctx, const char *lib_path)
{
    int ok;

    ok = load_pkcs11_functions(plugin_ctx->pkcs_ctx, lib_path);
    ERROR_CHECK(! ok,
                "Failed loading pkcs library!");

    ok = connect_to_pkcs11_token(plugin_ctx);
    ERROR_CHECK(! ok,
                "Failed connecting to token!");

    ok = determine_token_capabilities(plugin_ctx);
    ERROR_CHECK(! ok,
                "determine_token_capabilities() failed!");

    return true;
}

int disconnect_from_pkcs11_token(struct PKCS_Context *pkcs_ctx)
{
    return pkcs_ctx->p11_functions->C_CloseSession(pkcs_ctx->session) == CKR_OK;
}

/**
 * Drop the current session and establish a new one.
 *
 * The result of the disconnect is deliberately ignored; only the outcome of
 * the fresh connect is reported.
 *
 * @param plugin_ctx  plugin context providing pkcs_ctx and the pin
 * @return whatever connect_to_pkcs11_token() returns
 */
int reconnect_to_pkcs11_token(struct plugin_ctx *plugin_ctx)
{
    struct PKCS_Context *pkcs = plugin_ctx->pkcs_ctx;

    (void) disconnect_from_pkcs11_token(pkcs);

    return connect_to_pkcs11_token(plugin_ctx);
}

int main()
{
    struct plugin_ctx plugin_ctx;
    plugin_ctx.pin = "1234";
    struct PKCS_Context ctx;
    plugin_ctx.pkcs_ctx = &ctx;
    if (! pkcs11_startup(&plugin_ctx, "/usr/lib/pkcs11/libsofthsm2.so"))
    {
        return 1;
    }

    CK_MECHANISM_INFO info;
    ctx.p11_functions->C_GetMechanismInfo(ctx.slot, CKM_AES_CBC, &info);
    disconnect_from_pkcs11_token(&ctx);

} /* main */