From 7cd25fbd63745158d79c37871f32f510ccf131cd Mon Sep 17 00:00:00 2001 From: yhirose Date: Tue, 16 Jun 2020 17:46:23 -0400 Subject: [PATCH] Fix #499 --- README.md | 26 +-- httplib.h | 450 ++++++++++++++++++++++----------------------- test/test.cc | 92 +++++---- test/test_proxy.cc | 72 +++----- 4 files changed, 304 insertions(+), 336 deletions(-) diff --git a/README.md b/README.md index 48350db..9dc0c76 100644 --- a/README.md +++ b/README.md @@ -482,29 +482,15 @@ httplib::make_range_header({{0, 0}, {-1, 1}}) // 'Range: bytes=0-0, -1' ### Keep-Alive connection ```cpp -cli.set_keep_alive_max_count(2); // Default is 5 +httplib::Client cli("localhost", 1234); -std::vector requests; -Get(requests, "/get-request1"); -Get(requests, "/get-request2"); -Post(requests, "/post-request1", "text", "text/plain"); -Post(requests, "/post-request2", "text", "text/plain"); +cli.Get("/hello"); // with "Connection: close" -const size_t DATA_CHUNK_SIZE = 4; -std::string data("abcdefg"); -Post(requests, "/post-request-with-content-provider", - data.size(), - [&](size_t offset, size_t length, DataSink &sink){ - sink.write(&data[offset], std::min(length, DATA_CHUNK_SIZE)); - }, - "text/plain"); +cli.set_keep_alive(true); +cli.Get("/world"); -std::vector responses; -if (cli.send(requests, responses)) { - for (const auto& res: responses) { - ... - } -} +cli.set_keep_alive(false); +cli.Get("/last-request"); // with "Connection: close" ``` ### Redirect diff --git a/httplib.h b/httplib.h index fdcbed1..6a97689 100644 --- a/httplib.h +++ b/httplib.h @@ -188,6 +188,7 @@ using socket_t = int; #include #include #include +#include #include #include #include @@ -593,10 +594,11 @@ public: std::function new_task_queue; protected: - bool process_request(Stream &strm, bool last_connection, - bool &connection_close, + bool process_request(Stream &strm, bool close_connection, + bool &connection_closed, const std::function &setup_request); + std::atomic svr_sock_; size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; @@ -624,7 +626,7 @@ private: HandlersForContentReader &handlers); bool parse_request_line(const char *s, Request &req); - bool write_response(Stream &strm, bool last_connection, const Request &req, + bool write_response(Stream &strm, bool close_connection, const Request &req, Response &res); bool write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, @@ -643,7 +645,6 @@ private: virtual bool process_and_close_socket(socket_t sock); std::atomic is_running_; - std::atomic svr_sock_; std::vector> base_dirs_; std::map file_extension_and_mimetype_map_; Handler file_request_handler_; @@ -797,9 +798,6 @@ public: bool send(const Request &req, Response &res); - bool send(const std::vector &requests, - std::vector &responses); - size_t is_socket_open() const; void stop(); @@ -809,13 +807,12 @@ public: void set_read_timeout(time_t sec, time_t usec = 0); void set_write_timeout(time_t sec, time_t usec = 0); - void set_keep_alive_max_count(size_t count); - void set_basic_auth(const char *username, const char *password); #ifdef CPPHTTPLIB_OPENSSL_SUPPORT void set_digest_auth(const char *username, const char *password); #endif + void set_keep_alive(bool on); void set_follow_location(bool on); void set_compress(bool on); @@ -846,7 +843,7 @@ protected: virtual void close_socket(Socket &socket, bool process_socket_ret); bool process_request(Stream &strm, const Request &req, Response &res, - bool &connection_close); + bool close_connection); // Socket endoint information const std::string host_; @@ -869,8 +866,6 @@ protected: time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; - size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; - std::string basic_auth_username_; std::string basic_auth_password_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -878,6 +873,7 @@ protected: std::string digest_auth_password_; #endif + bool keep_alive_ = false; bool follow_location_ = false; bool compress_ = false; @@ -905,13 +901,13 @@ protected: read_timeout_usec_ = rhs.read_timeout_usec_; write_timeout_sec_ = rhs.write_timeout_sec_; write_timeout_usec_ = rhs.write_timeout_usec_; - keep_alive_max_count_ = rhs.keep_alive_max_count_; basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT digest_auth_username_ = rhs.digest_auth_username_; digest_auth_password_ = rhs.digest_auth_password_; #endif + keep_alive_ = rhs.keep_alive_; follow_location_ = rhs.follow_location_; compress_ = rhs.compress_; decompress_ = rhs.decompress_; @@ -930,22 +926,18 @@ protected: private: socket_t create_client_socket() const; bool read_response_line(Stream &strm, Response &res); - bool write_request(Stream &strm, const Request &req); + bool write_request(Stream &strm, const Request &req, bool close_connection); bool redirect(const Request &req, Response &res); bool handle_request(Stream &strm, const Request &req, Response &res, - bool &connection_close); + bool close_connection); std::shared_ptr send_with_content_provider( const char *method, const char *path, const Headers &headers, const std::string &body, size_t content_length, ContentProvider content_provider, const char *content_type); - virtual bool - process_socket(Socket &socket, size_t request_count, - std::function - callback); - + virtual bool process_socket(Socket &socket, + std::function callback); virtual bool is_ssl() const; }; @@ -1045,15 +1037,13 @@ public: private: bool create_and_connect_socket(Socket &socket) override; - bool connect_with_proxy(Socket &sock, bool &error); void close_socket(Socket &socket, bool process_socket_ret) override; - bool process_socket(Socket &socket, size_t request_count, - std::function - callback) override; + bool process_socket(Socket &socket, + std::function callback) override; bool is_ssl() const override; + bool connect_with_proxy(Socket &sock, Response &res, bool &success); bool initialize_ssl(Socket &socket); bool verify_host(X509 *server_cert) const; @@ -1070,6 +1060,8 @@ private: X509_STORE *ca_cert_store_ = nullptr; bool server_certificate_verification_ = false; long verify_result_ = 0; + + friend class Client; }; #endif @@ -1301,11 +1293,6 @@ public: bool send(const Request &req, Response &res) { return cli_->send(req, res); } - bool send(const std::vector &requests, - std::vector &responses) { - return cli_->send(requests, responses); - } - bool is_socket_open() { return cli_->is_socket_open(); } void stop() { cli_->stop(); } @@ -1320,11 +1307,6 @@ public: return *this; } - Client2 &set_keep_alive_max_count(size_t count) { - cli_->set_keep_alive_max_count(count); - return *this; - } - Client2 &set_basic_auth(const char *username, const char *password) { cli_->set_basic_auth(username, password); return *this; @@ -1337,6 +1319,11 @@ public: } #endif + Client2 &set_keep_alive(bool on) { + cli_->set_keep_alive(on); + return *this; + } + Client2 &set_follow_location(bool on) { cli_->set_follow_location(on); return *this; @@ -1863,49 +1850,75 @@ private: size_t position = 0; }; -template -inline bool process_socket_core(bool is_client_request, socket_t sock, - size_t keep_alive_max_count, T callback) { - assert(keep_alive_max_count > 0); - - auto ret = false; - - if (keep_alive_max_count > 1) { - auto count = keep_alive_max_count; - while (count > 0 && - (is_client_request || - select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, - CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { - auto last_connection = count == 1; - auto connection_close = false; - - ret = callback(last_connection, connection_close); - if (!ret || connection_close) { break; } - - count--; +inline bool keep_alive(socket_t sock, std::function is_shutting_down) { + using namespace std::chrono; + auto start = steady_clock::now(); + while (true) { + auto val = select_read(sock, 0, 10000); + if (is_shutting_down && is_shutting_down()) { + return false; + } else if (val < 0) { + return false; + } else if (val == 0) { + auto current = steady_clock::now(); + auto sec = duration_cast(current - start); + if (sec.count() > CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND) { + return false; + } else if (sec.count() == CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND) { + auto usec = duration_cast(current - start); + if (usec.count() > CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) { + return false; + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + return true; } - } else { // keep_alive_max_count is 0 or 1 - auto dummy_connection_close = false; - ret = callback(true, dummy_connection_close); } +} +template +inline bool process_server_socket_core(socket_t sock, + size_t keep_alive_max_count, + T is_shutting_down, U callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (count > 0 && keep_alive(sock, is_shutting_down)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { break; } + count--; + } return ret; } -template -inline bool process_socket(bool is_client_request, socket_t sock, - size_t keep_alive_max_count, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { - return process_socket_core( - is_client_request, sock, keep_alive_max_count, - [&](bool last_connection, bool connection_close) { +template +inline bool +process_server_socket(socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + T is_shutting_down, U callback) { + return process_server_socket_core( + sock, keep_alive_max_count, is_shutting_down, + [&](bool close_connection, bool connection_closed) { SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); - return callback(strm, last_connection, connection_close); + return callback(strm, close_connection, connection_closed); }); } +template +inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); +} + inline int shutdown_socket(socket_t sock) { #ifdef _WIN32 return shutdown(sock, SD_BOTH); @@ -2545,7 +2558,6 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, } if (!ret) { status = exceed_payload_max_length ? 413 : 400; } - return ret; } @@ -2582,8 +2594,9 @@ inline bool write_data(Stream &strm, const char *d, size_t l) { return true; } +template inline ssize_t write_content(Stream &strm, ContentProvider content_provider, - size_t offset, size_t length) { + size_t offset, size_t length, T is_shutting_down) { size_t begin_offset = offset; size_t end_offset = offset + length; @@ -2598,7 +2611,7 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider, }; data_sink.is_writable = [&](void) { return ok && strm.is_writable(); }; - while (ok && offset < end_offset) { + while (ok && offset < end_offset && !is_shutting_down()) { if (!content_provider(offset, end_offset - offset, data_sink)) { return -1; } @@ -3110,16 +3123,19 @@ get_multipart_ranges_data_length(const Request &req, Response &res, return data_length; } +template inline bool write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, const std::string &boundary, - const std::string &content_type) { + const std::string &content_type, + T is_shutting_down) { return process_multipart_ranges_data( req, res, boundary, content_type, [&](const std::string &token) { strm.write(token); }, [&](const char *token) { strm.write(token); }, [&](size_t offset, size_t length) { - return write_content(strm, res.content_provider_, offset, length) >= 0; + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down) >= 0; }); } @@ -3576,7 +3592,7 @@ inline const std::string &BufferStream::get_buffer() const { return buffer; } } // namespace detail // HTTP server implementation -inline Server::Server() : is_running_(false), svr_sock_(INVALID_SOCKET) { +inline Server::Server() : svr_sock_(INVALID_SOCKET), is_running_(false) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); #endif @@ -3758,7 +3774,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) { return false; } -inline bool Server::write_response(Stream &strm, bool last_connection, +inline bool Server::write_response(Stream &strm, bool close_connection, const Request &req, Response &res) { assert(res.status != -1); @@ -3773,11 +3789,11 @@ inline bool Server::write_response(Stream &strm, bool last_connection, } // Headers - if (last_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close") { res.set_header("Connection", "close"); } - if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + if (!close_connection && req.get_header_value("Connection") == "Keep-Alive") { res.set_header("Connection", "Keep-Alive"); } @@ -3891,10 +3907,14 @@ inline bool Server::write_content_with_provider(Stream &strm, const Request &req, Response &res, const std::string &boundary, const std::string &content_type) { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + if (res.content_length_) { if (req.ranges.empty()) { if (detail::write_content(strm, res.content_provider_, 0, - res.content_length_) < 0) { + res.content_length_, is_shutting_down) < 0) { return false; } } else if (req.ranges.size() == 1) { @@ -3902,20 +3922,17 @@ Server::write_content_with_provider(Stream &strm, const Request &req, detail::get_range_offset_and_length(req, res.content_length_, 0); auto offset = offsets.first; auto length = offsets.second; - if (detail::write_content(strm, res.content_provider_, offset, length) < - 0) { + if (detail::write_content(strm, res.content_provider_, offset, length, + is_shutting_down) < 0) { return false; } } else { - if (!detail::write_multipart_ranges_data(strm, req, res, boundary, - content_type)) { + if (!detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, is_shutting_down)) { return false; } } } else { - auto is_shutting_down = [this]() { - return this->svr_sock_ == INVALID_SOCKET; - }; if (detail::write_content_chunked(strm, res.content_provider_, is_shutting_down) < 0) { return false; @@ -4241,8 +4258,8 @@ inline bool Server::dispatch_request_for_content_reader( } inline bool -Server::process_request(Stream &strm, bool last_connection, - bool &connection_close, +Server::process_request(Stream &strm, bool close_connection, + bool &connection_closed, const std::function &setup_request) { std::array buf{}; @@ -4261,23 +4278,23 @@ Server::process_request(Stream &strm, bool last_connection, Headers dummy; detail::read_headers(strm, dummy); res.status = 414; - return write_response(strm, last_connection, req, res); + return write_response(strm, close_connection, req, res); } // Request line and headers if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { res.status = 400; - return write_response(strm, last_connection, req, res); + return write_response(strm, close_connection, req, res); } if (req.get_header_value("Connection") == "close") { - connection_close = true; + connection_closed = true; } if (req.version == "HTTP/1.0" && req.get_header_value("Connection") != "Keep-Alive") { - connection_close = true; + connection_closed = true; } strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); @@ -4304,7 +4321,7 @@ Server::process_request(Stream &strm, bool last_connection, strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, detail::status_message(status)); break; - default: return write_response(strm, last_connection, req, res); + default: return write_response(strm, close_connection, req, res); } } @@ -4315,20 +4332,23 @@ Server::process_request(Stream &strm, bool last_connection, if (res.status == -1) { res.status = 404; } } - return write_response(strm, last_connection, req, res); + return write_response(strm, close_connection, req, res); } inline bool Server::is_valid() const { return true; } inline bool Server::process_and_close_socket(socket_t sock) { - auto ret = detail::process_socket( - false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + auto ret = detail::process_server_socket( + sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this](Stream &strm, bool last_connection, bool &connection_close) { - return process_request(strm, last_connection, connection_close, + [this]() { return this->svr_sock_ == INVALID_SOCKET; }, + [this](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, nullptr); }); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + detail::shutdown_socket(sock); detail::close_socket(sock); return ret; } @@ -4347,12 +4367,7 @@ inline Client::Client(const std::string &host, int port, host_and_port_(host_ + ":" + std::to_string(port_)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} -inline Client::~Client() { - assert(socket_.sock == INVALID_SOCKET); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - assert(socket_.ssl == nullptr); -#endif -} +inline Client::~Client() { stop(); } inline bool Client::is_valid() const { return true; } @@ -4402,63 +4417,49 @@ inline bool Client::read_response_line(Stream &strm, Response &res) { inline bool Client::send(const Request &req, Response &res) { std::lock_guard request_mutex_guard(request_mutex_); - auto need_new_socket = !is_socket_open(); - if (need_new_socket) { + { std::lock_guard guard(socket_mutex_); - if (!create_and_connect_socket(socket_)) { return false; } + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::select_write(socket_.sock, 0, 0) > 0; + if (!is_alive) { close_socket(socket_, false); } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty()) { + bool success = false; + if (!scli.connect_with_proxy(socket_, res, success)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_)) { return false; } + } +#endif + } } - auto ret = process_socket( - socket_, 1, - [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { - return handle_request(strm, req, res, connection_close); - }); + auto close_connection = !keep_alive_; - if (need_new_socket) { - std::lock_guard guard(socket_mutex_); - if (socket_.is_open()) { close_socket(socket_, ret); } - } + auto ret = process_socket(socket_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection); + }); + + if (close_connection) { stop(); } return ret; } -inline bool Client::send(const std::vector &requests, - std::vector &responses) { - std::lock_guard request_mutex_guard(request_mutex_); - - size_t i = 0; - while (i < requests.size()) { - { - std::lock_guard guard(socket_mutex_); - if (!create_and_connect_socket(socket_)) { return false; } - } - - auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_); - - auto ret = process_socket( - socket_, request_count, - [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { - auto &req = requests[i++]; - auto res = Response(); - auto ret = handle_request(strm, req, res, connection_close); - if (ret) { responses.emplace_back(std::move(res)); } - return ret; - }); - - { - std::lock_guard guard(socket_mutex_); - if (socket_.is_open()) { close_socket(socket_, ret); } - } - - if (!ret) { return false; } - } - - return true; -} - inline bool Client::handle_request(Stream &strm, const Request &req, - Response &res, bool &connection_close) { + Response &res, bool close_connection) { if (req.path.empty()) { return false; } bool ret; @@ -4466,9 +4467,9 @@ inline bool Client::handle_request(Stream &strm, const Request &req, if (!is_ssl() && !proxy_host_.empty()) { auto req2 = req; req2.path = "http://" + host_and_port_ + req.path; - ret = process_request(strm, req2, res, connection_close); + ret = process_request(strm, req2, res, close_connection); } else { - ret = process_request(strm, req, res, connection_close); + ret = process_request(strm, req, res, close_connection); } if (!ret) { return false; } @@ -4558,7 +4559,8 @@ inline bool Client::redirect(const Request &req, Response &res) { } } -inline bool Client::write_request(Stream &strm, const Request &req) { +inline bool Client::write_request(Stream &strm, const Request &req, + bool close_connection) { detail::BufferStream bstrm; // Request line @@ -4568,6 +4570,8 @@ inline bool Client::write_request(Stream &strm, const Request &req) { // Additonal headers Headers headers; + if (close_connection) { headers.emplace("Connection", "close"); } + if (!req.has_header("Host")) { if (is_ssl()) { if (port_ == 443) { @@ -4710,9 +4714,9 @@ inline std::shared_ptr Client::send_with_content_provider( } inline bool Client::process_request(Stream &strm, const Request &req, - Response &res, bool &connection_close) { + Response &res, bool close_connection) { // Send request - if (!write_request(strm, req)) { return false; } + if (!write_request(strm, req, close_connection)) { return false; } // Receive response and headers if (!read_response_line(strm, res) || @@ -4720,11 +4724,6 @@ inline bool Client::process_request(Stream &strm, const Request &req, return false; } - if (res.get_header_value("Connection") == "close" || - res.version == "HTTP/1.0") { - connection_close = true; - } - if (req.response_handler) { if (!req.response_handler(res)) { return false; } } @@ -4749,20 +4748,22 @@ inline bool Client::process_request(Stream &strm, const Request &req, } } + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + stop(); + } + // Log if (logger_) { logger_(req, res); } return true; } -inline bool -Client::process_socket(Socket &socket, size_t request_count, - std::function - callback) { - return detail::process_socket( - true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, callback); +inline bool Client::process_socket(Socket &socket, + std::function callback) { + return detail::process_client_socket(socket.sock, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, callback); } inline bool Client::is_ssl() const { return false; } @@ -5066,9 +5067,9 @@ inline void Client::stop() { std::lock_guard guard(socket_mutex_); if (socket_.is_open()) { detail::shutdown_socket(socket_.sock); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); close_socket(socket_, true); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); } } @@ -5091,10 +5092,6 @@ inline void Client::set_write_timeout(time_t sec, time_t usec) { write_timeout_usec_ = usec; } -inline void Client::set_keep_alive_max_count(size_t count) { - keep_alive_max_count_ = count; -} - inline void Client::set_basic_auth(const char *username, const char *password) { basic_auth_username_ = username; basic_auth_password_ = password; @@ -5108,6 +5105,8 @@ inline void Client::set_digest_auth(const char *username, } #endif +inline void Client::set_keep_alive(bool on) { keep_alive_ = on; } + inline void Client::set_follow_location(bool on) { follow_location_ = on; } inline void Client::set_compress(bool on) { compress_ = on; } @@ -5181,19 +5180,29 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, template inline bool -process_socket_ssl(SSL *ssl, bool is_client_request, socket_t sock, - size_t keep_alive_max_count, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { - return process_socket_core( - is_client_request, sock, keep_alive_max_count, - [&](bool last_connection, bool connection_close) { +process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec, + std::function is_shutting_down, T callback) { + return process_server_socket_core( + sock, keep_alive_max_count, is_shutting_down, + [&](bool close_connection, bool connection_closed) { SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); - return callback(strm, last_connection, connection_close); + return callback(strm, close_connection, connection_closed); }); } +template +inline bool +process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); +} + #if OPENSSL_VERSION_NUMBER < 0x10100000L static std::shared_ptr> openSSL_locks_; @@ -5365,12 +5374,13 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) { [](SSL * /*ssl*/) { return true; }); if (ssl) { - auto ret = detail::process_socket_ssl( - ssl, false, sock, keep_alive_max_count_, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [this, ssl](Stream &strm, bool last_connection, - bool &connection_close) { - return process_request(strm, last_connection, connection_close, + auto ret = detail::process_server_socket_ssl( + ssl, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, + [this]() { return this->svr_sock_ == INVALID_SOCKET; }, + [this, ssl](Stream &strm, bool close_connection, + bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, [&](Request &req) { req.ssl = ssl; }); }); @@ -5455,49 +5465,36 @@ inline long SSLClient::get_openssl_verify_result() const { inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } inline bool SSLClient::create_and_connect_socket(Socket &socket) { - if (is_valid() && Client::create_and_connect_socket(socket) && - initialize_ssl(socket)) { - if (!proxy_host_.empty()) { - bool error; - if (!connect_with_proxy(socket, error)) { return error; } - } - return true; - } - return false; + return is_valid() && Client::create_and_connect_socket(socket); } -inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) { - error = true; - Response res; +inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, + bool &success) { + success = true; + Response res2; - if (!detail::process_socket_core( - true, socket.sock, 1, - [&](bool /*last_connection*/, bool &connection_close) { - detail::SocketStream strm(socket.sock, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { Request req2; req2.method = "CONNECT"; req2.path = host_and_port_; - return process_request(strm, req2, res, connection_close); + return process_request(strm, req2, res2, false); })) { close_socket(socket, true); - error = false; + success = false; return false; } - if (res.status == 407) { + if (res2.status == 407) { if (!proxy_digest_auth_username_.empty() && !proxy_digest_auth_password_.empty()) { std::map auth; - if (parse_www_authenticate(res, auth, true)) { + if (parse_www_authenticate(res2, auth, true)) { Response res3; - if (!detail::process_socket_core( - true, socket.sock, 1, - [&](bool /*last_connection*/, bool &connection_close) { - detail::SocketStream strm( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { Request req3; req3.method = "CONNECT"; req3.path = host_and_port_; @@ -5505,14 +5502,15 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) { req3, auth, 1, random_string(10), proxy_digest_auth_username_, proxy_digest_auth_password_, true)); - return process_request(strm, req3, res3, connection_close); + return process_request(strm, req3, res3, false); })) { close_socket(socket, true); - error = false; + success = false; return false; } } } else { + res = res2; return false; } } @@ -5583,17 +5581,12 @@ inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { } inline bool -SSLClient::process_socket(Socket &socket, size_t request_count, - std::function - callback) { +SSLClient::process_socket(Socket &socket, + std::function callback) { assert(socket.ssl); - return detail::process_socket_ssl( - socket.ssl, true, socket.sock, request_count, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, - [&](Stream &strm, bool last_connection, bool &connection_close) { - return callback(strm, last_connection, connection_close); - }); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, callback); } inline bool SSLClient::is_ssl() const { return true; } @@ -5678,7 +5671,6 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { } GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); - return ret; } diff --git a/test/test.cc b/test/test.cc index 94d3d34..c4b2b04 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1136,6 +1136,10 @@ protected: EXPECT_EQ(req.get_param_value("key"), "value"); EXPECT_EQ(req.body, "content"); }) + .Get("/last-request", + [&](const Request & req, Response &/*res*/) { + EXPECT_EQ("close", req.get_header_value("Connection")); + }) #ifdef CPPHTTPLIB_ZLIB_SUPPORT .Get("/gzip", [&](const Request & /*req*/, Response &res) { @@ -2127,42 +2131,48 @@ TEST_F(ServerTest, HTTP2Magic) { } TEST_F(ServerTest, KeepAlive) { - cli_.set_keep_alive_max_count(4); + auto res = cli_.Get("/hi"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_EQ("Hello World!", res->body); - std::vector requests; - Get(requests, "/hi"); - Get(requests, "/hi"); - Get(requests, "/hi"); - Get(requests, "/not-exist"); - Post(requests, "/empty", "", "text/plain"); - Post( - requests, "/empty", 0, - [&](size_t, size_t, httplib::DataSink &) { return true; }, "text/plain"); + res = cli_.Get("/hi"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_EQ("Hello World!", res->body); - std::vector responses; - auto ret = cli_.send(requests, responses); + res = cli_.Get("/hi"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_EQ("Hello World!", res->body); - ASSERT_TRUE(ret == true); - ASSERT_TRUE(requests.size() == responses.size()); + res = cli_.Get("/not-exist"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(404, res->status); - for (size_t i = 0; i < 3; i++) { - auto &res = responses[i]; - EXPECT_EQ(200, res.status); - EXPECT_EQ("text/plain", res.get_header_value("Content-Type")); - EXPECT_EQ("Hello World!", res.body); - } + res = cli_.Post("/empty", "", "text/plain"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_EQ("empty", res->body); + EXPECT_EQ("close", res->get_header_value("Connection")); - { - auto &res = responses[3]; - EXPECT_EQ(404, res.status); - } + res = cli_.Post( + "/empty", 0, [&](size_t, size_t, httplib::DataSink &) { return true; }, + "text/plain"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ("text/plain", res->get_header_value("Content-Type")); + EXPECT_EQ("empty", res->body); - for (size_t i = 4; i < 6; i++) { - auto &res = responses[i]; - EXPECT_EQ(200, res.status); - EXPECT_EQ("text/plain", res.get_header_value("Content-Type")); - EXPECT_EQ("empty", res.body); - } + cli_.set_keep_alive(false); + res = cli_.Get("/last-request"); + ASSERT_TRUE(res != nullptr); + EXPECT_EQ(200, res->status); + EXPECT_EQ("close", res->get_header_value("Connection")); } #ifdef CPPHTTPLIB_ZLIB_SUPPORT @@ -2310,10 +2320,8 @@ static bool send_request(time_t read_timeout_sec, const std::string &req, if (client_sock == INVALID_SOCKET) { return false; } - auto ret = detail::process_socket( - true, client_sock, 1, read_timeout_sec, 0, 0, 0, - [&](Stream &strm, bool /*last_connection*/, bool & - /*connection_close*/) -> bool { + auto ret = detail::process_client_socket( + client_sock, read_timeout_sec, 0, 0, 0, [&](Stream &strm) { if (req.size() != static_cast(strm.write(req.data(), req.size()))) { return false; @@ -2515,8 +2523,7 @@ TEST(ServerStopTest, StopServerWithChunkedTransmission) { } Client client(HOST, PORT); - const Headers headers = {{"Accept", "text/event-stream"}, - {"Connection", "Keep-Alive"}}; + const Headers headers = {{"Accept", "text/event-stream"}}; auto get_thread = std::thread([&client, &headers]() { std::shared_ptr res = client.Get( @@ -2742,19 +2749,24 @@ TEST(SSLClientTest, ServerNameIndication) { ASSERT_EQ(200, res->status); } -TEST(SSLClientTest, ServerCertificateVerification) { +TEST(SSLClientTest, ServerCertificateVerification1) { SSLClient cli("google.com"); - auto res = cli.Get("/"); ASSERT_TRUE(res != nullptr); ASSERT_EQ(301, res->status); +} +TEST(SSLClientTest, ServerCertificateVerification2) { + SSLClient cli("google.com"); cli.enable_server_certificate_verification(true); - res = cli.Get("/"); + auto res = cli.Get("/"); ASSERT_TRUE(res == nullptr); +} +TEST(SSLClientTest, ServerCertificateVerification3) { + SSLClient cli("google.com"); cli.set_ca_cert_path(CA_CERT_FILE); - res = cli.Get("/"); + auto res = cli.Get("/"); ASSERT_TRUE(res != nullptr); ASSERT_EQ(301, res->status); } diff --git a/test/test_proxy.cc b/test/test_proxy.cc index 1a36b77..9dd658e 100644 --- a/test/test_proxy.cc +++ b/test/test_proxy.cc @@ -222,66 +222,45 @@ void KeepAliveTest(Client& cli, bool basic) { #endif } - cli.set_keep_alive_max_count(4); cli.set_follow_location(true); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT cli.set_digest_auth("hello", "world"); +#endif - std::vector requests; + { + auto res = cli.Get("/get"); + EXPECT_EQ(200, res->status); + } + { + auto res = cli.Get("/redirect/2"); + EXPECT_EQ(200, res->status); + } - Get(requests, "/get"); - Get(requests, "/redirect/2"); + { + std::vector paths = { + "/digest-auth/auth/hello/world/MD5", + "/digest-auth/auth/hello/world/SHA-256", + "/digest-auth/auth/hello/world/SHA-512", + "/digest-auth/auth-int/hello/world/MD5", + }; - std::vector paths = { - "/digest-auth/auth/hello/world/MD5", - "/digest-auth/auth/hello/world/SHA-256", - "/digest-auth/auth/hello/world/SHA-512", - "/digest-auth/auth-int/hello/world/MD5", - }; - - for (auto path : paths) { - Get(requests, path.c_str()); + for (auto path: paths) { + auto res = cli.Get(path.c_str()); + EXPECT_EQ("{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n", res->body); + EXPECT_EQ(200, res->status); + } } { int count = 100; while (count--) { - Get(requests, "/get"); + auto res = cli.Get("/get"); + EXPECT_EQ(200, res->status); } } - - std::vector responses; - auto ret = cli.send(requests, responses); - ASSERT_TRUE(ret == true); - ASSERT_TRUE(requests.size() == responses.size()); - - size_t i = 0; - - { - auto &res = responses[i++]; - EXPECT_EQ(200, res.status); - } - - { - auto &res = responses[i++]; - EXPECT_EQ(200, res.status); - } - - - { - int count = static_cast(paths.size()); - while (count--) { - auto &res = responses[i++]; - EXPECT_EQ("{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n", res.body); - EXPECT_EQ(200, res.status); - } - } - - for (; i < responses.size(); i++) { - auto &res = responses[i]; - EXPECT_EQ(200, res.status); - } } +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT TEST(KeepAliveTest, NoSSLWithBasic) { Client cli("httpbin.org"); KeepAliveTest(cli, true); @@ -292,7 +271,6 @@ TEST(KeepAliveTest, SSLWithBasic) { KeepAliveTest(cli, true); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT TEST(KeepAliveTest, NoSSLWithDigest) { Client cli("httpbin.org"); KeepAliveTest(cli, false);