From 3b5ba04cb6e5f61428cc9ae86bffa7757fae3d1d Mon Sep 17 00:00:00 2001
From: huitema <huitema@huitema.net>
Date: Sun, 3 Jun 2018 21:58:45 -0700
Subject: [PATCH] Add PN encryption test vector, and fix issues in PN
 encryption key computation.

---
 UnitTest1/unittest1.cpp            |  7 +++
 picoquic/tls_api.c                 | 67 +++++++++++++++-----------
 picoquic/tls_api.h                 |  2 +-
 picoquic/util.c                    |  2 +-
 picoquic/util.h                    |  2 +-
 picoquic_t/picoquic_t.c            |  3 +-
 picoquictest/cleartext_aead_test.c | 77 ++++++++++++++++++++++++++++++
 picoquictest/picoquictest.h        |  1 +
 8 files changed, 128 insertions(+), 33 deletions(-)

diff --git a/UnitTest1/unittest1.cpp b/UnitTest1/unittest1.cpp
index 08c802f1..4f67a4d5 100644
--- a/UnitTest1/unittest1.cpp
+++ b/UnitTest1/unittest1.cpp
@@ -541,5 +541,12 @@ namespace UnitTest1
             Assert::AreEqual(ret, 0);
         }
 
+        TEST_METHOD(test_pn_vector)
+        {
+            int ret = cleartext_pn_vector_test();
+
+            Assert::AreEqual(ret, 0);
+        }
+
     };
 }
diff --git a/picoquic/tls_api.c b/picoquic/tls_api.c
index 91784c89..68139ac9 100644
--- a/picoquic/tls_api.c
+++ b/picoquic/tls_api.c
@@ -1015,6 +1015,35 @@ int picoquic_initialize_stream_zero(picoquic_cnx_t* cnx)
     return ret;
 }
 
+/*
+ * QUIC Specific HKDF Function
+ */
+
+int picoquic_hkdf_expand_label(ptls_hash_algorithm_t *algo, void *output, size_t outlen, ptls_iovec_t secret,
+    const char *label, const char *base_label)
+{
+    ptls_buffer_t hkdf_label;
+    uint8_t hkdf_label_buf[512];
+    int ret;
+
+    ptls_buffer_init(&hkdf_label, hkdf_label_buf, sizeof(hkdf_label_buf));
+
+    ptls_buffer_push16(&hkdf_label, (uint16_t)outlen);
+    ptls_buffer_push_block(&hkdf_label, 1, {
+        if (base_label == NULL)
+        base_label = "tls13 ";
+    ptls_buffer_pushv(&hkdf_label, base_label, strlen(base_label));
+    ptls_buffer_pushv(&hkdf_label, label, strlen(label));
+        });
+
+    ret = ptls_hkdf_expand(algo, output, outlen, secret, ptls_iovec_init(hkdf_label.base, hkdf_label.off));
+
+Exit:
+    ptls_buffer_dispose(&hkdf_label);
+    return ret;
+}
+
+
 /*
  * Packet number encryption and decryption utilities
  */
