From 457a5a7501fba18be3dcdf8cce776ccce24e71ff Mon Sep 17 00:00:00 2001 From: yhirose Date: Sun, 19 Jul 2020 17:44:45 -0400 Subject: [PATCH] Added compressor class --- httplib.h | 150 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 97 insertions(+), 53 deletions(-) diff --git a/httplib.h b/httplib.h index 9865a6e..b45bcaa 100644 --- a/httplib.h +++ b/httplib.h @@ -143,9 +143,9 @@ using ssize_t = int; #endif // NOMINMAX #include +#include #include #include -#include #ifndef WSA_FLAG_NO_HANDLE_INHERIT #define WSA_FLAG_NO_HANDLE_INHERIT 0x80 @@ -2271,90 +2271,106 @@ inline bool can_compress(const std::string &content_type) { content_type == "application/xhtml+xml"; } -inline bool compress(std::string &content) { - z_stream strm; - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; +class compressor { +public: + compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; - auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, - Z_DEFAULT_STRATEGY); - if (ret != Z_OK) { return false; } + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; + } - strm.avail_in = static_cast(content.size()); - strm.next_in = - const_cast(reinterpret_cast(content.data())); + ~compressor() { deflateEnd(&strm_); } - std::string compressed; + template + bool compress(const char *data, size_t data_length, bool last, T callback) { + assert(is_valid_); - std::array buff{}; - do { - strm.avail_out = buff.size(); - strm.next_out = reinterpret_cast(buff.data()); - ret = deflate(&strm, Z_FINISH); - assert(ret != Z_STREAM_ERROR); - compressed.append(buff.data(), buff.size() - strm.avail_out); - } while (strm.avail_out == 0); + auto flush = last ? Z_FINISH : Z_NO_FLUSH; - assert(ret == Z_STREAM_END); - assert(strm.avail_in == 0); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); - content.swap(compressed); + int ret = Z_OK; - deflateEnd(&strm); - return true; -} + std::array buff{}; + do { + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + assert(ret != Z_STREAM_ERROR); + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm_.avail_in == 0); + return true; + } + +private: + bool is_valid_ = false; + z_stream strm_; +}; class decompressor { public: decompressor() { - std::memset(&strm, 0, sizeof(strm)); - strm.zalloc = Z_NULL; - strm.zfree = Z_NULL; - strm.opaque = Z_NULL; + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; // 15 is the value of wbits, which should be at the maximum possible value // to ensure that any gzip stream can be decoded. The offset of 32 specifies // that the stream type should be automatically detected either gzip or // deflate. - is_valid_ = inflateInit2(&strm, 32 + 15) == Z_OK; + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; } - ~decompressor() { inflateEnd(&strm); } + ~decompressor() { inflateEnd(&strm_); } bool is_valid() const { return is_valid_; } template bool decompress(const char *data, size_t data_length, T callback) { + assert(is_valid_); + int ret = Z_OK; - strm.avail_in = static_cast(data_length); - strm.next_in = const_cast(reinterpret_cast(data)); + strm_.avail_in = static_cast(data_length); + strm_.next_in = const_cast(reinterpret_cast(data)); std::array buff{}; do { - strm.avail_out = buff.size(); - strm.next_out = reinterpret_cast(buff.data()); + strm_.avail_out = buff.size(); + strm_.next_out = reinterpret_cast(buff.data()); - ret = inflate(&strm, Z_NO_FLUSH); + ret = inflate(&strm_, Z_NO_FLUSH); assert(ret != Z_STREAM_ERROR); switch (ret) { case Z_NEED_DICT: case Z_DATA_ERROR: - case Z_MEM_ERROR: inflateEnd(&strm); return false; + case Z_MEM_ERROR: inflateEnd(&strm_); return false; } - if (!callback(buff.data(), buff.size() - strm.avail_out)) { + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { return false; } - } while (strm.avail_out == 0); + } while (strm_.avail_out == 0); return ret == Z_OK || ret == Z_STREAM_END; } private: - bool is_valid_; - z_stream strm; + bool is_valid_ = false; + z_stream strm_; }; #endif @@ -3924,9 +3940,17 @@ inline bool Server::write_response(Stream &strm, bool close_connection, const auto &encodings = req.get_header_value("Accept-Encoding"); if (encodings.find("gzip") != std::string::npos && detail::can_compress(res.get_header_value("Content-Type"))) { - if (detail::compress(res.body)) { - res.set_header("Content-Encoding", "gzip"); + std::string compressed; + detail::compressor compressor; + if (!compressor.compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + return false; } + res.body.swap(compressed); + res.set_header("Content-Encoding", "gzip"); } #endif @@ -4730,26 +4754,47 @@ inline std::shared_ptr Client::send_with_content_provider( #ifdef CPPHTTPLIB_ZLIB_SUPPORT if (compress_) { + detail::compressor compressor; + if (content_provider) { + auto ok = true; size_t offset = 0; DataSink data_sink; data_sink.write = [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - offset += data_len; - }; - data_sink.is_writable = [&](void) { return true; }; + if (ok) { + auto last = offset + data_len == content_length; - while (offset < content_length) { + auto ret = compressor.compress( + data, data_len, last, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + }; + data_sink.is_writable = [&](void) { return ok && true; }; + + while (ok && offset < content_length) { if (!content_provider(offset, content_length - offset, data_sink)) { return nullptr; } } } else { - req.body = body; + if (!compressor.compress(body.data(), body.size(), true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + return nullptr; + } } - if (!detail::compress(req.body)) { return nullptr; } req.headers.emplace("Content-Encoding", "gzip"); } else #endif @@ -5821,4 +5866,3 @@ inline bool SSLClient::check_host_name(const char *pattern, } // namespace httplib #endif // CPPHTTPLIB_HTTPLIB_H -