Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
openvpn_pkcs11_key_wrapping.c 10.25 KiB
#include <dlfcn.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>

#include "shared.h"
#include "pkcs11_handler.h"

/* Minimum OpenVPN plugin API version; reported to OpenVPN by
 * openvpn_plugin_min_version_required_v1(). */
#define OPENVPN_PLUGIN_VERSION_MIN 3
/* Minimum v3 argument-struct version; openvpn_plugin_open_v3() and
 * openvpn_plugin_func_v3() reject callers passing an older one. */
#define OPENVPN_PLUGIN_STRUCTVER_MIN 5

/*
 * Look up an environment variable in an OpenVPN-style envp array.
 *
 * @param name  Variable name to search for (without the '=').
 * @param envp  NULL-terminated array of "NAME=value" strings; may be NULL.
 * @return      Pointer to the value part of the first matching entry (it
 *              points into envp, it is not a copy), or NULL if name or envp
 *              is NULL or no entry matches.
 */
static const char *
get_env(const char *name, const char *envp[])
{
    if (!name || !envp)
    {
        return NULL;
    }

    const size_t namelen = strlen(name);
    for (size_t i = 0; envp[i]; ++i)
    {
        /* Require '=' right after the name so "verb" does not match "verbose=...".
         * Indexing envp[i][namelen] is safe: strncmp matched namelen non-NUL chars. */
        if (strncmp(envp[i], name, namelen) == 0 && envp[i][namelen] == '=')
        {
            return envp[i] + namelen + 1;
        }
    }
    return NULL;
}

/*
 * Report the oldest OpenVPN plugin API version this plugin can work with.
 * Declared with (void) so it is a real prototype in pre-C23 dialects.
 */
OPENVPN_EXPORT int
openvpn_plugin_min_version_required_v1(void)
{
    return OPENVPN_PLUGIN_VERSION_MIN;
}

/*
 * Ask OpenVPN to initialize this plugin at OPENVPN_PLUGIN_INIT_PRE_CONFIG_PARSE.
 * Declared with (void) so it is a real prototype in pre-C23 dialects.
 */
OPENVPN_EXPORT int
openvpn_plugin_select_initialization_point_v1(void)
{
    return OPENVPN_PLUGIN_INIT_PRE_CONFIG_PARSE;
}

OPENVPN_EXPORT int
openvpn_plugin_open_v3(const int v3structver,
                       struct openvpn_plugin_args_open_in const *args,
                       struct openvpn_plugin_args_open_return *ret)
{
    if (v3structver < OPENVPN_PLUGIN_STRUCTVER_MIN)
    {
        printf("Error: This plugin is incompatible with the running version of OpenVPN!");
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }

    /* Save global pointers to functions exported from openvpn */
    plugin_vlog_func = args->callbacks->plugin_vlog;
    ovpn_base64_decode = args->callbacks->plugin_base64_decode;
    ovpn_base64_encode = args->callbacks->plugin_base64_encode;
    ovpn_secure_memzero = args->callbacks->plugin_secure_memzero;

    if (!args->argv[1] || !args->argv[2])
    {
        plog(PLOG_ERR, "Missing arguments for plugin missing!");
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }

    const char **argv = args->argv;
    struct plugin_ctx *plugin_ctx = NULL;
    plugin_ctx = (struct plugin_ctx *) calloc(1, sizeof(struct plugin_ctx));
    if (!plugin_ctx)
    {
        goto error;
    }

    plugin_ctx->pin = args->argv[2];

    plugin_ctx->pkcs_ctx = calloc(1, sizeof(struct PKCS_Context));
    if (!pkcs11_startup(plugin_ctx, args->argv[1]))
    {
        goto error;
    }

    /*
     * Get verbosity level from environment
     */
    {
        const char *verb_string = get_env("verb", args->envp);
        if (verb_string)
        {
            plugin_ctx->verb = (int) strtol(verb_string, NULL, 10);
        }
    }

    /*
     * Which callbacks to intercept.
     */
    ret->type_mask = OPENVPN_PLUGIN_MASK(OPENVPN_PLUGIN_CLIENT_KEY_WRAPPING);
    ret->handle = (openvpn_plugin_handle_t *) plugin_ctx;

    return OPENVPN_PLUGIN_FUNC_SUCCESS;

error:
    plog(PLOG_NOTE, "initialization failed");
    return OPENVPN_PLUGIN_FUNC_ERROR;
}

/*
 * Read a 16-bit unsigned integer stored big-endian (most significant byte
 * first) in bytes[0] and bytes[1].
 */
static uint16_t
bytesToShort(const unsigned char *bytes)
{
    const unsigned int high_byte = bytes[0];
    const unsigned int low_byte = bytes[1];

    return (uint16_t) ((high_byte << 8) | low_byte);
}

