From 19d94d546b30bc8ccef52636e90361f83844bd3c Mon Sep 17 00:00:00 2001
From: Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com>
Date: Wed, 18 Apr 2018 22:31:14 +0900
Subject: [PATCH] Encode payload length

---
 lib/ngtcp2_conn.c          | 11 +++++++----
 lib/ngtcp2_conv.c          | 11 +++++++++++
 lib/ngtcp2_conv.h          |  8 ++++++++
 lib/ngtcp2_pkt.c           |  6 +++---
 lib/ngtcp2_ppe.c           |  7 +++++++
 lib/ngtcp2_ppe.h           |  2 ++
 tests/ngtcp2_conn_test.c   | 26 +++++++++++++-------------
 tests/ngtcp2_test_helper.c | 24 ++++++++++++++++--------
 tests/ngtcp2_test_helper.h |  4 ++--
 9 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/lib/ngtcp2_conn.c b/lib/ngtcp2_conn.c
index 4fa3205e..ad1f01a9 100644
--- a/lib/ngtcp2_conn.c
+++ b/lib/ngtcp2_conn.c
@@ -1184,8 +1184,7 @@ static ssize_t conn_write_handshake_pkt(ngtcp2_conn *conn, uint8_t *dest,
   pfrc = &frc_head;
 
   ngtcp2_pkt_hd_init(&hd, NGTCP2_PKT_FLAG_LONG_FORM, type, &conn->dcid,
-                     &conn->scid, conn->last_tx_pkt_num + 1, conn->version,
-                     /* TODO */ 256);
+                     &conn->scid, conn->last_tx_pkt_num + 1, conn->version, 0);
 
   ctx.ckm = conn->hs_tx_ckm;
   ctx.aead_overhead = NGTCP2_HANDSHAKE_AEAD_OVERHEAD;
@@ -1436,7 +1435,7 @@ static ssize_t conn_write_handshake_ack_pkt(ngtcp2_conn *conn, uint8_t *dest,
 
   ngtcp2_pkt_hd_init(&hd, NGTCP2_PKT_FLAG_LONG_FORM, NGTCP2_PKT_HANDSHAKE,
                      &conn->dcid, &conn->scid, conn->last_tx_pkt_num + 1,
-                     conn->version, /* TODO */ 256);
+                     conn->version, 0);
 
   ctx.ckm = conn->hs_tx_ckm;
   ctx.aead_overhead = NGTCP2_HANDSHAKE_AEAD_OVERHEAD;
