diff --git a/lib/ngtcp2_conn.c b/lib/ngtcp2_conn.c
index 43ebf82e09c22907a6c21c042803517240a2514c..eb3f8d877968d795c21d11963a9da0c4b1b75f34 100644
--- a/lib/ngtcp2_conn.c
+++ b/lib/ngtcp2_conn.c
@@ -3003,7 +3003,21 @@ static ssize_t conn_recv_handshake_pkt(ngtcp2_conn *conn, const uint8_t *pkt,
 
   ngtcp2_log_rx_pkt_hd(&conn->log, &hd);
 
-  /* Do this after decryption succeeded */
+  rv = conn_ensure_decrypt_buffer(conn, payloadlen);
+  if (rv != 0) {
+    return rv;
+  }
+
+  nwrite = conn_decrypt_pkt(conn, conn->decrypt_buf.base, payloadlen, payload,
+                            payloadlen, plain_hdpkt, hdpktlen, hd.pkt_num, ckm,
+                            decrypt);
+  if (nwrite < 0) {
+    return (int)nwrite;
+  }
+
+  payload = conn->decrypt_buf.base;
+  payloadlen = (size_t)nwrite;
+
   if (conn->server) {
     switch (hd.type) {
     case NGTCP2_PKT_INITIAL:
@@ -3037,21 +3051,6 @@ static ssize_t conn_recv_handshake_pkt(ngtcp2_conn *conn, const uint8_t *pkt,
     }
   }
 
-  rv = conn_ensure_decrypt_buffer(conn, payloadlen);
-  if (rv != 0) {
-    return rv;
-  }
-
-  nwrite = conn_decrypt_pkt(conn, conn->decrypt_buf.base, payloadlen, payload,
-                            payloadlen, plain_hdpkt, hdpktlen, hd.pkt_num, ckm,
-                            decrypt);
-  if (nwrite < 0) {
-    return (int)nwrite;
-  }
-
-  payload = conn->decrypt_buf.base;
-  payloadlen = (size_t)nwrite;
-
   for (; payloadlen;) {
     nread = ngtcp2_pkt_decode_frame(fr, payload, payloadlen);
     if (nread < 0) {