@@ -1034,8 +1063,12 @@ void * picoquic_pn_enc_create(
     /*
      * Derive the key by extending the secret for PN encryption 
      */
+#if 0
     ret = ptls_hkdf_expand(
         hash, key, aead->key_size, ptls_iovec_init(secret, hash->digest_size), ptls_iovec_init("pn", 2));
+#else
+    ret = picoquic_hkdf_expand_label(hash, key, aead->key_size, ptls_iovec_init(secret, hash->digest_size), "pn", base_label);
+#endif
 
     /*
      * Create the context. This is always an encryptng context, because of the stream cipher mode.
@@ -1059,7 +1092,7 @@ void * picoquic_pn_enc_create_for_test(const uint8_t * secret)
     return ret;
 }
 
-void picoquic_pn_encrypt(void *pn_enc, void * iv, void *output, const void *input, size_t len)
+void picoquic_pn_encrypt(void *pn_enc, const void * iv, void *output, const void *input, size_t len)
 {
     ptls_cipher_init((ptls_cipher_context_t *) pn_enc, iv);
     ptls_cipher_encrypt((ptls_cipher_context_t *) pn_enc, output, input, len);
@@ -1086,30 +1119,6 @@ revert to using ptls_aead_new.
 
 */
 
-int picoquic_hkdf_expand_label(ptls_hash_algorithm_t *algo, void *output, size_t outlen, ptls_iovec_t secret, 
-    const char *label, const char *base_label)
-{
-    ptls_buffer_t hkdf_label;
-    uint8_t hkdf_label_buf[512];
-    int ret;
-
-    ptls_buffer_init(&hkdf_label, hkdf_label_buf, sizeof(hkdf_label_buf));
-
-    ptls_buffer_push16(&hkdf_label, (uint16_t)outlen);
-    ptls_buffer_push_block(&hkdf_label, 1, {
-        if (base_label == NULL)
-            base_label = "tls13 ";
-        ptls_buffer_pushv(&hkdf_label, base_label, strlen(base_label));
-        ptls_buffer_pushv(&hkdf_label, label, strlen(label));
-        });
-
-    ret = ptls_hkdf_expand(algo, output, outlen, secret, ptls_iovec_init(hkdf_label.base, hkdf_label.off));
-
-Exit:
-    ptls_buffer_dispose(&hkdf_label);
-    return ret;
-}
-
 static int picoquic_get_traffic_key(ptls_hash_algorithm_t *algo, void *key, size_t key_size, int is_iv, const void *secret,
     const char *base_label)
 {
@@ -1227,7 +1236,7 @@ int picoquic_setup_1RTT_aead_contexts(picoquic_cnx_t* cnx, int is_server)
         if (ret == 0) {
             cnx->aead_encrypt_ctx = (void*)
                 picoquic_aead_new(cipher->aead, cipher->hash, 1, secret, PICOQUIC_QUIC_BASE_LABEL);
-            cnx->pn_enc = picoquic_pn_enc_create(cipher->aead, cipher->hash, secret, NULL);
+            cnx->pn_enc = picoquic_pn_enc_create(cipher->aead, cipher->hash, secret, PICOQUIC_QUIC_BASE_LABEL);
 
             if (cnx->aead_encrypt_ctx == NULL) {
                 ret = PICOQUIC_ERROR_MEMORY;
@@ -1246,7 +1255,7 @@ int picoquic_setup_1RTT_aead_contexts(picoquic_cnx_t* cnx, int is_server)
 
         if (ret == 0) {
             cnx->aead_decrypt_ctx = (void*)picoquic_aead_new(cipher->aead, cipher->hash, 0, secret, PICOQUIC_QUIC_BASE_LABEL);
-            cnx->pn_dec = picoquic_pn_enc_create(cipher->aead, cipher->hash, secret, NULL);
+            cnx->pn_dec = picoquic_pn_enc_create(cipher->aead, cipher->hash, secret, PICOQUIC_QUIC_BASE_LABEL);
 
             if (cnx->aead_decrypt_ctx == NULL) {
                 ret = -1;
@@ -1442,8 +1451,8 @@ int picoquic_setup_cleartext_aead_contexts(picoquic_cnx_t* cnx)
             cnx->aead_de_encrypt_cleartext_ctx = (void*)
                 picoquic_aead_new(aead, algo, 0, secret1, PICOQUIC_QUIC_BASE_LABEL);
 
-            cnx->pn_enc_cleartext = picoquic_pn_enc_create(aead, algo, secret1, NULL);
-            cnx->pn_dec_cleartext = picoquic_pn_enc_create(aead, algo, secret2, NULL);
+            cnx->pn_enc_cleartext = picoquic_pn_enc_create(aead, algo, secret1, PICOQUIC_QUIC_BASE_LABEL);
+            cnx->pn_dec_cleartext = picoquic_pn_enc_create(aead, algo, secret2, PICOQUIC_QUIC_BASE_LABEL);
         }
     }
 
diff --git a/picoquic/tls_api.h b/picoquic/tls_api.h
index 1cc65d16..e7798e8d 100644
--- a/picoquic/tls_api.h
+++ b/picoquic/tls_api.h
@@ -63,7 +63,7 @@ size_t picoquic_aead_decrypt_generic(uint8_t* output, uint8_t* input, size_t inp
 
 void picoquic_aead_free(void* aead_context);
 
-void picoquic_pn_encrypt(void *pn_enc, void * iv, void *output, const void *input, size_t len);
+void picoquic_pn_encrypt(void *pn_enc, const void * iv, void *output, const void *input, size_t len);
 
 void picoquic_pn_enc_free(void * pn_enc);
 
diff --git a/picoquic/util.c b/picoquic/util.c
index e199cdbf..232bfed1 100644
--- a/picoquic/util.c
+++ b/picoquic/util.c
@@ -161,7 +161,7 @@ uint32_t picoquic_format_connection_id(uint8_t* bytes, size_t bytes_max, picoqui
     return copied;
 }
 
-uint32_t picoquic_parse_connection_id(uint8_t * bytes, uint8_t len, picoquic_connection_id_t * cnx_id)
+uint32_t picoquic_parse_connection_id(const uint8_t * bytes, uint8_t len, picoquic_connection_id_t * cnx_id)
 {
     if (len <= PICOQUIC_CONNECTION_ID_MAX_SIZE) {
         cnx_id->id_len = len;
diff --git a/picoquic/util.h b/picoquic/util.h
index f151a757..de451654 100644
--- a/picoquic/util.h
+++ b/picoquic/util.h
@@ -46,7 +46,7 @@ void debug_printf_resume(void);
 
 extern const picoquic_connection_id_t picoquic_null_connection_id;
 uint32_t picoquic_format_connection_id(uint8_t* bytes, size_t bytes_max, picoquic_connection_id_t cnx_id);
-uint32_t picoquic_parse_connection_id(uint8_t* bytes, uint8_t len, picoquic_connection_id_t *cnx_id);
+uint32_t picoquic_parse_connection_id(const uint8_t* bytes, uint8_t len, picoquic_connection_id_t *cnx_id);
 int picoquic_is_connection_id_null(picoquic_connection_id_t cnx_id);
 int picoquic_compare_connection_id(picoquic_connection_id_t * cnx_id1, picoquic_connection_id_t * cnx_id2);
 uint64_t picoquic_val64_connection_id(picoquic_connection_id_t cnx_id);
diff --git a/picoquic_t/picoquic_t.c b/picoquic_t/picoquic_t.c
index 14599a9f..e0914f2b 100644
--- a/picoquic_t/picoquic_t.c
+++ b/picoquic_t/picoquic_t.c
@@ -106,7 +106,8 @@ static const picoquic_test_def_t test_table[] = {
     { "nat_rebinding_loss", nat_rebinding_loss_test },
     { "spin_bit", spin_bit_test},
     { "client_error", client_error_test },
-    { "packet_enc_dec", packet_enc_dec_test}
+    { "packet_enc_dec", packet_enc_dec_test},
+    { "pn_vector", cleartext_pn_vector_test }
 };
 
 static size_t const nb_tests = sizeof(test_table) / sizeof(picoquic_test_def_t);
diff --git a/picoquictest/cleartext_aead_test.c b/picoquictest/cleartext_aead_test.c
index a068808b..971c9678 100644
--- a/picoquictest/cleartext_aead_test.c
+++ b/picoquictest/cleartext_aead_test.c
@@ -24,6 +24,7 @@
 #endif
 #include "../picoquic/picoquic_internal.h"
 #include "../picoquic/tls_api.h"
+#include "../picoquic/util.h"
 #include "picotls.h"
 #include "picotls/openssl.h"
 #include <string.h>
@@ -513,5 +514,81 @@ int cleartext_pn_enc_test()
         picoquic_free(qserver);
     }
 
+    return ret;
+}
+
+/* Test vector copied from Kazuho Ohu's test code in quicly */
+
+int cleartext_pn_vector_test()
+{
+    int ret = 0;
+    static const uint8_t cid[] = { 0x77, 0x0d, 0xc2, 0x6c, 0x17, 0x50, 0x9b, 0x35 };
+    static const uint8_t sample[] = { 0x05, 0x80, 0x24, 0xa9, 0x72, 0x75, 0xf0, 0x1d, 0x2a, 0x1e, 0xc9, 0x1f, 0xd1, 0xc2, 0x65, 0xbb };
+    static const uint8_t encrypted_pn[] = { 0x3b, 0xb4, 0xb1, 0x74 };
+    static const uint8_t expected_pn[] = { 0xc0, 0x00, 0x00, 0x00 };
+
+    struct sockaddr_in test_addr_s;
+    picoquic_connection_id_t initial_cnxid;
+    picoquic_cnx_t* cnx_server = NULL;
+    picoquic_quic_t* qserver = picoquic_create(8,
+#ifdef _WINDOWS
+#ifdef _WINDOWS64
+        "..\\..\\certs\\cert.pem", "..\\..\\certs\\key.pem",
+#else
+        "..\\certs\\cert.pem", "..\\certs\\key.pem",
+#endif
+#else
+        "certs/cert.pem", "certs/key.pem",
+#endif
+        "test", NULL, NULL, NULL, NULL, NULL, 0, NULL, NULL, NULL, 0);
+    if (qserver == NULL) {
+        DBG_PRINTF("%s", "Could not create Quic contexts.\n");
+        ret = -1;
+    }
+
+    if (ret == 0 && picoquic_parse_connection_id(cid, sizeof(cid), &initial_cnxid) != sizeof(cid)) {
+        ret = -1;
+    }
+
+
+    if (ret == 0) {
+
+        memset(&test_addr_s, 0, sizeof(struct sockaddr_in));
+        test_addr_s.sin_family = AF_INET;
+        memcpy(&test_addr_s.sin_addr, addr2, 4);
+        test_addr_s.sin_port = 4433;
+
+        cnx_server = picoquic_create_cnx(qserver, initial_cnxid, initial_cnxid,
+            (struct sockaddr*)&test_addr_s, 0, PICOQUIC_SIXTH_INTEROP_VERSION, NULL, NULL, 0);
+
+        if (cnx_server == NULL) {
+            DBG_PRINTF("%s", "Could not create server connection context.\n");
+            ret = -1;
+        }
+    }
+
+    /* Try to decrypt the test vector */
+    if (ret == 0) {
+        uint8_t decrypted[8];
+
+        memset(decrypted, 0, sizeof(decrypted));
+
+        picoquic_pn_encrypt(cnx_server->pn_dec_cleartext, sample, decrypted, encrypted_pn, sizeof(encrypted_pn));
+
+        if (memcmp(decrypted, expected_pn, sizeof(expected_pn)) != 0)
+        {
+            DBG_PRINTF("%s", "Test of encoding PN vector failed.\n");
+            ret = -1;
+        }
+    }
+
+    if (cnx_server != NULL) {
+        picoquic_delete_cnx(cnx_server);
+    }
+
+    if (qserver != NULL) {
+        picoquic_free(qserver);
+    }
+
     return ret;
 }
\ No newline at end of file
diff --git a/picoquictest/picoquictest.h b/picoquictest/picoquictest.h
index 8233626a..3fa06327 100644
--- a/picoquictest/picoquictest.h
+++ b/picoquictest/picoquictest.h
@@ -101,6 +101,7 @@ int nat_rebinding_loss_test();
 int spin_bit_test();
 int client_error_test();
 int packet_enc_dec_test();
+int cleartext_pn_vector_test();
 
 #ifdef __cplusplus
 }
-- 
GitLab