From 8cd22a8d1406c1c3917e3d386d9285c861f249a7 Mon Sep 17 00:00:00 2001
From: Tatsuhiro Tsujikawa <tatsuhiro.t@gmail.com>
Date: Fri, 14 Aug 2020 18:27:30 +0900
Subject: [PATCH] Cache fd

---
 examples/examplestest.cc |   4 +-
 examples/server.cc       | 128 ++++++++++++++++++++-------------------
 examples/server.h        |  11 +---
 examples/util.cc         |  90 +++++++++++++++++++++++++++
 examples/util.h          |   5 ++
 examples/util_test.cc    |  15 +++++
 examples/util_test.h     |   1 +
 7 files changed, 184 insertions(+), 70 deletions(-)

diff --git a/examples/examplestest.cc b/examples/examplestest.cc
index 53d5b845..6487f819 100644
--- a/examples/examplestest.cc
+++ b/examples/examplestest.cc
@@ -63,7 +63,9 @@ int main(int argc, char *argv[]) {
       !CU_add_test(pSuite, "util_parse_uint_iec",
                    ngtcp2::test_util_parse_uint_iec) ||
       !CU_add_test(pSuite, "util_parse_duration",
-                   ngtcp2::test_util_parse_duration)) {
+                   ngtcp2::test_util_parse_duration) ||
+      !CU_add_test(pSuite, "util_normalize_path",
+                   ngtcp2::test_util_normalize_path)) {
     CU_cleanup_registry();
     return CU_get_error();
   }