/*
 * Hand a binary result back to OpenVPN as a base64-encoded string.
 *
 * Allocates one openvpn_plugin_string_list node named "wrapping result",
 * whose value is the base64 encoding of data[0..data_len), and stores it in
 * *ret->return_list. The node is not freed here; presumably OpenVPN frees
 * the returned list (confirm against the plugin API).
 *
 * @return true on success; false on allocation or encoding failure. On
 *         failure nothing is stored in ret and everything allocated here
 *         is released.
 */
static int
handle_return(struct openvpn_plugin_args_func_return *ret,
              const void *data, int data_len)
{
    struct openvpn_plugin_string_list *rl = calloc(1, sizeof(*rl));
    if (!rl)
    {
        return false;
    }

    rl->name = strdup("wrapping result");
    if (!rl->name)
    {
        free(rl);
        return false;
    }

    if (ovpn_base64_encode(data, data_len, &rl->value) < 0)
    {
        /* value is NULL from calloc unless the encoder set it, so freeing is safe. */
        free(rl->value);
        free(rl->name);
        free(rl);
        return false;
    }

    *ret->return_list = rl;
    return true;
}

/*
 * Compare two equal-length byte strings without an early exit, so the
 * running time does not depend on where the first mismatch is.
 * Returns 0 when the buffers are equal, 1 otherwise.
 */
static int
wkc_tag_mismatch(const unsigned char *a, const unsigned char *b, size_t len)
{
    unsigned char diff = 0;
    for (size_t i = 0; i < len; ++i)
    {
        diff |= (unsigned char) (a[i] ^ b[i]);
    }
    return diff != 0;
}

/**
 *  Unwrap a wrapped client key (WKc) using the keys found on the PKCS#11 token.
 *
 *  argv[2] holds the base64-encoded WKc, laid out as
 *      tag (TLS_CRYPT_V2_TAG_LEN bytes) || ciphertext || 16-bit big-endian total length.
 *  The ciphertext is decrypted with plugin_ctx->cipher_function, using the tag
 *  as IV. The tag is then recomputed over the plaintext with
 *  plugin_ctx->authentication_function and compared with the received one.
 *  On success the plaintext client key is returned to OpenVPN, base64-encoded,
 *  through handle_return().
 *
 *  @return int    Returns OPENVPN_PLUGIN_FUNC_ERROR on error and OPENVPN_PLUGIN_FUNC_SUCCESS on success
 *
 */
static int
softhsm_unwrap(struct plugin_ctx *plugin_ctx, const char **argv,
               struct openvpn_plugin_args_func_return *ret)
{
    unsigned char wkc[TLS_CRYPT_V2_MAX_WKC_LEN] = { 0 };
    unsigned char kc[TLS_CRYPT_V2_MAX_WKC_LEN]  = { 0 };
    unsigned char tag[TLS_CRYPT_V2_TAG_LEN]     = { 0 };
    int exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
    int decoded_len;
    size_t wkc_len;
    uint16_t net_len;
    /* The token callbacks write lengths through unsigned long pointers, so
     * these must really be unsigned long (a narrower type gets overwritten). */
    unsigned long kc_len, tag_len;

    // Decode WKc from argv
    const char *wkc_base64 = argv[2];
    CRYPTO_ECHECK(!wkc_base64,
                  "Missing WKc argument");
    plog(PLOG_DEBUG, "Received WKc: %s", wkc_base64);

    decoded_len = ovpn_base64_decode(wkc_base64, wkc, TLS_CRYPT_V2_MAX_WKC_LEN);
    CRYPTO_ECHECK(decoded_len < 0,
                  "ovpn_base64_decode failed");
    wkc_len = (size_t) decoded_len;

    // Length checks: verify before subtracting so the length cannot wrap around
    CRYPTO_ECHECK(wkc_len < TLS_CRYPT_V2_TAG_LEN + TLS_CRYPT_V2_LEN_LEN,
                  "Invalid Length of WKc");
    kc_len = wkc_len - TLS_CRYPT_V2_TAG_LEN - TLS_CRYPT_V2_LEN_LEN;
    net_len = bytesToShort(wkc + wkc_len - 2);
    CRYPTO_ECHECK(net_len != wkc_len,
                  "Invalid Declaration of Length for WKc");

    // Load keys on token
    if (!find_keys(plugin_ctx))
    {
        // Perhaps token got disconnected?
        try_reconnect(plugin_ctx);
        CRYPTO_ECHECK(!find_keys(plugin_ctx),
                      "Couldn't load keys on token");
    }

    // Decrypt the ciphertext; the tag at the start of the WKc serves as IV
    CRYPTO_ECHECK(!plugin_ctx->cipher_function(plugin_ctx->pkcs_ctx, wkc + TLS_CRYPT_V2_TAG_LEN, kc_len,
                                               kc, &kc_len, wkc),
                  "Couldn't decrypt WKc");
    CRYPTO_ECHECK(kc_len > sizeof(kc),
                  "Decrypted client key too long");

