From 2823a94fc17f59fa0d874e5d91521e9dadbea8c7 Mon Sep 17 00:00:00 2001 From: yhirose Date: Mon, 5 Aug 2019 09:40:23 +0900 Subject: [PATCH] Added 'resource_releaser' for content provider --- httplib.h | 175 +++++++++++++++++++++++++++++++-------------------- test/test.cc | 138 ++++++++++++++++------------------------ 2 files changed, 159 insertions(+), 154 deletions(-) diff --git a/httplib.h b/httplib.h index 56c8b68..df31605 100644 --- a/httplib.h +++ b/httplib.h @@ -128,13 +128,15 @@ typedef std::smatch Match; typedef std::function Out; -typedef std::function Done; +typedef std::function Done; typedef std::function ContentProvider; -typedef Out ContentReceiver; +typedef std::function + ContentReceiver; typedef std::function Progress; @@ -193,9 +195,6 @@ struct Response { Headers headers; std::string body; - ContentProvider content_provider; - uint64_t content_length; - ContentReceiver content_receiver; Progress progress; @@ -209,18 +208,25 @@ struct Response { void set_redirect(const char *uri); void set_content(const char *s, size_t n, const char *content_type); void set_content(const std::string &s, const char *content_type); - void set_content_producer(uint64_t length, ContentProvider producer); - void set_chunked_content_producer( - std::function producer); void set_content_provider( uint64_t length, - std::function provider); + std::function provider, + std::function resource_releaser = []{}); void set_chunked_content_provider( - std::function provider); + std::function provider, + std::function resource_releaser = []{}); - Response() : status(-1), content_length(0) {} + Response() : status(-1), content_provider_resource_length(0) {} + + ~Response() { + if (content_provider_resource_releaser) { content_provider_resource_releaser(); } + } + + uint64_t content_provider_resource_length; + ContentProvider content_provider; + std::function content_provider_resource_releaser; }; class Stream { @@ -272,7 +278,7 @@ class TaskQueue { public: TaskQueue() {} virtual ~TaskQueue() {} - virtual void enqueue(std::function fn) = 0; + virtual void enqueue(std::function fn) = 0; virtual void shutdown() = 0; }; @@ -366,7 +372,7 @@ public: Threads() : running_threads_(0) {} virtual ~Threads() {} - virtual void enqueue(std::function fn) override { + virtual void enqueue(std::function fn) override { std::thread([=]() { { std::lock_guard guard(running_threads_mutex_); @@ -456,6 +462,9 @@ private: bool parse_request_line(const char *s, Request &req); bool write_response(Stream &strm, bool last_connection, const Request &req, Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); virtual bool read_and_close_socket(socket_t sock); @@ -1175,10 +1184,10 @@ public: bool is_valid() const { return is_valid_; } template - bool decompress(const char *data, size_t data_len, T callback) { + bool decompress(const char *data, size_t data_length, T callback) { int ret = Z_OK; - strm.avail_in = data_len; + strm.avail_in = data_length; strm.next_in = (Bytef *)data; const auto bufsiz = 16384; @@ -1195,7 +1204,7 @@ public: case Z_MEM_ERROR: inflateEnd(&strm); return false; } - callback(buff, bufsiz - strm.avail_out); + if (!callback(buff, bufsiz - strm.avail_out)) { return false; } } while (strm.avail_out == 0); return ret == Z_STREAM_END; @@ -1250,9 +1259,12 @@ inline bool read_headers(Stream &strm, Headers &headers) { return true; } -template +typedef std::function + ContentReceiverCore; + inline bool read_content_with_length(Stream &strm, size_t len, - Progress progress, T callback) { + Progress progress, + ContentReceiverCore out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; size_t r = 0; @@ -1260,7 +1272,7 @@ inline bool read_content_with_length(Stream &strm, size_t len, auto n = strm.read(buf, std::min((len - r), CPPHTTPLIB_RECV_BUFSIZ)); if (n <= 0) { return false; } - callback(buf, n); + if (!out(buf, n)) { return false; } r += n; @@ -1282,8 +1294,7 @@ inline void skip_content_with_length(Stream &strm, size_t len) { } } -template -inline bool read_content_without_length(Stream &strm, T callback) { +inline bool read_content_without_length(Stream &strm, ContentReceiverCore out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; for (;;) { auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); @@ -1292,14 +1303,13 @@ inline bool read_content_without_length(Stream &strm, T callback) { } else if (n == 0) { return true; } - callback(buf, n); + if (!out(buf, n)) { return false; } } return true; } -template -inline bool read_content_chunked(Stream &strm, T callback) { +inline bool read_content_chunked(Stream &strm, ContentReceiverCore out) { const auto bufsiz = 16; char buf[bufsiz]; @@ -1310,7 +1320,7 @@ inline bool read_content_chunked(Stream &strm, T callback) { auto chunk_len = std::stoi(reader.ptr(), 0, 16); while (chunk_len > 0) { - if (!read_content_with_length(strm, chunk_len, nullptr, callback)) { + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { return false; } @@ -1336,11 +1346,13 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) { "chunked"); } -template +template bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status, - Progress progress, U callback) { + Progress progress, ContentReceiverCore receiver) { - ContentReceiver out = [&](const char *buf, size_t n) { callback(buf, n); }; + ContentReceiverCore out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; #ifdef CPPHTTPLIB_ZLIB_SUPPORT detail::decompressor decompressor; @@ -1352,8 +1364,8 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status, if (x.get_header_value("Content-Encoding") == "gzip") { out = [&](const char *buf, size_t n) { - decompressor.decompress( - buf, n, [&](const char *buf, size_t n) { callback(buf, n); }); + return decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); }; } #else @@ -1791,7 +1803,7 @@ get_range_offset_and_length(const Request &req, const Response &res, size_t index) { auto r = req.ranges[index]; - if (r.second == -1) { r.second = res.content_length - 1; } + if (r.second == -1) { r.second = res.content_provider_resource_length - 1; } return std::make_pair(r.first, r.second - r.first + 1); } @@ -1923,19 +1935,23 @@ inline void Response::set_content(const std::string &s, inline void Response::set_content_provider( uint64_t length, - std::function provider) { + std::function provider, + std::function resource_releaser) { assert(length > 0); - content_length = length; + content_provider_resource_length = length; content_provider = [provider](uint64_t offset, uint64_t length, Out out, Done) { provider(offset, length, out); }; + content_provider_resource_releaser = resource_releaser; } inline void Response::set_chunked_content_provider( - std::function provider) { - content_length = 0; + std::function provider, + std::function resource_releaser) { + content_provider_resource_length = 0; content_provider = [provider](uint64_t offset, uint64_t, Out out, Done done) { provider(offset, out, done); }; + content_provider_resource_releaser = resource_releaser; } // Rstream implementation @@ -2184,17 +2200,17 @@ inline bool Server::write_response(Stream &strm, bool last_connection, } if (res.body.empty()) { - if (res.content_length > 0) { + if (res.content_provider_resource_length > 0) { uint64_t length = 0; if (req.ranges.empty()) { - length = res.content_length; + length = res.content_provider_resource_length; } else if (req.ranges.size() == 1) { auto offsets = - detail::get_range_offset_and_length(req, res.content_length, 0); + detail::get_range_offset_and_length(req, res.content_provider_resource_length, 0); auto offset = offsets.first; length = offsets.second; auto content_range = detail::make_content_range_header_field( - offset, length, res.content_length); + offset, length, res.content_provider_resource_length); res.set_header("Content-Range", content_range); } else { length = detail::get_multipart_ranges_data_length(req, res, boundary, @@ -2247,31 +2263,9 @@ inline bool Server::write_response(Stream &strm, bool last_connection, if (!res.body.empty()) { if (!strm.write(res.body)) { return false; } } else if (res.content_provider) { - if (res.content_length) { - if (req.ranges.empty()) { - if (detail::write_content(strm, res.content_provider, 0, - res.content_length) < 0) { - return false; - } - } else if (req.ranges.size() == 1) { - auto offsets = - detail::get_range_offset_and_length(req, res.content_length, 0); - auto offset = offsets.first; - auto length = offsets.second; - if (detail::write_content(strm, res.content_provider, offset, - length) < 0) { - return false; - } - } else { - if (!detail::write_multipart_ranges_data(strm, req, res, boundary, - content_type)) { - return false; - } - } - } else { - if (detail::write_content_chunked(strm, res.content_provider) < 0) { - return false; - } + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; } } } @@ -2282,6 +2276,39 @@ inline bool Server::write_response(Stream &strm, bool last_connection, return true; } +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_provider_resource_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_provider_resource_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_provider_resource_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } else { + if (detail::write_content_chunked(strm, res.content_provider) < 0) { + return false; + } + } + return true; +} + inline bool Server::handle_file_request(Request &req, Response &res) { if (!base_dir_.empty() && detail::is_valid_path(req.path)) { std::string path = base_dir_ + req.path; @@ -2459,9 +2486,11 @@ Server::process_request(Stream &strm, bool last_connection, // Body if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - if (!detail::read_content( - strm, req, payload_max_length_, res.status, Progress(), - [&](const char *buf, size_t n) { req.body.append(buf, n); })) { + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), [&](const char *buf, size_t n) { + req.body.append(buf, n); + return true; + })) { return write_response(strm, last_connection, req, res); } @@ -2643,12 +2672,20 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, // Body if (req.method != "HEAD") { - ContentReceiver out = [&](const char *buf, size_t n) { + detail::ContentReceiverCore out = [&](const char *buf, size_t n) { res.body.append(buf, n); + return true; }; if (res.content_receiver) { - out = [&](const char *buf, size_t n) { res.content_receiver(buf, n); }; + auto offset = std::make_shared(); + auto length = get_header_value_uint64(res.headers, "Content-Length", 0); + auto receiver = res.content_receiver; + out = [offset, length, receiver](const char *buf, size_t n) { + auto ret = receiver(buf, n, *offset, length); + (*offset) += n; + return ret; + }; } int dummy_status; diff --git a/test/test.cc b/test/test.cc index 87e4e97..bff59cf 100644 --- a/test/test.cc +++ b/test/test.cc @@ -229,7 +229,10 @@ TEST(ChunkedEncodingTest, WithContentReceiver) { std::string body; auto res = cli.Get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137", - [&](const char *data, size_t len) { body.append(data, len); }); + [&](const char *data, uint64_t data_length, uint64_t, uint64_t) { + body.append(data, data_length); + return true; + }); ASSERT_TRUE(res != nullptr); std::string out; @@ -508,23 +511,30 @@ protected: .Get("/streamed", [&](const Request & /*req*/, Response &res) { res.set_content_provider( - 6, [](uint64_t offset, uint64_t /*length*/, Out out) { - if (offset < 3) { - out("a", 1); - } else { - out("b", 1); - } + 6, + [](uint64_t offset, uint64_t /*length*/, Out out) { + out(offset < 3 ? "a" : "b", 1); }); }) .Get("/streamed-with-range", [&](const Request & /*req*/, Response &res) { - auto data = std::make_shared("abcdefg"); + auto data = new std::string("abcdefg"); res.set_content_provider( data->size(), [data](uint64_t offset, uint64_t length, Out out) { const uint64_t DATA_CHUNK_SIZE = 4; const auto &d = *data; out(&d[offset], std::min(length, DATA_CHUNK_SIZE)); + }, + [data] { delete data; }); + }) + .Get("/streamed-cancel", + [&](const Request & /*req*/, Response &res) { + res.set_content_provider( + uint64_t(-1), + [](uint64_t /*offset*/, uint64_t /*length*/, Out out) { + std::string data = "data_chunk"; + out(data.data(), data.size()); }); }) .Get("/with-range", @@ -849,85 +859,29 @@ TEST_F(ServerTest, EmptyRequest) { } TEST_F(ServerTest, LongRequest) { - auto res = - cli_.Get("/TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/__ok__"); + std::string request; + for (size_t i = 0; i < 545; i++) { + request += "/TooLongRequest"; + } + request += "OK"; + + auto res = cli_.Get(request.c_str()); ASSERT_TRUE(res != nullptr); EXPECT_EQ(404, res->status); } TEST_F(ServerTest, TooLongRequest) { - auto res = - cli_.Get("/TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/" - "TooLongRequest/TooLongRequest/TooLongRequest/__ng___"); + std::string request; + for (size_t i = 0; i < 545; i++) { + request += "/TooLongRequest"; + } + request += "_NG"; + + auto res = cli_.Get(request.c_str()); ASSERT_TRUE(res != nullptr); - EXPECT_EQ(404, res->status); + EXPECT_EQ(414, res->status); } TEST_F(ServerTest, LongHeader) { @@ -1169,6 +1123,14 @@ TEST_F(ServerTest, GetStreamedWithRangeMultipart) { EXPECT_EQ(269, res->body.size()); } +TEST_F(ServerTest, GetStreamedEndless) { + auto res = cli_.Get("/streamed-cancel", + [](const char * /*data*/, uint64_t /*data_length*/, + uint64_t offset, + uint64_t /*content_length*/) { return offset < 100; }); + ASSERT_TRUE(res == nullptr); +} + TEST_F(ServerTest, GetWithRange1) { auto res = cli_.Get("/with-range", {{make_range_header({{3, 5}})}}); ASSERT_TRUE(res != nullptr); @@ -1339,9 +1301,12 @@ TEST_F(ServerTest, GzipWithContentReceiver) { Headers headers; headers.emplace("Accept-Encoding", "gzip, deflate"); std::string body; - auto res = cli_.Get("/gzip", headers, [&](const char *data, size_t len) { - body.append(data, len); - }); + auto res = cli_.Get("/gzip", headers, + [&](const char *data, uint64_t data_length, + uint64_t /*offset*/, uint64_t /*content_length*/) { + body.append(data, data_length); + return true; + }); ASSERT_TRUE(res != nullptr); EXPECT_EQ("gzip", res->get_header_value("Content-Encoding")); @@ -1372,9 +1337,12 @@ TEST_F(ServerTest, NoGzipWithContentReceiver) { Headers headers; headers.emplace("Accept-Encoding", "gzip, deflate"); std::string body; - auto res = cli_.Get("/nogzip", headers, [&](const char *data, size_t len) { - body.append(data, len); - }); + auto res = cli_.Get("/nogzip", headers, + [&](const char *data, uint64_t data_length, + uint64_t /*offset*/, uint64_t /*content_length*/) { + body.append(data, data_length); + return true; + }); ASSERT_TRUE(res != nullptr); EXPECT_EQ(false, res->has_header("Content-Encoding"));