From 004d80fccdfa563ddc78744648d907d0f00ae8f3 Mon Sep 17 00:00:00 2001
From: Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com>
Date: Sun, 26 Apr 2020 13:25:05 +0900
Subject: [PATCH] Send ACK for server Initial packet

---
 lib/ngtcp2_conn.c        | 190 ++++++++++++++++++++++-----------------
 tests/ngtcp2_conn_test.c |   9 +-
 2 files changed, 118 insertions(+), 81 deletions(-)

diff --git a/lib/ngtcp2_conn.c b/lib/ngtcp2_conn.c
index 9a91efd2..e94541f1 100644
--- a/lib/ngtcp2_conn.c
+++ b/lib/ngtcp2_conn.c
@@ -1536,14 +1536,28 @@ static int conn_should_pad_pkt(ngtcp2_conn *conn, uint8_t type, size_t left,
                                size_t early_datalen) {
   size_t min_payloadlen;
 
-  if (conn->server || type != NGTCP2_PKT_INITIAL) {
+  if (conn->server) {
     return 0;
   }
 
-  if (!conn->early.ckm || early_datalen == 0) {
+  if (type == NGTCP2_PKT_HANDSHAKE) {
+    return conn->in_pktns != NULL;
+  }
+
+  if (conn->hs_pktns->crypto.tx.ckm &&
+      (conn->hs_pktns->rtb.probe_pkt_left ||
+       ngtcp2_ksl_len(&conn->hs_pktns->crypto.tx.frq) ||
+       !ngtcp2_acktr_empty(&conn->hs_pktns->acktr))) {
+    /* If we have something to send in Handshake packet, then add
+       PADDING in Handshake packet. */
+    min_payloadlen = 128;
+  } else if (!conn->early.ckm || early_datalen == 0) {
     return 1;
+  } else {
+    /* If we have something to send in 0RTT packet, then add PADDING
+       in 0RTT packet. */
+    min_payloadlen = ngtcp2_min(early_datalen, 128);
   }
-  min_payloadlen = ngtcp2_min(early_datalen, 128);
 
   return left <
          /* TODO Assuming that pkt_num is encoded in 1 byte. */
@@ -1645,6 +1659,25 @@ static ngtcp2_ssize conn_write_handshake_pkt(ngtcp2_conn *conn, uint8_t *dest,
     return 0;
   }
 
+  rv = conn_create_ack_frame(conn, &ackfr, &pktns->acktr, type, ts,
+                             /* ack_delay = */ 0,
+                             NGTCP2_DEFAULT_ACK_DELAY_EXPONENT);
+  if (rv != 0) {
+    ngtcp2_frame_chain_list_del(frq, conn->mem);
+    return rv;
+  }
+
+  if (ackfr) {
+    rv = conn_ppe_write_frame_hd_log(conn, &ppe, &hd_logged, &hd, ackfr);
+    if (rv != 0) {
+      assert(NGTCP2_ERR_NOBUF == rv);
+    } else {
+      ngtcp2_acktr_commit_ack(&pktns->acktr);
+      ngtcp2_acktr_add_ack(&pktns->acktr, hd.pkt_num, ackfr->ack.largest_ack);
+      pkt_empty = 0;
+    }
+  }
+
   for (; ngtcp2_ksl_len(&pktns->crypto.tx.frq);) {
     left = ngtcp2_ppe_left(&ppe);
     left = ngtcp2_pkt_crypto_max_datalen(
@@ -1678,25 +1711,6 @@ static ngtcp2_ssize conn_write_handshake_pkt(ngtcp2_conn *conn, uint8_t *dest,
         NGTCP2_RTB_FLAG_ACK_ELICITING | NGTCP2_RTB_FLAG_CRYPTO_PKT;
   }
 
-  rv = conn_create_ack_frame(conn, &ackfr, &pktns->acktr, type, ts,
-                             /* ack_delay = */ 0,
-                             NGTCP2_DEFAULT_ACK_DELAY_EXPONENT);
-  if (rv != 0) {
-    ngtcp2_frame_chain_list_del(frq, conn->mem);
-    return rv;
-  }
-
-  if (ackfr) {
-    rv = conn_ppe_write_frame_hd_log(conn, &ppe, &hd_logged, &hd, ackfr);
-    if (rv != 0) {
-      assert(NGTCP2_ERR_NOBUF == rv);
-    } else {
-      ngtcp2_acktr_commit_ack(&pktns->acktr);
-      ngtcp2_acktr_add_ack(&pktns->acktr, hd.pkt_num, ackfr->ack.largest_ack);
-      pkt_empty = 0;
-    }
-  }
-
   /* Don't send any PING frame if client Initial has not been
      acknowledged yet. */
   if (!(rtb_entry_flags & NGTCP2_RTB_FLAG_ACK_ELICITING) &&
@@ -1864,6 +1878,47 @@ static ngtcp2_ssize conn_write_ack_pkt(ngtcp2_conn *conn, uint8_t *dest,
                                             NGTCP2_RTB_FLAG_NONE, ts);
 }
 
+static void conn_discard_pktns(ngtcp2_conn *conn, ngtcp2_pktns *pktns) {
+  uint64_t bytes_in_flight;
+
+  bytes_in_flight = ngtcp2_rtb_get_bytes_in_flight(&pktns->rtb);
+
+  assert(conn->ccs.bytes_in_flight >= bytes_in_flight);
+
+  conn->ccs.bytes_in_flight -= bytes_in_flight;
+  conn->rcs.pto_count = 0;
+  conn->rcs.last_tx_pkt_ts[pktns->rtb.pktns_id] = UINT64_MAX;
+  conn->rcs.loss_time[pktns->rtb.pktns_id] = UINT64_MAX;
+
+  pktns_del(pktns, conn->mem);
+}
+
+/*
+ * conn_discard_initial_state discards state for Initial packet number
+ * space.
+ */
+static void conn_discard_initial_state(ngtcp2_conn *conn) {
+  if (!conn->in_pktns) {
+    return;
+  }
+
+  conn_discard_pktns(conn, conn->in_pktns);
+  conn->in_pktns = NULL;
+}
+
+/*
+ * conn_discard_handshake_state discards state for Handshake packet
+ * number space.
+ */
+static void conn_discard_handshake_state(ngtcp2_conn *conn) {
+  if (!conn->hs_pktns) {
+    return;
+  }
+
+  conn_discard_pktns(conn, conn->hs_pktns);
+  conn->hs_pktns = NULL;
+}
+
 /*
  * conn_write_handshake_ack_pkts writes packets which contain ACK
  * frame only.  This function writes at most 2 packets for each
@@ -1874,9 +1929,12 @@ static ngtcp2_ssize conn_write_handshake_ack_pkts(ngtcp2_conn *conn,
                                                   ngtcp2_tstamp ts) {
   ngtcp2_ssize res = 0, nwrite = 0;
 
-  /* Client never send ACK for server Initial.  This is because once
-     it gets server Initial, it gets Handshake tx key and discards
-     Initial key. */
+  /* In the most cases, client sends ACK in conn_write_handshake_pkt.
+     This function is only called when it is CWND limited.  It is not
+     required for client to send ACK for server Initial.  This is
+     because once it gets server Initial, it gets Handshake tx key and
+     discards Initial key.  The only good reason to send ACK is give
+     server RTT measurement early. */
   if (conn->server && conn->in_pktns) {
     nwrite = conn_write_ack_pkt(conn, dest, destlen, NGTCP2_PKT_INITIAL, ts);
     if (nwrite < 0) {
@@ -1895,7 +1953,12 @@ static ngtcp2_ssize conn_write_handshake_ack_pkts(ngtcp2_conn *conn,
       assert(nwrite != NGTCP2_ERR_NOBUF);
       return nwrite;
     }
+
     res += nwrite;
+
+    if (!conn->server && nwrite) {
+      conn_discard_initial_state(conn);
+    }
   }
 
   return res;
@@ -1949,21 +2012,24 @@ static ngtcp2_ssize conn_write_handshake_pkts(ngtcp2_conn *conn, uint8_t *dest,
   ngtcp2_ssize nwrite;
   ngtcp2_ssize res = 0;
 
-  nwrite = conn_write_handshake_pkt(conn, dest, destlen, NGTCP2_PKT_INITIAL,
-                                    early_datalen, ts);
-  if (nwrite < 0) {
-    assert(nwrite != NGTCP2_ERR_NOBUF);
-    return nwrite;
-  }
+  if (!conn->server && conn->hs_pktns->crypto.tx.ckm &&
+      !ngtcp2_acktr_empty(&conn->hs_pktns->acktr)) {
+    /* Discard Initial state here so that Handshake packet is not
+       padded. */
+    conn_discard_initial_state(conn);
+  } else {
+    nwrite = conn_write_handshake_pkt(conn, dest, destlen, NGTCP2_PKT_INITIAL,
+                                      early_datalen, ts);
+    if (nwrite < 0) {
+      assert(nwrite != NGTCP2_ERR_NOBUF);
+      return nwrite;
+    }
 
-  if (!conn->server && nwrite) {
-    return nwrite;
+    res += nwrite;
+    dest += nwrite;
+    destlen -= (size_t)nwrite;
   }
 
-  res += nwrite;
-  dest += nwrite;
-  destlen -= (size_t)nwrite;
-
   nwrite = conn_write_handshake_pkt(conn, dest, destlen, NGTCP2_PKT_HANDSHAKE,
                                     0, ts);
   if (nwrite < 0) {
@@ -1973,6 +2039,13 @@ static ngtcp2_ssize conn_write_handshake_pkts(ngtcp2_conn *conn, uint8_t *dest,
 
   res += nwrite;
 
+  if (!conn->server && conn->hs_pktns->crypto.tx.ckm && nwrite) {
+    /* We don't need to send further Initial packet if we have
+       Handshake key and sent something with it.  So discard initial
+       state here. */
+    conn_discard_initial_state(conn);
+  }
+
   return res;
 }
 
@@ -4015,47 +4088,6 @@ static int pktns_commit_recv_pkt_num(ngtcp2_pktns *pktns, int64_t pkt_num) {
   return 0;
 }
 
-static void conn_discard_pktns(ngtcp2_conn *conn, ngtcp2_pktns *pktns) {
-  uint64_t bytes_in_flight;
-
-  bytes_in_flight = ngtcp2_rtb_get_bytes_in_flight(&pktns->rtb);
-
-  assert(conn->ccs.bytes_in_flight >= bytes_in_flight);
-
-  conn->ccs.bytes_in_flight -= bytes_in_flight;
-  conn->rcs.pto_count = 0;
-  conn->rcs.last_tx_pkt_ts[pktns->rtb.pktns_id] = UINT64_MAX;
-  conn->rcs.loss_time[pktns->rtb.pktns_id] = UINT64_MAX;
-
-  pktns_del(pktns, conn->mem);
-}
-
-/*
- * conn_discard_initial_state discards state for Initial packet number
- * space.
- */
-static void conn_discard_initial_state(ngtcp2_conn *conn) {
-  if (!conn->in_pktns) {
-    return;
-  }
-
-  conn_discard_pktns(conn, conn->in_pktns);
-  conn->in_pktns = NULL;
-}
-
-/*
- * conn_discard_handshake_state discards state for Handshake packet
- * number space.
- */
-static void conn_discard_handshake_state(ngtcp2_conn *conn) {
-  if (!conn->hs_pktns) {
-    return;
-  }
-
-  conn_discard_pktns(conn, conn->hs_pktns);
-  conn->hs_pktns = NULL;
-}
-
 /*
  * verify_token verifies |hd| contains |token| in its token field.  It
  * returns 0 if it succeeds, or NGTCP2_ERR_PROTO.
@@ -7093,8 +7125,6 @@ int ngtcp2_conn_read_handshake(ngtcp2_conn *conn, const ngtcp2_path *path,
       if (rv != 0) {
         return rv;
       }
-
-      conn_discard_initial_state(conn);
     }
 
     return 0;
diff --git a/tests/ngtcp2_conn_test.c b/tests/ngtcp2_conn_test.c
index 5693db2c..6312b40e 100644
--- a/tests/ngtcp2_conn_test.c
+++ b/tests/ngtcp2_conn_test.c
@@ -4767,7 +4767,14 @@ void test_ngtcp2_conn_handshake_probe(void) {
   ent = ngtcp2_ksl_it_get(&it);
 
   CU_ASSERT(ent->flags & NGTCP2_RTB_FLAG_PROBE);
-  CU_ASSERT(sizeof(buf) > ent->pktlen);
+  /* We should expect sizeof(buf) > ent->pktlen, but we haven't
+     discarded Initial state in this test case, therefore Handshake
+     packet is padded.  In practice, client gets Initial from server,
+     which produces Handshake tx key.  Client acknowledges server
+     Initial.  After that client no longer needs to send Initial
+     packet, so initial state is discarded.  Then no more padding is
+     involved. */
+  CU_ASSERT(sizeof(buf) == ent->pktlen);
 
   ngtcp2_conn_del(conn);
 }
-- 
GitLab