From e43a8c4e53d2bb7383cba4b395cac5d2c56d3fd7 Mon Sep 17 00:00:00 2001 From: Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com> Date: Sat, 23 Feb 2019 18:00:22 +0900 Subject: [PATCH] Server preferred address --- README.rst | 2 +- ci/build_openssl.sh | 2 +- examples/client.cc | 107 +++++++- examples/client.h | 4 + examples/server.cc | 507 +++++++++++++++++++++++------------ examples/server.h | 46 ++-- lib/includes/ngtcp2/ngtcp2.h | 25 +- lib/ngtcp2_addr.c | 5 +- lib/ngtcp2_cid.c | 4 +- lib/ngtcp2_conn.c | 124 ++++++++- lib/ngtcp2_conn.h | 4 + lib/ngtcp2_path.c | 8 +- lib/ngtcp2_pv.h | 3 - tests/ngtcp2_conn_test.c | 5 +- tests/ngtcp2_pv_test.c | 4 +- 15 files changed, 631 insertions(+), 219 deletions(-) diff --git a/README.rst b/README.rst index ca77434c..25b63965 100644 --- a/README.rst +++ b/README.rst @@ -52,7 +52,7 @@ Build from git .. code-block:: text - $ git clone --depth 1 -b quic-draft-17 https://github.com/tatsuhiro-t/openssl + $ git clone --depth 1 -b quic-draft-18 https://github.com/tatsuhiro-t/openssl $ cd openssl $ # For Linux $ ./config enable-tls1_3 --prefix=$PWD/build diff --git a/ci/build_openssl.sh b/ci/build_openssl.sh index 025502a2..c86cf7a1 100755 --- a/ci/build_openssl.sh +++ b/ci/build_openssl.sh @@ -2,7 +2,7 @@ #build last openssl master (for Travis) cd .. -git clone --depth 1 -b quic-draft-17 https://github.com/tatsuhiro-t/openssl +git clone --depth 1 -b quic-draft-18 https://github.com/tatsuhiro-t/openssl cd openssl ./config enable-tls1_3 --prefix=$PWD/build make -j$(nproc) diff --git a/examples/client.cc b/examples/client.cc index ad0076bd..bcd3f24f 100644 --- a/examples/client.cc +++ b/examples/client.cc @@ -793,6 +793,25 @@ int path_validation(ngtcp2_conn *conn, const ngtcp2_path *path, } } // namespace +namespace { +int select_preferred_address(ngtcp2_conn *conn, ngtcp2_addr *dest, + const ngtcp2_preferred_addr *paddr, + void *user_data) { + auto c = static_cast<Client *>(user_data); + Address addr; + + if (c->select_preferred_address(addr, paddr) != 0) { + dest->len = 0; + return 0; + } + + dest->len = addr.len; + memcpy(dest->addr, &addr.su, dest->len); + + return 0; +} +} // namespace + int Client::init_ssl() { if (ssl_) { SSL_free(ssl_); @@ -909,6 +928,7 @@ int Client::init(int fd, const Address &local_addr, const Address &remote_addr, remove_connection_id, ::update_key, path_validation, + ::select_preferred_address, }; auto dis = std::uniform_int_distribution<uint8_t>( @@ -1342,9 +1362,11 @@ int Client::on_write(bool retransmit) { return rv; } + PathStorage path; + for (;;) { - auto n = ngtcp2_conn_write_pkt(conn_, nullptr, sendbuf_.wpos(), max_pktlen_, - util::timestamp(loop_)); + auto n = ngtcp2_conn_write_pkt(conn_, &path.path, sendbuf_.wpos(), + max_pktlen_, util::timestamp(loop_)); if (n < 0) { std::cerr << "ngtcp2_conn_write_pkt: " << ngtcp2_strerror(n) << std::endl; disconnect(n); @@ -1356,6 +1378,8 @@ int Client::on_write(bool retransmit) { sendbuf_.push(n); + update_remote_addr(&path.path.remote); + auto rv = send_packet(); if (rv == NETWORK_ERR_SEND_NON_FATAL) { schedule_retransmit(); @@ -1408,10 +1432,11 @@ int Client::write_streams() { int Client::on_write_stream(uint64_t stream_id, uint8_t fin, Buffer &data) { ssize_t ndatalen; + PathStorage path; for (;;) { auto n = ngtcp2_conn_write_stream( - conn_, nullptr, sendbuf_.wpos(), max_pktlen_, &ndatalen, stream_id, fin, - data.rpos(), data.size(), util::timestamp(loop_)); + conn_, &path.path, sendbuf_.wpos(), max_pktlen_, &ndatalen, stream_id, + fin, data.rpos(), data.size(), util::timestamp(loop_)); if (n < 0) { switch (n) { case NGTCP2_ERR_EARLY_DATA_REJECTED: @@ -1437,6 +1462,8 @@ int Client::on_write_stream(uint64_t stream_id, uint8_t fin, Buffer &data) { sendbuf_.push(n); + update_remote_addr(&path.path.remote); + auto rv = send_packet(); if (rv != NETWORK_ERR_OK) { return rv; @@ -1744,7 +1771,7 @@ int Client::change_local_addr() { if (config.nat_rebinding) { ngtcp2_addr addr; ngtcp2_conn_set_local_addr( - conn_, ngtcp2_addr_init(&addr, &local_addr.su, local_addr.len)); + conn_, ngtcp2_addr_init(&addr, &local_addr.su, local_addr.len, NULL)); } else { auto path = ngtcp2_path{ {local_addr.len, reinterpret_cast<uint8_t *>(&local_addr.su)}, @@ -1872,6 +1899,11 @@ int Client::initiate_key_update() { return 0; } +void Client::update_remote_addr(const ngtcp2_addr *addr) { + remote_addr_.len = addr->len; + memcpy(&remote_addr_.su, addr->addr, addr->len); +} + int Client::send_packet() { if (debug::packet_lost(config.tx_loss_prob)) { if (!config.quiet) { @@ -2008,9 +2040,10 @@ int Client::handle_error(int liberr) { err_code = ngtcp2_err_infer_quic_transport_error_code(liberr); } - auto n = ngtcp2_conn_write_connection_close(conn_, nullptr, sendbuf_.wpos(), - max_pktlen_, err_code, - util::timestamp(loop_)); + PathStorage path; + auto n = ngtcp2_conn_write_connection_close(conn_, &path.path, + sendbuf_.wpos(), max_pktlen_, + err_code, util::timestamp(loop_)); if (n < 0) { std::cerr << "ngtcp2_conn_write_connection_close: " << ngtcp2_strerror(n) << std::endl; @@ -2019,6 +2052,8 @@ int Client::handle_error(int liberr) { sendbuf_.push(n); + update_remote_addr(&path.path.remote); + return send_packet(); } @@ -2197,6 +2232,62 @@ int Client::on_extend_max_streams() { return 0; } +int Client::select_preferred_address(Address &selected_addr, + const ngtcp2_preferred_addr *paddr) { + int af; + const uint8_t *binaddr; + uint16_t port; + constexpr uint8_t empty_addr[] = {0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0}; + if (local_addr_.su.sa.sa_family == AF_INET && + memcmp(empty_addr, paddr->ipv4_addr, sizeof(paddr->ipv4_addr)) != 0) { + af = AF_INET; + binaddr = paddr->ipv4_addr; + port = paddr->ipv4_port; + } else if (local_addr_.su.sa.sa_family == AF_INET6 && + memcmp(empty_addr, paddr->ipv6_addr, sizeof(paddr->ipv6_addr)) != + 0) { + af = AF_INET6; + binaddr = paddr->ipv6_addr; + port = paddr->ipv6_port; + } else { + return -1; + } + + char host[NI_MAXHOST]; + if (inet_ntop(af, binaddr, host, sizeof(host)) == NULL) { + std::cerr << "inet_ntop: " << strerror(errno) << std::endl; + return -1; + } + + if (!config.quiet) { + std::cerr << "selected server preferred_address is [" << host + << "]:" << port << std::endl; + } + + addrinfo hints{}; + addrinfo *res; + + hints.ai_flags = AI_NUMERICHOST | AI_NUMERICSERV; + hints.ai_family = af; + hints.ai_socktype = SOCK_DGRAM; + + auto rv = getaddrinfo(host, std::to_string(port).c_str(), &hints, &res); + if (rv != 0) { + std::cerr << "getaddrinfo: " << gai_strerror(rv) << std::endl; + return -1; + } + + assert(res); + + selected_addr.len = res->ai_addrlen; + memcpy(&selected_addr.su, res->ai_addr, res->ai_addrlen); + + freeaddrinfo(res); + + return 0; +} + void Client::start_wev() { ev_io_start(loop_, &wev_); } void Client::set_tls_alert(uint8_t alert) { tls_alert_ = alert; } diff --git a/examples/client.h b/examples/client.h index 62655c91..dd8baeb5 100644 --- a/examples/client.h +++ b/examples/client.h @@ -199,6 +199,7 @@ public: ssize_t hp_mask(uint8_t *data, size_t destlen, const uint8_t *key, size_t keylen, const uint8_t *sample, size_t samplelen); ngtcp2_conn *conn() const; + void update_remote_addr(const ngtcp2_addr *addr); int send_packet(); int start_interactive_input(); int send_interactive_input(); @@ -221,6 +222,9 @@ public: void set_tls_alert(uint8_t alert); + int select_preferred_address(Address &selected_addr, + const ngtcp2_preferred_addr *paddr); + private: Address local_addr_; Address remote_addr_; diff --git a/examples/server.cc b/examples/server.cc index cc8232c3..1b29e11d 100644 --- a/examples/server.cc +++ b/examples/server.cc @@ -658,7 +658,6 @@ void retransmitcb(struct ev_loop *loop, ev_timer *w, int revents) { case NETWORK_ERR_CLOSE_WAIT: return; case NETWORK_ERR_SEND_NON_FATAL: - s->start_wev(); return; default: s->remove(h); @@ -673,7 +672,6 @@ void retransmitcb(struct ev_loop *loop, ev_timer *w, int revents) { case NETWORK_ERR_CLOSE_WAIT: return; case NETWORK_ERR_SEND_NON_FATAL: - s->start_wev(); return; default: s->remove(h); @@ -685,16 +683,18 @@ void retransmitcb(struct ev_loop *loop, ev_timer *w, int revents) { Handler::Handler(struct ev_loop *loop, SSL_CTX *ssl_ctx, Server *server, const ngtcp2_cid *rcid) - : remote_addr_{}, + : endpoint_{nullptr}, + remote_addr_{}, max_pktlen_(0), loop_(loop), ssl_ctx_(ssl_ctx), ssl_(nullptr), server_(server), - fd_(-1), ncread_(0), shandshake_idx_(0), conn_(nullptr), + scid_{}, + pscid_{}, rcid_(*rcid), hs_crypto_ctx_{}, crypto_ctx_{}, @@ -990,11 +990,13 @@ int path_validation(ngtcp2_conn *conn, const ngtcp2_path *path, } } // namespace -int Handler::init(int fd, const sockaddr *sa, socklen_t salen, +int Handler::init(Endpoint &ep, const sockaddr *sa, socklen_t salen, const ngtcp2_cid *dcid, const ngtcp2_cid *ocid, uint32_t version) { int rv; + endpoint_ = &ep; + remote_addr_.len = salen; memcpy(&remote_addr_.su.sa, sa, salen); @@ -1009,7 +1011,6 @@ int Handler::init(int fd, const sockaddr *sa, socklen_t salen, return -1; } - fd_ = fd; ssl_ = SSL_new(ssl_ctx_); auto bio = BIO_new(create_bio_method()); BIO_set_data(bio, this); @@ -1073,10 +1074,37 @@ int Handler::init(int fd, const sockaddr *sa, socklen_t salen, std::generate(scid_.data, scid_.data + scid_.datalen, [&dis]() { return dis(randgen); }); - auto &local_addr = server_->get_local_addr(); + if (config.preferred_ipv4_addr.len || config.preferred_ipv6_addr.len) { + settings.preferred_address_present = 1; + if (config.preferred_ipv4_addr.len) { + auto &dest = settings.preferred_address.ipv4_addr; + const auto &addr = config.preferred_ipv4_addr; + assert(sizeof(dest) == sizeof(addr.su.in.sin_addr)); + memcpy(&dest, &addr.su.in.sin_addr, sizeof(dest)); + settings.preferred_address.ipv4_port = htons(addr.su.in.sin_port); + } + if (config.preferred_ipv6_addr.len) { + auto &dest = settings.preferred_address.ipv6_addr; + const auto &addr = config.preferred_ipv6_addr; + assert(sizeof(dest) == sizeof(addr.su.in6.sin6_addr)); + memcpy(&dest, &addr.su.in6.sin6_addr, sizeof(dest)); + settings.preferred_address.ipv6_port = htons(addr.su.in6.sin6_port); + } + + auto &token = settings.preferred_address.stateless_reset_token; + std::generate(std::begin(token), std::end(token), + [&dis]() { return dis(randgen); }); + + pscid_.datalen = NGTCP2_SV_SCIDLEN; + std::generate(pscid_.data, pscid_.data + pscid_.datalen, + [&dis]() { return dis(randgen); }); + settings.preferred_address.cid = pscid_; + } + auto path = ngtcp2_path{ - {local_addr.len, const_cast<uint8_t *>( - reinterpret_cast<const uint8_t *>(&local_addr.su))}, + {ep.addr.len, + const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(&ep.addr.su)), + &ep}, {salen, const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(sa))}}; rv = ngtcp2_conn_server_new(&conn_, dcid, &scid_, &path, version, &callbacks, &settings, this); @@ -1423,7 +1451,7 @@ ssize_t Handler::do_handshake_write_once() { sendbuf_.push(nwrite); - auto rv = server_->send_packet(remote_addr_, sendbuf_); + auto rv = server_->send_packet(*endpoint_, remote_addr_, sendbuf_); if (rv == NETWORK_ERR_SEND_NON_FATAL) { schedule_retransmit(); return rv; @@ -1442,7 +1470,7 @@ int Handler::do_handshake(const uint8_t *data, size_t datalen) { } if (sendbuf_.size() > 0) { - auto rv = server_->send_packet(remote_addr_, sendbuf_); + auto rv = server_->send_packet(*endpoint_, remote_addr_, sendbuf_); if (rv != NETWORK_ERR_OK) { return rv; } @@ -1459,21 +1487,25 @@ int Handler::do_handshake(const uint8_t *data, size_t datalen) { } } +void Handler::update_endpoint(const ngtcp2_addr *addr) { + endpoint_ = static_cast<Endpoint *>(addr->user_data); + assert(endpoint_); +} + void Handler::update_remote_addr(const ngtcp2_addr *addr) { remote_addr_.len = addr->len; - memcpy(&remote_addr_.su, addr->addr, sizeof(addr->len)); + memcpy(&remote_addr_.su, addr->addr, addr->len); } -int Handler::feed_data(const sockaddr *sa, socklen_t salen, uint8_t *data, - size_t datalen) { +int Handler::feed_data(Endpoint &ep, const sockaddr *sa, socklen_t salen, + uint8_t *data, size_t datalen) { int rv; if (ngtcp2_conn_get_handshake_completed(conn_)) { - auto &local_addr = server_->get_local_addr(); auto path = ngtcp2_path{ - {local_addr.len, - const_cast<uint8_t *>( - reinterpret_cast<const uint8_t *>(&local_addr.su))}, + {ep.addr.len, + const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(&ep.addr.su)), + &ep}, {salen, const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(sa))}}; rv = ngtcp2_conn_read_pkt(conn_, &path, data, datalen, util::timestamp(loop_)); @@ -1486,6 +1518,7 @@ int Handler::feed_data(const sockaddr *sa, socklen_t salen, uint8_t *data, return handle_error(rv); } } else { + // TODO Should we check that path is consistent during handshake? rv = do_handshake(data, datalen); if (rv != 0) { return handle_error(rv); @@ -1495,11 +1528,11 @@ int Handler::feed_data(const sockaddr *sa, socklen_t salen, uint8_t *data, return 0; } -int Handler::on_read(const sockaddr *sa, socklen_t salen, uint8_t *data, - size_t datalen) { +int Handler::on_read(Endpoint &ep, const sockaddr *sa, socklen_t salen, + uint8_t *data, size_t datalen) { int rv; - rv = feed_data(sa, salen, data, datalen); + rv = feed_data(ep, sa, salen, data, datalen); if (rv != 0) { return rv; } @@ -1517,7 +1550,7 @@ int Handler::on_write(bool retransmit) { } if (sendbuf_.size() > 0) { - auto rv = server_->send_packet(remote_addr_, sendbuf_); + auto rv = server_->send_packet(*endpoint_, remote_addr_, sendbuf_); if (rv != NETWORK_ERR_OK) { return rv; } @@ -1576,9 +1609,10 @@ int Handler::on_write(bool retransmit) { sendbuf_.push(n); + update_endpoint(&path.path.local); update_remote_addr(&path.path.remote); - auto rv = server_->send_packet(remote_addr_, sendbuf_); + auto rv = server_->send_packet(*endpoint_, remote_addr_, sendbuf_); if (rv == NETWORK_ERR_SEND_NON_FATAL) { schedule_retransmit(); return rv; @@ -1655,9 +1689,10 @@ int Handler::write_stream_data(Stream &stream, int fin, Buffer &data) { sendbuf_.push(n); + update_endpoint(&path.path.local); update_remote_addr(&path.path.remote); - auto rv = server_->send_packet(remote_addr_, sendbuf_); + auto rv = server_->send_packet(*endpoint_, remote_addr_, sendbuf_); if (rv != NETWORK_ERR_OK) { return rv; } @@ -1711,8 +1746,9 @@ int Handler::start_closing_period(int liberr) { err_code = ngtcp2_err_infer_quic_transport_error_code(liberr); } + PathStorage path; auto n = ngtcp2_conn_write_connection_close( - conn_, nullptr, conn_closebuf_->wpos(), max_pktlen_, err_code, + conn_, &path.path, conn_closebuf_->wpos(), max_pktlen_, err_code, util::timestamp(loop_)); if (n < 0) { std::cerr << "ngtcp2_conn_write_connection_close: " << ngtcp2_strerror(n) @@ -1722,6 +1758,9 @@ int Handler::start_closing_period(int liberr) { conn_closebuf_->push(n); + update_endpoint(&path.path.local); + update_remote_addr(&path.path.remote); + return 0; } @@ -1754,7 +1793,7 @@ int Handler::send_conn_close() { sendbuf_.push(conn_closebuf_->size()); } - return server_->send_packet(remote_addr_, sendbuf_); + return server_->send_packet(*endpoint_, remote_addr_, sendbuf_); } void Handler::schedule_retransmit() { @@ -1886,6 +1925,8 @@ int Handler::update_key() { const ngtcp2_cid *Handler::scid() const { return &scid_; } +const ngtcp2_cid *Handler::pscid() const { return &pscid_; } + const ngtcp2_cid *Handler::rcid() const { return &rcid_; } Server *Handler::server() const { return server_; } @@ -1969,22 +2010,18 @@ namespace { void swritecb(struct ev_loop *loop, ev_io *w, int revents) { ev_io_stop(loop, w); - auto s = static_cast<Server *>(w->data); - - auto rv = s->on_write(); - if (rv != 0) { - if (rv == NETWORK_ERR_SEND_NON_FATAL) { - s->start_wev(); - } - } + auto ep = static_cast<Endpoint *>(w->data); + // TODO At the moment, this triggers writes to the all endpoints, + // which is not ideal. + ep->server->on_write(); } } // namespace namespace { void sreadcb(struct ev_loop *loop, ev_io *w, int revents) { - auto s = static_cast<Server *>(w->data); + auto ep = static_cast<Endpoint *>(w->data); - s->on_read(); + ep->server->on_read(*ep); } } // namespace @@ -1995,11 +2032,7 @@ void siginthandler(struct ev_loop *loop, ev_signal *watcher, int revents) { } // namespace Server::Server(struct ev_loop *loop, SSL_CTX *ssl_ctx) - : loop_(loop), ssl_ctx_(ssl_ctx), token_crypto_ctx_{}, fd_(-1) { - ev_io_init(&wev_, swritecb, 0, EV_WRITE); - ev_io_init(&rev_, sreadcb, 0, EV_READ); - wev_.data = this; - rev_.data = this; + : loop_(loop), ssl_ctx_(ssl_ctx), token_crypto_ctx_{} { ev_signal_init(&sigintev_, siginthandler, SIGINT); crypto::aead_aes_128_gcm(token_crypto_ctx_); @@ -2020,7 +2053,9 @@ void Server::disconnect() { disconnect(0); } void Server::disconnect(int liberr) { config.tx_loss_prob = 0; - ev_io_stop(loop_, &rev_); + for (auto &ep : endpoints_) { + ev_io_stop(loop_, &ep.rev); + } ev_signal_stop(loop_, &sigintev_); @@ -2035,22 +2070,173 @@ void Server::disconnect(int liberr) { } void Server::close() { - ev_io_stop(loop_, &wev_); + for (auto &ep : endpoints_) { + ev_io_stop(loop_, &ep.wev); + ::close(ep.fd); + } + + endpoints_.clear(); +} + +namespace { +int create_sock(Address &local_addr, const char *addr, const char *port, + int family) { + addrinfo hints{}; + addrinfo *res, *rp; + int rv; + int val = 1; + + hints.ai_family = family; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_flags = AI_PASSIVE; + + if (strcmp("addr", "*") == 0) { + addr = nullptr; + } - if (fd_ != -1) { - ::close(fd_); - fd_ = -1; + rv = getaddrinfo(addr, port, &hints, &res); + if (rv != 0) { + std::cerr << "getaddrinfo: " << gai_strerror(rv) << std::endl; + return -1; } + + auto res_d = defer(freeaddrinfo, res); + + int fd = -1; + + for (rp = res; rp; rp = rp->ai_next) { + fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (fd == -1) { + continue; + } + + if (rp->ai_family == AF_INET6) { + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &val, + static_cast<socklen_t>(sizeof(val))) == -1) { + close(fd); + continue; + } + } + + if (bind(fd, rp->ai_addr, rp->ai_addrlen) != -1) { + break; + } + + close(fd); + } + + if (!rp) { + std::cerr << "Could not bind" << std::endl; + return -1; + } + + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, + static_cast<socklen_t>(sizeof(val))) == -1) { + return -1; + } + + socklen_t len = sizeof(local_addr.su.storage); + rv = getsockname(fd, &local_addr.su.sa, &len); + if (rv == -1) { + std::cerr << "getsockname: " << strerror(errno) << std::endl; + return -1; + } + local_addr.len = len; + + return fd; } -int Server::init(int fd, const Address &local_addr) { - local_addr_ = local_addr; - fd_ = fd; +} // namespace - ev_io_set(&wev_, fd_, EV_WRITE); - ev_io_set(&rev_, fd_, EV_READ); +namespace { +int add_endpoint(std::vector<Endpoint> &endpoints, const char *addr, + const char *port, int af) { + Address dest; + auto fd = create_sock(dest, addr, port, af); + if (fd == -1) { + return -1; + } - ev_io_start(loop_, &rev_); + endpoints.emplace_back(); + auto &ep = endpoints.back(); + ep.addr = dest; + ep.fd = fd; + ev_io_init(&ep.wev, swritecb, 0, EV_WRITE); + ev_io_init(&ep.rev, sreadcb, 0, EV_READ); + + return 0; +} +} // namespace + +namespace { +int add_endpoint(std::vector<Endpoint> &endpoints, const Address &addr) { + auto fd = socket(addr.su.sa.sa_family, SOCK_DGRAM, 0); + if (fd == -1) { + std::cerr << "socket: " << strerror(errno) << std::endl; + return -1; + } + + int val = 1; + if (addr.su.sa.sa_family == AF_INET6 && + setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &val, + static_cast<socklen_t>(sizeof(val)))) { + std::cerr << "setsockopt: " << strerror(errno) << std::endl; + close(fd); + return -1; + } + + if (bind(fd, &addr.su.sa, addr.len) == -1) { + std::cerr << "bind: " << strerror(errno) << std::endl; + close(fd); + return -1; + } + + endpoints.emplace_back(Endpoint{}); + auto &ep = endpoints.back(); + ep.addr = addr; + ep.fd = fd; + ev_io_init(&ep.wev, swritecb, 0, EV_WRITE); + ev_io_init(&ep.rev, sreadcb, 0, EV_READ); + + return 0; +} +} // namespace + +int Server::init(const char *addr, const char *port) { + endpoints_.reserve(4); + + auto ready = false; + if (!util::numeric_host(addr, AF_INET6) && + add_endpoint(endpoints_, addr, port, AF_INET) == 0) { + ready = true; + } + if (!util::numeric_host(addr, AF_INET) && + add_endpoint(endpoints_, addr, port, AF_INET6) == 0) { + ready = true; + } + if (!ready) { + return -1; + } + + if (config.preferred_ipv4_addr.len && + add_endpoint(endpoints_, config.preferred_ipv4_addr) != 0) { + return -1; + } + if (config.preferred_ipv6_addr.len && + add_endpoint(endpoints_, config.preferred_ipv6_addr) != 0) { + return -1; + } + + for (auto &ep : endpoints_) { + ep.server = this; + ep.wev.data = &ep; + ep.rev.data = &ep; + + ev_io_set(&ep.wev, ep.fd, EV_WRITE); + ev_io_set(&ep.rev, ep.fd, EV_READ); + + ev_io_start(loop_, &ep.rev); + } ev_signal_start(loop_, &sigintev_); @@ -2075,7 +2261,7 @@ int Server::on_write() { return NETWORK_ERR_OK; } -int Server::on_read() { +int Server::on_read(Endpoint &ep) { sockaddr_union su; socklen_t addrlen; std::array<uint8_t, 64_k> buf; @@ -2085,7 +2271,7 @@ int Server::on_read() { while (true) { addrlen = sizeof(su); auto nread = - recvfrom(fd_, buf.data(), buf.size(), MSG_DONTWAIT, &su.sa, &addrlen); + recvfrom(ep.fd, buf.data(), buf.size(), MSG_DONTWAIT, &su.sa, &addrlen); if (nread == -1) { if (!(errno == EAGAIN || errno == ENOTCONN)) { std::cerr << "recvfrom: " << strerror(errno) << std::endl; @@ -2142,7 +2328,7 @@ int Server::on_read() { std::cerr << "Unsupported version: Send Version Negotiation" << std::endl; } - send_version_negotiation(&hd, &su.sa, addrlen); + send_version_negotiation(&hd, ep, &su.sa, addrlen); return 0; } @@ -2152,16 +2338,16 @@ int Server::on_read() { std::cerr << "Perform stateless address validation" << std::endl; if (hd.tokenlen == 0 || verify_token(&ocid, &hd, &su.sa, addrlen) != 0) { - send_retry(&hd, &su.sa, addrlen); + send_retry(&hd, ep, &su.sa, addrlen); return 0; } pocid = &ocid; } auto h = std::make_unique<Handler>(loop_, ssl_ctx_, this, &hd.dcid); - h->init(fd_, &su.sa, addrlen, &hd.scid, pocid, hd.version); + h->init(ep, &su.sa, addrlen, &hd.scid, pocid, hd.version); - if (h->on_read(&su.sa, addrlen, buf.data(), nread) != 0) { + if (h->on_read(ep, &su.sa, addrlen, buf.data(), nread) != 0) { return 0; } rv = h->on_write(); @@ -2169,7 +2355,6 @@ int Server::on_read() { case 0: break; case NETWORK_ERR_SEND_NON_FATAL: - start_wev(); break; default: return 0; @@ -2177,8 +2362,15 @@ int Server::on_read() { auto scid = h->scid(); auto scid_key = util::make_cid_key(scid); - handlers_.emplace(scid_key, std::move(h)); ctos_.emplace(dcid_key, scid_key); + + auto pscid = h->pscid(); + if (pscid->datalen) { + auto pscid_key = util::make_cid_key(pscid); + ctos_.emplace(pscid_key, scid_key); + } + + handlers_.emplace(scid_key, std::move(h)); return 0; } if (!config.quiet) { @@ -2207,7 +2399,7 @@ int Server::on_read() { return 0; } - rv = h->on_read(&su.sa, addrlen, buf.data(), nread); + rv = h->on_read(ep, &su.sa, addrlen, buf.data(), nread); if (rv != 0) { if (rv != NETWORK_ERR_CLOSE_WAIT) { remove(handler_it); @@ -2221,7 +2413,6 @@ int Server::on_read() { case NETWORK_ERR_CLOSE_WAIT: break; case NETWORK_ERR_SEND_NON_FATAL: - start_wev(); break; default: remove(handler_it); @@ -2253,7 +2444,7 @@ uint32_t generate_reserved_version(const sockaddr *sa, socklen_t salen, } } // namespace -int Server::send_version_negotiation(const ngtcp2_pkt_hd *chd, +int Server::send_version_negotiation(const ngtcp2_pkt_hd *chd, Endpoint &ep, const sockaddr *sa, socklen_t salen) { Buffer buf{NGTCP2_MAX_PKTLEN_IPV4}; std::array<uint32_t, 2> sv; @@ -2278,15 +2469,15 @@ int Server::send_version_negotiation(const ngtcp2_pkt_hd *chd, remote_addr.len = salen; memcpy(&remote_addr.su.sa, sa, salen); - if (send_packet(remote_addr, buf) != NETWORK_ERR_OK) { + if (send_packet(ep, remote_addr, buf) != NETWORK_ERR_OK) { return -1; } return 0; } -int Server::send_retry(const ngtcp2_pkt_hd *chd, const sockaddr *sa, - socklen_t salen) { +int Server::send_retry(const ngtcp2_pkt_hd *chd, Endpoint &ep, + const sockaddr *sa, socklen_t salen) { std::array<char, NI_MAXHOST> host; std::array<char, NI_MAXSERV> port; int rv; @@ -2345,7 +2536,7 @@ int Server::send_retry(const ngtcp2_pkt_hd *chd, const sockaddr *sa, remote_addr.len = salen; memcpy(&remote_addr.su.sa, sa, salen); - if (send_packet(remote_addr, buf) != NETWORK_ERR_OK) { + if (send_packet(ep, remote_addr, buf) != NETWORK_ERR_OK) { return -1; } @@ -2535,7 +2726,7 @@ int Server::verify_token(ngtcp2_cid *ocid, const ngtcp2_pkt_hd *hd, return 0; } -int Server::send_packet(Address &remote_addr, Buffer &buf) { +int Server::send_packet(Endpoint &ep, const Address &remote_addr, Buffer &buf) { if (debug::packet_lost(config.tx_loss_prob)) { if (!config.quiet) { std::cerr << "** Simulated outgoing packet loss **" << std::endl; @@ -2548,7 +2739,7 @@ int Server::send_packet(Address &remote_addr, Buffer &buf) { ssize_t nwrite = 0; do { - nwrite = sendto(fd_, buf.rpos(), buf.size(), 0, &remote_addr.su.sa, + nwrite = sendto(ep.fd, buf.rpos(), buf.size(), 0, &remote_addr.su.sa, remote_addr.len); } while ((nwrite == -1) && (errno == EINTR) && (eintr_retries-- > 0)); @@ -2557,6 +2748,7 @@ int Server::send_packet(Address &remote_addr, Buffer &buf) { case EAGAIN: case EINTR: case 0: + ev_io_start(loop_, &ep.wev); return NETWORK_ERR_SEND_NON_FATAL; default: std::cerr << "sendto: " << strerror(errno) << std::endl; @@ -2589,6 +2781,7 @@ void Server::dissociate_cid(const ngtcp2_cid *cid) { void Server::remove(const Handler *h) { ctos_.erase(util::make_cid_key(h->rcid())); + ctos_.erase(util::make_cid_key(h->pscid())); auto conn = h->conn(); std::vector<ngtcp2_cid> cids(ngtcp2_conn_get_num_scid(conn)); @@ -2607,10 +2800,6 @@ std::map<std::string, std::unique_ptr<Handler>>::const_iterator Server::remove( return handlers_.erase(it); } -void Server::start_wev() { ev_io_start(loop_, &wev_); } - -const Address &Server::get_local_addr() const { return local_addr_; } - namespace { int alpn_select_proto_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, @@ -2814,106 +3003,67 @@ fail: } // namespace namespace { -int create_sock(Address &local_addr, const char *addr, const char *port, - int family) { - addrinfo hints{}; - addrinfo *res, *rp; - int rv; - int val = 1; - - hints.ai_family = family; - hints.ai_socktype = SOCK_DGRAM; - hints.ai_flags = AI_PASSIVE; - - if (strcmp("addr", "*") == 0) { - addr = nullptr; - } +std::ofstream keylog_file; +void keylog_callback(const SSL *ssl, const char *line) { + keylog_file.write(line, strlen(line)); + keylog_file.put('\n'); + keylog_file.flush(); +} +} // namespace - rv = getaddrinfo(addr, port, &hints, &res); - if (rv != 0) { - std::cerr << "getaddrinfo: " << gai_strerror(rv) << std::endl; +namespace { +int parse_host_port(Address &dest, int af, const char *first, + const char *last) { + if (std::distance(first, last) == 0) { return -1; } - auto res_d = defer(freeaddrinfo, res); - - int fd = -1; - - for (rp = res; rp; rp = rp->ai_next) { - fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (fd == -1) { - continue; + const char *host_begin, *host_end, *it; + if (*first == '[') { + host_begin = first + 1; + it = std::find(host_begin, last, ']'); + if (it == last) { + return -1; } - - if (rp->ai_family == AF_INET6) { - if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &val, - static_cast<socklen_t>(sizeof(val))) == -1) { - close(fd); - continue; - } + host_end = it; + ++it; + if (it == last || *it != ':') { + return -1; } - - if (bind(fd, rp->ai_addr, rp->ai_addrlen) != -1) { - break; + } else { + host_begin = first; + it = std::find(host_begin, last, ':'); + if (it == last) { + return -1; } - - close(fd); - } - - if (!rp) { - std::cerr << "Could not bind" << std::endl; - return -1; + host_end = it; } - if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, - static_cast<socklen_t>(sizeof(val))) == -1) { + if (++it == last) { return -1; } + auto svc_begin = it; - socklen_t len = sizeof(local_addr.su.storage); - rv = getsockname(fd, &local_addr.su.sa, &len); - if (rv == -1) { - std::cerr << "getsockname: " << strerror(errno) << std::endl; - return -1; - } - local_addr.len = len; - - return fd; -} - -} // namespace + std::array<char, NI_MAXHOST> host; + *std::copy(host_begin, host_end, std::begin(host)) = '\0'; -namespace { -int serve(Server &s, const char *addr, const char *port, int family) { - Address local_addr; + addrinfo hints{}, *res; + hints.ai_family = af; + hints.ai_socktype = SOCK_DGRAM; - auto fd = create_sock(local_addr, addr, port, family); - if (fd == -1) { + auto rv = getaddrinfo(host.data(), svc_begin, &hints, &res); + if (rv != 0) { + std::cerr << "getaddrinfo: [" << host.data() << "]:" << svc_begin << ": " + << gai_strerror(rv) << std::endl; return -1; } - if (s.init(fd, local_addr) != 0) { - return -1; - } + dest.len = res->ai_addrlen; + memcpy(&dest.su, res->ai_addr, res->ai_addrlen); - return 0; -} -} // namespace - -namespace { -void close(Server &s) { - s.disconnect(); - - s.close(); -} -} // namespace + freeaddrinfo(res); -namespace { -std::ofstream keylog_file; -void keylog_callback(const SSL *ssl, const char *line) { - keylog_file.write(line, strlen(line)); - keylog_file.put('\n'); - keylog_file.flush(); + return 0; } } // namespace @@ -2984,6 +3134,12 @@ Options: << config.timeout << R"( -V, --validate-addr Perform address validation. + --preferred-ipv4-addr=<ADDR>:<PORT> + Specify preferred IPv4 address and port. + --preferred-ipv6-addr=<ADDR>:<PORT> + Specify preferred IPv6 address and port. A numeric IPv6 + address must be enclosed by '[' and ']' (e.g., + [::1]:8443) -h, --help Display this help and exit. )"; } @@ -3005,6 +3161,8 @@ int main(int argc, char **argv) { {"ciphers", required_argument, &flag, 1}, {"groups", required_argument, &flag, 2}, {"timeout", required_argument, &flag, 3}, + {"preferred-ipv4-addr", required_argument, &flag, 4}, + {"preferred-ipv6-addr", required_argument, &flag, 5}, {nullptr, 0, nullptr, 0}}; auto optidx = 0; @@ -3065,6 +3223,24 @@ int main(int argc, char **argv) { // --timeout config.timeout = strtol(optarg, nullptr, 10); break; + case 4: + // --preferred-ipv4-addr + if (parse_host_port(config.preferred_ipv4_addr, AF_INET, optarg, + optarg + strlen(optarg)) != 0) { + std::cerr << "preferred-ipv4-addr: could not use '" << optarg << "'" + << std::endl; + exit(EXIT_FAILURE); + } + break; + case 5: + // --preferred-ipv6-addr + if (parse_host_port(config.preferred_ipv6_addr, AF_INET6, optarg, + optarg + strlen(optarg)) != 0) { + std::cerr << "preferred-ipv6-addr: could not use '" << optarg << "'" + << std::endl; + exit(EXIT_FAILURE); + } + break; } break; default: @@ -3117,30 +3293,15 @@ int main(int argc, char **argv) { } } - auto ready = false; - - Server s4(EV_DEFAULT, ssl_ctx); - if (!util::numeric_host(addr, AF_INET6)) { - if (serve(s4, addr, port, AF_INET) == 0) { - ready = true; - } - } - - Server s6(EV_DEFAULT, ssl_ctx); - if (!util::numeric_host(addr, AF_INET)) { - if (serve(s6, addr, port, AF_INET6) == 0) { - ready = true; - } - } - - if (!ready) { + Server s(EV_DEFAULT, ssl_ctx); + if (s.init(addr, port) != 0) { exit(EXIT_FAILURE); } ev_run(EV_DEFAULT, 0); - close(s6); - close(s4); + s.disconnect(); + s.close(); return EXIT_SUCCESS; } diff --git a/examples/server.h b/examples/server.h index a86798e9..2757ccdc 100644 --- a/examples/server.h +++ b/examples/server.h @@ -47,6 +47,8 @@ using namespace ngtcp2; struct Config { + Address preferred_ipv4_addr; + Address preferred_ipv6_addr; // tx_loss_prob is probability of losing outgoing packet. double tx_loss_prob; // rx_loss_prob is probability of losing incoming packet. @@ -153,24 +155,33 @@ struct Stream { class Server; +// Endpoint is a local endpoint. +struct Endpoint { + Address addr; + ev_io wev; + ev_io rev; + Server *server; + int fd; +}; + class Handler { public: Handler(struct ev_loop *loop, SSL_CTX *ssl_ctx, Server *server, const ngtcp2_cid *rcid); ~Handler(); - int init(int fd, const sockaddr *sa, socklen_t salen, const ngtcp2_cid *dcid, - const ngtcp2_cid *ocid, uint32_t version); + int init(Endpoint &ep, const sockaddr *sa, socklen_t salen, + const ngtcp2_cid *dcid, const ngtcp2_cid *ocid, uint32_t version); int tls_handshake(); int read_tls(); - int on_read(const sockaddr *sa, socklen_t salen, uint8_t *data, + int on_read(Endpoint &ep, const sockaddr *sa, socklen_t salen, uint8_t *data, size_t datalen); int on_write(bool retransmit = false); int on_write_stream(Stream &stream); int write_stream_data(Stream &stream, int fin, Buffer &data); - int feed_data(const sockaddr *sa, socklen_t salen, uint8_t *data, - size_t datalen); + int feed_data(Endpoint &ep, const sockaddr *sa, socklen_t salen, + uint8_t *data, size_t datalen); int do_handshake_read_once(const uint8_t *data, size_t datalen); ssize_t do_handshake_write_once(); int do_handshake(const uint8_t *data, size_t datalen); @@ -214,6 +225,7 @@ public: int recv_stream_data(uint64_t stream_id, uint8_t fin, const uint8_t *data, size_t datalen); const ngtcp2_cid *scid() const; + const ngtcp2_cid *pscid() const; const ngtcp2_cid *rcid() const; uint32_t version() const; void remove_tx_crypto_data(uint64_t offset, size_t datalen); @@ -225,6 +237,7 @@ public: bool draining() const; int handle_error(int liberror); int send_conn_close(); + void update_endpoint(const ngtcp2_addr *addr); void update_remote_addr(const ngtcp2_addr *addr); int send_greeting(); @@ -236,13 +249,13 @@ public: int update_key(); private: + Endpoint *endpoint_; Address remote_addr_; size_t max_pktlen_; struct ev_loop *loop_; SSL_CTX *ssl_ctx_; SSL *ssl_; Server *server_; - int fd_; ev_timer timer_; ev_timer rttimer_; std::vector<uint8_t> chandshake_; @@ -253,6 +266,7 @@ private: size_t shandshake_idx_; ngtcp2_conn *conn_; ngtcp2_cid scid_; + ngtcp2_cid pscid_; ngtcp2_cid rcid_; crypto::Context hs_crypto_ctx_; crypto::Context crypto_ctx_; @@ -287,46 +301,42 @@ public: Server(struct ev_loop *loop, SSL_CTX *ssl_ctx); ~Server(); - int init(int fd, const Address &local_addr); + int init(const char *addr, const char *port); void disconnect(); void disconnect(int liberr); void close(); int on_write(); - int on_read(); - int send_version_negotiation(const ngtcp2_pkt_hd *hd, const sockaddr *sa, - socklen_t salen); - int send_retry(const ngtcp2_pkt_hd *chd, const sockaddr *sa, socklen_t salen); + int on_read(Endpoint &ep); + int send_version_negotiation(const ngtcp2_pkt_hd *hd, Endpoint &ep, + const sockaddr *sa, socklen_t salen); + int send_retry(const ngtcp2_pkt_hd *chd, Endpoint &ep, const sockaddr *sa, + socklen_t salen); int generate_token(uint8_t *token, size_t &tokenlen, const sockaddr *sa, socklen_t salen, const ngtcp2_cid *ocid); int verify_token(ngtcp2_cid *ocid, const ngtcp2_pkt_hd *hd, const sockaddr *sa, socklen_t salen); - int send_packet(Address &remote_addr, Buffer &buf); + int send_packet(Endpoint &ep, const Address &remote_addr, Buffer &buf); void remove(const Handler *h); std::map<std::string, std::unique_ptr<Handler>>::const_iterator remove(std::map<std::string, std::unique_ptr<Handler>>::const_iterator it); - void start_wev(); int derive_token_key(uint8_t *key, size_t &keylen, uint8_t *iv, size_t &ivlen, const uint8_t *rand_data, size_t rand_datalen); int generate_rand_data(uint8_t *buf, size_t len); void associate_cid(const ngtcp2_cid *cid, Handler *h); void dissociate_cid(const ngtcp2_cid *cid); - const Address &get_local_addr() const; private: - Address local_addr_; std::map<std::string, std::unique_ptr<Handler>> handlers_; // ctos_ is a mapping between client's initial destination // connection ID, and server source connection ID. std::map<std::string, std::string> ctos_; struct ev_loop *loop_; + std::vector<Endpoint> endpoints_; SSL_CTX *ssl_ctx_; crypto::Context token_crypto_ctx_; std::array<uint8_t, TOKEN_SECRETLEN> token_secret_; - int fd_; - ev_io wev_; - ev_io rev_; ev_signal sigintev_; }; diff --git a/lib/includes/ngtcp2/ngtcp2.h b/lib/includes/ngtcp2/ngtcp2.h index e5d60e11..823587cd 100644 --- a/lib/includes/ngtcp2/ngtcp2.h +++ b/lib/includes/ngtcp2/ngtcp2.h @@ -740,6 +740,8 @@ typedef struct { /* addr points to the buffer which contains endpoint address. It is opaque to the ngtcp2 library. */ uint8_t *addr; + /* user_data is an arbitrary data and opaque to the library. */ + void *user_data; } ngtcp2_addr; /** @@ -1390,6 +1392,25 @@ typedef int (*ngtcp2_path_validation)(ngtcp2_conn *conn, ngtcp2_path_validation_result res, void *user_data); +/** + * @functypedef + * + * :type:`ngtcp2_select_preferred_addr` is a callback function which + * asks a client application to choose server address from preferred + * addresses |paddr| received from server. An application should + * write preferred address in |dest|. If an application denies the + * preferred addresses, just leave |dest| unmodified (or set dest->len + * to 0) and return 0. + * + * The callback function must return 0 if it succeeds. Returning + * :enum:`NGTCP2_ERR_CALLBACK_FAILURE` makes the library call return + * immediately. + */ +typedef int (*ngtcp2_select_preferred_addr)(ngtcp2_conn *conn, + ngtcp2_addr *dest, + const ngtcp2_preferred_addr *paddr, + void *user_data); + typedef struct { ngtcp2_client_initial client_initial; ngtcp2_recv_client_initial recv_client_initial; @@ -1440,6 +1461,7 @@ typedef struct { ngtcp2_remove_connection_id remove_connection_id; ngtcp2_update_key update_key; ngtcp2_path_validation path_validation; + ngtcp2_select_preferred_addr select_preferred_addr; } ngtcp2_conn_callbacks; /* @@ -2540,7 +2562,8 @@ NGTCP2_EXTERN uint16_t ngtcp2_err_infer_quic_transport_error_code(int liberr); * returns |addr|. */ NGTCP2_EXTERN ngtcp2_addr *ngtcp2_addr_init(ngtcp2_addr *addr, - const void *address, size_t len); + const void *address, size_t len, + void *user_data); /** * @function diff --git a/lib/ngtcp2_addr.c b/lib/ngtcp2_addr.c index c86f8962..9efb6fb2 100644 --- a/lib/ngtcp2_addr.c +++ b/lib/ngtcp2_addr.c @@ -26,9 +26,11 @@ #include <string.h> -ngtcp2_addr *ngtcp2_addr_init(ngtcp2_addr *dest, const void *addr, size_t len) { +ngtcp2_addr *ngtcp2_addr_init(ngtcp2_addr *dest, const void *addr, size_t len, + void *user_data) { dest->len = len; dest->addr = (uint8_t *)addr; + dest->user_data = user_data; return dest; } @@ -37,6 +39,7 @@ void ngtcp2_addr_copy(ngtcp2_addr *dest, const ngtcp2_addr *src) { if (src->len) { memcpy(dest->addr, src->addr, src->len); } + dest->user_data = src->user_data; } void ngtcp2_addr_copy_byte(ngtcp2_addr *dest, const void *addr, diff --git a/lib/ngtcp2_cid.c b/lib/ngtcp2_cid.c index 7491740d..081c2390 100644 --- a/lib/ngtcp2_cid.c +++ b/lib/ngtcp2_cid.c @@ -85,8 +85,8 @@ void ngtcp2_dcid_init(ngtcp2_dcid *dcid, uint64_t seq, const ngtcp2_cid *cid, } else { memset(dcid->token, 0, NGTCP2_STATELESS_RESET_TOKENLEN); } - ngtcp2_addr_init(&dcid->path.local, dcid->local_addrbuf, 0); - ngtcp2_addr_init(&dcid->path.remote, dcid->remote_addrbuf, 0); + ngtcp2_addr_init(&dcid->path.local, dcid->local_addrbuf, 0, NULL); + ngtcp2_addr_init(&dcid->path.remote, dcid->remote_addrbuf, 0, NULL); } void ngtcp2_dcid_copy(ngtcp2_dcid *dest, const ngtcp2_dcid *src) { diff --git a/lib/ngtcp2_conn.c b/lib/ngtcp2_conn.c index 5c088e2c..b32963e5 100644 --- a/lib/ngtcp2_conn.c +++ b/lib/ngtcp2_conn.c @@ -230,6 +230,22 @@ static int conn_call_path_validation(ngtcp2_conn *conn, const ngtcp2_path *path, return 0; } +static int conn_call_select_preferred_addr(ngtcp2_conn *conn, + ngtcp2_addr *dest) { + int rv; + + assert(conn->callbacks.select_preferred_addr); + assert(conn->remote_settings.preferred_address_present); + + rv = conn->callbacks.select_preferred_addr( + conn, dest, &conn->remote_settings.preferred_address, conn->user_data); + if (rv != 0) { + return NGTCP2_ERR_CALLBACK_FAILURE; + } + + return 0; +} + static int crypto_offset_less(const ngtcp2_pq_entry *lhs, const ngtcp2_pq_entry *rhs) { ngtcp2_crypto_frame_chain *lfrc = @@ -455,6 +471,25 @@ static int conn_new(ngtcp2_conn **pconn, const ngtcp2_cid *dcid, goto fail_scids_insert; } + if (server && settings->preferred_address_present) { + scident = ngtcp2_mem_malloc(mem, sizeof(*scident)); + if (scid == NULL) { + rv = NGTCP2_ERR_NOMEM; + goto fail_scident; + } + + ngtcp2_scid_init(scident, 0, &settings->preferred_address.cid, + settings->preferred_address.stateless_reset_token); + + rv = ngtcp2_ksl_insert(&(*pconn)->scids, NULL, + ngtcp2_ksl_key_ptr(&key, &scident->cid), scident); + if (rv != 0) { + goto fail_scids_insert; + } + + (*pconn)->tx_last_cid_seq = 1; + } + ngtcp2_dcid_init(&(*pconn)->dcid, 0, dcid, NULL); ngtcp2_path_copy(&(*pconn)->dcid.path, path); @@ -3018,7 +3053,7 @@ static ssize_t conn_write_path_response(ngtcp2_conn *conn, ngtcp2_path *path, } if (!conn->server) { - /* Client don't expect to response path validation against + /* Client does not expect to respond to path validation against unknown path */ ngtcp2_ringbuf_pop_front(&conn->rx_path_challenge); pcent = NULL; @@ -3063,6 +3098,14 @@ static ssize_t conn_write_path_response(ngtcp2_conn *conn, ngtcp2_path *path, return nwrite; } +/* + * conn_peer_has_unused_cid returns nonzero if the remote endpoint has + * at least one unused connection ID. + */ +static int conn_peer_has_unused_cid(ngtcp2_conn *conn) { + return ngtcp2_ksl_len(&conn->scids) - ngtcp2_pq_size(&conn->used_scids) > 0; +} + ssize_t ngtcp2_conn_write_pkt(ngtcp2_conn *conn, ngtcp2_path *path, uint8_t *dest, size_t destlen, ngtcp2_tstamp ts) { ssize_t nwrite; @@ -3096,7 +3139,7 @@ ssize_t ngtcp2_conn_write_pkt(ngtcp2_conn *conn, ngtcp2_path *path, return nwrite; } - if (conn->pv) { + if (conn->pv && conn_peer_has_unused_cid(conn)) { nwrite = conn_write_path_challenge(conn, path, dest, destlen, ts); if (nwrite || (conn->pv && (conn->pv->flags & NGTCP2_PV_FLAG_BLOCKING))) { return nwrite; @@ -4371,6 +4414,7 @@ static ssize_t conn_recv_handshake_pkt(ngtcp2_conn *conn, const uint8_t *pkt, } else { conn->dcid.cid = hd.scid; } + conn->odcid = hd.scid; } for (; payloadlen;) { @@ -5767,7 +5811,7 @@ static ssize_t conn_recv_pkt(ngtcp2_conn *conn, const ngtcp2_path *path, /* Quoted from spec: if subsequent packets of those types include a different Source Connection ID, they MUST be discarded. */ - if (!ngtcp2_cid_eq(&conn->dcid.cid, &hd.scid)) { + if (!ngtcp2_cid_eq(&conn->odcid, &hd.scid)) { ngtcp2_log_rx_pkt_hd(&conn->log, &hd); ngtcp2_log_info(&conn->log, NGTCP2_LOG_EVENT_PKT, "packet was ignored because of mismatched SCID"); @@ -6325,6 +6369,8 @@ int ngtcp2_conn_read_pkt(ngtcp2_conn *conn, const ngtcp2_path *path, /* client does not expect a packet from unknown path. */ if (!conn->server && !ngtcp2_path_eq(&conn->dcid.path, path) && (!conn->pv || !ngtcp2_path_eq(&conn->pv->dcid.path, path))) { + ngtcp2_log_info(&conn->log, NGTCP2_LOG_EVENT_CON, + "ignore packet from unknown path"); return 0; } @@ -6484,6 +6530,64 @@ int ngtcp2_conn_read_handshake(ngtcp2_conn *conn, const uint8_t *pkt, } } +/* + * conn_select_preferred_addr asks a client application to select a + * server address from preferred addresses received from server. If a + * client chooses the address, path validation will start. + * + * This function returns 0 if it succeeds, or one of the following + * negative error codes: + * + * NGTCP2_ERR_NOMEM + * Out of memory. + * NGTCP2_ERR_CALLBACK_FAILURE + * User-defined callback function failed. + */ +static int conn_select_preferred_addr(ngtcp2_conn *conn) { + uint8_t buf[128]; + ngtcp2_addr addr; + int rv; + ngtcp2_duration timeout; + ngtcp2_pv *pv; + ngtcp2_dcid dcid; + + ngtcp2_addr_init(&addr, buf, 0, NULL); + + rv = conn_call_select_preferred_addr(conn, &addr); + if (rv != 0) { + return rv; + } + + if (addr.len == 0 || ngtcp2_addr_eq(&conn->dcid.path.remote, &addr)) { + return 0; + } + + ngtcp2_dcid_init( + &dcid, 1, &conn->remote_settings.preferred_address.cid, + conn->remote_settings.preferred_address.stateless_reset_token); + + assert(conn->pv == NULL); + + timeout = rcvry_stat_compute_pto(&conn->rcs); + timeout = ngtcp2_max(timeout, 6 * NGTCP2_DEFAULT_INITIAL_RTT); + + rv = ngtcp2_pv_new(&pv, &dcid, timeout, NGTCP2_PV_FLAG_BLOCKING, &conn->log, + conn->mem); + if (rv != 0) { + /* TODO Call ngtcp2_dcid_free here if it is introduced */ + return rv; + } + + conn->pv = pv; + + ngtcp2_addr_copy(&pv->dcid.path.local, &conn->dcid.path.local); + ngtcp2_addr_copy(&pv->dcid.path.remote, &addr); + + conn_reset_congestion_state(conn); + + return 0; +} + /* * conn_write_handshake writes QUIC handshake packets to the buffer * pointed by |dest| of length |destlen|. |early_datalen| specifies @@ -6629,6 +6733,18 @@ static ssize_t conn_write_handshake(ngtcp2_conn *conn, uint8_t *dest, return (ssize_t)rv; } + if (conn->remote_settings.preferred_address_present) { + /* TODO Starting path validation against preferred address must + be done after dropping Handshake key which is impossible at + draft-18. */ + /* TODO And client has to send NEW_CONNECTION_ID before starting + path validation */ + rv = conn_select_preferred_addr(conn); + if (rv != 0) { + return (ssize_t)rv; + } + } + return res; case NGTCP2_CS_SERVER_INITIAL: nwrite = conn_write_server_handshake(conn, dest, destlen, ts); @@ -7609,7 +7725,7 @@ ssize_t ngtcp2_conn_writev_stream(ngtcp2_conn *conn, ngtcp2_path *path, return nwrite; } - if (conn->pv) { + if (conn->pv && conn_peer_has_unused_cid(conn)) { nwrite = conn_write_path_challenge(conn, path, dest, destlen, ts); if (nwrite || (conn->pv && (conn->pv->flags & NGTCP2_PV_FLAG_BLOCKING))) { return nwrite; diff --git a/lib/ngtcp2_conn.h b/lib/ngtcp2_conn.h index ed5bd272..29d7d2c9 100644 --- a/lib/ngtcp2_conn.h +++ b/lib/ngtcp2_conn.h @@ -240,6 +240,10 @@ struct ngtcp2_conn { /* oscid is the source connection ID initially used by the local endpoint. */ ngtcp2_cid oscid; + /* odcid is the destination connection ID initially negotiated + during handshake. It is used to receive late handshake packets + after handshake completion. */ + ngtcp2_cid odcid; /* dcid is the destination connection ID. */ ngtcp2_dcid dcid; /* bound_dcids is a set of destination connection ID which is bound diff --git a/lib/ngtcp2_path.c b/lib/ngtcp2_path.c index b7b52040..a3964de8 100644 --- a/lib/ngtcp2_path.c +++ b/lib/ngtcp2_path.c @@ -47,14 +47,14 @@ int ngtcp2_path_eq(const ngtcp2_path *a, const ngtcp2_path *b) { void ngtcp2_path_storage_init(ngtcp2_path_storage *ps, const void *local_addr, size_t local_addrlen, const void *remote_addr, size_t remote_addrlen) { - ngtcp2_addr_init(&ps->path.local, ps->local_addrbuf, 0); - ngtcp2_addr_init(&ps->path.remote, ps->remote_addrbuf, 0); + ngtcp2_addr_init(&ps->path.local, ps->local_addrbuf, 0, NULL); + ngtcp2_addr_init(&ps->path.remote, ps->remote_addrbuf, 0, NULL); ngtcp2_addr_copy_byte(&ps->path.local, local_addr, local_addrlen); ngtcp2_addr_copy_byte(&ps->path.remote, remote_addr, remote_addrlen); } void ngtcp2_path_storage_zero(ngtcp2_path_storage *ps) { - ngtcp2_addr_init(&ps->path.local, ps->local_addrbuf, 0); - ngtcp2_addr_init(&ps->path.remote, ps->remote_addrbuf, 0); + ngtcp2_addr_init(&ps->path.local, ps->local_addrbuf, 0, NULL); + ngtcp2_addr_init(&ps->path.remote, ps->remote_addrbuf, 0, NULL); } diff --git a/lib/ngtcp2_pv.h b/lib/ngtcp2_pv.h index 34d527dd..974731c5 100644 --- a/lib/ngtcp2_pv.h +++ b/lib/ngtcp2_pv.h @@ -70,9 +70,6 @@ typedef enum { validation against the old path should be done after successful path validation. */ NGTCP2_PV_FLAG_VERIFY_OLD_PATH_ON_SUCCESS = 0x08, - /* NGTCP2_PV_FLAG_INVOKE_CALLBACK indicates that callback must be - called after path validation finishes. */ - NGTCP2_PV_FLAG_INVOKE_CALLBACK = 0x10, } ngtcp2_pv_flag; struct ngtcp2_pv; diff --git a/tests/ngtcp2_conn_test.c b/tests/ngtcp2_conn_test.c index b963be72..95843bef 100644 --- a/tests/ngtcp2_conn_test.c +++ b/tests/ngtcp2_conn_test.c @@ -127,7 +127,8 @@ static uint8_t null_iv[16]; static uint8_t null_pn[16]; static uint8_t null_data[4096]; static ngtcp2_path null_path = {}; -static ngtcp2_path new_path = {{1, (uint8_t *)"1"}, {1, (uint8_t *)"2"}}; +static ngtcp2_path new_path = {{1, (uint8_t *)"1", NULL}, + {1, (uint8_t *)"2", NULL}}; static ngtcp2_vec *null_datav(ngtcp2_vec *datav, size_t len) { datav->base = null_data; @@ -363,6 +364,7 @@ static void setup_default_server(ngtcp2_conn **pconn) { (*pconn)->max_local_stream_id_uni = ngtcp2_nth_server_uni_id((*pconn)->remote_settings.max_streams_uni); (*pconn)->max_tx_offset = (*pconn)->remote_settings.max_data; + (*pconn)->odcid = dcid; } static void setup_default_client(ngtcp2_conn **pconn) { @@ -412,6 +414,7 @@ static void setup_default_client(ngtcp2_conn **pconn) { (*pconn)->max_local_stream_id_uni = ngtcp2_nth_client_uni_id((*pconn)->remote_settings.max_streams_uni); (*pconn)->max_tx_offset = (*pconn)->remote_settings.max_data; + (*pconn)->odcid = dcid; } static void setup_handshake_server(ngtcp2_conn **pconn) { diff --git a/tests/ngtcp2_pv_test.c b/tests/ngtcp2_pv_test.c index 0193bbe9..e0140bc3 100644 --- a/tests/ngtcp2_pv_test.c +++ b/tests/ngtcp2_pv_test.c @@ -91,8 +91,8 @@ void test_ngtcp2_pv_validate(void) { uint8_t data[8]; ngtcp2_duration timeout = 100ULL * NGTCP2_SECONDS; ngtcp2_tstamp t = 1; - ngtcp2_path path = {{1, (uint8_t *)"1"}, {1, (uint8_t *)"2"}}; - ngtcp2_path alt_path = {{1, (uint8_t *)"3"}, {1, (uint8_t *)"4"}}; + ngtcp2_path path = {{1, (uint8_t *)"1", NULL}, {1, (uint8_t *)"2", NULL}}; + ngtcp2_path alt_path = {{1, (uint8_t *)"3", NULL}, {1, (uint8_t *)"4", NULL}}; dcid_init(&cid); ngtcp2_dcid_init(&dcid, 1000000007, &cid, token); -- GitLab