Refactoring to make it ready for KeepAlive connection on Client

This commit is contained in:
yhirose 2020-06-13 21:42:23 -04:00
parent 34282c79a9
commit e022b8b80b
2 changed files with 169 additions and 195 deletions

353
httplib.h
View file

@ -800,7 +800,9 @@ public:
bool send(const std::vector<Request> &requests, bool send(const std::vector<Request> &requests,
std::vector<Response> &responses); std::vector<Response> &responses);
virtual void stop(); size_t is_socket_open() const;
void stop();
CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec); CPPHTTPLIB_DEPRECATED void set_timeout_sec(time_t timeout_sec);
void set_connection_timeout(time_t sec, time_t usec = 0); void set_connection_timeout(time_t sec, time_t usec = 0);
@ -831,26 +833,31 @@ public:
void set_logger(Logger logger); void set_logger(Logger logger);
protected: protected:
struct Endpoint { struct Socket {
socket_t sock = INVALID_SOCKET; socket_t sock = INVALID_SOCKET;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
SSL *ssl = nullptr; SSL *ssl = nullptr;
#endif #endif
bool is_open() const { return sock != INVALID_SOCKET; }
}; };
virtual bool create_and_connect_socket(Endpoint &endpoint); virtual bool create_and_connect_socket(Socket &socket);
virtual void close_socket(Endpoint &endpoint, bool process_socket_ret); virtual void close_socket(Socket &socket, bool process_socket_ret);
bool process_request(Stream &strm, const Request &req, Response &res, bool process_request(Stream &strm, const Request &req, Response &res,
bool last_connection, bool &connection_close); bool &connection_close);
std::vector<Endpoint> endpoints_;
std::mutex endpoints_mutex_;
// Socket endoint information
const std::string host_; const std::string host_;
const int port_; const int port_;
const std::string host_and_port_; const std::string host_and_port_;
// Current open socket
Socket socket_;
mutable std::mutex socket_mutex_;
std::recursive_mutex request_mutex_;
// Settings // Settings
std::string client_cert_path_; std::string client_cert_path_;
std::string client_key_path_; std::string client_key_path_;
@ -923,13 +930,10 @@ protected:
private: private:
socket_t create_client_socket() const; socket_t create_client_socket() const;
bool read_response_line(Stream &strm, Response &res); bool read_response_line(Stream &strm, Response &res);
bool write_request(Stream &strm, const Request &req, bool last_connection); bool write_request(Stream &strm, const Request &req);
bool redirect(const Request &req, Response &res); bool redirect(const Request &req, Response &res);
bool handle_request(Stream &strm, const Request &req, Response &res, bool handle_request(Stream &strm, const Request &req, Response &res,
bool last_connection, bool &connection_close); bool &connection_close);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
bool connect_with_proxy(socket_t sock, Response &res, bool &error);
#endif
std::shared_ptr<Response> send_with_content_provider( std::shared_ptr<Response> send_with_content_provider(
const char *method, const char *path, const Headers &headers, const char *method, const char *path, const Headers &headers,
@ -937,7 +941,7 @@ private:
ContentProvider content_provider, const char *content_type); ContentProvider content_provider, const char *content_type);
virtual bool virtual bool
process_socket(Endpoint &endpoint, size_t request_count, process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback); callback);
@ -1026,8 +1030,6 @@ public:
~SSLClient() override; ~SSLClient() override;
void stop() override;
bool is_valid() const override; bool is_valid() const override;
void set_ca_cert_path(const char *ca_cert_file_path, void set_ca_cert_path(const char *ca_cert_file_path,
@ -1042,16 +1044,17 @@ public:
SSL_CTX *ssl_context() const; SSL_CTX *ssl_context() const;
private: private:
bool create_and_connect_socket(Endpoint &endpoint) override; bool create_and_connect_socket(Socket &socket) override;
void close_socket(Endpoint &endpoint, bool process_socket_ret) override; bool connect_with_proxy(Socket &sock, bool &error);
void close_socket(Socket &socket, bool process_socket_ret) override;
bool process_socket(Endpoint &endpoint, size_t request_count, bool process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback) override; callback) override;
bool is_ssl() const override; bool is_ssl() const override;
bool initialize_ssl(Endpoint &endpoint); bool initialize_ssl(Socket &socket);
bool verify_host(X509 *server_cert) const; bool verify_host(X509 *server_cert) const;
bool verify_host_with_subject_alt_name(X509 *server_cert) const; bool verify_host_with_subject_alt_name(X509 *server_cert) const;
@ -1303,6 +1306,8 @@ public:
return cli_->send(requests, responses); return cli_->send(requests, responses);
} }
bool is_socket_open() { return cli_->is_socket_open(); }
void stop() { cli_->stop(); } void stop() { cli_->stop(); }
Client2 &set_connection_timeout(time_t sec, time_t usec) { Client2 &set_connection_timeout(time_t sec, time_t usec) {
@ -4330,7 +4335,12 @@ inline Client::Client(const std::string &host, int port,
host_and_port_(host_ + ":" + std::to_string(port_)), host_and_port_(host_ + ":" + std::to_string(port_)),
client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
inline Client::~Client() {} inline Client::~Client() {
assert(socket_.sock == INVALID_SOCKET);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
assert(socket_.ssl == nullptr);
#endif
}
inline bool Client::is_valid() const { return true; } inline bool Client::is_valid() const { return true; }
@ -4345,24 +4355,19 @@ inline socket_t Client::create_client_socket() const {
connection_timeout_usec_, interface_); connection_timeout_usec_, interface_);
} }
inline bool Client::create_and_connect_socket(Endpoint &endpoint) { inline bool Client::create_and_connect_socket(Socket &socket) {
auto sock = create_client_socket(); auto sock = create_client_socket();
if (sock == INVALID_SOCKET) { return false; } if (sock == INVALID_SOCKET) { return false; }
socket.sock = sock;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (is_ssl() && !proxy_host_.empty()) {
Response res;
bool error;
if (!connect_with_proxy(sock, res, error)) { return error; }
}
#endif
endpoint.sock = sock;
return true; return true;
} }
inline void Client::close_socket(Endpoint &endpoint, inline void Client::close_socket(Socket &socket, bool /*process_socket_ret*/) {
bool /*process_socket_ret*/) { detail::close_socket(socket.sock);
detail::close_socket(endpoint.sock); socket_.sock = INVALID_SOCKET;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
socket_.ssl = nullptr;
#endif
} }
inline bool Client::read_response_line(Stream &strm, Response &res) { inline bool Client::read_response_line(Stream &strm, Response &res) {
@ -4384,32 +4389,23 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
} }
inline bool Client::send(const Request &req, Response &res) { inline bool Client::send(const Request &req, Response &res) {
Endpoint endpoint; std::lock_guard<std::recursive_mutex> guard(request_mutex_);
if (!create_and_connect_socket(endpoint)) { return false; } auto need_new_socket = !is_socket_open();
{ if (need_new_socket) {
std::lock_guard<std::mutex> guard(endpoints_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
endpoints_.push_back(endpoint); if (!create_and_connect_socket(socket_)) { return false; }
} }
auto ret = process_socket( auto ret = process_socket(
endpoint, 1, socket_, 1,
[&](Stream &strm, bool last_connection, bool &connection_close) { [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
return handle_request(strm, req, res, last_connection, return handle_request(strm, req, res, connection_close);
connection_close);
}); });
{ if (need_new_socket) {
std::lock_guard<std::mutex> guard(endpoints_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
if (socket_.is_open()) { close_socket(socket_, ret); }
auto it = std::find_if(
endpoints_.begin(), endpoints_.end(),
[&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
if (it != endpoints_.end()) {
close_socket(endpoint, ret);
endpoints_.erase(it);
}
} }
return ret; return ret;
@ -4417,43 +4413,30 @@ inline bool Client::send(const Request &req, Response &res) {
inline bool Client::send(const std::vector<Request> &requests, inline bool Client::send(const std::vector<Request> &requests,
std::vector<Response> &responses) { std::vector<Response> &responses) {
std::lock_guard<std::recursive_mutex> guard(request_mutex_);
size_t i = 0; size_t i = 0;
while (i < requests.size()) { while (i < requests.size()) {
Endpoint endpoint;
if (!create_and_connect_socket(endpoint)) { return false; }
{ {
std::lock_guard<std::mutex> guard(endpoints_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
endpoints_.push_back(endpoint); if (!create_and_connect_socket(socket_)) { return false; }
} }
auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_); auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_);
auto ret = process_socket(endpoint, request_count, auto ret = process_socket(
[&](Stream &strm, bool last_connection, socket_, request_count,
bool &connection_close) -> bool { [&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
auto &req = requests[i++]; auto &req = requests[i++];
auto res = Response(); auto res = Response();
auto ret = handle_request(strm, req, res, auto ret = handle_request(strm, req, res, connection_close);
last_connection, if (ret) { responses.emplace_back(std::move(res)); }
connection_close); return ret;
if (ret) { });
responses.emplace_back(std::move(res));
}
return ret;
});
{ {
std::lock_guard<std::mutex> guard(endpoints_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
if (socket_.is_open()) { close_socket(socket_, ret); }
auto it = std::find_if(
endpoints_.begin(), endpoints_.end(),
[&](Endpoint &endpoint2) { return endpoint.sock == endpoint2.sock; });
if (it != endpoints_.end()) {
close_socket(endpoint, ret);
endpoints_.erase(it);
}
} }
if (!ret) { return false; } if (!ret) { return false; }
@ -4463,8 +4446,7 @@ inline bool Client::send(const std::vector<Request> &requests,
} }
inline bool Client::handle_request(Stream &strm, const Request &req, inline bool Client::handle_request(Stream &strm, const Request &req,
Response &res, bool last_connection, Response &res, bool &connection_close) {
bool &connection_close) {
if (req.path.empty()) { return false; } if (req.path.empty()) { return false; }
bool ret; bool ret;
@ -4472,9 +4454,9 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
if (!is_ssl() && !proxy_host_.empty()) { if (!is_ssl() && !proxy_host_.empty()) {
auto req2 = req; auto req2 = req;
req2.path = "http://" + host_and_port_ + req.path; req2.path = "http://" + host_and_port_ + req.path;
ret = process_request(strm, req2, res, last_connection, connection_close); ret = process_request(strm, req2, res, connection_close);
} else { } else {
ret = process_request(strm, req, res, last_connection, connection_close); ret = process_request(strm, req, res, connection_close);
} }
if (!ret) { return false; } if (!ret) { return false; }
@ -4515,64 +4497,6 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
return ret; return ret;
} }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
inline bool Client::connect_with_proxy(socket_t sock, Response &res,
bool &error) {
error = true;
Response res2;
if (!detail::process_socket_core(
true, sock, 1, [&](bool /*last_connection*/, bool &connection_close) {
detail::SocketStream strm(sock, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_);
Request req2;
req2.method = "CONNECT";
req2.path = host_and_port_;
return process_request(strm, req2, res2, false, connection_close);
})) {
detail::close_socket(sock);
error = false;
return false;
}
if (res2.status == 407) {
if (!proxy_digest_auth_username_.empty() &&
!proxy_digest_auth_password_.empty()) {
std::map<std::string, std::string> auth;
if (parse_www_authenticate(res2, auth, true)) {
Response res3;
if (!detail::process_socket_core(
true, sock, 1,
[&](bool /*last_connection*/, bool &connection_close) {
detail::SocketStream strm(
sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_);
Request req3;
req3.method = "CONNECT";
req3.path = host_and_port_;
req3.headers.insert(make_digest_authentication_header(
req3, auth, 1, random_string(10),
proxy_digest_auth_username_, proxy_digest_auth_password_,
true));
return process_request(strm, req3, res3, false,
connection_close);
})) {
detail::close_socket(sock);
error = false;
return false;
}
}
} else {
res = res2;
return false;
}
}
return true;
}
#endif
inline bool Client::redirect(const Request &req, Response &res) { inline bool Client::redirect(const Request &req, Response &res) {
if (req.redirect_count == 0) { return false; } if (req.redirect_count == 0) { return false; }
@ -4622,8 +4546,7 @@ inline bool Client::redirect(const Request &req, Response &res) {
} }
} }
inline bool Client::write_request(Stream &strm, const Request &req, inline bool Client::write_request(Stream &strm, const Request &req) {
bool last_connection) {
detail::BufferStream bstrm; detail::BufferStream bstrm;
// Request line // Request line
@ -4633,8 +4556,6 @@ inline bool Client::write_request(Stream &strm, const Request &req,
// Additonal headers // Additonal headers
Headers headers; Headers headers;
if (last_connection) { headers.emplace("Connection", "close"); }
if (!req.has_header("Host")) { if (!req.has_header("Host")) {
if (is_ssl()) { if (is_ssl()) {
if (port_ == 443) { if (port_ == 443) {
@ -4777,10 +4698,9 @@ inline std::shared_ptr<Response> Client::send_with_content_provider(
} }
inline bool Client::process_request(Stream &strm, const Request &req, inline bool Client::process_request(Stream &strm, const Request &req,
Response &res, bool last_connection, Response &res, bool &connection_close) {
bool &connection_close) {
// Send request // Send request
if (!write_request(strm, req, last_connection)) { return false; } if (!write_request(strm, req)) { return false; }
// Receive response and headers // Receive response and headers
if (!read_response_line(strm, res) || if (!read_response_line(strm, res) ||
@ -4824,12 +4744,12 @@ inline bool Client::process_request(Stream &strm, const Request &req,
} }
inline bool inline bool
Client::process_socket(Endpoint &endpoint, size_t request_count, Client::process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback) { callback) {
return detail::process_socket( return detail::process_socket(
true, endpoint.sock, request_count, read_timeout_sec_, read_timeout_usec_, true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, callback); write_timeout_sec_, write_timeout_usec_, callback);
} }
@ -5125,13 +5045,17 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
return send(req, *res) ? res : nullptr; return send(req, *res) ? res : nullptr;
} }
inline size_t Client::is_socket_open() const {
std::lock_guard<std::mutex> guard(socket_mutex_);
return socket_.is_open();
}
inline void Client::stop() { inline void Client::stop() {
std::lock_guard<std::mutex> guard(endpoints_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
for (auto &endpoint : endpoints_) { if (socket_.is_open()) {
detail::shutdown_socket(endpoint.sock); detail::shutdown_socket(socket_.sock);
detail::close_socket(endpoint.sock); close_socket(socket_, true);
} }
endpoints_.clear();
} }
inline void Client::set_timeout_sec(time_t timeout_sec) { inline void Client::set_timeout_sec(time_t timeout_sec) {
@ -5494,25 +5418,6 @@ inline SSLClient::~SSLClient() {
if (ctx_) { SSL_CTX_free(ctx_); } if (ctx_) { SSL_CTX_free(ctx_); }
} }
inline void SSLClient::stop() {
auto endpoints = endpoints_;
{
std::lock_guard<std::mutex> guard(endpoints_mutex_);
for (auto &endpoint : endpoints_) {
detail::shutdown_socket(endpoint.sock);
detail::close_socket(endpoint.sock);
}
endpoints_.clear();
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
for (auto &endpoint : endpoints) {
SSL_shutdown(endpoint.ssl);
SSL_free(endpoint.ssl);
}
}
inline bool SSLClient::is_valid() const { return ctx_; } inline bool SSLClient::is_valid() const { return ctx_; }
inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path,
@ -5535,14 +5440,75 @@ inline long SSLClient::get_openssl_verify_result() const {
inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
inline bool SSLClient::create_and_connect_socket(Endpoint &endpoint) { inline bool SSLClient::create_and_connect_socket(Socket &socket) {
return is_valid() && Client::create_and_connect_socket(endpoint) && if (is_valid() && Client::create_and_connect_socket(socket) &&
initialize_ssl(endpoint); initialize_ssl(socket)) {
if (!proxy_host_.empty()) {
bool error;
if (!connect_with_proxy(socket, error)) { return error; }
}
return true;
}
return false;
} }
inline bool SSLClient::initialize_ssl(Endpoint &endpoint) { inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) {
error = true;
Response res;
if (!detail::process_socket_core(
true, socket.sock, 1,
[&](bool /*last_connection*/, bool &connection_close) {
detail::SocketStream strm(socket.sock, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_);
Request req2;
req2.method = "CONNECT";
req2.path = host_and_port_;
return process_request(strm, req2, res, connection_close);
})) {
close_socket(socket, true);
error = false;
return false;
}
if (res.status == 407) {
if (!proxy_digest_auth_username_.empty() &&
!proxy_digest_auth_password_.empty()) {
std::map<std::string, std::string> auth;
if (parse_www_authenticate(res, auth, true)) {
Response res3;
if (!detail::process_socket_core(
true, socket.sock, 1,
[&](bool /*last_connection*/, bool &connection_close) {
detail::SocketStream strm(
socket.sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_);
Request req3;
req3.method = "CONNECT";
req3.path = host_and_port_;
req3.headers.insert(make_digest_authentication_header(
req3, auth, 1, random_string(10),
proxy_digest_auth_username_, proxy_digest_auth_password_,
true));
return process_request(strm, req3, res3, connection_close);
})) {
close_socket(socket, true);
error = false;
return false;
}
}
} else {
return false;
}
}
return true;
}
inline bool SSLClient::initialize_ssl(Socket &socket) {
auto ssl = detail::ssl_new( auto ssl = detail::ssl_new(
endpoint.sock, ctx_, ctx_mutex_, socket.sock, ctx_, ctx_mutex_,
[&](SSL *ssl) { [&](SSL *ssl) {
if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) { if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) {
SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
@ -5585,29 +5551,32 @@ inline bool SSLClient::initialize_ssl(Endpoint &endpoint) {
}); });
if (ssl) { if (ssl) {
endpoint.ssl = ssl; socket.ssl = ssl;
return true; return true;
} }
detail::close_socket(endpoint.sock); close_socket(socket, false);
return false; return false;
} }
inline void SSLClient::close_socket(Endpoint &endpoint, inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) {
bool process_socket_ret) { detail::close_socket(socket.sock);
assert(endpoint.ssl); socket_.sock = INVALID_SOCKET;
detail::ssl_delete(ctx_mutex_, endpoint.ssl, process_socket_ret); std::this_thread::sleep_for(std::chrono::milliseconds(10));
detail::close_socket(endpoint.sock); if (socket.ssl) {
detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret);
socket_.ssl = nullptr;
}
} }
inline bool inline bool
SSLClient::process_socket(Endpoint &endpoint, size_t request_count, SSLClient::process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection, std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)> bool &connection_close)>
callback) { callback) {
assert(endpoint.ssl); assert(socket.ssl);
return detail::process_socket_ssl( return detail::process_socket_ssl(
endpoint.ssl, true, endpoint.sock, request_count, read_timeout_sec_, socket.ssl, true, socket.sock, request_count, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
[&](Stream &strm, bool last_connection, bool &connection_close) { [&](Stream &strm, bool last_connection, bool &connection_close) {
return callback(strm, last_connection, connection_close); return callback(strm, last_connection, connection_close);

View file

@ -1767,15 +1767,20 @@ TEST_F(ServerTest, GetStreamedEndless) {
TEST_F(ServerTest, ClientStop) { TEST_F(ServerTest, ClientStop) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
for (auto i = 0; i < 3; i++) { for (auto i = 0; i < 100; i++) {
threads.emplace_back(thread([&]() { threads.emplace_back(thread([&]() {
auto res = cli_.Get("/streamed-cancel", auto res = cli_.Get("/streamed-cancel",
[&](const char *, uint64_t) { return true; }); [&](const char *, uint64_t) { return true; });
ASSERT_TRUE(res == nullptr); ASSERT_TRUE(res == nullptr);
})); }));
} }
std::this_thread::sleep_for(std::chrono::seconds(3));
cli_.stop(); std::this_thread::sleep_for(std::chrono::seconds(1));
while (cli_.is_socket_open()) {
cli_.stop();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
for (auto &t : threads) { for (auto &t : threads) {
t.join(); t.join();
} }