From adca3791fff5262bfdb6af9f8504b54262948cd6 Mon Sep 17 00:00:00 2001
From: huitema <huitema@huitema.net>
Date: Tue, 9 Oct 2018 22:51:02 -0700
Subject: [PATCH] Implement key rotation and test

---
 UnitTest1/unittest1.cpp      |   7 ++
 picoquic/packet.c            |  59 +++++++++++--
 picoquic/picoquic.h          |   5 +-
 picoquic/picoquic_internal.h |   4 +
 picoquic/quicctx.c           |  13 +++
 picoquic/sender.c            |   2 +-
 picoquic/tls_api.c           |  53 ++++++++++--
 picoquic/tls_api.h           |   1 +
 picoquic_t/picoquic_t.c      |   3 +-
 picoquictest/picoquictest.h  |   1 +
 picoquictest/tls_api_test.c  | 157 +++++++++++++++++++++++++++++++++++
 11 files changed, 290 insertions(+), 15 deletions(-)

diff --git a/UnitTest1/unittest1.cpp b/UnitTest1/unittest1.cpp
index 336afbfd..6eeac3ab 100644
--- a/UnitTest1/unittest1.cpp
+++ b/UnitTest1/unittest1.cpp
@@ -693,6 +693,13 @@ namespace UnitTest1
 
             Assert::AreEqual(ret, 0);
         }
