Added 'resource_releaser' for content provider

This commit is contained in:
yhirose 2019-08-05 09:40:23 +09:00
parent 5a13539e57
commit 2823a94fc1
2 changed files with 159 additions and 154 deletions

175
httplib.h
View file

@ -128,13 +128,15 @@ typedef std::smatch Match;
typedef std::function<void(const char *data, uint64_t len)> Out;
typedef std::function<void(void)> Done;
typedef std::function<void()> Done;
typedef std::function<void(uint64_t offset, uint64_t length, Out out,
Done done)>
ContentProvider;
typedef Out ContentReceiver;
typedef std::function<bool(const char *data, uint64_t data_length,
uint64_t offset, uint64_t content_length)>
ContentReceiver;
typedef std::function<bool(uint64_t current, uint64_t total)> Progress;
@ -193,9 +195,6 @@ struct Response {
Headers headers;
std::string body;
ContentProvider content_provider;
uint64_t content_length;
ContentReceiver content_receiver;
Progress progress;
@ -209,18 +208,25 @@ struct Response {
void set_redirect(const char *uri);
void set_content(const char *s, size_t n, const char *content_type);
void set_content(const std::string &s, const char *content_type);
void set_content_producer(uint64_t length, ContentProvider producer);
void set_chunked_content_producer(
std::function<std::string(uint64_t offset)> producer);
void set_content_provider(
uint64_t length,
std::function<void(uint64_t offset, uint64_t length, Out out)> provider);
std::function<void(uint64_t offset, uint64_t length, Out out)> provider,
std::function<void()> resource_releaser = []{});
void set_chunked_content_provider(
std::function<void(uint64_t offset, Out out, Done done)> provider);
std::function<void(uint64_t offset, Out out, Done done)> provider,
std::function<void()> resource_releaser = []{});
Response() : status(-1), content_length(0) {}
Response() : status(-1), content_provider_resource_length(0) {}
~Response() {
if (content_provider_resource_releaser) { content_provider_resource_releaser(); }
}
uint64_t content_provider_resource_length;
ContentProvider content_provider;
std::function<void()> content_provider_resource_releaser;
};
class Stream {
@ -272,7 +278,7 @@ class TaskQueue {
public:
TaskQueue() {}
virtual ~TaskQueue() {}
virtual void enqueue(std::function<void(void)> fn) = 0;
virtual void enqueue(std::function<void()> fn) = 0;
virtual void shutdown() = 0;
};
@ -366,7 +372,7 @@ public:
Threads() : running_threads_(0) {}
virtual ~Threads() {}
virtual void enqueue(std::function<void(void)> fn) override {
virtual void enqueue(std::function<void()> fn) override {
std::thread([=]() {
{
std::lock_guard<std::mutex> guard(running_threads_mutex_);
@ -456,6 +462,9 @@ private:
bool parse_request_line(const char *s, Request &req);
bool write_response(Stream &strm, bool last_connection, const Request &req,
Response &res);
bool write_content_with_provider(Stream &strm, const Request &req,
Response &res, const std::string &boundary,
const std::string &content_type);
virtual bool read_and_close_socket(socket_t sock);
@ -1175,10 +1184,10 @@ public:
bool is_valid() const { return is_valid_; }
template <typename T>
bool decompress(const char *data, size_t data_len, T callback) {
bool decompress(const char *data, size_t data_length, T callback) {
int ret = Z_OK;
strm.avail_in = data_len;
strm.avail_in = data_length;
strm.next_in = (Bytef *)data;
const auto bufsiz = 16384;
@ -1195,7 +1204,7 @@ public:
case Z_MEM_ERROR: inflateEnd(&strm); return false;
}
callback(buff, bufsiz - strm.avail_out);
if (!callback(buff, bufsiz - strm.avail_out)) { return false; }
} while (strm.avail_out == 0);
return ret == Z_STREAM_END;
@ -1250,9 +1259,12 @@ inline bool read_headers(Stream &strm, Headers &headers) {
return true;
}
template <typename T>
typedef std::function<bool(const char *data, uint64_t data_length)>
ContentReceiverCore;
inline bool read_content_with_length(Stream &strm, size_t len,
Progress progress, T callback) {
Progress progress,
ContentReceiverCore out) {
char buf[CPPHTTPLIB_RECV_BUFSIZ];
size_t r = 0;
@ -1260,7 +1272,7 @@ inline bool read_content_with_length(Stream &strm, size_t len,
auto n = strm.read(buf, std::min((len - r), CPPHTTPLIB_RECV_BUFSIZ));
if (n <= 0) { return false; }
callback(buf, n);
if (!out(buf, n)) { return false; }
r += n;
@ -1282,8 +1294,7 @@ inline void skip_content_with_length(Stream &strm, size_t len) {
}
}
template <typename T>
inline bool read_content_without_length(Stream &strm, T callback) {
inline bool read_content_without_length(Stream &strm, ContentReceiverCore out) {
char buf[CPPHTTPLIB_RECV_BUFSIZ];
for (;;) {
auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ);
@ -1292,14 +1303,13 @@ inline bool read_content_without_length(Stream &strm, T callback) {
} else if (n == 0) {
return true;
}
callback(buf, n);
if (!out(buf, n)) { return false; }
}
return true;
}
template <typename T>
inline bool read_content_chunked(Stream &strm, T callback) {
inline bool read_content_chunked(Stream &strm, ContentReceiverCore out) {
const auto bufsiz = 16;
char buf[bufsiz];
@ -1310,7 +1320,7 @@ inline bool read_content_chunked(Stream &strm, T callback) {
auto chunk_len = std::stoi(reader.ptr(), 0, 16);
while (chunk_len > 0) {
if (!read_content_with_length(strm, chunk_len, nullptr, callback)) {
if (!read_content_with_length(strm, chunk_len, nullptr, out)) {
return false;
}
@ -1336,11 +1346,13 @@ inline bool is_chunked_transfer_encoding(const Headers &headers) {
"chunked");
}
template <typename T, typename U>
template <typename T>
bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status,
Progress progress, U callback) {
Progress progress, ContentReceiverCore receiver) {
ContentReceiver out = [&](const char *buf, size_t n) { callback(buf, n); };
ContentReceiverCore out = [&](const char *buf, size_t n) {
return receiver(buf, n);
};
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
detail::decompressor decompressor;
@ -1352,8 +1364,8 @@ bool read_content(Stream &strm, T &x, uint64_t payload_max_length, int &status,
if (x.get_header_value("Content-Encoding") == "gzip") {
out = [&](const char *buf, size_t n) {
decompressor.decompress(
buf, n, [&](const char *buf, size_t n) { callback(buf, n); });
return decompressor.decompress(
buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); });
};
}
#else
@ -1791,7 +1803,7 @@ get_range_offset_and_length(const Request &req, const Response &res,
size_t index) {
auto r = req.ranges[index];
if (r.second == -1) { r.second = res.content_length - 1; }
if (r.second == -1) { r.second = res.content_provider_resource_length - 1; }
return std::make_pair(r.first, r.second - r.first + 1);
}
@ -1923,19 +1935,23 @@ inline void Response::set_content(const std::string &s,
inline void Response::set_content_provider(
uint64_t length,
std::function<void(uint64_t offset, uint64_t length, Out out)> provider) {
std::function<void(uint64_t offset, uint64_t length, Out out)> provider,
std::function<void()> resource_releaser) {
assert(length > 0);
content_length = length;
content_provider_resource_length = length;
content_provider = [provider](uint64_t offset, uint64_t length, Out out,
Done) { provider(offset, length, out); };
content_provider_resource_releaser = resource_releaser;
}
inline void Response::set_chunked_content_provider(
std::function<void(uint64_t offset, Out out, Done done)> provider) {
content_length = 0;
std::function<void(uint64_t offset, Out out, Done done)> provider,
std::function<void()> resource_releaser) {
content_provider_resource_length = 0;
content_provider = [provider](uint64_t offset, uint64_t, Out out, Done done) {
provider(offset, out, done);
};
content_provider_resource_releaser = resource_releaser;
}
// Rstream implementation
@ -2184,17 +2200,17 @@ inline bool Server::write_response(Stream &strm, bool last_connection,
}
if (res.body.empty()) {
if (res.content_length > 0) {
if (res.content_provider_resource_length > 0) {
uint64_t length = 0;
if (req.ranges.empty()) {
length = res.content_length;
length = res.content_provider_resource_length;
} else if (req.ranges.size() == 1) {
auto offsets =
detail::get_range_offset_and_length(req, res.content_length, 0);
detail::get_range_offset_and_length(req, res.content_provider_resource_length, 0);
auto offset = offsets.first;
length = offsets.second;
auto content_range = detail::make_content_range_header_field(
offset, length, res.content_length);
offset, length, res.content_provider_resource_length);
res.set_header("Content-Range", content_range);
} else {
length = detail::get_multipart_ranges_data_length(req, res, boundary,
@ -2247,31 +2263,9 @@ inline bool Server::write_response(Stream &strm, bool last_connection,
if (!res.body.empty()) {
if (!strm.write(res.body)) { return false; }
} else if (res.content_provider) {
if (res.content_length) {
if (req.ranges.empty()) {
if (detail::write_content(strm, res.content_provider, 0,
res.content_length) < 0) {
return false;
}
} else if (req.ranges.size() == 1) {
auto offsets =
detail::get_range_offset_and_length(req, res.content_length, 0);
auto offset = offsets.first;
auto length = offsets.second;
if (detail::write_content(strm, res.content_provider, offset,
length) < 0) {
return false;
}
} else {
if (!detail::write_multipart_ranges_data(strm, req, res, boundary,
content_type)) {
return false;
}
}
} else {
if (detail::write_content_chunked(strm, res.content_provider) < 0) {
return false;
}
if (!write_content_with_provider(strm, req, res, boundary,
content_type)) {
return false;
}
}
}
@ -2282,6 +2276,39 @@ inline bool Server::write_response(Stream &strm, bool last_connection,
return true;
}
inline bool
Server::write_content_with_provider(Stream &strm, const Request &req,
Response &res, const std::string &boundary,
const std::string &content_type) {
if (res.content_provider_resource_length) {
if (req.ranges.empty()) {
if (detail::write_content(strm, res.content_provider, 0,
res.content_provider_resource_length) < 0) {
return false;
}
} else if (req.ranges.size() == 1) {
auto offsets =
detail::get_range_offset_and_length(req, res.content_provider_resource_length, 0);
auto offset = offsets.first;
auto length = offsets.second;
if (detail::write_content(strm, res.content_provider, offset, length) <
0) {
return false;
}
} else {
if (!detail::write_multipart_ranges_data(strm, req, res, boundary,
content_type)) {
return false;
}
}
} else {
if (detail::write_content_chunked(strm, res.content_provider) < 0) {
return false;
}
}
return true;
}
inline bool Server::handle_file_request(Request &req, Response &res) {
if (!base_dir_.empty() && detail::is_valid_path(req.path)) {
std::string path = base_dir_ + req.path;
@ -2459,9 +2486,11 @@ Server::process_request(Stream &strm, bool last_connection,
// Body
if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") {
if (!detail::read_content(
strm, req, payload_max_length_, res.status, Progress(),
[&](const char *buf, size_t n) { req.body.append(buf, n); })) {
if (!detail::read_content(strm, req, payload_max_length_, res.status,
Progress(), [&](const char *buf, size_t n) {
req.body.append(buf, n);
return true;
})) {
return write_response(strm, last_connection, req, res);
}
@ -2643,12 +2672,20 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
// Body
if (req.method != "HEAD") {
ContentReceiver out = [&](const char *buf, size_t n) {
detail::ContentReceiverCore out = [&](const char *buf, size_t n) {
res.body.append(buf, n);
return true;
};
if (res.content_receiver) {
out = [&](const char *buf, size_t n) { res.content_receiver(buf, n); };
auto offset = std::make_shared<uint64_t>();
auto length = get_header_value_uint64(res.headers, "Content-Length", 0);
auto receiver = res.content_receiver;
out = [offset, length, receiver](const char *buf, size_t n) {
auto ret = receiver(buf, n, *offset, length);
(*offset) += n;
return ret;
};
}
int dummy_status;

View file

@ -229,7 +229,10 @@ TEST(ChunkedEncodingTest, WithContentReceiver) {
std::string body;
auto res =
cli.Get("/httpgallery/chunked/chunkedimage.aspx?0.4153841143030137",
[&](const char *data, size_t len) { body.append(data, len); });
[&](const char *data, uint64_t data_length, uint64_t, uint64_t) {
body.append(data, data_length);
return true;
});
ASSERT_TRUE(res != nullptr);
std::string out;
@ -508,23 +511,30 @@ protected:
.Get("/streamed",
[&](const Request & /*req*/, Response &res) {
res.set_content_provider(
6, [](uint64_t offset, uint64_t /*length*/, Out out) {
if (offset < 3) {
out("a", 1);
} else {
out("b", 1);
}
6,
[](uint64_t offset, uint64_t /*length*/, Out out) {
out(offset < 3 ? "a" : "b", 1);
});
})
.Get("/streamed-with-range",
[&](const Request & /*req*/, Response &res) {
auto data = std::make_shared<std::string>("abcdefg");
auto data = new std::string("abcdefg");
res.set_content_provider(
data->size(),
[data](uint64_t offset, uint64_t length, Out out) {
const uint64_t DATA_CHUNK_SIZE = 4;
const auto &d = *data;
out(&d[offset], std::min(length, DATA_CHUNK_SIZE));
},
[data] { delete data; });
})
.Get("/streamed-cancel",
[&](const Request & /*req*/, Response &res) {
res.set_content_provider(
uint64_t(-1),
[](uint64_t /*offset*/, uint64_t /*length*/, Out out) {
std::string data = "data_chunk";
out(data.data(), data.size());
});
})
.Get("/with-range",
@ -849,85 +859,29 @@ TEST_F(ServerTest, EmptyRequest) {
}
TEST_F(ServerTest, LongRequest) {
auto res =
cli_.Get("/TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/__ok__");
std::string request;
for (size_t i = 0; i < 545; i++) {
request += "/TooLongRequest";
}
request += "OK";
auto res = cli_.Get(request.c_str());
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(404, res->status);
}
TEST_F(ServerTest, TooLongRequest) {
auto res =
cli_.Get("/TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/TooLongRequest/"
"TooLongRequest/TooLongRequest/TooLongRequest/__ng___");
std::string request;
for (size_t i = 0; i < 545; i++) {
request += "/TooLongRequest";
}
request += "_NG";
auto res = cli_.Get(request.c_str());
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(404, res->status);
EXPECT_EQ(414, res->status);
}
TEST_F(ServerTest, LongHeader) {
@ -1169,6 +1123,14 @@ TEST_F(ServerTest, GetStreamedWithRangeMultipart) {
EXPECT_EQ(269, res->body.size());
}
TEST_F(ServerTest, GetStreamedEndless) {
auto res = cli_.Get("/streamed-cancel",
[](const char * /*data*/, uint64_t /*data_length*/,
uint64_t offset,
uint64_t /*content_length*/) { return offset < 100; });
ASSERT_TRUE(res == nullptr);
}
TEST_F(ServerTest, GetWithRange1) {
auto res = cli_.Get("/with-range", {{make_range_header({{3, 5}})}});
ASSERT_TRUE(res != nullptr);
@ -1339,9 +1301,12 @@ TEST_F(ServerTest, GzipWithContentReceiver) {
Headers headers;
headers.emplace("Accept-Encoding", "gzip, deflate");
std::string body;
auto res = cli_.Get("/gzip", headers, [&](const char *data, size_t len) {
body.append(data, len);
});
auto res = cli_.Get("/gzip", headers,
[&](const char *data, uint64_t data_length,
uint64_t /*offset*/, uint64_t /*content_length*/) {
body.append(data, data_length);
return true;
});
ASSERT_TRUE(res != nullptr);
EXPECT_EQ("gzip", res->get_header_value("Content-Encoding"));
@ -1372,9 +1337,12 @@ TEST_F(ServerTest, NoGzipWithContentReceiver) {
Headers headers;
headers.emplace("Accept-Encoding", "gzip, deflate");
std::string body;
auto res = cli_.Get("/nogzip", headers, [&](const char *data, size_t len) {
body.append(data, len);
});
auto res = cli_.Get("/nogzip", headers,
[&](const char *data, uint64_t data_length,
uint64_t /*offset*/, uint64_t /*content_length*/) {
body.append(data, data_length);
return true;
});
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(false, res->has_header("Content-Encoding"));