From 1aaa247ef3baa2db732d203d32032d4d18f5e33f Mon Sep 17 00:00:00 2001
From: Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com>
Date: Sun, 3 Feb 2019 16:41:55 +0900
Subject: [PATCH] Don't regenerate TLS context on Retry

---
 examples/client.cc           |  5 +---
 lib/includes/ngtcp2/ngtcp2.h |  7 +++--
 lib/ngtcp2_conn.c            | 54 +++++++++++++++++++-----------------
 tests/ngtcp2_conn_test.c     |  6 ++--
 4 files changed, 38 insertions(+), 34 deletions(-)

diff --git a/examples/client.cc b/examples/client.cc
index ad18df43..d8442611 100644
--- a/examples/client.cc
+++ b/examples/client.cc
@@ -1623,10 +1623,7 @@ ssize_t Client::hp_mask(uint8_t *dest, size_t destlen, const uint8_t *key,
                          samplelen);
 }
 
-void Client::on_recv_retry() {
-  init_ssl();
-  setup_initial_crypto_context();
-}
+void Client::on_recv_retry() { setup_initial_crypto_context(); }
 
 namespace {
 int bind_addr(Address &local_addr, int fd, int family) {
diff --git a/lib/includes/ngtcp2/ngtcp2.h b/lib/includes/ngtcp2/ngtcp2.h
index 4bdb5207..a0fb3a09 100644
--- a/lib/includes/ngtcp2/ngtcp2.h
+++ b/lib/includes/ngtcp2/ngtcp2.h
@@ -1094,8 +1094,11 @@ typedef int (*ngtcp2_recv_version_negotiation)(ngtcp2_conn *conn,
  * :type:`ngtcp2_recv_retry` is invoked when Retry packet is received.
  * This callback is client only.
  *
- * Application must reinitialize TLS stack in order to start a fresh
- * cryptographic handshake.
+ * Application must regenerate packet protection key, IV, and header
+ * protection key for Initial packets using the destination connection
+ * ID obtained by `ngtcp2_conn_get_dcid()` and install them by calling
+ * `ngtcp2_conn_install_initial_tx_keys()` and
+ * `ngtcp2_conn_install_initial_rx_keys()`.
  *
  * 0-RTT data accepted by the ngtcp2 library will be retransmitted by
  * the library automatically.
diff --git a/lib/ngtcp2_conn.c b/lib/ngtcp2_conn.c
index b1312a34..41c73eca 100644
--- a/lib/ngtcp2_conn.c
+++ b/lib/ngtcp2_conn.c
@@ -3281,6 +3281,7 @@ static int conn_on_retry(ngtcp2_conn *conn, const ngtcp2_pkt_hd *hd,
   ngtcp2_pkt_retry retry;
   uint8_t *p;
   ngtcp2_rtb *rtb = &conn->pktns.rtb;
+  ngtcp2_rtb *in_rtb = &conn->in_pktns.rtb;
   uint8_t cidbuf[sizeof(retry.odcid.data) * 2 + 1];
   ngtcp2_frame_chain *frc = NULL;
 
@@ -3318,11 +3319,6 @@ static int conn_on_retry(ngtcp2_conn *conn, const ngtcp2_pkt_hd *hd,
 
   /* Just freeing memory is dangerous because we might free twice. */
 
-  ngtcp2_vec_del(conn->early_hp, conn->mem);
-  conn->early_hp = NULL;
-  ngtcp2_crypto_km_del(conn->early_ckm, conn->mem);
-  conn->early_ckm = NULL;
-
   rv = ngtcp2_rtb_remove_all(rtb, &frc);
   if (rv != 0) {
     assert(ngtcp2_err_is_fatal(rv));
@@ -3337,18 +3333,20 @@ static int conn_on_retry(ngtcp2_conn *conn, const ngtcp2_pkt_hd *hd,
     return rv;
   }
 
-  conn->pktns.last_tx_pkt_num = (uint64_t)-1;
-  conn->pktns.crypto_tx_offset = 0;
-  ngtcp2_rtb_clear(&conn->pktns.rtb);
-
-  conn->in_pktns.last_tx_pkt_num = (uint64_t)-1;
-  conn->in_pktns.crypto_tx_offset = 0;
-  ngtcp2_rtb_clear(&conn->in_pktns.rtb);
-
-  ngtcp2_frame_chain_list_del(conn->in_pktns.frq, conn->mem);
-  conn->in_pktns.frq = NULL;
+  frc = NULL;
+  rv = ngtcp2_rtb_remove_all(in_rtb, &frc);
+  if (rv != 0) {
+    assert(ngtcp2_err_is_fatal(rv));
+    ngtcp2_frame_chain_list_del(frc, conn->mem);
+    return rv;
+  }
 
-  conn->crypto.tx_offset = 0;
+  rv = conn_resched_frames(conn, &conn->in_pktns, &frc);
+  if (rv != 0) {
+    assert(ngtcp2_err_is_fatal(rv));
+    ngtcp2_frame_chain_list_del(frc, conn->mem);
+    return rv;
+  }
 
   assert(conn->token.begin == NULL);
 
@@ -6478,7 +6476,7 @@ static ssize_t conn_write_handshake(ngtcp2_conn *conn, uint8_t *dest,
                                     size_t destlen, size_t early_datalen,
                                     ngtcp2_tstamp ts) {
   int rv;
-  ssize_t res = 0, nwrite, early_spktlen = 0;
+  ssize_t res = 0, nwrite = 0, early_spktlen = 0;
   uint64_t cwnd;
   size_t origlen = destlen;
   size_t server_hs_tx_left;
@@ -6501,9 +6499,18 @@ static ssize_t conn_write_handshake(ngtcp2_conn *conn, uint8_t *dest,
       early_datalen = pending_early_datalen;
     }
 
-    nwrite = conn_write_client_initial(conn, dest, destlen, early_datalen, ts);
-    if (nwrite <= 0) {
-      return nwrite;
+    if (!(conn->flags & NGTCP2_CONN_FLAG_RECV_RETRY)) {
+      nwrite =
+          conn_write_client_initial(conn, dest, destlen, early_datalen, ts);
+      if (nwrite <= 0) {
+        return nwrite;
+      }
+    } else {
+      nwrite = conn_write_handshake_pkt(conn, dest, destlen, NGTCP2_PKT_INITIAL,
+                                        early_datalen, ts);
+      if (nwrite < 0) {
+        return nwrite;
+      }
     }
 
     if (pending_early_datalen) {
@@ -6511,11 +6518,8 @@ static ssize_t conn_write_handshake(ngtcp2_conn *conn, uint8_t *dest,
                                                   destlen - (size_t)nwrite, ts);
 
       if (early_spktlen < 0) {
-        if (ngtcp2_err_is_fatal((int)early_spktlen)) {
-          return early_spktlen;
-        }
-        conn->state = NGTCP2_CS_CLIENT_WAIT_HANDSHAKE;
-        return nwrite;
+        assert(ngtcp2_err_is_fatal((int)early_spktlen));
+        return early_spktlen;
       }
     }
 
diff --git a/tests/ngtcp2_conn_test.c b/tests/ngtcp2_conn_test.c
index b625146b..c620db5a 100644
--- a/tests/ngtcp2_conn_test.c
+++ b/tests/ngtcp2_conn_test.c
@@ -2063,7 +2063,7 @@ void test_ngtcp2_conn_recv_retry(void) {
       CU_ASSERT(spktlen == 0);
     } else {
       CU_ASSERT(spktlen > 0);
-      CU_ASSERT(0 == conn->in_pktns.last_tx_pkt_num);
+      CU_ASSERT(1 == conn->in_pktns.last_tx_pkt_num);
       CU_ASSERT(ngtcp2_cid_eq(&dcid, ngtcp2_conn_get_dcid(conn)));
       CU_ASSERT(conn->flags & NGTCP2_CONN_FLAG_RECV_RETRY);
     }
@@ -2134,7 +2134,7 @@ void test_ngtcp2_conn_recv_retry(void) {
   spktlen = ngtcp2_conn_write_handshake(conn, buf, sizeof(buf), ++t);
 
   CU_ASSERT(spktlen > 219 + 119);
-  CU_ASSERT(0 == conn->pktns.last_tx_pkt_num);
+  CU_ASSERT(2 == conn->pktns.last_tx_pkt_num);
 
   strm = ngtcp2_conn_find_stream(conn, stream_id);
 
@@ -2145,7 +2145,7 @@ void test_ngtcp2_conn_recv_retry(void) {
                                      stream_id, 0, null_data, 120, ++t);
 
   CU_ASSERT(spktlen > 0);
-  CU_ASSERT(1 == conn->pktns.last_tx_pkt_num);
+  CU_ASSERT(3 == conn->pktns.last_tx_pkt_num);
   CU_ASSERT(120 == datalen);
   CU_ASSERT(NULL == conn->pktns.frq);
   CU_ASSERT(!ngtcp2_rtb_empty(&conn->pktns.rtb));
-- 
GitLab