    // Recompute the tag over the plaintext and compare in constant time
    tag_len = TLS_CRYPT_V2_TAG_LEN;
    CRYPTO_ECHECK(!plugin_ctx->authentication_function(plugin_ctx->pkcs_ctx, kc, kc_len, tag, &tag_len),
                  "Couldn't calculate tag");

    CRYPTO_ECHECK(wkc_tag_mismatch(tag, wkc, TLS_CRYPT_V2_TAG_LEN),
                  "Tags don't match");

    // Prepare return for openvpn
    CRYPTO_ECHECK(!handle_return(ret, kc, (int) kc_len),
                  "Returning results failed");

    exit_code = OPENVPN_PLUGIN_FUNC_SUCCESS;

error_exit:
    // Do not leave the plaintext client key on the stack
    ovpn_secure_memzero(kc, sizeof(kc));
    return exit_code;
}

/**
 *  Wrap a client key (Kc) into a WKc using the keys found on the PKCS#11 token.
 *
 *  argv[2] holds the base64-encoded Kc. A tag is computed over Kc with
 *  plugin_ctx->authentication_function. Kc is then encrypted with
 *  plugin_ctx->cipher_function, using the tag as IV. The resulting
 *      tag || ciphertext || 16-bit big-endian total length
 *  is returned to OpenVPN, base64-encoded, through handle_return().
 *
 *  @return int    Returns OPENVPN_PLUGIN_FUNC_ERROR on error and OPENVPN_PLUGIN_FUNC_SUCCESS on success
 *
 */
static int
softhsm_wrap(struct plugin_ctx *plugin_ctx, const char **argv,
             struct openvpn_plugin_args_func_return *ret)
{
    unsigned char kc[TLS_CRYPT_V2_MAX_WKC_LEN] = { 0 };
    unsigned char wkc[TLS_CRYPT_V2_MAX_WKC_LEN] = { 0 };
    unsigned char tag[TLS_CRYPT_V2_TAG_LEN] = { 0 };
    /* Room for the ciphertext once the tag and trailing length field are reserved. */
    const size_t max_ct_len = TLS_CRYPT_V2_MAX_WKC_LEN - TLS_CRYPT_V2_TAG_LEN - TLS_CRYPT_V2_LEN_LEN;
    int exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
    int decoded_len;
    size_t wkc_len;
    /* The token callbacks write lengths through unsigned long pointers, so
     * these must really be unsigned long (a narrower type gets overwritten). */
    unsigned long kc_len, ct_len, tag_len;

    // Decode KC from argv
    const char *kc_base64 = argv[2];
    CRYPTO_ECHECK(!kc_base64,
                  "Missing Kc argument");
    decoded_len = ovpn_base64_decode(kc_base64, kc, TLS_CRYPT_V2_MAX_WKC_LEN);
    CRYPTO_ECHECK(decoded_len < 0,
                  "ovpn_base64_decode failed");
    kc_len = (unsigned long) decoded_len;
    CRYPTO_ECHECK(kc_len > max_ct_len,
                  "Client key too long to wrap");

    // Load keys on token
    if (!find_keys(plugin_ctx))
    {
        // Perhaps token got disconnected?
        try_reconnect(plugin_ctx);
        CRYPTO_ECHECK(!find_keys(plugin_ctx),
                      "Couldn't load keys on token");
    }

    // Calculate tag
    tag_len = TLS_CRYPT_V2_TAG_LEN;
    CRYPTO_ECHECK(!plugin_ctx->authentication_function(plugin_ctx->pkcs_ctx, kc, kc_len, tag, &tag_len),
                  "Couldn't calculate tag");

    // Create WKc: tag first, then the ciphertext (tag doubles as IV)
    memcpy(wkc, tag, TLS_CRYPT_V2_TAG_LEN);
    ct_len = kc_len;
    CRYPTO_ECHECK(!plugin_ctx->cipher_function(plugin_ctx->pkcs_ctx, kc, kc_len, wkc + TLS_CRYPT_V2_TAG_LEN,
                                               &ct_len, tag),
                  "Couldn't encrypt client key");
    CRYPTO_ECHECK(ct_len > max_ct_len,
                  "Encrypted client key too long");

    // Append the total length, big-endian
    wkc_len = ct_len + TLS_CRYPT_V2_TAG_LEN + TLS_CRYPT_V2_LEN_LEN;
    wkc[wkc_len - 2] = (unsigned char) ((wkc_len >> 8) & 0xFF);
    wkc[wkc_len - 1] = (unsigned char) (wkc_len & 0xFF);

