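/*
 * openvpn_pkcs11_key_wrapping.c - OpenVPN plugin that wraps/unwraps
 * tls-crypt-v2 client keys with keys held on a PKCS#11 token.
 *
 * NOTE(review): the lines below are residue from the GitLab file view this
 * source was copied from. They are not C and are kept only as a comment:
 *   Skip to content
 *   Snippets Groups Projects
 *   Code owners
 *   Assign users and groups as approvers for specific file changes. Learn more.
 *   openvpn_pkcs11_key_wrapping.c 9.68 KiB
 */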
#include <stdio.h>
#include <string.h>

#include "shared.h"
#include "pkcs11_handler.h"

/* Plugin API version reported by openvpn_plugin_min_version_required_v1(). */
#define OPENVPN_PLUGIN_VERSION_MIN   (3)
/* Oldest v3 argument-struct version accepted by the open/func v3 entry points;
 * older OpenVPN builds are refused there. NOTE(review): presumably chosen
 * because the base64 callbacks copied in open_v3 need it - confirm. */
#define OPENVPN_PLUGIN_STRUCTVER_MIN (5)

/**
 *  Report the minimum OpenVPN plugin API version this plugin needs.
 *
 *  @return int    OPENVPN_PLUGIN_VERSION_MIN
 */
OPENVPN_EXPORT int
openvpn_plugin_min_version_required_v1(void)
{
    return OPENVPN_PLUGIN_VERSION_MIN;
}

/**
 *  Tell OpenVPN at which point to initialize this plugin.
 *
 *  NOTE(review): pre-config-parse is presumably chosen so the interactive PIN
 *  prompt in openvpn_plugin_open_v3 can still use the terminal - confirm.
 *
 *  @return int    OPENVPN_PLUGIN_INIT_PRE_CONFIG_PARSE
 */
OPENVPN_EXPORT int
openvpn_plugin_select_initialization_point_v1(void)
{
    return OPENVPN_PLUGIN_INIT_PRE_CONFIG_PARSE;
}

OPENVPN_EXPORT int
openvpn_plugin_open_v3(const int v3structver,
                       struct openvpn_plugin_args_open_in const *args,
                       struct openvpn_plugin_args_open_return *ret)
{
    if (v3structver < OPENVPN_PLUGIN_STRUCTVER_MIN)
    {
        printf("Error: This plugin is incompatible with the running version of OpenVPN!");
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }

    // Save global pointers to functions exported from openvpn
    plugin_vlog_func = args->callbacks->plugin_vlog;
    ovpn_base64_decode = args->callbacks->plugin_base64_decode;
    ovpn_base64_encode = args->callbacks->plugin_base64_encode;
    ovpn_secure_memzero = args->callbacks->plugin_secure_memzero;

    if (!args->argv[1] || !args->argv[2])
    {
        plog(PLOG_ERR, "Missing arguments for plugin missing!");
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }

    struct plugin_ctx *plugin_ctx = (struct plugin_ctx *) calloc(1, sizeof(struct plugin_ctx));
    if (!plugin_ctx)
    {
        goto error;
    }

    if(strcmp(args->argv[2], "0") != 0)
    {
        plugin_ctx->pin = strdup(args->argv[2]);
    }
    else
    {
        char pin[17] = { 0 };
        printf("Please enter the User PIN for the connected PKCS#11 Token (up to 16 characters): ");
        scanf("%16s", pin);
        plugin_ctx->pin = strdup(pin);
    }

    plugin_ctx->pkcs_ctx = calloc(1, sizeof(struct PKCS_Context));
    int startup_success = pkcs11_startup(plugin_ctx, args->argv[1]);
    if(!startup_success)
    {
        goto error;
    }

    // Load keys on token, don't care if none are found. May have not been generated yet.
    find_keys(plugin_ctx);

    // Which callbacks to intercept.
    ret->type_mask = OPENVPN_PLUGIN_MASK(OPENVPN_PLUGIN_CLIENT_KEY_WRAPPING);
    ret->handle = (openvpn_plugin_handle_t *) plugin_ctx;

    return OPENVPN_PLUGIN_FUNC_SUCCESS;

error:
    plog(PLOG_NOTE, "Initialization failed");
    return OPENVPN_PLUGIN_FUNC_ERROR;
}

/* Decode a big-endian (network byte order) 16-bit value from bytes[0..1]. */
static uint16_t bytesToShort(const unsigned char *bytes)
{
    const unsigned int high = bytes[0];
    const unsigned int low = bytes[1];

    return (uint16_t) ((high << 8) | low);
}

/**
 *  Run the context's cipher function. If it fails, call try_reconnect() and
 *  run it once more.
 *
 *  @param out_len  in: capacity of @p out; out: as set by cipher_function
 *  @return int     true if the first attempt succeeds, otherwise the result
 *                  of the second attempt
 */
static int
try_cipher(struct plugin_ctx *plugin_ctx,
            const unsigned char *data, unsigned long data_len,
            unsigned char *out, unsigned long *out_len,
            const unsigned char *nonce)
{
    // A failed attempt may overwrite *out_len (PKCS#11 C_Encrypt/C_Decrypt
    // report a required size that way), so keep the caller's capacity.
    const unsigned long out_capacity = *out_len;

    if (plugin_ctx->cipher_function(plugin_ctx->pkcs_ctx, data, data_len,
                                    out, out_len, nonce))
    {
        return true;
    }

    try_reconnect(plugin_ctx);
    *out_len = out_capacity;
    return plugin_ctx->cipher_function(plugin_ctx->pkcs_ctx, data, data_len,
                                       out, out_len, nonce);
}

/**
 *  Base64-encode @p data into one "wrapping result" entry and hand it to
 *  OpenVPN through ret->return_list.
 *
 *  @return int    true on success. false if the return list is missing or
 *                 allocation/encoding failed; in that case nothing is leaked
 *                 and the return list is left untouched.
 */
static int
handle_return(struct openvpn_plugin_args_func_return *ret,
              const void *data, int data_len)
{
    if (!ret || !ret->return_list)
    {
        return false;
    }

    struct openvpn_plugin_string_list *rl = calloc(1, sizeof(*rl));
    if (!rl)
    {
        return false;
    }

    rl->name = strdup("wrapping result");
    // rl->value stays NULL (from calloc) unless encoding succeeds.
    if (!rl->name || ovpn_base64_encode(data, data_len, &rl->value) < 0)
    {
        free(rl->name);
        free(rl->value);
        free(rl);
        return false;
    }

    *ret->return_list = rl;
    return true;
}

/* Compare two tags without an early exit, so timing does not reveal how many
 * leading bytes of a forged tag were correct. Returns non-zero if they differ. */
static int
wkc_tags_differ(const unsigned char *a, const unsigned char *b, size_t len)
{
    volatile unsigned char diff = 0;
    for (size_t i = 0; i < len; i++)
    {
        diff |= (unsigned char) (a[i] ^ b[i]);
    }
    return diff != 0;
}

/**
 *  Unwrap a client key (WKc -> Kc) with the keys held on the PKCS#11 token,
 *  using the context's cipher and authentication functions.
 *
 *  WKc layout as parsed here:
 *    tag (TLS_CRYPT_V2_TAG_LEN) | encrypted Kc | total WKc length (big-endian, last 2 bytes)
 *  The tag is used as the decryption nonce and must equal the tag recomputed
 *  over the decrypted Kc.
 *
 *  @param argv    argv[2] holds the base64-encoded WKc
 *  @param ret     receives the base64-encoded Kc on success
 *  @return int    Returns OPENVPN_PLUGIN_FUNC_ERROR on error and OPENVPN_PLUGIN_FUNC_SUCCESS on success
 *
 */
static int
pkcs11_unwrap(struct plugin_ctx *plugin_ctx, const char **argv,
              struct openvpn_plugin_args_func_return *ret)
{
    unsigned char wkc[TLS_CRYPT_V2_MAX_WKC_LEN]   = { 0 };
    unsigned char kc[TLS_CRYPT_V2_MAX_WKC_LEN]    = { 0 };
    unsigned char tag[TLS_CRYPT_V2_TAG_LEN]       = { 0 };
    int exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
    // Signed, so a failing base64 decode (-1) and a too-short WKc are detectable.
    int wkc_len = -1;
    int kc_len = -1;
    // The token callbacks write through unsigned long pointers; they must get
    // real unsigned long objects, not the address of a narrower integer.
    unsigned long kc_out_len = 0;
    unsigned long tag_out_len = TLS_CRYPT_V2_TAG_LEN;

    // Decode WKc from argv
    const char *wkc_base64 = argv[2];
    CRYPTO_ECHECK(!wkc_base64,
                  "No WKc supplied");
    plog(PLOG_DEBUG, "Received WKc: %s", wkc_base64);

    wkc_len = ovpn_base64_decode(wkc_base64, wkc, TLS_CRYPT_V2_MAX_WKC_LEN);
    CRYPTO_ECHECK(wkc_len < 0,
                  "ovpn_base64_decode failed");

    // Length checks (done before subtracting, so no unsigned wrap-around)
    CRYPTO_ECHECK(wkc_len < (int) (TLS_CRYPT_V2_TAG_LEN + TLS_CRYPT_V2_LEN_LEN),
                  "Invalid Length of WKc");
    kc_len = wkc_len - (int) (TLS_CRYPT_V2_TAG_LEN + TLS_CRYPT_V2_LEN_LEN);
    CRYPTO_ECHECK(bytesToShort(wkc + wkc_len - 2) != wkc_len,
                  "Invalid Declaration of Length for WKc");

    // Decrypt Kc on the token; the tag at the start of WKc is the nonce.
    kc_out_len = (unsigned long) kc_len;
    CRYPTO_ECHECK(!try_cipher(plugin_ctx, wkc + TLS_CRYPT_V2_TAG_LEN, (unsigned long) kc_len,
                              kc, &kc_out_len, wkc),
                  "Couldn't decrypt WKc");
    CRYPTO_ECHECK(kc_out_len > sizeof(kc),
                  "Invalid length of decrypted Kc");

    // Calculate tag and compare
    CRYPTO_ECHECK(!plugin_ctx->authentication_function(plugin_ctx->pkcs_ctx, kc, kc_out_len, tag, &tag_out_len),
                  "Couldn't calculate tag");

    CRYPTO_ECHECK(wkc_tags_differ(tag, wkc, TLS_CRYPT_V2_TAG_LEN),
                  "Tags don't match");

    // Prepare return for openvpn
    CRYPTO_ECHECK(!handle_return(ret, kc, (int) kc_out_len),
                  "Returning results failed");

    exit_code = OPENVPN_PLUGIN_FUNC_SUCCESS;
error_exit:
    // kc holds the plaintext client key: wipe it on every path, success included.
    ovpn_secure_memzero(kc, sizeof(kc));
    return exit_code;
}

/**
 *  Wrap a client key (Kc -> WKc) with the keys held on the PKCS#11 token,
 *  using the context's authentication and cipher functions.
 *
 *  Output layout (the one pkcs11_unwrap() parses):
 *    tag over Kc | Kc encrypted with the tag as nonce | total WKc length (big-endian, last 2 bytes)
 *
 *  @param argv    argv[2] holds the base64-encoded Kc
 *  @param ret     receives the base64-encoded WKc on success
 *  @return int    Returns OPENVPN_PLUGIN_FUNC_ERROR on error and OPENVPN_PLUGIN_FUNC_SUCCESS on success
 *
 */
static int
pkcs11_wrap(struct plugin_ctx *plugin_ctx, const char **argv,
            struct openvpn_plugin_args_func_return *ret)
{
    unsigned char kc[TLS_CRYPT_V2_MAX_WKC_LEN] = { 0 };
    unsigned char wkc[TLS_CRYPT_V2_MAX_WKC_LEN] = { 0 };
    unsigned char tag[TLS_CRYPT_V2_TAG_LEN] = { 0 };
    // Largest Kc whose WKc (tag + encrypted Kc + length field) still fits in wkc[].
    const int max_kc_len = (int) (sizeof(wkc) - TLS_CRYPT_V2_TAG_LEN - TLS_CRYPT_V2_LEN_LEN);
    int exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
    // Signed, so a failing base64 decode (-1) is detectable.
    int kc_len = -1;
    // The token callbacks write through unsigned long pointers; give them real
    // unsigned long objects rather than the address of a narrower integer.
    unsigned long enc_len = 0;
    unsigned long tag_len = TLS_CRYPT_V2_TAG_LEN;
    unsigned long wkc_len = 0;

    // Decode KC from argv
    const char *kc_base64 = argv[2];
    CRYPTO_ECHECK(!kc_base64,
                  "No Kc supplied");
    // Bounded by max_kc_len so the WKc built below cannot overflow wkc[].
    kc_len = ovpn_base64_decode(kc_base64, kc, max_kc_len);
    CRYPTO_ECHECK(kc_len < 0,
                  "ovpn_base64_decode failed");

    // Calculate tag
    CRYPTO_ECHECK(!plugin_ctx->authentication_function(plugin_ctx->pkcs_ctx, kc, (unsigned long) kc_len, tag, &tag_len),
                  "Couldn't calculate tag");

    // Create WKc: tag first, then Kc encrypted with the tag as nonce.
    memcpy(wkc, tag, TLS_CRYPT_V2_TAG_LEN);
    enc_len = (unsigned long) kc_len;
    CRYPTO_ECHECK(!plugin_ctx->cipher_function(plugin_ctx->pkcs_ctx, kc, (unsigned long) kc_len, wkc + TLS_CRYPT_V2_TAG_LEN,
                                               &enc_len, tag),
                  "Couldn't encrypt client key");
    CRYPTO_ECHECK(enc_len > (unsigned long) max_kc_len,
                  "Encrypted client key too long");

    // Trailing big-endian length of the whole WKc.
    wkc_len = enc_len + TLS_CRYPT_V2_TAG_LEN + TLS_CRYPT_V2_LEN_LEN;
    wkc[wkc_len - 2] = (unsigned char) ((wkc_len >> 8) & 0xFF);
    wkc[wkc_len - 1] = (unsigned char) (wkc_len & 0xFF);

    // Prepare return for OpenVPN
    CRYPTO_ECHECK(!handle_return(ret, wkc, (int) wkc_len),
                  "Returning results failed");
    exit_code = OPENVPN_PLUGIN_FUNC_SUCCESS;

error_exit:
    // kc holds the plaintext client key: wipe it on every path.
    ovpn_secure_memzero(kc, sizeof(kc));
    return exit_code;
}

/**
 *  Import AES and HMAC server keys onto the PKCS#11 token
 *
 *  @return int    Returns OPENVPN_PLUGIN_FUNC_ERROR on error and OPENVPN_PLUGIN_FUNC_SUCCESS on success
 *
 */
static int
pkcs11_generate_server_key(struct plugin_ctx *plugin_ctx) {
    delete_aes_key(plugin_ctx->pkcs_ctx);
    delete_hmac_key(plugin_ctx->pkcs_ctx);

    return plugin_ctx->cipher_key_generation(plugin_ctx->pkcs_ctx, AES_KEY_LABEL) && plugin_ctx->authentication_key_generation(plugin_ctx->pkcs_ctx, HMAC_KEY_LABEL);
}

int
openvpn_plugin_func_v3(const int v3structver,
                       struct openvpn_plugin_args_func_in const *args,
                       struct openvpn_plugin_args_func_return *ret)
{
    if (v3structver < OPENVPN_PLUGIN_STRUCTVER_MIN)
    {
        fprintf(stderr, "%s: this plugin is incompatible with the running version of OpenVPN\n", MODULE);
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }
    const char **argv = args->argv;
    struct plugin_ctx *plugin_context = (struct plugin_ctx *) args->handle;
    int exit_code;

    if(args->type != OPENVPN_PLUGIN_CLIENT_KEY_WRAPPING)
    {
        plog(PLOG_NOTE, "OPENVPN_PLUGIN_?");
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }

    if(strcmp(argv[1], "unwrap") == 0)
    {
        plog(PLOG_NOTE, "Unwrapping Client Key with PKCS#11");
        exit_code = pkcs11_unwrap(plugin_context, argv, ret);
        if (exit_code != OPENVPN_PLUGIN_FUNC_SUCCESS) {
            try_reconnect(plugin_context);
            exit_code = pkcs11_unwrap(plugin_context, argv, ret);
        }
    }
    else if (strcmp(argv[1], "wrap") == 0)
    {
        plog(PLOG_NOTE, "Wrapping Client Key with PKCS#11");
        exit_code = pkcs11_wrap(plugin_context, argv, ret);
    }
    else if (strcmp(argv[1], "import") == 0)
    {
        plog(PLOG_NOTE, "Importing Server Key onto PKCS#11 Token");
        exit_code = pkcs11_generate_server_key(plugin_context);
    }
    else
    {
        exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
    }

    return exit_code;
}

void
openvpn_plugin_close_v1(openvpn_plugin_handle_t handle)
{
    struct plugin_ctx *context = (struct plugin_ctx *) handle;
    ovpn_secure_memzero((char*) context->pin, strlen(context->pin) + 1);
    free((char*) context->pin);
    context->pin = NULL;
    if(context->pkcs_ctx)
    {
        disconnect_from_pkcs11_token(context->pkcs_ctx);
        dlclose(context->pkcs_ctx->lib_handle);
        if(context->pkcs_ctx)
        {
            free(context->pkcs_ctx->aes_mechanism.pParameter);
            context->pkcs_ctx->aes_mechanism.pParameter = NULL;
            free(context->pkcs_ctx);
            context->pkcs_ctx = NULL;
        }
    }

}