diff --git a/lib/ngtcp2_conn.c b/lib/ngtcp2_conn.c index 43ebf82e09c22907a6c21c042803517240a2514c..eb3f8d877968d795c21d11963a9da0c4b1b75f34 100644 --- a/lib/ngtcp2_conn.c +++ b/lib/ngtcp2_conn.c @@ -3003,7 +3003,21 @@ static ssize_t conn_recv_handshake_pkt(ngtcp2_conn *conn, const uint8_t *pkt, ngtcp2_log_rx_pkt_hd(&conn->log, &hd); - /* Do this after decryption succeeded */ + rv = conn_ensure_decrypt_buffer(conn, payloadlen); + if (rv != 0) { + return rv; + } + + nwrite = conn_decrypt_pkt(conn, conn->decrypt_buf.base, payloadlen, payload, + payloadlen, plain_hdpkt, hdpktlen, hd.pkt_num, ckm, + decrypt); + if (nwrite < 0) { + return (int)nwrite; + } + + payload = conn->decrypt_buf.base; + payloadlen = (size_t)nwrite; + if (conn->server) { switch (hd.type) { case NGTCP2_PKT_INITIAL: @@ -3037,21 +3051,6 @@ static ssize_t conn_recv_handshake_pkt(ngtcp2_conn *conn, const uint8_t *pkt, } } - rv = conn_ensure_decrypt_buffer(conn, payloadlen); - if (rv != 0) { - return rv; - } - - nwrite = conn_decrypt_pkt(conn, conn->decrypt_buf.base, payloadlen, payload, - payloadlen, plain_hdpkt, hdpktlen, hd.pkt_num, ckm, - decrypt); - if (nwrite < 0) { - return (int)nwrite; - } - - payload = conn->decrypt_buf.base; - payloadlen = (size_t)nwrite; - for (; payloadlen;) { nread = ngtcp2_pkt_decode_frame(fr, payload, payloadlen); if (nread < 0) {