    // Prepare return for openvpn
    CRYPTO_ECHECK(!handle_return(ret, wkc, (int) wkc_len),
                  "Returning results failed");
    exit_code = OPENVPN_PLUGIN_FUNC_SUCCESS;

error_exit:
    // Do not leave the plaintext client key on the stack
    ovpn_secure_memzero(kc, sizeof(kc));
    return exit_code;
}

/**
 *  (Re)create the server keys on the PKCS#11 token.
 *
 *  Deletes the existing AES and HMAC keys (their results are ignored), then
 *  creates one key under AES_KEY_LABEL and one under HMAC_KEY_LABEL. When the
 *  configured cipher/authentication function is a key-derivation function,
 *  the key of the opposite type is created for that label. For example,
 *  perform_aes_key_derivation gets an HMAC key under AES_KEY_LABEL.
 *
 *  @param argv    Unused; kept so all operation handlers share a signature.
 *  @return int    OPENVPN_PLUGIN_FUNC_SUCCESS if both keys were created, OPENVPN_PLUGIN_FUNC_ERROR otherwise
 *
 */
static int
softhsm_generate_server_key(struct plugin_ctx *plugin_ctx, const char **argv)
{
    int aes_label_ok;
    int hmac_label_ok;

    (void) argv;

    delete_aes_key(plugin_ctx->pkcs_ctx);
    delete_hmac_key(plugin_ctx->pkcs_ctx);

    // Depending on if a key derivation is performed need keys of the opposite type
    if (plugin_ctx->cipher_function == &perform_aes_key_derivation)
    {
        aes_label_ok = create_hmac_key(plugin_ctx->pkcs_ctx, AES_KEY_LABEL) != 0;
    }
    else
    {
        aes_label_ok = create_aes_key(plugin_ctx->pkcs_ctx, AES_KEY_LABEL) != 0;
    }

    if (plugin_ctx->authentication_function == &perform_hmac_key_derivation)
    {
        hmac_label_ok = create_aes_key(plugin_ctx->pkcs_ctx, HMAC_KEY_LABEL) != 0;
    }
    else
    {
        hmac_label_ok = create_hmac_key(plugin_ctx->pkcs_ctx, HMAC_KEY_LABEL) != 0;
    }

    /* Logical AND of normalized flags: a bitwise & of arbitrary non-zero
     * success values (e.g. 1 & 2) could wrongly yield 0. */
    return (aes_label_ok && hmac_label_ok) ? OPENVPN_PLUGIN_FUNC_SUCCESS : OPENVPN_PLUGIN_FUNC_ERROR;
}
int
openvpn_plugin_func_v3(const int v3structver,
                       struct openvpn_plugin_args_func_in const *args,
                       struct openvpn_plugin_args_func_return *ret)
{
    struct timeval start, end;
    gettimeofday(&start, NULL);

    if (v3structver < OPENVPN_PLUGIN_STRUCTVER_MIN)
    {
        fprintf(stderr, "%s: this plugin is incompatible with the running version of OpenVPN\n", MODULE);
        return OPENVPN_PLUGIN_FUNC_ERROR;
    }
    const char **argv = args->argv;
    struct plugin_ctx *plugin_context = (struct plugin_ctx *) args->handle;
    int exit_code;

    if(args->type != OPENVPN_PLUGIN_CLIENT_KEY_WRAPPING)
    {
        plog(PLOG_NOTE, "OPENVPN_PLUGIN_?");
        exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
        goto end;
    }

    if(strcmp(argv[1], "unwrap") == 0)
    {
        plog(PLOG_NOTE, "Unwrapping Client Key with SoftHSM");
        exit_code = softhsm_unwrap(plugin_context, argv, ret);
    }
    else if (strcmp(argv[1], "wrap") == 0)
    {
        plog(PLOG_NOTE, "Wrapping Client Key with SoftHSM");
        exit_code = softhsm_wrap(plugin_context, argv, ret);
    }
    else if (strcmp(argv[1], "import") == 0)
    {
        plog(PLOG_NOTE, "Importing Server Key to SoftHSM");
        exit_code = softhsm_generate_server_key(plugin_context, argv);
    }
    else
    {
        exit_code = OPENVPN_PLUGIN_FUNC_ERROR;
    }

end:
    gettimeofday(&end, NULL);
    plog(PLOG_DEBUG, "Operation took : %ld micro seconds\n",
         ((end.tv_sec * 1000000 + end.tv_usec) - (start.tv_sec * 1000000 + start.tv_usec)));
    return exit_code;
}

void
openvpn_plugin_close_v1(openvpn_plugin_handle_t handle)
{
    struct plugin_ctx *context = (struct plugin_ctx *) handle;
    if(context->pkcs_ctx)
    {
        disconnect_from_pkcs11_token(context->pkcs_ctx);
        dlclose(context->pkcs_ctx->lib_handle);
        free(context->pkcs_ctx);
        context->pkcs_ctx = NULL_PTR;
    }

}