diff --git a/httplib.h b/httplib.h index 99ca369..278e2ff 100644 --- a/httplib.h +++ b/httplib.h @@ -85,7 +85,9 @@ typedef int socket_t; */ #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 #define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH std::numeric_limits::max() namespace httplib { @@ -233,6 +235,7 @@ public: void set_logger(Logger logger); void set_keep_alive_max_count(size_t count); + void set_payload_max_length(uint64_t length); int bind_to_any_port(const char *host, int socket_flags = 0); bool listen_after_bind(); @@ -247,6 +250,7 @@ protected: bool &connection_close); size_t keep_alive_max_count_; + size_t payload_max_length_; private: typedef std::vector> Handlers; @@ -762,6 +766,7 @@ inline const char *status_message(int status) { case 400: return "Bad Request"; case 403: return "Forbidden"; case 404: return "Not Found"; + case 413: return "Payload Too Large"; case 414: return "Request-URI Too Long"; case 415: return "Unsupported Media Type"; default: @@ -782,12 +787,12 @@ inline const char *get_header_value(const Headers &headers, const char *key, } inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, - int def = 0) { - auto it = headers.find(key); - if (it != headers.end()) { - return std::strtoull(it->second.data(), nullptr, 10); - } - return def; + int def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; } inline bool read_headers(Stream &strm, Headers &headers) { @@ -881,7 +886,9 @@ inline bool read_content_chunked(Stream &strm, std::string &out) { } template -bool read_content(Stream &strm, T &x, Progress progress = Progress()) { +bool read_content(Stream &strm, T &x, uint64_t payload_max_length, + bool &exceed_payload_max_length, + Progress progress = Progress()) { if (has_header(x.headers, "Content-Length")) { auto len = get_header_value_uint64(x.headers, "Content-Length", 0); if (len == 0) { @@ -891,6 +898,15 @@ bool read_content(Stream &strm, T &x, Progress progress = Progress()) { return read_content_chunked(strm, x.body); } } + + if ((len > payload_max_length) || + // For 32-bit platform + (sizeof(size_t) < sizeof(uint64_t) && + len > std::numeric_limits::max())) { + exceed_payload_max_length = true; + return false; + } + return read_content_with_length(strm, x.body, len, progress); } else { const auto &encoding = @@ -1427,8 +1443,9 @@ inline const std::string &BufferStream::get_buffer() const { return buffer; } // HTTP server implementation inline Server::Server() - : keep_alive_max_count_(5), is_running_(false), svr_sock_(INVALID_SOCKET), - running_threads_(0) { + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET), running_threads_(0) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); #endif @@ -1484,6 +1501,10 @@ inline void Server::set_keep_alive_max_count(size_t count) { keep_alive_max_count_ = count; } +inline void Server::set_payload_max_length(uint64_t length) { + payload_max_length_ = length; +} + inline int Server::bind_to_any_port(const char *host, int socket_flags) { return bind_internal(host, 0, socket_flags); } @@ -1702,8 +1723,7 @@ inline bool Server::listen_internal() { std::lock_guard guard(running_threads_mutex_); running_threads_--; } - }) - .detach(); + }).detach(); } // TODO: Use thread pool... @@ -1789,10 +1809,12 @@ inline bool Server::process_request(Stream &strm, bool last_connection, // Body if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - if (!detail::read_content(strm, req)) { - res.status = 400; + bool exceed_payload_max_length = false; + if (!detail::read_content(strm, req, payload_max_length_, + exceed_payload_max_length)) { + res.status = exceed_payload_max_length ? 413 : 400; write_response(strm, last_connection, req, res); - return true; + return !exceed_payload_max_length; } const auto &content_type = req.get_header_value("Content-Type"); @@ -1975,7 +1997,11 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res, // Body if (req.method != "HEAD") { - if (!detail::read_content(strm, res, req.progress)) { return false; } + bool exceed_payload_max_length = false; + if (!detail::read_content(strm, res, std::numeric_limits::max(), + exceed_payload_max_length, req.progress)) { + return false; + } if (res.get_header_value("Content-Encoding") == "gzip") { #ifdef CPPHTTPLIB_ZLIB_SUPPORT diff --git a/test/test.cc b/test/test.cc index deb6f6d..bb71ec8 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1288,6 +1288,56 @@ TEST_F(ServerUpDownTest, QuickStartStop) { // --gtest_filter=ServerUpDownTest.QuickStartStop --gtest_repeat=1000 } +class PayloadMaxLengthTest : public ::testing::Test { +protected: + PayloadMaxLengthTest() + : cli_(HOST, PORT) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + , + svr_(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE) +#endif + { + } + + virtual void SetUp() { + svr_.set_payload_max_length(8); + + svr_.Post("/test", [&](const Request & /*req*/, Response &res) { + res.set_content("test", "text/plain"); + }); + + t_ = thread([&]() { ASSERT_TRUE(svr_.listen(HOST, PORT)); }); + + while (!svr_.is_running()) { + msleep(1); + } + } + + virtual void TearDown() { + svr_.stop(); + t_.join(); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli_; + SSLServer svr_; +#else + Client cli_; + Server svr_; +#endif + thread t_; +}; + +TEST_F(PayloadMaxLengthTest, ExceedLimit) { + auto res = cli_.Post("/test", "123456789", "text/plain"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(413, res->status); + + res = cli_.Post("/test", "12345678", "text/plain"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); +} + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT TEST(SSLClientTest, ServerNameIndication) { SSLClient cli("httpbin.org", 443);