From b7566f6961275f465638ac6e80e861e7707a5a5c Mon Sep 17 00:00:00 2001 From: yhirose Date: Tue, 2 Feb 2021 22:09:35 -0500 Subject: [PATCH] Resolve #852 --- httplib.h | 232 ++++++++++++++++++++++++++++++--------------------- test/test.cc | 51 ++++++++--- 2 files changed, 180 insertions(+), 103 deletions(-) diff --git a/httplib.h b/httplib.h index 3a65144..c21b25e 100644 --- a/httplib.h +++ b/httplib.h @@ -390,6 +390,9 @@ struct Request { Match matches; // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT const SSL *ssl; #endif @@ -413,12 +416,9 @@ struct Request { // private members... size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; - ResponseHandler response_handler_; - ContentReceiverWithProgress content_receiver_; size_t content_length_ = 0; ContentProvider content_provider_; bool is_chunked_content_provider_ = false; - Progress progress_; size_t authorization_count_ = 0; }; @@ -794,8 +794,11 @@ enum Error { class Result { public: - Result(std::unique_ptr res, Error err) - : res_(std::move(res)), err_(err) {} + Result(std::unique_ptr &&res, Error err, + Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)) {} + // Response operator bool() const { return res_ != nullptr; } bool operator==(std::nullptr_t) const { return res_ == nullptr; } bool operator!=(std::nullptr_t) const { return res_ != nullptr; } @@ -805,11 +808,21 @@ public: Response &operator*() { return *res_; } const Response *operator->() const { return res_.get(); } Response *operator->() { return res_.get(); } + + // Error Error error() const { return err_; } + // Request Headers + bool has_request_header(const char *key) const; + std::string get_request_header_value(const char *key, size_t id = 0) const; + template + T get_request_header_value(const char *key, size_t id = 0) const; + size_t get_request_header_value_count(const char *key) const; + private: std::unique_ptr res_; Error err_; + Headers request_headers_; }; class ClientImpl { @@ -939,7 +952,7 @@ public: Result Options(const char *path); Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res, Error &error); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); size_t is_socket_open() const; @@ -993,6 +1006,8 @@ protected: bool is_open() const { return sock != INVALID_SOCKET; } }; + Result send_(Request &&req); + virtual bool create_and_connect_socket(Socket &socket, Error &error); // All of: @@ -1010,7 +1025,7 @@ protected: // concurrently with a DIFFERENT thread sending requests from the socket void lock_socket_and_shutdown_and_close(); - bool process_request(Stream &strm, const Request &req, Response &res, + bool process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); bool write_content_with_provider(Stream &strm, const Request &req, @@ -1086,13 +1101,14 @@ protected: private: socket_t create_client_socket(Error &error) const; bool read_response_line(Stream &strm, const Request &req, Response &res); - bool write_request(Stream &strm, const Request &req, bool close_connection, + bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); - bool redirect(const Request &req, Response &res, Error &error); - bool handle_request(Stream &strm, const Request &req, Response &res, + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); std::unique_ptr send_with_content_provider( - const char *method, const char *path, const Headers &headers, + Request &req, + // const char *method, const char *path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, const char *content_type, Error &error); @@ -1238,7 +1254,7 @@ public: Result Options(const char *path); Result Options(const char *path, const Headers &headers); - bool send(const Request &req, Response &res, Error &error); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); size_t is_socket_open() const; @@ -2922,17 +2938,8 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, }); } -template -inline ssize_t write_headers(Stream &strm, const T &info, - const Headers &headers) { +inline ssize_t write_headers(Stream &strm, const Headers &headers) { ssize_t write_len = 0; - for (const auto &x : info.headers) { - if (x.first == "EXCEPTION_WHAT") { continue; } - auto len = - strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); - if (len < 0) { return len; } - write_len += len; - } for (const auto &x : headers) { auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); @@ -3119,7 +3126,7 @@ inline bool write_content_chunked(Stream &strm, } template -inline bool redirect(T &cli, const Request &req, Response &res, +inline bool redirect(T &cli, Request &req, Response &res, const std::string &path, const std::string &location, Error &error) { Request new_req = req; @@ -3136,8 +3143,9 @@ inline bool redirect(T &cli, const Request &req, Response &res, auto ret = cli.send(new_req, new_res, error); if (ret) { - new_res.location = location; + req = new_req; res = new_res; + res.location = location; } return ret; } @@ -3978,7 +3986,27 @@ inline void Response::set_chunked_content_provider( is_chunked_content_provider_ = true; } -// Rstream implementation +// Result implementation +inline bool Result::has_request_header(const char *key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const char *key, + size_t id) const { + return detail::get_header_value(request_headers_, key, id, ""); +} + +template +inline T Result::get_request_header_value(const char *key, size_t id) const { + return detail::get_header_value(request_headers_, key, id, 0); +} + +inline size_t Result::get_request_header_value_count(const char *key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation inline ssize_t Stream::write(const char *ptr) { return write(ptr, strlen(ptr)); } @@ -4473,7 +4501,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, return false; } - if (!detail::write_headers(bstrm, res, Headers())) { return false; } + if (!detail::write_headers(bstrm, res.headers)) { return false; } // Flush buffer auto &data = bstrm.get_buffer(); @@ -4795,26 +4823,26 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) { if (req.method == "POST") { if (dispatch_request_for_content_reader( - req, res, std::move(reader), - post_handlers_for_content_reader_)) { + req, res, std::move(reader), + post_handlers_for_content_reader_)) { return true; } } else if (req.method == "PUT") { if (dispatch_request_for_content_reader( - req, res, std::move(reader), - put_handlers_for_content_reader_)) { + req, res, std::move(reader), + put_handlers_for_content_reader_)) { return true; } } else if (req.method == "PATCH") { if (dispatch_request_for_content_reader( - req, res, std::move(reader), - patch_handlers_for_content_reader_)) { + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { return true; } } else if (req.method == "DELETE") { if (dispatch_request_for_content_reader( - req, res, std::move(reader), - delete_handlers_for_content_reader_)) { + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { return true; } } @@ -5069,7 +5097,7 @@ Server::process_request(Stream &strm, bool close_connection, bool routed = false; try { routed = routing(req, res, strm); - } catch (std::exception & e) { + } catch (std::exception &e) { if (exception_handler_) { exception_handler_(req, res, e); routed = true; @@ -5253,7 +5281,7 @@ inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, return true; } -inline bool ClientImpl::send(const Request &req, Response &res, Error &error) { +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { std::lock_guard request_mutex_guard(request_mutex_); { @@ -5306,6 +5334,12 @@ inline bool ClientImpl::send(const Request &req, Response &res, Error &error) { socket_requests_are_from_thread_ = std::this_thread::get_id(); } + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + auto close_connection = !keep_alive_; auto ret = process_socket(socket_, [&](Stream &strm) { return handle_request(strm, req, res, close_connection, error); @@ -5336,13 +5370,18 @@ inline bool ClientImpl::send(const Request &req, Response &res, Error &error) { } inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { auto res = detail::make_unique(); auto error = Error::Success; auto ret = send(req, *res, error); - return Result{ret ? std::move(res) : nullptr, error}; + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; } -inline bool ClientImpl::handle_request(Stream &strm, const Request &req, +inline bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { if (req.path.empty()) { @@ -5350,12 +5389,16 @@ inline bool ClientImpl::handle_request(Stream &strm, const Request &req, return false; } + auto req_save = req; + bool ret; if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { auto req2 = req; req2.path = "http://" + host_and_port_ + req.path; ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; } else { ret = process_request(strm, req, res, close_connection, error); } @@ -5363,6 +5406,7 @@ inline bool ClientImpl::handle_request(Stream &strm, const Request &req, if (!ret) { return false; } if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; ret = redirect(req, res, error); } @@ -5398,8 +5442,7 @@ inline bool ClientImpl::handle_request(Stream &strm, const Request &req, return ret; } -inline bool ClientImpl::redirect(const Request &req, Response &res, - Error &error) { +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { if (req.redirect_count_ == 0) { error = Error::ExceedRedirectCount; return false; @@ -5476,75 +5519,74 @@ inline bool ClientImpl::write_content_with_provider(Stream &strm, } } // namespace httplib -inline bool ClientImpl::write_request(Stream &strm, const Request &req, +inline bool ClientImpl::write_request(Stream &strm, Request &req, bool close_connection, Error &error) { // Prepare additional headers - Headers headers; - if (close_connection) { headers.emplace("Connection", "close"); } + if (close_connection) { req.headers.emplace("Connection", "close"); } if (!req.has_header("Host")) { if (is_ssl()) { if (port_ == 443) { - headers.emplace("Host", host_); + req.headers.emplace("Host", host_); } else { - headers.emplace("Host", host_and_port_); + req.headers.emplace("Host", host_and_port_); } } else { if (port_ == 80) { - headers.emplace("Host", host_); + req.headers.emplace("Host", host_); } else { - headers.emplace("Host", host_and_port_); + req.headers.emplace("Host", host_and_port_); } } } - if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + if (!req.has_header("Accept")) { req.headers.emplace("Accept", "*/*"); } if (!req.has_header("User-Agent")) { - headers.emplace("User-Agent", "cpp-httplib/0.7"); + req.headers.emplace("User-Agent", "cpp-httplib/0.7"); } if (req.body.empty()) { if (req.content_provider_) { if (!req.is_chunked_content_provider_) { auto length = std::to_string(req.content_length_); - headers.emplace("Content-Length", length); + req.headers.emplace("Content-Length", length); } } else { if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { - headers.emplace("Content-Length", "0"); + req.headers.emplace("Content-Length", "0"); } } } else { if (!req.has_header("Content-Type")) { - headers.emplace("Content-Type", "text/plain"); + req.headers.emplace("Content-Type", "text/plain"); } if (!req.has_header("Content-Length")) { auto length = std::to_string(req.body.size()); - headers.emplace("Content-Length", length); + req.headers.emplace("Content-Length", length); } } if (!basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( + req.headers.insert(make_basic_authentication_header( basic_auth_username_, basic_auth_password_, false)); } if (!proxy_basic_auth_username_.empty() && !proxy_basic_auth_password_.empty()) { - headers.insert(make_basic_authentication_header( + req.headers.insert(make_basic_authentication_header( proxy_basic_auth_username_, proxy_basic_auth_password_, true)); } if (!bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( + req.headers.insert(make_bearer_token_authentication_header( bearer_token_auth_token_, false)); } if (!proxy_bearer_token_auth_token_.empty()) { - headers.insert(make_bearer_token_authentication_header( + req.headers.insert(make_bearer_token_authentication_header( proxy_bearer_token_auth_token_, true)); } @@ -5555,7 +5597,7 @@ inline bool ClientImpl::write_request(Stream &strm, const Request &req, const auto &path = detail::encode_url(req.path); bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); - detail::write_headers(bstrm, req, headers); + detail::write_headers(bstrm, req.headers); // Flush buffer auto &data = bstrm.get_buffer(); @@ -5576,16 +5618,16 @@ inline bool ClientImpl::write_request(Stream &strm, const Request &req, } inline std::unique_ptr ClientImpl::send_with_content_provider( - const char *method, const char *path, const Headers &headers, + Request &req, + // const char *method, const char *path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, const char *content_type, Error &error) { - Request req; - req.method = method; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.path = path; + // Request req; + // req.method = method; + // req.headers = headers; + // req.path = path; if (content_type) { req.headers.emplace("Content-Type", content_type); } @@ -5667,14 +5709,23 @@ inline Result ClientImpl::send_with_content_provider( const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, const char *content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + auto error = Error::Success; + auto res = send_with_content_provider( - method, path, headers, body, content_length, std::move(content_provider), + req, + // method, path, headers, + body, content_length, std::move(content_provider), std::move(content_provider_without_length), content_type, error); - return Result{std::move(res), error}; + + return Result{std::move(res), error, std::move(req.headers)}; } -inline bool ClientImpl::process_request(Stream &strm, const Request &req, +inline bool ClientImpl::process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { // Send request @@ -5687,8 +5738,8 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, return false; } - if (req.response_handler_) { - if (!req.response_handler_(res)) { + if (req.response_handler) { + if (!req.response_handler(res)) { error = Error::Canceled; return false; } @@ -5697,10 +5748,10 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, // Body if ((res.status != 204) && req.method != "HEAD" && req.method != "CONNECT") { auto out = - req.content_receiver_ + req.content_receiver ? static_cast( [&](const char *buf, size_t n, uint64_t off, uint64_t len) { - auto ret = req.content_receiver_(buf, n, off, len); + auto ret = req.content_receiver(buf, n, off, len); if (!ret) { error = Error::Canceled; } return ret; }) @@ -5715,8 +5766,8 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req, }); auto progress = [&](uint64_t current, uint64_t total) { - if (!req.progress_) { return true; } - auto ret = req.progress_(current, total); + if (!req.progress) { return true; } + auto ret = req.progress(current, total); if (!ret) { error = Error::Canceled; } return ret; }; @@ -5778,11 +5829,10 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers, Request req; req.method = "GET"; req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.progress_ = std::move(progress); + req.headers = headers; + req.progress = std::move(progress); - return send(req); + return send_(std::move(req)); } inline Result ClientImpl::Get(const char *path, @@ -5838,17 +5888,16 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers, Request req; req.method = "GET"; req.path = path; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); - req.response_handler_ = std::move(response_handler); - req.content_receiver_ = + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = [content_receiver](const char *data, size_t data_length, uint64_t /*offset*/, uint64_t /*total_length*/) { return content_receiver(data, data_length); }; - req.progress_ = std::move(progress); + req.progress = std::move(progress); - return send(req); + return send_(std::move(req)); } inline Result ClientImpl::Get(const char *path, const Params ¶ms, @@ -5887,11 +5936,10 @@ inline Result ClientImpl::Head(const char *path) { inline Result ClientImpl::Head(const char *path, const Headers &headers) { Request req; req.method = "HEAD"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); + req.headers = headers; req.path = path; - return send(req); + return send_(std::move(req)); } inline Result ClientImpl::Post(const char *path) { @@ -6151,14 +6199,13 @@ inline Result ClientImpl::Delete(const char *path, const Headers &headers, const char *content_type) { Request req; req.method = "DELETE"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); + req.headers = headers; req.path = path; if (content_type) { req.headers.emplace("Content-Type", content_type); } req.body.assign(body, content_length); - return send(req); + return send_(std::move(req)); } inline Result ClientImpl::Delete(const char *path, const std::string &body, @@ -6179,11 +6226,10 @@ inline Result ClientImpl::Options(const char *path) { inline Result ClientImpl::Options(const char *path, const Headers &headers) { Request req; req.method = "OPTIONS"; - req.headers = default_headers_; - req.headers.insert(headers.begin(), headers.end()); + req.headers = headers; req.path = path; - return send(req); + return send_(std::move(req)); } inline size_t ClientImpl::is_socket_open() const { @@ -7303,7 +7349,7 @@ inline Result Client::Options(const char *path, const Headers &headers) { return cli_->Options(path, headers); } -inline bool Client::send(const Request &req, Response &res, Error &error) { +inline bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } diff --git a/test/test.cc b/test/test.cc index 6325cdc..012b390 100644 --- a/test/test.cc +++ b/test/test.cc @@ -5,8 +5,8 @@ #include #include #include -#include #include +#include #define SERVER_CERT_FILE "./cert.pem" #define SERVER_CERT2_FILE "./cert2.pem" @@ -549,12 +549,11 @@ TEST(ConnectionErrorTest, InvalidHost2) { TEST(ConnectionErrorTest, InvalidPort) { auto host = "localhost"; + auto port = 44380; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT - auto port = 44380; SSLClient cli(host, port); #else - auto port = 8080; Client cli(host, port); #endif cli.set_connection_timeout(2); @@ -982,11 +981,12 @@ TEST(ErrorHandlerTest, ContentLength) { TEST(ExceptionHandlerTest, ContentLength) { Server svr; - svr.set_exception_handler([](const Request & /*req*/, Response &res, std::exception & /*e*/) { - res.status = 500; - res.set_content("abcdefghijklmnopqrstuvwxyz", - "text/html"); // <= Content-Length still 13 - }); + svr.set_exception_handler( + [](const Request & /*req*/, Response &res, std::exception & /*e*/) { + res.status = 500; + res.set_content("abcdefghijklmnopqrstuvwxyz", + "text/html"); // <= Content-Length still 13 + }); svr.Get("/hi", [](const Request & /*req*/, Response &res) { res.set_content("Hello World!\n", "text/plain"); @@ -2614,6 +2614,24 @@ TEST_F(ServerTest, PutLargeFileWithGzip) { EXPECT_EQ(LARGE_DATA, res->body); } +TEST_F(ServerTest, PutLargeFileWithGzip2) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + Client cli("https://localhost:1234"); + cli.enable_server_certificate_verification(false); +#else + Client cli("http://localhost:1234"); +#endif + cli.set_compress(true); + + auto res = cli.Put("/put-large", LARGE_DATA, "text/plain"); + + ASSERT_TRUE(res); + EXPECT_EQ(200, res->status); + EXPECT_EQ(LARGE_DATA, res->body); + EXPECT_EQ(101942u, res.get_request_header_value("Content-Length")); + EXPECT_EQ("gzip", res.get_request_header_value("Content-Encoding")); +} + TEST_F(ServerTest, PutContentWithDeflate) { cli_.set_compress(false); Headers headers; @@ -3405,7 +3423,8 @@ TEST(ExceptionTest, ThrowExceptionInHandler) { auto res = cli.Get("/hi"); ASSERT_TRUE(res); EXPECT_EQ(500, res->status); - ASSERT_FALSE(res->has_header("EXCEPTION_WHAT")); + ASSERT_TRUE(res->has_header("EXCEPTION_WHAT")); + EXPECT_EQ("exception...", res->get_header_value("EXCEPTION_WHAT")); svr.stop(); listen_thread.join(); @@ -3963,7 +3982,6 @@ TEST(NoSSLSupport, SimpleInterface) { } #endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT TEST(InvalidScheme, SimpleInterface) { ASSERT_ANY_THROW(Client cli("scheme://yahoo.com")); } @@ -3973,6 +3991,19 @@ TEST(NoScheme, SimpleInterface) { ASSERT_TRUE(cli.is_valid()); } +TEST(SendAPI, SimpleInterface) { + Client cli("http://yahoo.com"); + + Request req; + req.method = "GET"; + req.path = "/"; + auto res = cli.send(req); + + ASSERT_TRUE(res); + EXPECT_EQ(301, res->status); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT TEST(YahooRedirectTest2, SimpleInterface) { Client cli("http://yahoo.com");