diff --git a/examples/server.cc b/examples/server.cc
index 55059079..c20617c8 100644
--- a/examples/server.cc
+++ b/examples/server.cc
@@ -171,22 +171,11 @@ int Handler::on_key(ngtcp2_crypto_level level, const uint8_t *rx_secret,
 Stream::Stream(int64_t stream_id, Handler *handler)
     : stream_id(stream_id),
       handler(handler),
-      fd(-1),
       data(nullptr),
       datalen(0),
       dynresp(false),
       dyndataleft(0),
-      dynbuflen(0),
-      mmapped(false) {}
-
-Stream::~Stream() {
-  if (mmapped) {
-    munmap(data, datalen);
-  }
-  if (fd != -1) {
-    close(fd);
-  }
-}
+      dynbuflen(0) {}
 
 namespace {
 constexpr char NGTCP2_SERVER[] = "nghttp3/ngtcp2 server";
@@ -312,45 +301,67 @@ Request request_path(const std::string_view &uri, bool is_connect) {
 
 namespace {
 std::string resolve_path(const std::string &req_path) {
-  auto raw_path = config.htdocs + req_path;
-  std::array<char, PATH_MAX> buf;
-  auto p = realpath(raw_path.c_str(), buf.data());
-  if (p == nullptr) {
-    return "";
-  }
-  auto path = std::string(p);
-
-  if (path.size() < config.htdocs.size() ||
-      !std::equal(std::begin(config.htdocs), std::end(config.htdocs),
-                  std::begin(path))) {
-    return "";
-  }
-  return path;
+  auto path = util::normalize_path(req_path); // lexical only; no realpath()
+  return config.htdocs + path;                // hence never empty
 }
 } // namespace
 
-int Stream::open_file(const std::string &path) {
-  fd = open(path.c_str(), O_RDONLY);
-  if (fd == -1) {
-    return -1;
+enum FileEntryFlag { // bit flags stored in FileEntry::flags
+  FILE_ENTRY_TYPE_DIR = 0x1, // the cached path is a directory
+};
+
+struct FileEntry { // one cached result of Stream::open_file
+  uint64_t len;  // st_size of a regular file; 0 for a directory
+  void *map;     // mmap()ed contents of a regular file; null for a directory
+  int fd;        // descriptor of a regular file; -1 for a directory
+  uint8_t flags; // bitwise OR of FileEntryFlag values
+};
+
+namespace { // file-local: FileEntry results of Stream::open_file, keyed by path
+std::unordered_map<std::string, FileEntry> file_cache;
+} // namespace
+
+std::pair<FileEntry, int> Stream::open_file(const std::string &path) {
+  // Returns {entry, 0}, cached after the first successful open, or {{}, -1}.
+  if (auto it = file_cache.find(path); it != std::end(file_cache)) {
+    return {it->second, 0};
   }
 
-  return 0;
-}
+  auto fd = open(path.c_str(), O_RDONLY);
+  if (fd == -1) {
+    return {{}, -1};
+  }
 
-int Stream::map_file(size_t len) {
-  if (len == 0) {
-    return 0;
+  struct stat st {};
+  if (fstat(fd, &st) != 0) {
+    close(fd);
+    return {{}, -1};
   }
-  data =
-      static_cast<uint8_t *>(mmap(nullptr, len, PROT_READ, MAP_SHARED, fd, 0));
-  if (data == MAP_FAILED) {
-    std::cerr << "mmap: " << strerror(errno) << std::endl;
-    return -1;
+
+  FileEntry fe{};
+  if (S_ISDIR(st.st_mode)) { // "& S_IFDIR" also matched e.g. block devices
+    fe.flags |= FILE_ENTRY_TYPE_DIR;
+    fe.fd = -1;
+    close(fd);
+  } else {
+    fe.fd = fd;
+    fe.len = st.st_size;
+    if (fe.len > 0) { // mmap() fails with EINVAL on an empty file
+      fe.map = mmap(nullptr, fe.len, PROT_READ, MAP_SHARED, fd, 0);
+    }
+    if (fe.map == MAP_FAILED) {
+      std::cerr << "mmap: " << strerror(errno) << std::endl;
+      close(fd);
+      return {{}, -1};
+    }
   }
-  datalen = len;
-  mmapped = true;
-  return 0;
+  file_cache.emplace(path, fe);
+  return {fe, 0};
+}
+
+void Stream::map_file(const FileEntry &fe) {
+  data = static_cast<uint8_t *>(fe.map); // borrows fe's mapping; no copy
+  datalen = fe.len;                      // bytes available at |data|
 }
 
 int64_t Stream::find_dyn_length(const std::string_view &path) {
@@ -524,33 +535,28 @@ int Stream::start_response(nghttp3_conn *httpconn) {
 
   if (dyn_len == -1) {
     auto path = resolve_path(req.path);
-    if (path.empty() || open_file(path) != 0) {
+    if (path.empty()) {
       send_status_response(httpconn, 404);
       return 0;
     }
-
-    struct stat st {};
-
-    if (fstat(fd, &st) == 0) {
-      if (st.st_mode & S_IFDIR) {
-        send_redirect_response(httpconn, 308,
-                               path.substr(config.htdocs.size() - 1) + '/');
-        return 0;
-      }
-      content_length = st.st_size;
-    } else {
+    auto [fe, rv] = open_file(path);
+    if (rv != 0) {
       send_status_response(httpconn, 404);
       return 0;
     }
 
-    if (method == "HEAD") {
-      close(fd);
-      fd = -1;
-    } else if (map_file(content_length) != 0) {
-      send_status_response(httpconn, 500);
+    if (fe.flags & FILE_ENTRY_TYPE_DIR) {
+      send_redirect_response(httpconn, 308,
+                             path.substr(config.htdocs.size() - 1) + '/');
       return 0;
     }
 
+    content_length = fe.len;
+
+    if (method != "HEAD") {
+      map_file(fe);
+    }
+
     dr.read_data = read_data;
 
     auto ext = std::end(req.path) - 1;
@@ -1281,7 +1287,7 @@ int http_acked_stream_data(nghttp3_conn *conn, int64_t stream_id,
 void Handler::http_acked_stream_data(Stream *stream, size_t datalen) {
   stream->http_acked_stream_data(datalen);
 
-  if (stream->fd == -1 && stream->dynbuflen < MAX_DYNBUFLEN - 16384) {
+  if (stream->dynresp && stream->dynbuflen < MAX_DYNBUFLEN - 16384) {
     if (auto rv = nghttp3_conn_resume_stream(httpconn_, stream->stream_id);
         rv != 0) {
       // TODO Handle error
diff --git a/examples/server.h b/examples/server.h
index 845abbdd..fb0102c0 100644
--- a/examples/server.h
+++ b/examples/server.h
@@ -162,14 +162,14 @@ struct HTTPHeader {
 };
 
 class Handler;
+struct FileEntry;
 
 struct Stream {
   Stream(int64_t stream_id, Handler *handler);
-  ~Stream();
 
   int start_response(nghttp3_conn *conn);
-  int open_file(const std::string &path);
-  int map_file(size_t len);
+  std::pair<FileEntry, int> open_file(const std::string &path);
+  void map_file(const FileEntry &fe);
   int send_status_response(nghttp3_conn *conn, unsigned int status_code,
                            const std::vector<HTTPHeader> &extra_headers = {});
   int send_redirect_response(nghttp3_conn *conn, unsigned int status_code,
@@ -183,9 +183,6 @@ struct Stream {
   std::string uri;
   std::string method;
   std::string authority;
-  // fd is a file descriptor to read file to send its content to a
-  // client.
-  int fd;
   std::string status_resp_body;
   // data is a pointer to the memory which maps file denoted by fd.
   uint8_t *data;
@@ -197,8 +194,6 @@ struct Stream {
   uint64_t dyndataleft;
   // dynbuflen is the number of bytes in-flight.
   uint64_t dynbuflen;
-  // mmapped is true if data points to the memory assigned by mmap.
-  bool mmapped;
 };
 
 class Server;
diff --git a/examples/util.cc b/examples/util.cc
index b1b92725..95717359 100644
--- a/examples/util.cc
+++ b/examples/util.cc
@@ -516,6 +516,96 @@ int generate_secret(uint8_t *secret, size_t secretlen) {
   return 0;
 }
 
+namespace {
+template <typename InputIt> InputIt eat_file(InputIt first, InputIt last) {
+  if (first == last) { // empty buffer: write a lone '/' and return its end
+    *first++ = '/';
+    return first;
+  }
+
+  if (*(last - 1) == '/') { // no trailing file component to drop
+    return last;
+  }
+
+  auto p = last;
+  for (; p != first && *(p - 1) != '/'; --p) // back up to just past a '/'
+    ;
+  if (p == first) {
+    // This should not happen in the normal case, where the path is
+    // expected to start with '/'.
+    *first++ = '/';
+    return first;
+  }
+
+  return p; // new end: one past the '/' preceding the dropped component
+}
+} // namespace
+
+namespace {
+template <typename InputIt> InputIt eat_dir(InputIt first, InputIt last) {
+  auto p = eat_file(first, last); // p is now one past a '/'
+
+  --p; // step onto that '/' so the range [first, p) ends in a component name
+
+  assert(*p == '/');
+
+  return eat_file(first, p); // drop that (parent) component as well
+}
+} // namespace
+
+std::string normalize_path(const std::string &path) {
+  // Resolves "." and "..", collapses runs of '/', and never climbs above "/".
+  assert(path.size() > 0);
+  assert(path[0] == '/');
+  // The output is never longer than |path|, so size the buffer from the
+  // input; a fixed 1024-byte array overflows once assert() is compiled out.
+  std::string res(path.size(), '\0');
+  auto p = res.data();
+  auto first = std::begin(path);
+  auto last = std::end(path);
+
+  *p++ = '/';
+  ++first;
+  for (;;) {
+    // Skip '/' before every component, including right after "./" and
+    // "../", so that "/.//foo" yields "/foo" rather than "//foo".
+    first = std::find_if(first, last, [](char c) { return c != '/'; });
+    if (first == last) {
+      break;
+    }
+    if (*first == '.') {
+      if (first + 1 == last) {
+        break;
+      }
+      if (*(first + 1) == '/') {
+        first += 2;
+        continue;
+      }
+      if (*(first + 1) == '.') {
+        if (first + 2 == last) {
+          p = eat_dir(res.data(), p);
+          break;
+        }
+        if (*(first + 2) == '/') {
+          p = eat_dir(res.data(), p);
+          first += 3;
+          continue;
+        }
+      }
+    }
+    // |res| always ends with '/' here, so the component is appended as-is.
+    auto slash = std::find(first, last, '/');
+    if (slash == last) {
+      p = std::copy(first, last, p);
+      break;
+    }
+    p = std::copy(first, slash + 1, p);
+    first = slash + 1;
+  }
+  res.resize(static_cast<size_t>(p - res.data()));
+  return res;
+}
+
 } // namespace util
 
 std::ostream &operator<<(std::ostream &os, const ngtcp2_cid &cid) {
diff --git a/examples/util.h b/examples/util.h
index 42b28062..553e6148 100644
--- a/examples/util.h
+++ b/examples/util.h
@@ -276,6 +276,11 @@ std::pair<uint64_t, int> parse_duration(const std::string_view &s);
 // must be 32.
 int generate_secret(uint8_t *secret, size_t secretlen);
 
+// normalize_path removes ".." by consuming a previous path component.
+// It also removes ".".  It assumes that |path| starts with "/".  If
+// it cannot consume a previous path component, it just removes "..".
+std::string normalize_path(const std::string &path);
+
 } // namespace util
 
 std::ostream &operator<<(std::ostream &os, const ngtcp2_cid &cid);
diff --git a/examples/util_test.cc b/examples/util_test.cc
index 36f05c73..9037a514 100644
--- a/examples/util_test.cc
+++ b/examples/util_test.cc
@@ -219,4 +219,19 @@ void test_util_parse_duration() {
   }
 }
 
+void test_util_normalize_path() {
+  CU_ASSERT("/" == util::normalize_path("/"));
+  CU_ASSERT("/" == util::normalize_path("//")); // repeated '/' collapses
+  CU_ASSERT("/foo" == util::normalize_path("/foo"));
+  CU_ASSERT("/foo/bar/" == util::normalize_path("/foo/bar/"));
+  CU_ASSERT("/foo/bar/" == util::normalize_path("/foo/abc/../bar/"));
+  CU_ASSERT("/foo/bar/" == util::normalize_path("/../foo/abc/../bar/"));
+  CU_ASSERT("/foo/bar/" ==
+            util::normalize_path("/./foo/././abc///.././bar/./"));
+  CU_ASSERT("/foo/" == util::normalize_path("/foo/.")); // trailing "." dropped
+  CU_ASSERT("/foo/bar" == util::normalize_path("/foo/./bar"));
+  CU_ASSERT("/bar" == util::normalize_path("/foo/./../bar"));
+  CU_ASSERT("/bar" == util::normalize_path("/../../bar")); // ".." stops at "/"
+}
+
 } // namespace ngtcp2
diff --git a/examples/util_test.h b/examples/util_test.h
index 06a0a84a..376d3a91 100644
--- a/examples/util_test.h
+++ b/examples/util_test.h
@@ -38,6 +38,7 @@ void test_util_format_duration();
 void test_util_parse_uint();
 void test_util_parse_uint_iec();
 void test_util_parse_duration();
+void test_util_normalize_path();
 
 } // namespace ngtcp2
 
-- 
GitLab