+
+        TEST_METHOD(key_rotation)
+        {
+            int ret = key_rotation_test();
+
+            Assert::AreEqual(ret, 0);
+        }
         
         TEST_METHOD(stress)
         {
diff --git a/picoquic/packet.c b/picoquic/packet.c
index bc7f49c1..b6482521 100644
--- a/picoquic/packet.c
+++ b/picoquic/packet.c
@@ -492,11 +492,60 @@ int picoquic_parse_header_and_decrypt(
                     (*pcnx)->crypto_context[1].aead_decrypt, &already_received);
                 break;
             case picoquic_packet_1rtt_protected:
-                /* TODO : roll key based on PHI */
-                /* AEAD Decrypt, in place */
-                decoded_length = picoquic_decrypt_packet(*pcnx, bytes, length, ph,
-                    (*pcnx)->crypto_context[3].pn_dec,
-                    (*pcnx)->crypto_context[3].aead_decrypt, &already_received);
+                if (ph->key_phase == (*pcnx)->key_phase_dec) {
+                    /* AEAD Decrypt, in place */
+                    decoded_length = picoquic_decrypt_packet(*pcnx, bytes, length, ph,
+                        (*pcnx)->crypto_context[3].pn_dec,
+                        (*pcnx)->crypto_context[3].aead_decrypt, &already_received);
+                }
+                else {
+                    if ((*pcnx)->crypto_context_old.aead_decrypt != NULL &&
+                        (*pcnx)->crypto_context_old.pn_dec != NULL &&
+                        current_time < (*pcnx)->crypto_rotation_time_guard)
+                    {
+                        /* If there is an old key available, try decrypt with it */
+                        decoded_length = picoquic_decrypt_packet(*pcnx, bytes, length, ph,
+                            (*pcnx)->crypto_context_old.pn_dec,
+                            (*pcnx)->crypto_context_old.aead_decrypt, &already_received);
+
+                        if (decoded_length <= (length - ph->offset) &&
+                            ph->pn64 > (*pcnx)->crypto_rotation_sequence) {
+                            ret = picoquic_connection_error(*pcnx, PICOQUIC_TRANSPORT_PROTOCOL_VIOLATION, 0);
+                        }
+                    }
+                    else {
+                        /* These could only be a new key */
+                        if ((*pcnx)->crypto_context_new.aead_decrypt == NULL &&
+                            (*pcnx)->crypto_context_new.aead_encrypt == NULL &&
+                            (*pcnx)->crypto_context_new.pn_dec == NULL &&
+                            (*pcnx)->crypto_context_new.pn_enc == NULL) {
+                            /* If the new context was already computed, don't do it again */
+                            ret = picoquic_compute_new_rotated_keys(*pcnx);
+                        }
+
+                        if ((*pcnx)->crypto_context_new.aead_decrypt != NULL &&
+                            (*pcnx)->crypto_context_new.pn_dec != NULL)
+                        {
+                            /* If there is an old key available, try decrypt with it */
+                            decoded_length = picoquic_decrypt_packet(*pcnx, bytes, length, ph,
+                                (*pcnx)->crypto_context_new.pn_dec,
+                                (*pcnx)->crypto_context_new.aead_decrypt, &already_received);
+
+                            if (decoded_length <= (length - ph->offset)) {
+                                /* Rotation only if the packet was correctly decrypted with the new key */
+                                (*pcnx)->crypto_rotation_time_guard = current_time + (*pcnx)->path[0]->retransmit_timer;
+                                (*pcnx)->crypto_rotation_sequence = ph->pn64;
+                                picoquic_apply_rotated_keys(*pcnx, 0);
+
+                                if ((*pcnx)->crypto_context_new.aead_encrypt != NULL &&
+                                    (*pcnx)->crypto_context_new.pn_enc != NULL) {
+                                    /* If that move was not already validated, move to the new encryption keys */
+                                    picoquic_apply_rotated_keys(*pcnx, 1);
+                                }
+                            }
+                        }
+                    }
+                }
                 break;
             default:
                 /* Packet type error. Log and ignore */
diff --git a/picoquic/picoquic.h b/picoquic/picoquic.h
index a27aeceb..282d0789 100644
--- a/picoquic/picoquic.h
+++ b/picoquic/picoquic.h
@@ -72,7 +72,8 @@ extern "C" {
 #define PICOQUIC_ERROR_CONNECTION_DELETED (PICOQUIC_ERROR_CLASS + 31)
 #define PICOQUIC_ERROR_CNXID_SEGMENT (PICOQUIC_ERROR_CLASS + 32)
 #define PICOQUIC_ERROR_CNXID_NOT_AVAILABLE (PICOQUIC_ERROR_CLASS + 33)
-#define PICOQUIC_ERROR_MIGRATION_DISABLED (PICOQUIC_ERROR_CLASS + 33)
+#define PICOQUIC_ERROR_MIGRATION_DISABLED (PICOQUIC_ERROR_CLASS + 34)
+#define PICOQUIC_ERROR_CANNOT_COMPUTE_KEY (PICOQUIC_ERROR_CLASS + 35)
 
 /*
  * Protocol errors defined in the QUIC spec
@@ -366,6 +367,8 @@ int picoquic_create_probe(picoquic_cnx_t* cnx, const struct sockaddr* addr_to, c
 
 int picoquic_renew_connection_id(picoquic_cnx_t* cnx);
 
+int picoquic_start_key_rotation(picoquic_cnx_t * cnx);
+
 picoquic_cnx_t* picoquic_get_first_cnx(picoquic_quic_t* quic);
 picoquic_cnx_t* picoquic_get_next_cnx(picoquic_cnx_t* cnx);
 int64_t picoquic_get_next_wake_delay(picoquic_quic_t* quic,
diff --git a/picoquic/picoquic_internal.h b/picoquic/picoquic_internal.h
index 5d70269a..b75e91db 100644
--- a/picoquic/picoquic_internal.h
+++ b/picoquic/picoquic_internal.h
@@ -541,6 +541,8 @@ typedef struct st_picoquic_cnx_t {
     unsigned int is_0RTT_accepted : 1; /* whether 0-RTT is accepted */
     unsigned int remote_parameters_received : 1; /* whether remote parameters where received */
     unsigned int client_mode : 1; /* Is this connection the client side? */
+    unsigned int key_phase_enc : 1; /* Key phase used in outgoing packets */
+    unsigned int key_phase_dec : 1; /* Key phase expected in incoming packets */
 
 
     /* Local and remote parameters */
@@ -577,6 +579,8 @@ typedef struct st_picoquic_cnx_t {
 
     /* TLS context, TLS Send Buffer, streams, epochs */
     void* tls_ctx;
+    uint64_t crypto_rotation_sequence;
+    uint64_t crypto_rotation_time_guard;
     struct st_ptls_buffer_t* tls_sendbuf;
     uint16_t psk_cipher_suite_id;
 
diff --git a/picoquic/quicctx.c b/picoquic/quicctx.c
index e2b2dd9c..d467060e 100644
--- a/picoquic/quicctx.c
+++ b/picoquic/quicctx.c
@@ -1810,6 +1810,19 @@ int picoquic_connection_error(picoquic_cnx_t* cnx, uint16_t local_error, uint64_
     return PICOQUIC_ERROR_DETECTED;
 }
 
+int picoquic_start_key_rotation(picoquic_cnx_t* cnx)
+{
+    int ret = picoquic_compute_new_rotated_keys(cnx);
+
+    if (ret == 0) {
+        picoquic_apply_rotated_keys(cnx, 1);
+
+        picoquic_crypto_context_free(&cnx->crypto_context_old);
+    }
+
+    return ret;
+}
+
 void picoquic_delete_cnx(picoquic_cnx_t* cnx)
 {
     picoquic_stream_head* stream;
diff --git a/picoquic/sender.c b/picoquic/sender.c
index 8f7cf4f8..f5a4bbec 100644
--- a/picoquic/sender.c
+++ b/picoquic/sender.c
@@ -237,7 +237,7 @@ uint32_t picoquic_create_packet_header(
     /* Prepare the packet header */
     if (packet_type == picoquic_packet_1rtt_protected) {
         /* Create a short packet -- using 32 bit sequence numbers for now */
-        uint8_t K = (packet_type == picoquic_packet_1rtt_protected) ? 0 : 0x40;
+        uint8_t K = (cnx->key_phase_enc) ? 0x40 : 0;
         const uint8_t C = 0x30;
         uint8_t spin_vec = (uint8_t)(cnx->path[0]->spin_vec);
         uint8_t spin_bit = (uint8_t)((cnx->path[0]->current_spin) << 2);
diff --git a/picoquic/tls_api.c b/picoquic/tls_api.c
index f2c29a2f..d6e10de0 100644
--- a/picoquic/tls_api.c
+++ b/picoquic/tls_api.c
@@ -56,8 +56,8 @@ typedef struct st_picoquic_tls_ctx_t {
     uint8_t ext_received[128];
     size_t ext_received_length;
     int ext_received_return;
-    uint8_t app_secret_enc[PTLS_MAX_SECRET_SIZE];
-    uint8_t app_secret_dec[PTLS_MAX_SECRET_SIZE];
+    uint8_t app_secret_enc[PTLS_MAX_DIGEST_SIZE];
+    uint8_t app_secret_dec[PTLS_MAX_DIGEST_SIZE];
 } picoquic_tls_ctx_t;
 
 int picoquic_receive_transport_extensions(picoquic_cnx_t* cnx, int extension_mode,
@@ -831,10 +831,10 @@ int picoquic_setup_initial_traffic_keys(picoquic_cnx_t* cnx)
 static int picoquic_rotate_app_secret(ptls_cipher_suite_t * cipher, uint8_t * secret)
 {
     int ret = 0;
-    uint8_t new_secret[PTLS_MAX_SECRET_SIZE];
+    uint8_t new_secret[PTLS_MAX_DIGEST_SIZE];
 
     ret = ptls_hkdf_expand_label(cipher->hash, new_secret,
-        cipher->aead->ctr_cipher->key_size, ptls_iovec_init(secret, cipher->aead->ctr_cipher->key_size),
+        cipher->hash->digest_size, ptls_iovec_init(secret, cipher->hash->digest_size),
         PICOQUIC_LABEL_TRAFFIC_UPDATE, ptls_iovec_init(NULL, 0), PICOQUIC_LABEL_QUIC_BASE);
     if (ret == 0) {
         memcpy(secret, new_secret, cipher->aead->ctr_cipher->key_size);
@@ -857,7 +857,7 @@ size_t picoquic_get_app_secret_size(picoquic_cnx_t* cnx)
 
     ptls_cipher_suite_t * cipher = ptls_get_cipher(tls_ctx->tls);
 
-    return (cipher->aead->ctr_cipher->key_size);
+    return (cipher->hash->digest_size);
 }
 
 int picoquic_compute_new_rotated_keys(picoquic_cnx_t * cnx)
@@ -871,7 +871,7 @@ int picoquic_compute_new_rotated_keys(picoquic_cnx_t * cnx)
         cnx->crypto_context_new.aead_encrypt != NULL ||
         cnx->crypto_context_new.pn_dec != NULL ||
         cnx->crypto_context_new.pn_enc != NULL) {
-        ret = -1;
+        ret = PICOQUIC_ERROR_CANNOT_COMPUTE_KEY;
     }
 
     /* Recompute the secrets */
@@ -886,11 +886,50 @@ int picoquic_compute_new_rotated_keys(picoquic_cnx_t * cnx)
     if (ret == 0) {
         ret = picoquic_rotate_app_secret(cipher, tls_ctx->app_secret_dec);
     }
+
     if (ret == 0) {
         ret = picoquic_set_key_from_secret(cnx, cipher, 0, &cnx->crypto_context_new, tls_ctx->app_secret_dec);
     }
 
-    return 0;
+    return (ret == 0)?0: PICOQUIC_ERROR_CANNOT_COMPUTE_KEY;
+}
+
+void picoquic_apply_rotated_keys(picoquic_cnx_t * cnx, int is_enc)
+{
+    if (is_enc) {
+        if (cnx->crypto_context[3].aead_encrypt != NULL) {
+            ptls_aead_free((ptls_aead_context_t *)cnx->crypto_context[3].aead_encrypt);
+        }
+
+        if (cnx->crypto_context[3].pn_enc != NULL) {
+            ptls_cipher_free((ptls_cipher_context_t *)cnx->crypto_context[3].pn_enc);
+        }
+
+        cnx->crypto_context[3].aead_encrypt = cnx->crypto_context_new.aead_encrypt;
+        cnx->crypto_context_new.aead_encrypt = NULL;
+        cnx->crypto_context[3].pn_enc = cnx->crypto_context_new.pn_enc;
+        cnx->crypto_context_new.pn_enc = NULL;
+
+        cnx->key_phase_enc ^= 1;
+    }
+    else {
+        if (cnx->crypto_context_old.aead_decrypt != NULL) {
+            ptls_aead_free((ptls_aead_context_t *)cnx->crypto_context_old.aead_decrypt);
+        }
+
+        if (cnx->crypto_context_old.pn_dec != NULL) {
+            ptls_cipher_free((ptls_cipher_context_t *)cnx->crypto_context_old.pn_dec);
+        }
+
+        cnx->crypto_context_old.aead_decrypt = cnx->crypto_context[3].aead_decrypt;
+        cnx->crypto_context[3].aead_decrypt = cnx->crypto_context_new.aead_decrypt;
+        cnx->crypto_context_new.aead_decrypt = NULL;
+        cnx->crypto_context_old.pn_dec = cnx->crypto_context[3].pn_dec;
+        cnx->crypto_context[3].pn_dec = cnx->crypto_context_new.pn_dec;
+        cnx->crypto_context_new.pn_dec = NULL;
+
+        cnx->key_phase_dec ^= 1;
+    }
 }
 
 /*
diff --git a/picoquic/tls_api.h b/picoquic/tls_api.h
index 7a01f6e4..92f29234 100644
--- a/picoquic/tls_api.h
+++ b/picoquic/tls_api.h
@@ -98,6 +98,7 @@ int picoquic_setup_initial_traffic_keys(picoquic_cnx_t* cnx);
 uint8_t * picoquic_get_app_secret(picoquic_cnx_t* cnx, int is_enc);
 size_t picoquic_get_app_secret_size(picoquic_cnx_t* cnx);
 int picoquic_compute_new_rotated_keys(picoquic_cnx_t * cnx);
+void picoquic_apply_rotated_keys(picoquic_cnx_t * cnx, int is_enc);
 
 void picoquic_crypto_context_free(picoquic_crypto_context_t * ctx);
 
diff --git a/picoquic_t/picoquic_t.c b/picoquic_t/picoquic_t.c
index 2d97918a..3de75076 100644
--- a/picoquic_t/picoquic_t.c
+++ b/picoquic_t/picoquic_t.c
@@ -138,7 +138,8 @@ static const picoquic_test_def_t test_table[] = {
     { "retire_cnxid", retire_cnxid_test },
     { "server_busy", server_busy_test },
     { "initial_close", initial_close_test },
-    { "new_rotated_key", new_rotated_key_test},
+    { "new_rotated_key", new_rotated_key_test },
+    { "key_rotation", key_rotation_test },
     { "stress", stress_test },
     { "fuzz", fuzz_test },
     { "fuzz_initial", fuzz_initial_test}
diff --git a/picoquictest/picoquictest.h b/picoquictest/picoquictest.h
index c3cfd67a..4a0ac220 100644
--- a/picoquictest/picoquictest.h
+++ b/picoquictest/picoquictest.h
@@ -137,6 +137,7 @@ int server_busy_test();
 int initial_close_test();
 int fuzz_initial_test();
 int new_rotated_key_test();
+int key_rotation_test();
 
 #ifdef __cplusplus
 }
diff --git a/picoquictest/tls_api_test.c b/picoquictest/tls_api_test.c
index 0972f41d..bf92b0d1 100644
--- a/picoquictest/tls_api_test.c
+++ b/picoquictest/tls_api_test.c
@@ -4294,6 +4294,36 @@ int initial_close_test()
  * Test that rotated keys are computed in a compatible way on client and server.
  */
 
+static int aead_iv_check(void * aead1, void * aead2)
+{
+    int ret = 0; 
+    ptls_aead_context_t *ctx1 = (ptls_aead_context_t *)aead1;
+    ptls_aead_context_t *ctx2 = (ptls_aead_context_t *)aead2;
+
+    if (memcmp(ctx1->static_iv, ctx2->static_iv, ctx1->algo->iv_size) != 0) {
+        ret = -1;
+    }
+    return;
+}
+
+
+static int pn_enc_check(void * pn1, void * pn2)
+{
+    int ret = 0;
+    uint8_t seed[16] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 };
+    uint8_t pn[4] = { 0, 1, 2 ,3 };
+    uint8_t pn_enc[4];
+    uint8_t pn_dec[4];
+
+    picoquic_pn_encrypt(pn1, seed, pn_enc, pn, 4);
+    picoquic_pn_encrypt(pn2, seed, pn_dec, pn_enc, 4);
+
+    if (memcmp(pn_dec, pn, 4) != 0) {
+        ret = -1;
+    }
+    return ret;
+}
+
 int new_rotated_key_test()
 {
     uint64_t loss_mask = 0;
@@ -4346,6 +4376,22 @@ int new_rotated_key_test()
                 DBG_PRINTF("Round %d. Server decryption secret does not match client encryption secret\n", i);
                 ret = -1;
             }
+            else if (aead_iv_check(test_ctx->cnx_server->crypto_context_new.aead_encrypt, test_ctx->cnx_client->crypto_context_new.aead_decrypt) != 0) {
+                DBG_PRINTF("Round %d. Client AEAD decryption does not match server AEAD encryption.\n", i);
+                ret = -1;
+            }
+            else if (aead_iv_check(test_ctx->cnx_client->crypto_context_new.aead_encrypt, test_ctx->cnx_server->crypto_context_new.aead_decrypt) != 0) {
+                DBG_PRINTF("Round %d. Server AEAD decryption does not match cliens AEAD encryption.\n", i);
+                ret = -1;
+            }
+            else if (pn_enc_check(test_ctx->cnx_server->crypto_context_new.pn_enc, test_ctx->cnx_client->crypto_context_new.pn_dec) != 0) {
+                DBG_PRINTF("Round %d. Client PN decryption does not match server PN encryption.\n", i);
+                ret = -1;
+            }
+            else if (pn_enc_check(test_ctx->cnx_client->crypto_context_new.pn_enc, test_ctx->cnx_server->crypto_context_new.pn_dec) != 0) {
+                DBG_PRINTF("Round %d. Server PN decryption does not match client PN encryption.\n", i);
+                ret = -1;
+            }
         }
 
         picoquic_crypto_context_free(&test_ctx->cnx_server->crypto_context_new);
@@ -4357,5 +4403,116 @@ int new_rotated_key_test()
         test_ctx = NULL;
     }
 
+    return ret;
+}
+
+
+/*
+ * Key rotation tests
+ */
+
+
+int key_rotation_test()
+{
+    uint64_t loss_mask_data = 0;
+    uint64_t simulated_time = 0;
+    uint64_t next_time = 0;
+    uint64_t loss_mask = 0;
+    int nb_trials = 0;
+    int nb_inactive = 0;
+    int max_trials = 100000;
+    int nb_rotation = 0;
+    uint64_t rotation_sequence = 100;
+    picoquic_test_tls_api_ctx_t* test_ctx = NULL;
+    int ret = tls_api_init_ctx(&test_ctx, PICOQUIC_INTERNAL_TEST_VERSION_1,
+        PICOQUIC_TEST_SNI, PICOQUIC_TEST_ALPN, &simulated_time, NULL, 0, 0, 0);
+
+    if (ret == 0 && test_ctx == NULL) {
+        ret = PICOQUIC_ERROR_MEMORY;
+    }
+
+    if (ret == 0) {
+        ret = tls_api_connection_loop(test_ctx, &loss_mask, 0, &simulated_time);
+    }
+
+    /* Prepare to send data */
+    if (ret == 0) {
+        ret = test_api_init_send_recv_scenario(test_ctx, test_scenario_very_long, sizeof(test_scenario_very_long));
+    }
+
+    /* Perform a data sending loop, during which various key rotations are tried
+     * every 100 packets or so */
+
+    while (ret == 0 && nb_trials < max_trials && nb_inactive < 256 && test_ctx->cnx_client->cnx_state == picoquic_state_client_ready && test_ctx->cnx_server->cnx_state == picoquic_state_server_ready) {
+        int was_active = 0;
+
+        nb_trials++;
+
+        if (test_ctx->cnx_server->pkt_ctx[picoquic_packet_context_application].send_sequence > rotation_sequence &&
+            test_ctx->cnx_server->crypto_context_new.aead_decrypt == NULL &&
+            test_ctx->cnx_server->crypto_context_new.aead_encrypt == NULL && 
+            test_ctx->cnx_server->crypto_context_new.pn_enc == NULL && 
+            test_ctx->cnx_server->crypto_context_new.pn_dec == NULL) {
+            rotation_sequence = test_ctx->cnx_server->pkt_ctx[picoquic_packet_context_application].send_sequence + 100;
+            nb_rotation++;
+            switch (nb_rotation) {
+            case 1: /* Key rotation at the client */
+                ret = picoquic_start_key_rotation(test_ctx->cnx_client);
+                break;
+            case 2: /* Key rotation at the server */
+                ret = picoquic_start_key_rotation(test_ctx->cnx_server);
+                break;
+            case 3: /* Simultaneous key rotation at the client */
+                rotation_sequence += 1000000000;
+                ret = picoquic_start_key_rotation(test_ctx->cnx_client);
+                if (ret == 0) {
+                    ret = picoquic_start_key_rotation(test_ctx->cnx_server);
+                }
+                break;
+            default:
+                break;
+            }
+        }
+
+        ret = tls_api_one_sim_round(test_ctx, &simulated_time, 0, &was_active);
+
+        if (ret < 0)
+        {
+            break;
+        }
+
+        if (was_active) {
+            nb_inactive = 0;
+        }
+        else {
+            nb_inactive++;
+        }
+
+        if (test_ctx->test_finished) {
+            if (picoquic_is_cnx_backlog_empty(test_ctx->cnx_client) && picoquic_is_cnx_backlog_empty(test_ctx->cnx_server)) {
+                break;
+            }
+        }
+    }
+
+    if (ret == 0 && nb_rotation < 3) {
+        DBG_PRINTF("Only %d key rotations completed out of 3\n", nb_rotation);
+        ret = -1;
+    }
+
+    if (ret == 0) {
+        ret = tls_api_attempt_to_close(test_ctx, &simulated_time);
+
+        if (ret != 0)
+        {
+            DBG_PRINTF("Connection close returns %d\n", ret);
+        }
+    }
+
+    if (test_ctx != NULL) {
+        tls_api_delete_ctx(test_ctx);
+        test_ctx = NULL;
+    }
+
     return ret;
 }
\ No newline at end of file
-- 
GitLab