diff --git a/httplib.h b/httplib.h index c02ab2b..5eeb0af 100644 --- a/httplib.h +++ b/httplib.h @@ -206,6 +206,8 @@ typedef std::function ContentReceiver; +typedef std::function ContentReader; + typedef std::function Progress; struct Response; @@ -477,6 +479,9 @@ private: class Server { public: typedef std::function Handler; + typedef std::function + HandlerWithContentReader; typedef std::function Logger; Server(); @@ -487,9 +492,11 @@ public: Server &Get(const char *pattern, Handler handler); Server &Post(const char *pattern, Handler handler); - + Server &Post(const char *pattern, HandlerWithContentReader handler); Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); Server &Delete(const char *pattern, Handler handler); Server &Options(const char *pattern, Handler handler); @@ -526,15 +533,20 @@ protected: private: typedef std::vector> Handlers; + typedef std::vector> + HandersForContentReader; socket_t create_server_socket(const char *host, int port, int socket_flags) const; int bind_internal(const char *host, int port, int socket_flags); bool listen_internal(); - bool routing(Request &req, Response &res); + bool routing(Request &req, Response &res, ContentReader content_reader); bool handle_file_request(Request &req, Response &res); bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandersForContentReader &handlers); bool parse_request_line(const char *s, Request &req); bool write_response(Stream &strm, bool last_connection, const Request &req, @@ -542,6 +554,11 @@ private: bool write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type); + bool read_content(Stream &strm, bool last_connection, Request &req, + Response &res); + bool read_content_with_content_receiver(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver reveiver); virtual bool process_and_close_socket(socket_t sock); @@ -551,8 +568,11 @@ private: Handler file_request_handler_; Handlers get_handlers_; Handlers post_handlers_; + HandersForContentReader post_handlers_for_content_reader; Handlers put_handlers_; + HandersForContentReader put_handlers_for_content_reader; Handlers patch_handlers_; + HandersForContentReader patch_handlers_for_content_reader; Handlers delete_handlers_; Handlers options_handlers_; Handler error_handler_; @@ -1487,13 +1507,13 @@ inline bool read_headers(Stream &strm, Headers &headers) { const auto bufsiz = 2048; char buf[bufsiz]; - stream_line_reader reader(strm, buf, bufsiz); + stream_line_reader line_reader(strm, buf, bufsiz); for (;;) { - if (!reader.getline()) { return false; } - if (!strcmp(reader.ptr(), "\r\n")) { break; } + if (!line_reader.getline()) { return false; } + if (!strcmp(line_reader.ptr(), "\r\n")) { break; } std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { + if (std::regex_match(line_reader.ptr(), m, re)) { auto key = std::string(m[1]); auto val = std::string(m[2]); headers.emplace(key, val); @@ -1559,29 +1579,30 @@ inline bool read_content_chunked(Stream &strm, ContentReceiverCore out) { const auto bufsiz = 16; char buf[bufsiz]; - stream_line_reader reader(strm, buf, bufsiz); + stream_line_reader line_reader(strm, buf, bufsiz); - if (!reader.getline()) { return false; } + if (!line_reader.getline()) { return false; } - auto chunk_len = std::stoi(reader.ptr(), 0, 16); + auto chunk_len = std::stoi(line_reader.ptr(), 0, 16); while (chunk_len > 0) { if (!read_content_with_length(strm, chunk_len, nullptr, out)) { return false; } - if (!reader.getline()) { return false; } + if (!line_reader.getline()) { return false; } - if (strcmp(reader.ptr(), "\r\n")) { break; } + if (strcmp(line_reader.ptr(), "\r\n")) { break; } - if (!reader.getline()) { return false; } + if (!line_reader.getline()) { return false; } - chunk_len = std::stoi(reader.ptr(), 0, 16); + chunk_len = std::stoi(line_reader.ptr(), 0, 16); } if (chunk_len == 0) { // Reader terminator after chunks - if (!reader.getline() || strcmp(reader.ptr(), "\r\n")) return false; + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; } return true; @@ -1898,32 +1919,33 @@ inline bool parse_multipart_formdata(const std::string &boundary, inline bool parse_range_header(const std::string &s, Ranges &ranges) { try { - static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + static auto re_first_range = + std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); std::smatch m; if (std::regex_match(s, m, re_first_range)) { auto pos = m.position(1); auto len = m.length(1); - detail::split(&s[pos], &s[pos + len], ',', - [&](const char *b, const char *e) { - static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); - std::cmatch m; - if (std::regex_match(b, e, m, re_another_range)) { - ssize_t first = -1; - if (!m.str(1).empty()) { - first = static_cast(std::stoll(m.str(1))); - } + detail::split( + &s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch m; + if (std::regex_match(b, e, m, re_another_range)) { + ssize_t first = -1; + if (!m.str(1).empty()) { + first = static_cast(std::stoll(m.str(1))); + } - ssize_t last = -1; - if (!m.str(2).empty()) { - last = static_cast(std::stoll(m.str(2))); - } + ssize_t last = -1; + if (!m.str(2).empty()) { + last = static_cast(std::stoll(m.str(2))); + } - if (first != -1 && last != -1 && first > last) { - throw std::runtime_error("invalid range error"); - } - ranges.emplace_back(std::make_pair(first, last)); - } - }); + if (first != -1 && last != -1 && first > last) { + throw std::runtime_error("invalid range error"); + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); return true; } return false; @@ -2344,16 +2366,37 @@ inline Server &Server::Post(const char *pattern, Handler handler) { return *this; } +inline Server &Server::Post(const char *pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + inline Server &Server::Put(const char *pattern, Handler handler) { put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } +inline Server &Server::Put(const char *pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + inline Server &Server::Patch(const char *pattern, Handler handler) { patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; } +inline Server &Server::Patch(const char *pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + inline Server &Server::Delete(const char *pattern, Handler handler) { delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); return *this; @@ -2597,6 +2640,58 @@ Server::write_content_with_provider(Stream &strm, const Request &req, return true; } +inline bool Server::read_content(Stream &strm, bool last_connection, + Request &req, Response &res) { + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + })) { + return write_response(strm, last_connection, req, res); + } + + const auto &content_type = req.get_header_value("Content-Type"); + + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } else if (!content_type.find("multipart/form-data")) { + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary) || + !detail::parse_multipart_formdata(boundary, req.body, req.files)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + } + + return true; +} + +inline bool +Server::read_content_with_content_receiver(Stream &strm, bool last_connection, + Request &req, Response &res, + ContentReceiver receiver) { + size_t offset = 0; + + size_t length = 0; + if (req.get_header_value("Content-Encoding") != "gzip") { + length = get_header_value_uint64(req.headers, "Content-Length", 0); + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), [&](const char *buf, size_t n) { + auto ret = receiver(buf, n, offset, length); + offset += n; + return ret; + })) { + return write_response(strm, last_connection, req, res); + } + + return true; +} + inline bool Server::handle_file_request(Request &req, Response &res) { if (!base_dir_.empty() && detail::is_valid_path(req.path)) { std::string path = base_dir_ + req.path; @@ -2705,9 +2800,33 @@ inline bool Server::listen_internal() { return ret; } -inline bool Server::routing(Request &req, Response &res) { +inline bool Server::routing(Request &req, Response &res, + ContentReader content_reader) { + // File handler if (req.method == "GET" && handle_file_request(req, res)) { return true; } + // Content reader handler + if (req.method == "POST") { + if (dispatch_request_for_content_reader(req, res, content_reader, + post_handlers_for_content_reader)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader(req, res, content_reader, + put_handlers_for_content_reader)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, content_reader, patch_handlers_for_content_reader)) { + return true; + } + } + + // Read content into `req.body` + if (!content_reader(nullptr)) { return false; } + + // Regular handler if (req.method == "GET" || req.method == "HEAD") { return dispatch_request(req, res, get_handlers_); } else if (req.method == "POST") { @@ -2740,6 +2859,22 @@ inline bool Server::dispatch_request(Request &req, Response &res, return false; } +inline bool +Server::dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + inline bool Server::process_request(Stream &strm, bool last_connection, bool &connection_close, @@ -2747,10 +2882,10 @@ Server::process_request(Stream &strm, bool last_connection, const auto bufsiz = 2048; char buf[bufsiz]; - detail::stream_line_reader reader(strm, buf, bufsiz); + detail::stream_line_reader line_reader(strm, buf, bufsiz); // Connection has been closed on client - if (!reader.getline()) { return false; } + if (!line_reader.getline()) { return false; } Request req; Response res; @@ -2758,7 +2893,7 @@ Server::process_request(Stream &strm, bool last_connection, res.version = "HTTP/1.1"; // Check if the request URI doesn't exceed the limit - if (reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { Headers dummy; detail::read_headers(strm, dummy); res.status = 414; @@ -2766,7 +2901,7 @@ Server::process_request(Stream &strm, bool last_connection, } // Request line and headers - if (!parse_request_line(reader.ptr(), req) || + if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { res.status = 400; return write_response(strm, last_connection, req, res); @@ -2783,34 +2918,6 @@ Server::process_request(Stream &strm, bool last_connection, req.set_header("REMOTE_ADDR", strm.get_remote_addr()); - // Body - if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || - req.method == "PRI") { - if (!detail::read_content(strm, req, payload_max_length_, res.status, - Progress(), [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { - return false; - } - req.body.append(buf, n); - return true; - })) { - return write_response(strm, last_connection, req, res); - } - - const auto &content_type = req.get_header_value("Content-Type"); - - if (!content_type.find("application/x-www-form-urlencoded")) { - detail::parse_query_text(req.body, req.params); - } else if (!content_type.find("multipart/form-data")) { - std::string boundary; - if (!detail::parse_multipart_boundary(content_type, boundary) || - !detail::parse_multipart_formdata(boundary, req.body, req.files)) { - res.status = 400; - return write_response(strm, last_connection, req, res); - } - } - } - if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); if (!detail::parse_range_header(range_header_value, req.ranges)) { @@ -2820,7 +2927,23 @@ Server::process_request(Stream &strm, bool last_connection, if (setup_request) { setup_request(req); } - if (routing(req, res)) { + // Body + ContentReader content_reader = [&](ContentReceiver receiver) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { + if (receiver) { + return read_content_with_content_receiver(strm, last_connection, req, + res, receiver); + } else { + return read_content(strm, last_connection, req, res); + } + } else if (req.method == "PRI") { + return read_content(strm, last_connection, req, res); + } + return true; + }; + + // Rounting + if (routing(req, res, content_reader)) { if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } } else { if (res.status == -1) { res.status = 404; } @@ -2876,14 +2999,14 @@ inline bool Client::read_response_line(Stream &strm, Response &res) { const auto bufsiz = 2048; char buf[bufsiz]; - detail::stream_line_reader reader(strm, buf, bufsiz); + detail::stream_line_reader line_reader(strm, buf, bufsiz); - if (!reader.getline()) { return false; } + if (!line_reader.getline()) { return false; } const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); std::cmatch m; - if (std::regex_match(reader.ptr(), m, re)) { + if (std::regex_match(line_reader.ptr(), m, re)) { res.version = std::string(m[1]); res.status = std::stoi(std::string(m[2])); } @@ -3138,8 +3261,14 @@ inline bool Client::process_request(Stream &strm, const Request &req, if (req.content_receiver) { auto offset = std::make_shared(); - auto length = get_header_value_uint64(res.headers, "Content-Length", 0); + + size_t length = 0; + if (res.get_header_value("Content-Encoding") != "gzip") { + length = get_header_value_uint64(res.headers, "Content-Length", 0); + } + auto receiver = req.content_receiver; + out = [offset, length, receiver](const char *buf, size_t n) { auto ret = receiver(buf, n, *offset, length); (*offset) += n; diff --git a/test/test.cc b/test/test.cc index f22ee44..83467f0 100644 --- a/test/test.cc +++ b/test/test.cc @@ -744,6 +744,52 @@ protected: EXPECT_EQ(1u, req.get_header_value_count("Content-Length")); EXPECT_EQ("5", req.get_header_value("Content-Length")); }) + .Post("/content_receiver", + [&](const Request & req, Response &res, + const ContentReader &content_reader) { + std::string body; + content_reader([&](const char *data, size_t data_length, + size_t offset, + uint64_t content_length) { + EXPECT_EQ(offset, 0); + if (req.get_header_value("Content-Encoding") == "gzip") { + EXPECT_EQ(content_length, 0); + } else { + EXPECT_EQ(content_length, 7); + } + EXPECT_EQ(data_length, 7); + body.append(data, data_length); + return true; + }); + EXPECT_EQ(body, "content"); + res.set_content(body, "text/plain"); + }) + .Put("/content_receiver", + [&](const Request & /*req*/, Response &res, + const ContentReader &content_reader) { + std::string body; + content_reader([&](const char *data, size_t data_length, + size_t /*offset*/, + uint64_t /*content_length*/) { + body.append(data, data_length); + return true; + }); + EXPECT_EQ(body, "content"); + res.set_content(body, "text/plain"); + }) + .Patch("/content_receiver", + [&](const Request & /*req*/, Response &res, + const ContentReader &content_reader) { + std::string body; + content_reader([&](const char *data, size_t data_length, + size_t /*offset*/, + uint64_t /*content_length*/) { + body.append(data, data_length); + return true; + }); + EXPECT_EQ(body, "content"); + res.set_content(body, "text/plain"); + }) #ifdef CPPHTTPLIB_ZLIB_SUPPORT .Get("/gzip", [&](const Request & /*req*/, Response &res) { @@ -1354,9 +1400,12 @@ TEST_F(ServerTest, Put) { } TEST_F(ServerTest, PutWithContentProvider) { - auto res = cli_.Put("/put", 3, [](size_t /*offset*/, size_t /*length*/, DataSink sink) { - sink("PUT", 3); - }, "text/plain"); + auto res = cli_.Put( + "/put", 3, + [](size_t /*offset*/, size_t /*length*/, DataSink sink) { + sink("PUT", 3); + }, + "text/plain"); ASSERT_TRUE(res != nullptr); EXPECT_EQ(200, res->status); @@ -1365,9 +1414,12 @@ TEST_F(ServerTest, PutWithContentProvider) { #ifdef CPPHTTPLIB_ZLIB_SUPPORT TEST_F(ServerTest, PutWithContentProviderWithGzip) { - auto res = cli_.Put("/put", 3, [](size_t /*offset*/, size_t /*length*/, DataSink sink) { - sink("PUT", 3); - }, "text/plain", true); + auto res = cli_.Put( + "/put", 3, + [](size_t /*offset*/, size_t /*length*/, DataSink sink) { + sink("PUT", 3); + }, + "text/plain", true); ASSERT_TRUE(res != nullptr); EXPECT_EQ(200, res->status); @@ -1417,6 +1469,34 @@ TEST_F(ServerTest, NoMultipleHeaders) { EXPECT_EQ(200, res->status); } +TEST_F(ServerTest, PostContentReceiver) { + auto res = cli_.Post("/content_receiver", "content", "text/plain"); + ASSERT_TRUE(res != nullptr); + ASSERT_EQ(200, res->status); + ASSERT_EQ("content", res->body); +} + +TEST_F(ServerTest, PostContentReceiverGzip) { + auto res = cli_.Post("/content_receiver", "content", "text/plain", true); + ASSERT_TRUE(res != nullptr); + ASSERT_EQ(200, res->status); + ASSERT_EQ("content", res->body); +} + +TEST_F(ServerTest, PutContentReceiver) { + auto res = cli_.Put("/content_receiver", "content", "text/plain"); + ASSERT_TRUE(res != nullptr); + ASSERT_EQ(200, res->status); + ASSERT_EQ("content", res->body); +} + +TEST_F(ServerTest, PatchContentReceiver) { + auto res = cli_.Patch("/content_receiver", "content", "text/plain"); + ASSERT_TRUE(res != nullptr); + ASSERT_EQ(200, res->status); + ASSERT_EQ("content", res->body); +} + TEST_F(ServerTest, HTTP2Magic) { Request req; req.method = "PRI"; @@ -1501,7 +1581,10 @@ TEST_F(ServerTest, GzipWithContentReceiver) { std::string body; auto res = cli_.Get("/gzip", headers, [&](const char *data, uint64_t data_length, - uint64_t /*offset*/, uint64_t /*content_length*/) { + uint64_t offset, uint64_t content_length) { + EXPECT_EQ(data_length, 100); + EXPECT_EQ(offset, 0); + EXPECT_EQ(content_length, 0); body.append(data, data_length); return true; }); @@ -1521,7 +1604,10 @@ TEST_F(ServerTest, GzipWithContentReceiverWithoutAcceptEncoding) { std::string body; auto res = cli_.Get("/gzip", headers, [&](const char *data, uint64_t data_length, - uint64_t /*offset*/, uint64_t /*content_length*/) { + uint64_t offset, uint64_t content_length) { + EXPECT_EQ(data_length, 100); + EXPECT_EQ(offset, 0); + EXPECT_EQ(content_length, 100); body.append(data, data_length); return true; }); @@ -1557,7 +1643,10 @@ TEST_F(ServerTest, NoGzipWithContentReceiver) { std::string body; auto res = cli_.Get("/nogzip", headers, [&](const char *data, uint64_t data_length, - uint64_t /*offset*/, uint64_t /*content_length*/) { + uint64_t offset, uint64_t content_length) { + EXPECT_EQ(data_length, 100); + EXPECT_EQ(offset, 0); + EXPECT_EQ(content_length, 100); body.append(data, data_length); return true; });