@@ -2742,6 +2741,10 @@ static int conn_recv_handshake_pkt(ngtcp2_conn *conn, const uint8_t *pkt,
   if (nread < 0) {
     return (int)nread;
   }
+  /* TODO support compound packet */
+  if (pktlen != (size_t)nread + hd.payloadlen) {
+    return NGTCP2_ERR_PROTO;
+  }
 
   if (conn->server && conn->early_ckm && ngtcp2_cid_eq(&conn->rcid, &hd.dcid) &&
       hd.type == NGTCP2_PKT_0RTT_PROTECTED) {
@@ -4668,7 +4671,7 @@ ssize_t ngtcp2_conn_write_stream(ngtcp2_conn *conn, uint8_t *dest,
   ctx.ckm = conn->early_ckm;
 
   ngtcp2_pkt_hd_init(&hd, pkt_flags, pkt_type, &conn->rcid, &conn->scid,
-                     conn->last_tx_pkt_num + 1, conn->version, /* TODO */ 256);
+                     conn->last_tx_pkt_num + 1, conn->version, 0);
 
   ctx.aead_overhead = conn->aead_overhead;
   ctx.encrypt = conn->callbacks.encrypt;
diff --git a/lib/ngtcp2_conv.c b/lib/ngtcp2_conv.c
index 98a275ae..533e61e9 100644
--- a/lib/ngtcp2_conv.c
+++ b/lib/ngtcp2_conv.c
@@ -144,6 +144,17 @@ uint8_t *ngtcp2_put_varint(uint8_t *p, uint64_t n) {
   return rv;
 }
 
+uint8_t *ngtcp2_put_varint14(uint8_t *p, uint16_t n) {
+  uint8_t *rv;
+
+  assert(n < 16384);
+
+  rv = ngtcp2_put_uint16be(p, n);
+  *p |= 0x40;
+
+  return rv;
+}
+
 size_t ngtcp2_get_varint_len(const uint8_t *p) {
   return varintlen_def[*p >> 6];
 }
diff --git a/lib/ngtcp2_conv.h b/lib/ngtcp2_conv.h
index 73c2796b..99969166 100644
--- a/lib/ngtcp2_conv.h
+++ b/lib/ngtcp2_conv.h
@@ -125,6 +125,14 @@ uint8_t *ngtcp2_put_uint16be(uint8_t *p, uint16_t n);
  */
 uint8_t *ngtcp2_put_varint(uint8_t *p, uint64_t n);
 
+/*
+ * ngtcp2_put_varint14 writes |n| in |p| using variable-length integer
+ * encoding.  |n| must be strictly less than 16384.  The function
+ * always encodes |n| in 2 bytes.  It returns the one beyond of the
+ * last written position.
+ */
+uint8_t *ngtcp2_put_varint14(uint8_t *p, uint16_t n);
+
 /*
  * ngtcp2_get_varint_len returns the required number of bytes to read
  * variable-length integer starting at |p|.
diff --git a/lib/ngtcp2_pkt.c b/lib/ngtcp2_pkt.c
index 1d109f6a..1231aa25 100644
--- a/lib/ngtcp2_pkt.c
+++ b/lib/ngtcp2_pkt.c
@@ -122,7 +122,6 @@ ssize_t ngtcp2_pkt_decode_hd_long(ngtcp2_pkt_hd *dest, const uint8_t *pkt,
   p += scil;
 
   dest->payloadlen = ngtcp2_get_varint(&n, p);
-
   p += n;
 
   dest->pkt_num = ngtcp2_get_uint32(p);
@@ -211,7 +210,8 @@ ssize_t ngtcp2_pkt_encode_hd_long(uint8_t *out, size_t outlen,
                                   const ngtcp2_pkt_hd *hd) {
   uint8_t *p;
   size_t len = NGTCP2_MIN_LONG_HEADERLEN + hd->dcid.datalen + hd->scid.datalen +
-               ngtcp2_put_varint_len(hd->payloadlen) - 1;
+               2 - 1 /* NGTCP2_MIN_LONG_HEADERLEN includes 1 byte for
+                        payloadlen */;
 
   if (outlen < len) {
     return NGTCP2_ERR_NOBUF;
@@ -237,7 +237,7 @@ ssize_t ngtcp2_pkt_encode_hd_long(uint8_t *out, size_t outlen,
   if (hd->scid.datalen) {
     p = ngtcp2_cpymem(p, hd->scid.data, hd->scid.datalen);
   }
-  p = ngtcp2_put_varint(p, hd->payloadlen);
+  p = ngtcp2_put_varint14(p, (uint16_t)hd->payloadlen);
   p = ngtcp2_put_uint32be(p, (uint32_t)hd->pkt_num);
 
   assert((size_t)(p - out) == len);
diff --git a/lib/ngtcp2_ppe.c b/lib/ngtcp2_ppe.c
index c99e183c..e1fe74ba 100644
--- a/lib/ngtcp2_ppe.c
+++ b/lib/ngtcp2_ppe.c
@@ -37,6 +37,7 @@ void ngtcp2_ppe_init(ngtcp2_ppe *ppe, uint8_t *out, size_t outlen,
   ngtcp2_buf_init(&ppe->buf, out, outlen);
 
   ppe->hdlen = 0;
+  ppe->payloadlen_offset = 0;
   ppe->pkt_num = 0;
   ppe->ctx = cctx;
 }
@@ -51,6 +52,7 @@ int ngtcp2_ppe_encode_hd(ngtcp2_ppe *ppe, const ngtcp2_pkt_hd *hd) {
   }
 
   if (hd->flags & NGTCP2_PKT_FLAG_LONG_FORM) {
+    ppe->payloadlen_offset = 1 + 4 + 1 + hd->dcid.datalen + hd->scid.datalen;
     rv = ngtcp2_pkt_encode_hd_long(
         buf->last, ngtcp2_buf_left(buf) - ctx->aead_overhead, hd);
   } else {
@@ -99,6 +101,11 @@ ssize_t ngtcp2_ppe_final(ngtcp2_ppe *ppe, const uint8_t **ppkt) {
   size_t payloadlen = ngtcp2_buf_len(buf) - ppe->hdlen;
   size_t destlen = (size_t)(buf->end - buf->begin) - ppe->hdlen;
 
+  if (ppe->payloadlen_offset) {
+    ngtcp2_put_varint14(buf->begin + ppe->payloadlen_offset,
+                        (uint16_t)(payloadlen + ctx->aead_overhead));
+  }
+
   ngtcp2_crypto_create_nonce(ppe->nonce, ctx->ckm->iv, ctx->ckm->ivlen,
                              ppe->pkt_num);
 
diff --git a/lib/ngtcp2_ppe.h b/lib/ngtcp2_ppe.h
index 24176e17..842776ec 100644
--- a/lib/ngtcp2_ppe.h
+++ b/lib/ngtcp2_ppe.h
@@ -42,6 +42,8 @@ typedef struct {
   ngtcp2_crypto_ctx *ctx;
   /* hdlen is the number of bytes for packet header written in buf. */
   size_t hdlen;
+  /* payloadlen_offset is the offset to payload length field. */
+  size_t payloadlen_offset;
   /* pkt_num is the packet number written in buf. */
   uint64_t pkt_num;
   /* nonce is the buffer to store nonce.  It should be equal or longer
diff --git a/tests/ngtcp2_conn_test.c b/tests/ngtcp2_conn_test.c
index 7c9021df..27c85b99 100644
--- a/tests/ngtcp2_conn_test.c
+++ b/tests/ngtcp2_conn_test.c
@@ -1699,7 +1699,7 @@ void test_ngtcp2_conn_recv_server_stateless_retry(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_RETRY, &conn->scid, &conn->dcid,
+      conn, buf, sizeof(buf), NGTCP2_PKT_RETRY, &conn->scid, &conn->dcid,
       conn->last_tx_pkt_num, conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, 2);
@@ -1732,7 +1732,7 @@ void test_ngtcp2_conn_recv_delayed_handshake_pkt(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid, 1,
+      conn, buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid, 1,
       NGTCP2_PROTO_VER_MAX, &fr);
   rv = ngtcp2_conn_recv(conn, buf, pktlen, 1);
 
@@ -1757,7 +1757,7 @@ void test_ngtcp2_conn_recv_delayed_handshake_pkt(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid, 1,
+      conn, buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid, 1,
       NGTCP2_PROTO_VER_MAX, &fr);
   rv = ngtcp2_conn_recv(conn, buf, pktlen, 1);
 
@@ -1777,7 +1777,7 @@ void test_ngtcp2_conn_recv_delayed_handshake_pkt(void) {
   fr.ack.num_blks = 0;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid, 1,
+      conn, buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid, 1,
       NGTCP2_PROTO_VER_MAX, &fr);
   rv = ngtcp2_conn_recv(conn, buf, pktlen, 1);
 
@@ -1840,7 +1840,7 @@ void test_ngtcp2_conn_handshake(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
+      conn, buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
       conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -1852,7 +1852,7 @@ void test_ngtcp2_conn_handshake(void) {
   memset(fr.path_response.data, 0, sizeof(fr.path_response));
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid,
+      conn, buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid,
       ++pkt_num, conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -1891,7 +1891,7 @@ void test_ngtcp2_conn_handshake_error(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid,
+      conn, buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid,
       ++pkt_num, conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -1911,7 +1911,7 @@ void test_ngtcp2_conn_handshake_error(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
+      conn, buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
       conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -1926,7 +1926,7 @@ void test_ngtcp2_conn_handshake_error(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid,
+      conn, buf, sizeof(buf), NGTCP2_PKT_HANDSHAKE, &conn->scid, &conn->dcid,
       ++pkt_num, conn->version, &fr);
 
   conn->callbacks.recv_stream0_data = recv_stream0_handshake_error;
@@ -2640,7 +2640,7 @@ void test_ngtcp2_conn_recv_early_data(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
+      conn, buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
       conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -2655,7 +2655,7 @@ void test_ngtcp2_conn_recv_early_data(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_0RTT_PROTECTED, &rcid, &conn->dcid,
+      conn, buf, sizeof(buf), NGTCP2_PKT_0RTT_PROTECTED, &rcid, &conn->dcid,
       ++pkt_num, conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -2680,7 +2680,7 @@ void test_ngtcp2_conn_recv_early_data(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_0RTT_PROTECTED, &rcid, &conn->dcid,
+      conn, buf, sizeof(buf), NGTCP2_PKT_0RTT_PROTECTED, &rcid, &conn->dcid,
       ++pkt_num, conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
@@ -2695,7 +2695,7 @@ void test_ngtcp2_conn_recv_early_data(void) {
   fr.stream.data = null_data;
 
   pktlen = write_single_frame_handshake_pkt(
-      buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
+      conn, buf, sizeof(buf), NGTCP2_PKT_INITIAL, &rcid, &conn->dcid, ++pkt_num,
       conn->version, &fr);
 
   spktlen = ngtcp2_conn_handshake(conn, buf, sizeof(buf), buf, pktlen, ++t);
diff --git a/tests/ngtcp2_test_helper.c b/tests/ngtcp2_test_helper.c
index c2492fc9..2cedf8cb 100644
--- a/tests/ngtcp2_test_helper.c
+++ b/tests/ngtcp2_test_helper.c
@@ -30,7 +30,6 @@
 #include "ngtcp2_conv.h"
 #include "ngtcp2_pkt.h"
 #include "ngtcp2_ppe.h"
-#include "ngtcp2_upe.h"
 
 size_t ngtcp2_t_encode_stream_frame(uint8_t *out, uint8_t flags,
                                     uint64_t stream_id, uint64_t offset,
@@ -155,25 +154,34 @@ size_t write_single_frame_pkt_without_conn_id(ngtcp2_conn *conn, uint8_t *out,
   return (size_t)n;
 }
 
-size_t write_single_frame_handshake_pkt(uint8_t *out, size_t outlen,
-                                        uint8_t pkt_type,
+size_t write_single_frame_handshake_pkt(ngtcp2_conn *conn, uint8_t *out,
+                                        size_t outlen, uint8_t pkt_type,
                                         const ngtcp2_cid *dcid,
                                         const ngtcp2_cid *scid,
                                         uint64_t pkt_num, uint32_t version,
                                         ngtcp2_frame *fr) {
-  ngtcp2_upe upe;
+  ngtcp2_crypto_ctx ctx;
+  ngtcp2_ppe ppe;
   ngtcp2_pkt_hd hd;
   int rv;
+  ssize_t n;
+
+  memset(&ctx, 0, sizeof(ctx));
+  ctx.encrypt = null_encrypt;
+  ctx.ckm = conn->hs_rx_ckm;
+  ctx.user_data = conn;
 
   ngtcp2_pkt_hd_init(&hd, NGTCP2_PKT_FLAG_LONG_FORM, pkt_type, dcid, scid,
                      pkt_num, version, 0);
 
-  ngtcp2_upe_init(&upe, out, outlen);
-  rv = ngtcp2_upe_encode_hd(&upe, &hd);
+  ngtcp2_ppe_init(&ppe, out, outlen, &ctx);
+  rv = ngtcp2_ppe_encode_hd(&ppe, &hd);
   assert(0 == rv);
-  rv = ngtcp2_upe_encode_frame(&upe, fr);
+  rv = ngtcp2_ppe_encode_frame(&ppe, fr);
   assert(0 == rv);
-  return ngtcp2_upe_final(&upe, NULL);
+  n = ngtcp2_ppe_final(&ppe, NULL);
+  assert(n > 0);
+  return (size_t)n;
 }
 
 ngtcp2_strm *open_stream(ngtcp2_conn *conn, uint64_t stream_id) {
diff --git a/tests/ngtcp2_test_helper.h b/tests/ngtcp2_test_helper.h
index 1c234040..7413013a 100644
--- a/tests/ngtcp2_test_helper.h
+++ b/tests/ngtcp2_test_helper.h
@@ -100,8 +100,8 @@ size_t write_single_frame_pkt_without_conn_id(ngtcp2_conn *conn, uint8_t *out,
  * capacity is |outlen|.  This function returns the number of bytes
  * written.
  */
-size_t write_single_frame_handshake_pkt(uint8_t *out, size_t outlen,
-                                        uint8_t pkt_type,
+size_t write_single_frame_handshake_pkt(ngtcp2_conn *conn, uint8_t *out,
+                                        size_t outlen, uint8_t pkt_type,
                                         const ngtcp2_cid *dcid,
                                         const ngtcp2_cid *scid,
                                         uint64_t pkt_num, uint32_t version,
-- 
GitLab