Keep-alive connection support on client (Fix #36)

This commit is contained in:
yhirose 2019-08-31 09:06:24 -04:00
parent a4160e6ac1
commit 1e82359329
2 changed files with 171 additions and 55 deletions

190
httplib.h
View file

@ -171,6 +171,9 @@ struct Request {
Ranges ranges; Ranges ranges;
Match matches; Match matches;
ContentReceiver content_receiver;
Progress progress;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
const SSL *ssl; const SSL *ssl;
#endif #endif
@ -195,10 +198,6 @@ struct Response {
Headers headers; Headers headers;
std::string body; std::string body;
ContentReceiver content_receiver;
Progress progress;
bool has_header(const char *key) const; bool has_header(const char *key) const;
std::string get_header_value(const char *key, size_t id = 0) const; std::string get_header_value(const char *key, size_t id = 0) const;
size_t get_header_value_count(const char *key) const; size_t get_header_value_count(const char *key) const;
@ -456,7 +455,7 @@ private:
Response &res, const std::string &boundary, Response &res, const std::string &boundary,
const std::string &content_type); const std::string &content_type);
virtual bool read_and_close_socket(socket_t sock); virtual bool process_and_close_socket(socket_t sock);
std::atomic<bool> is_running_; std::atomic<bool> is_running_;
std::atomic<socket_t> svr_sock_; std::atomic<socket_t> svr_sock_;
@ -533,6 +532,10 @@ public:
bool send(Request &req, Response &res); bool send(Request &req, Response &res);
bool send(std::vector<Request> &requests, std::vector<Response>& responses);
void set_keep_alive_max_count(size_t count);
protected: protected:
bool process_request(Stream &strm, Request &req, Response &res, bool process_request(Stream &strm, Request &req, Response &res,
bool &connection_close); bool &connection_close);
@ -541,17 +544,48 @@ protected:
const int port_; const int port_;
time_t timeout_sec_; time_t timeout_sec_;
const std::string host_and_port_; const std::string host_and_port_;
size_t keep_alive_max_count_;
private: private:
socket_t create_client_socket() const; socket_t create_client_socket() const;
bool read_response_line(Stream &strm, Response &res); bool read_response_line(Stream &strm, Response &res);
void write_request(Stream &strm, Request &req); void write_request(Stream &strm, Request &req);
virtual bool read_and_close_socket(socket_t sock, Request &req, virtual bool process_and_close_socket(
Response &res); socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback);
virtual bool is_ssl() const; virtual bool is_ssl() const;
}; };
inline void Get(std::vector<Request> &requests, const char *path, const Headers &headers) {
Request req;
req.method = "GET";
req.path = path;
req.headers = headers;
requests.emplace_back(std::move(req));
}
inline void Get(std::vector<Request> &requests, const char *path) {
Get(requests, path, Headers());
}
inline void Post(std::vector<Request> &requests, const char *path, const Headers &headers, const std::string &body, const char *content_type) {
Request req;
req.method = "POST";
req.path = path;
req.headers = headers;
req.headers.emplace("Content-Type", content_type);
req.body = body;
requests.emplace_back(std::move(req));
}
inline void Post(std::vector<Request> &requests, const char *path, const std::string &body, const char *content_type) {
Post(requests, path, Headers(), body, content_type);
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
class SSLSocketStream : public Stream { class SSLSocketStream : public Stream {
public: public:
@ -580,7 +614,7 @@ public:
virtual bool is_valid() const; virtual bool is_valid() const;
private: private:
virtual bool read_and_close_socket(socket_t sock); virtual bool process_and_close_socket(socket_t sock);
SSL_CTX *ctx_; SSL_CTX *ctx_;
std::mutex ctx_mutex_; std::mutex ctx_mutex_;
@ -603,8 +637,11 @@ public:
long get_openssl_verify_result() const; long get_openssl_verify_result() const;
private: private:
virtual bool read_and_close_socket(socket_t sock, Request &req, virtual bool process_and_close_socket(
Response &res); socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback);
virtual bool is_ssl() const; virtual bool is_ssl() const;
bool verify_host(X509 *server_cert) const; bool verify_host(X509 *server_cert) const;
@ -928,15 +965,18 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) {
} }
template <typename T> template <typename T>
inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count, inline bool process_and_close_socket(bool is_client_request, socket_t sock,
T callback) { size_t keep_alive_max_count, T callback) {
assert(keep_alive_max_count > 0);
bool ret = false; bool ret = false;
if (keep_alive_max_count > 0) { if (keep_alive_max_count > 1) {
auto count = keep_alive_max_count; auto count = keep_alive_max_count;
while (count > 0 && while (count > 0 &&
detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, (is_client_request ||
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
SocketStream strm(sock); SocketStream strm(sock);
auto last_connection = count == 1; auto last_connection = count == 1;
auto connection_close = false; auto connection_close = false;
@ -2315,9 +2355,7 @@ inline bool Server::handle_file_request(Request &req, Response &res) {
auto type = detail::find_content_type(path); auto type = detail::find_content_type(path);
if (type) { res.set_header("Content-Type", type); } if (type) { res.set_header("Content-Type", type); }
res.status = 200; res.status = 200;
if (file_request_handler_) { if (file_request_handler_) { file_request_handler_(req, res); }
file_request_handler_(req, res);
}
return true; return true;
} }
} }
@ -2398,7 +2436,7 @@ inline bool Server::listen_internal() {
break; break;
} }
task_queue->enqueue([=]() { read_and_close_socket(sock); }); task_queue->enqueue([=]() { process_and_close_socket(sock); });
} }
task_queue->shutdown(); task_queue->shutdown();
@ -2528,9 +2566,9 @@ Server::process_request(Stream &strm, bool last_connection,
inline bool Server::is_valid() const { return true; } inline bool Server::is_valid() const { return true; }
inline bool Server::read_and_close_socket(socket_t sock) { inline bool Server::process_and_close_socket(socket_t sock) {
return detail::read_and_close_socket( return detail::process_and_close_socket(
sock, keep_alive_max_count_, false, sock, keep_alive_max_count_,
[this](Stream &strm, bool last_connection, bool &connection_close) { [this](Stream &strm, bool last_connection, bool &connection_close) {
return process_request(strm, last_connection, connection_close, return process_request(strm, last_connection, connection_close,
nullptr); nullptr);
@ -2540,7 +2578,8 @@ inline bool Server::read_and_close_socket(socket_t sock) {
// HTTP client implementation // HTTP client implementation
inline Client::Client(const char *host, int port, time_t timeout_sec) inline Client::Client(const char *host, int port, time_t timeout_sec)
: host_(host), port_(port), timeout_sec_(timeout_sec), : host_(host), port_(port), timeout_sec_(timeout_sec),
host_and_port_(host_ + ":" + std::to_string(port_)) {} host_and_port_(host_ + ":" + std::to_string(port_)),
keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT) {}
inline Client::~Client() {} inline Client::~Client() {}
@ -2590,7 +2629,37 @@ inline bool Client::send(Request &req, Response &res) {
auto sock = create_client_socket(); auto sock = create_client_socket();
if (sock == INVALID_SOCKET) { return false; } if (sock == INVALID_SOCKET) { return false; }
return read_and_close_socket(sock, req, res); return process_and_close_socket(
sock, 1,
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
return process_request(strm, req, res, connection_close);
});
}
inline bool Client::send(std::vector<Request> &requests, std::vector<Response>& responses) {
size_t i = 0;
while (i < requests.size()) {
auto sock = create_client_socket();
if (sock == INVALID_SOCKET) { return false; }
if (!process_and_close_socket(
sock, requests.size() - i,
[&](Stream &strm, bool last_connection, bool &connection_close) {
auto &req = requests[i];
auto res = Response();
i++;
if (req.path.empty()) { return false; }
if (last_connection) { req.set_header("Connection", "close"); }
auto ret = process_request(strm, req, res, connection_close);
if (ret) { responses.emplace_back(std::move(res)); }
return ret;
})) {
return false;
}
}
return true;
} }
inline void Client::write_request(Stream &strm, Request &req) { inline void Client::write_request(Stream &strm, Request &req) {
@ -2677,10 +2746,10 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
return true; return true;
}; };
if (res.content_receiver) { if (req.content_receiver) {
auto offset = std::make_shared<uint64_t>(); auto offset = std::make_shared<uint64_t>();
auto length = get_header_value_uint64(res.headers, "Content-Length", 0); auto length = get_header_value_uint64(res.headers, "Content-Length", 0);
auto receiver = res.content_receiver; auto receiver = req.content_receiver;
out = [offset, length, receiver](const char *buf, size_t n) { out = [offset, length, receiver](const char *buf, size_t n) {
auto ret = receiver(buf, n, *offset, length); auto ret = receiver(buf, n, *offset, length);
(*offset) += n; (*offset) += n;
@ -2690,7 +2759,7 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
int dummy_status; int dummy_status;
if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(), if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(),
dummy_status, res.progress, out)) { dummy_status, req.progress, out)) {
return false; return false;
} }
} }
@ -2698,13 +2767,13 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
return true; return true;
} }
inline bool Client::read_and_close_socket(socket_t sock, Request &req, inline bool Client::process_and_close_socket(
Response &res) { socket_t sock, size_t request_count,
return detail::read_and_close_socket( std::function<bool(Stream &strm, bool last_connection,
sock, 0, bool &connection_close)>
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) { callback) {
return process_request(strm, req, res, connection_close); request_count = std::min(request_count, keep_alive_max_count_);
}); return detail::process_and_close_socket(true, sock, request_count, callback);
} }
inline bool Client::is_ssl() const { return false; } inline bool Client::is_ssl() const { return false; }
@ -2720,10 +2789,9 @@ Client::Get(const char *path, const Headers &headers, Progress progress) {
req.method = "GET"; req.method = "GET";
req.path = path; req.path = path;
req.headers = headers; req.headers = headers;
req.progress = progress;
auto res = std::make_shared<Response>(); auto res = std::make_shared<Response>();
res->progress = progress;
return send(req, *res) ? res : nullptr; return send(req, *res) ? res : nullptr;
} }
@ -2741,11 +2809,10 @@ inline std::shared_ptr<Response> Client::Get(const char *path,
req.method = "GET"; req.method = "GET";
req.path = path; req.path = path;
req.headers = headers; req.headers = headers;
req.content_receiver = content_receiver;
req.progress = progress;
auto res = std::make_shared<Response>(); auto res = std::make_shared<Response>();
res->content_receiver = content_receiver;
res->progress = progress;
return send(req, *res) ? res : nullptr; return send(req, *res) ? res : nullptr;
} }
@ -2930,6 +2997,10 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
return send(req, *res) ? res : nullptr; return send(req, *res) ? res : nullptr;
} }
inline void Client::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
}
/* /*
* SSL Implementation * SSL Implementation
*/ */
@ -2937,10 +3008,13 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
namespace detail { namespace detail {
template <typename U, typename V, typename T> template <typename U, typename V, typename T>
inline bool inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock,
read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count, size_t keep_alive_max_count,
SSL_CTX *ctx, std::mutex &ctx_mutex, SSL_CTX *ctx, std::mutex &ctx_mutex,
U SSL_connect_or_accept, V setup, T callback) { U SSL_connect_or_accept, V setup,
T callback) {
assert(keep_alive_max_count > 0);
SSL *ssl = nullptr; SSL *ssl = nullptr;
{ {
std::lock_guard<std::mutex> guard(ctx_mutex); std::lock_guard<std::mutex> guard(ctx_mutex);
@ -2969,11 +3043,12 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
bool ret = false; bool ret = false;
if (SSL_connect_or_accept(ssl) == 1) { if (SSL_connect_or_accept(ssl) == 1) {
if (keep_alive_max_count > 0) { if (keep_alive_max_count > 1) {
auto count = keep_alive_max_count; auto count = keep_alive_max_count;
while (count > 0 && while (count > 0 &&
detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, (is_client_request ||
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) { detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
SSLSocketStream strm(sock, ssl); SSLSocketStream strm(sock, ssl);
auto last_connection = count == 1; auto last_connection = count == 1;
auto connection_close = false; auto connection_close = false;
@ -3123,9 +3198,9 @@ inline SSLServer::~SSLServer() {
inline bool SSLServer::is_valid() const { return ctx_; } inline bool SSLServer::is_valid() const { return ctx_; }
inline bool SSLServer::read_and_close_socket(socket_t sock) { inline bool SSLServer::process_and_close_socket(socket_t sock) {
return detail::read_and_close_socket_ssl( return detail::process_and_close_socket_ssl(
sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept, false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
[](SSL * /*ssl*/) { return true; }, [](SSL * /*ssl*/) { return true; },
[this](SSL *ssl, Stream &strm, bool last_connection, [this](SSL *ssl, Stream &strm, bool last_connection,
bool &connection_close) { bool &connection_close) {
@ -3176,12 +3251,17 @@ inline long SSLClient::get_openssl_verify_result() const {
return verify_result_; return verify_result_;
} }
inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req, inline bool SSLClient::process_and_close_socket(
Response &res) { socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) {
request_count = std::min(request_count, keep_alive_max_count_);
return is_valid() && return is_valid() &&
detail::read_and_close_socket_ssl( detail::process_and_close_socket_ssl(
sock, 0, ctx_, ctx_mutex_, true, sock, request_count, ctx_, ctx_mutex_,
[&](SSL *ssl) { [&](SSL *ssl) {
if (ca_cert_file_path_.empty()) { if (ca_cert_file_path_.empty()) {
SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
@ -3217,9 +3297,9 @@ inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req,
SSL_set_tlsext_host_name(ssl, host_.c_str()); SSL_set_tlsext_host_name(ssl, host_.c_str());
return true; return true;
}, },
[&](SSL * /*ssl*/, Stream &strm, bool /*last_connection*/, [&](SSL * /*ssl*/, Stream &strm, bool last_connection,
bool &connection_close) { bool &connection_close) {
return process_request(strm, req, res, connection_close); return callback(strm, last_connection, connection_close);
}); });
} }

View file

@ -1280,6 +1280,42 @@ TEST_F(ServerTest, NoMultipleHeaders) {
EXPECT_EQ(200, res->status); EXPECT_EQ(200, res->status);
} }
TEST_F(ServerTest, KeepAlive) {
cli_.set_keep_alive_max_count(4);
std::vector<Request> requests;
Get(requests, "/hi");
Get(requests, "/hi");
Get(requests, "/hi");
Get(requests, "/not-exist");
Post(requests, "/empty", "", "text/plain");
std::vector<Response> responses;
auto ret = cli_.send(requests, responses);
ASSERT_TRUE(ret == true);
ASSERT_TRUE(requests.size() == responses.size());
for (int i = 0; i < 3; i++) {
auto& res = responses[i];
EXPECT_EQ(200, res.status);
EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
EXPECT_EQ("Hello World!", res.body);
}
{
auto& res = responses[3];
EXPECT_EQ(404, res.status);
}
{
auto& res = responses[4];
EXPECT_EQ(200, res.status);
EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
EXPECT_EQ("empty", res.body);
}
}
#ifdef CPPHTTPLIB_ZLIB_SUPPORT #ifdef CPPHTTPLIB_ZLIB_SUPPORT
TEST_F(ServerTest, Gzip) { TEST_F(ServerTest, Gzip) {
Headers headers; Headers headers;