This commit is contained in:
yhirose 2019-04-12 23:34:27 -04:00
parent 77536acef7
commit 744e8e7071
2 changed files with 91 additions and 15 deletions

View file

@ -85,7 +85,9 @@ typedef int socket_t;
*/
#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5
#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0
#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192
#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH std::numeric_limits<uint64_t>::max()
namespace httplib {
@ -233,6 +235,7 @@ public:
void set_logger(Logger logger);
void set_keep_alive_max_count(size_t count);
void set_payload_max_length(uint64_t length);
int bind_to_any_port(const char *host, int socket_flags = 0);
bool listen_after_bind();
@ -247,6 +250,7 @@ protected:
bool &connection_close);
size_t keep_alive_max_count_;
size_t payload_max_length_;
private:
typedef std::vector<std::pair<std::regex, Handler>> Handlers;
@ -762,6 +766,7 @@ inline const char *status_message(int status) {
case 400: return "Bad Request";
case 403: return "Forbidden";
case 404: return "Not Found";
case 413: return "Payload Too Large";
case 414: return "Request-URI Too Long";
case 415: return "Unsupported Media Type";
default:
@ -782,12 +787,12 @@ inline const char *get_header_value(const Headers &headers, const char *key,
}
inline uint64_t get_header_value_uint64(const Headers &headers, const char *key,
int def = 0) {
auto it = headers.find(key);
if (it != headers.end()) {
return std::strtoull(it->second.data(), nullptr, 10);
}
return def;
int def = 0) {
auto it = headers.find(key);
if (it != headers.end()) {
return std::strtoull(it->second.data(), nullptr, 10);
}
return def;
}
inline bool read_headers(Stream &strm, Headers &headers) {
@ -881,7 +886,9 @@ inline bool read_content_chunked(Stream &strm, std::string &out) {
}
template <typename T>
bool read_content(Stream &strm, T &x, Progress progress = Progress()) {
bool read_content(Stream &strm, T &x, uint64_t payload_max_length,
bool &exceed_payload_max_length,
Progress progress = Progress()) {
if (has_header(x.headers, "Content-Length")) {
auto len = get_header_value_uint64(x.headers, "Content-Length", 0);
if (len == 0) {
@ -891,6 +898,15 @@ bool read_content(Stream &strm, T &x, Progress progress = Progress()) {
return read_content_chunked(strm, x.body);
}
}
if ((len > payload_max_length) ||
// For 32-bit platform
(sizeof(size_t) < sizeof(uint64_t) &&
len > std::numeric_limits<size_t>::max())) {
exceed_payload_max_length = true;
return false;
}
return read_content_with_length(strm, x.body, len, progress);
} else {
const auto &encoding =
@ -1427,8 +1443,9 @@ inline const std::string &BufferStream::get_buffer() const { return buffer; }
// HTTP server implementation
inline Server::Server()
: keep_alive_max_count_(5), is_running_(false), svr_sock_(INVALID_SOCKET),
running_threads_(0) {
: keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT),
payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false),
svr_sock_(INVALID_SOCKET), running_threads_(0) {
#ifndef _WIN32
signal(SIGPIPE, SIG_IGN);
#endif
@ -1484,6 +1501,10 @@ inline void Server::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
}
inline void Server::set_payload_max_length(uint64_t length) {
payload_max_length_ = length;
}
inline int Server::bind_to_any_port(const char *host, int socket_flags) {
return bind_internal(host, 0, socket_flags);
}
@ -1702,8 +1723,7 @@ inline bool Server::listen_internal() {
std::lock_guard<std::mutex> guard(running_threads_mutex_);
running_threads_--;
}
})
.detach();
}).detach();
}
// TODO: Use thread pool...
@ -1789,10 +1809,12 @@ inline bool Server::process_request(Stream &strm, bool last_connection,
// Body
if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") {
if (!detail::read_content(strm, req)) {
res.status = 400;
bool exceed_payload_max_length = false;
if (!detail::read_content(strm, req, payload_max_length_,
exceed_payload_max_length)) {
res.status = exceed_payload_max_length ? 413 : 400;
write_response(strm, last_connection, req, res);
return true;
return !exceed_payload_max_length;
}
const auto &content_type = req.get_header_value("Content-Type");
@ -1975,7 +1997,11 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
// Body
if (req.method != "HEAD") {
if (!detail::read_content(strm, res, req.progress)) { return false; }
bool exceed_payload_max_length = false;
if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(),
exceed_payload_max_length, req.progress)) {
return false;
}
if (res.get_header_value("Content-Encoding") == "gzip") {
#ifdef CPPHTTPLIB_ZLIB_SUPPORT

View file

@ -1288,6 +1288,56 @@ TEST_F(ServerUpDownTest, QuickStartStop) {
// --gtest_filter=ServerUpDownTest.QuickStartStop --gtest_repeat=1000
}
class PayloadMaxLengthTest : public ::testing::Test {
protected:
PayloadMaxLengthTest()
: cli_(HOST, PORT)
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
,
svr_(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE)
#endif
{
}
virtual void SetUp() {
svr_.set_payload_max_length(8);
svr_.Post("/test", [&](const Request & /*req*/, Response &res) {
res.set_content("test", "text/plain");
});
t_ = thread([&]() { ASSERT_TRUE(svr_.listen(HOST, PORT)); });
while (!svr_.is_running()) {
msleep(1);
}
}
virtual void TearDown() {
svr_.stop();
t_.join();
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
SSLClient cli_;
SSLServer svr_;
#else
Client cli_;
Server svr_;
#endif
thread t_;
};
TEST_F(PayloadMaxLengthTest, ExceedLimit) {
auto res = cli_.Post("/test", "123456789", "text/plain");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(413, res->status);
res = cli_.Post("/test", "12345678", "text/plain");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(SSLClientTest, ServerNameIndication) {
SSLClient cli("httpbin.org", 443);