Keep-alive connection support on client (Fix #36)

This commit is contained in:
yhirose 2019-08-31 09:06:24 -04:00
parent a4160e6ac1
commit 1e82359329
2 changed files with 171 additions and 55 deletions

190
httplib.h
View file

@ -171,6 +171,9 @@ struct Request {
Ranges ranges;
Match matches;
ContentReceiver content_receiver;
Progress progress;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
const SSL *ssl;
#endif
@ -195,10 +198,6 @@ struct Response {
Headers headers;
std::string body;
ContentReceiver content_receiver;
Progress progress;
bool has_header(const char *key) const;
std::string get_header_value(const char *key, size_t id = 0) const;
size_t get_header_value_count(const char *key) const;
@ -456,7 +455,7 @@ private:
Response &res, const std::string &boundary,
const std::string &content_type);
virtual bool read_and_close_socket(socket_t sock);
virtual bool process_and_close_socket(socket_t sock);
std::atomic<bool> is_running_;
std::atomic<socket_t> svr_sock_;
@ -533,6 +532,10 @@ public:
bool send(Request &req, Response &res);
bool send(std::vector<Request> &requests, std::vector<Response>& responses);
void set_keep_alive_max_count(size_t count);
protected:
bool process_request(Stream &strm, Request &req, Response &res,
bool &connection_close);
@ -541,17 +544,48 @@ protected:
const int port_;
time_t timeout_sec_;
const std::string host_and_port_;
size_t keep_alive_max_count_;
private:
socket_t create_client_socket() const;
bool read_response_line(Stream &strm, Response &res);
void write_request(Stream &strm, Request &req);
virtual bool read_and_close_socket(socket_t sock, Request &req,
Response &res);
virtual bool process_and_close_socket(
socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback);
virtual bool is_ssl() const;
};
inline void Get(std::vector<Request> &requests, const char *path, const Headers &headers) {
Request req;
req.method = "GET";
req.path = path;
req.headers = headers;
requests.emplace_back(std::move(req));
}
inline void Get(std::vector<Request> &requests, const char *path) {
Get(requests, path, Headers());
}
inline void Post(std::vector<Request> &requests, const char *path, const Headers &headers, const std::string &body, const char *content_type) {
Request req;
req.method = "POST";
req.path = path;
req.headers = headers;
req.headers.emplace("Content-Type", content_type);
req.body = body;
requests.emplace_back(std::move(req));
}
inline void Post(std::vector<Request> &requests, const char *path, const std::string &body, const char *content_type) {
Post(requests, path, Headers(), body, content_type);
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
class SSLSocketStream : public Stream {
public:
@ -580,7 +614,7 @@ public:
virtual bool is_valid() const;
private:
virtual bool read_and_close_socket(socket_t sock);
virtual bool process_and_close_socket(socket_t sock);
SSL_CTX *ctx_;
std::mutex ctx_mutex_;
@ -603,8 +637,11 @@ public:
long get_openssl_verify_result() const;
private:
virtual bool read_and_close_socket(socket_t sock, Request &req,
Response &res);
virtual bool process_and_close_socket(
socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback);
virtual bool is_ssl() const;
bool verify_host(X509 *server_cert) const;
@ -928,15 +965,18 @@ inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) {
}
template <typename T>
inline bool read_and_close_socket(socket_t sock, size_t keep_alive_max_count,
T callback) {
inline bool process_and_close_socket(bool is_client_request, socket_t sock,
size_t keep_alive_max_count, T callback) {
assert(keep_alive_max_count > 0);
bool ret = false;
if (keep_alive_max_count > 0) {
if (keep_alive_max_count > 1) {
auto count = keep_alive_max_count;
while (count > 0 &&
detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) {
(is_client_request ||
detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
SocketStream strm(sock);
auto last_connection = count == 1;
auto connection_close = false;
@ -2315,9 +2355,7 @@ inline bool Server::handle_file_request(Request &req, Response &res) {
auto type = detail::find_content_type(path);
if (type) { res.set_header("Content-Type", type); }
res.status = 200;
if (file_request_handler_) {
file_request_handler_(req, res);
}
if (file_request_handler_) { file_request_handler_(req, res); }
return true;
}
}
@ -2398,7 +2436,7 @@ inline bool Server::listen_internal() {
break;
}
task_queue->enqueue([=]() { read_and_close_socket(sock); });
task_queue->enqueue([=]() { process_and_close_socket(sock); });
}
task_queue->shutdown();
@ -2528,9 +2566,9 @@ Server::process_request(Stream &strm, bool last_connection,
inline bool Server::is_valid() const { return true; }
inline bool Server::read_and_close_socket(socket_t sock) {
return detail::read_and_close_socket(
sock, keep_alive_max_count_,
inline bool Server::process_and_close_socket(socket_t sock) {
return detail::process_and_close_socket(
false, sock, keep_alive_max_count_,
[this](Stream &strm, bool last_connection, bool &connection_close) {
return process_request(strm, last_connection, connection_close,
nullptr);
@ -2540,7 +2578,8 @@ inline bool Server::read_and_close_socket(socket_t sock) {
// HTTP client implementation
inline Client::Client(const char *host, int port, time_t timeout_sec)
: host_(host), port_(port), timeout_sec_(timeout_sec),
host_and_port_(host_ + ":" + std::to_string(port_)) {}
host_and_port_(host_ + ":" + std::to_string(port_)),
keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT) {}
inline Client::~Client() {}
@ -2590,7 +2629,37 @@ inline bool Client::send(Request &req, Response &res) {
auto sock = create_client_socket();
if (sock == INVALID_SOCKET) { return false; }
return read_and_close_socket(sock, req, res);
return process_and_close_socket(
sock, 1,
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
return process_request(strm, req, res, connection_close);
});
}
inline bool Client::send(std::vector<Request> &requests, std::vector<Response>& responses) {
size_t i = 0;
while (i < requests.size()) {
auto sock = create_client_socket();
if (sock == INVALID_SOCKET) { return false; }
if (!process_and_close_socket(
sock, requests.size() - i,
[&](Stream &strm, bool last_connection, bool &connection_close) {
auto &req = requests[i];
auto res = Response();
i++;
if (req.path.empty()) { return false; }
if (last_connection) { req.set_header("Connection", "close"); }
auto ret = process_request(strm, req, res, connection_close);
if (ret) { responses.emplace_back(std::move(res)); }
return ret;
})) {
return false;
}
}
return true;
}
inline void Client::write_request(Stream &strm, Request &req) {
@ -2677,10 +2746,10 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
return true;
};
if (res.content_receiver) {
if (req.content_receiver) {
auto offset = std::make_shared<uint64_t>();
auto length = get_header_value_uint64(res.headers, "Content-Length", 0);
auto receiver = res.content_receiver;
auto receiver = req.content_receiver;
out = [offset, length, receiver](const char *buf, size_t n) {
auto ret = receiver(buf, n, *offset, length);
(*offset) += n;
@ -2690,7 +2759,7 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
int dummy_status;
if (!detail::read_content(strm, res, std::numeric_limits<uint64_t>::max(),
dummy_status, res.progress, out)) {
dummy_status, req.progress, out)) {
return false;
}
}
@ -2698,13 +2767,13 @@ inline bool Client::process_request(Stream &strm, Request &req, Response &res,
return true;
}
inline bool Client::read_and_close_socket(socket_t sock, Request &req,
Response &res) {
return detail::read_and_close_socket(
sock, 0,
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
return process_request(strm, req, res, connection_close);
});
inline bool Client::process_and_close_socket(
socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) {
request_count = std::min(request_count, keep_alive_max_count_);
return detail::process_and_close_socket(true, sock, request_count, callback);
}
inline bool Client::is_ssl() const { return false; }
@ -2720,10 +2789,9 @@ Client::Get(const char *path, const Headers &headers, Progress progress) {
req.method = "GET";
req.path = path;
req.headers = headers;
req.progress = progress;
auto res = std::make_shared<Response>();
res->progress = progress;
return send(req, *res) ? res : nullptr;
}
@ -2741,11 +2809,10 @@ inline std::shared_ptr<Response> Client::Get(const char *path,
req.method = "GET";
req.path = path;
req.headers = headers;
req.content_receiver = content_receiver;
req.progress = progress;
auto res = std::make_shared<Response>();
res->content_receiver = content_receiver;
res->progress = progress;
return send(req, *res) ? res : nullptr;
}
@ -2930,6 +2997,10 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
return send(req, *res) ? res : nullptr;
}
inline void Client::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
}
/*
* SSL Implementation
*/
@ -2937,10 +3008,13 @@ inline std::shared_ptr<Response> Client::Options(const char *path,
namespace detail {
template <typename U, typename V, typename T>
inline bool
read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
SSL_CTX *ctx, std::mutex &ctx_mutex,
U SSL_connect_or_accept, V setup, T callback) {
inline bool process_and_close_socket_ssl(bool is_client_request, socket_t sock,
size_t keep_alive_max_count,
SSL_CTX *ctx, std::mutex &ctx_mutex,
U SSL_connect_or_accept, V setup,
T callback) {
assert(keep_alive_max_count > 0);
SSL *ssl = nullptr;
{
std::lock_guard<std::mutex> guard(ctx_mutex);
@ -2969,11 +3043,12 @@ read_and_close_socket_ssl(socket_t sock, size_t keep_alive_max_count,
bool ret = false;
if (SSL_connect_or_accept(ssl) == 1) {
if (keep_alive_max_count > 0) {
if (keep_alive_max_count > 1) {
auto count = keep_alive_max_count;
while (count > 0 &&
detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0) {
(is_client_request ||
detail::select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
SSLSocketStream strm(sock, ssl);
auto last_connection = count == 1;
auto connection_close = false;
@ -3123,9 +3198,9 @@ inline SSLServer::~SSLServer() {
inline bool SSLServer::is_valid() const { return ctx_; }
inline bool SSLServer::read_and_close_socket(socket_t sock) {
return detail::read_and_close_socket_ssl(
sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
inline bool SSLServer::process_and_close_socket(socket_t sock) {
return detail::process_and_close_socket_ssl(
false, sock, keep_alive_max_count_, ctx_, ctx_mutex_, SSL_accept,
[](SSL * /*ssl*/) { return true; },
[this](SSL *ssl, Stream &strm, bool last_connection,
bool &connection_close) {
@ -3176,12 +3251,17 @@ inline long SSLClient::get_openssl_verify_result() const {
return verify_result_;
}
inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req,
Response &res) {
inline bool SSLClient::process_and_close_socket(
socket_t sock, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) {
request_count = std::min(request_count, keep_alive_max_count_);
return is_valid() &&
detail::read_and_close_socket_ssl(
sock, 0, ctx_, ctx_mutex_,
detail::process_and_close_socket_ssl(
true, sock, request_count, ctx_, ctx_mutex_,
[&](SSL *ssl) {
if (ca_cert_file_path_.empty()) {
SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr);
@ -3217,9 +3297,9 @@ inline bool SSLClient::read_and_close_socket(socket_t sock, Request &req,
SSL_set_tlsext_host_name(ssl, host_.c_str());
return true;
},
[&](SSL * /*ssl*/, Stream &strm, bool /*last_connection*/,
[&](SSL * /*ssl*/, Stream &strm, bool last_connection,
bool &connection_close) {
return process_request(strm, req, res, connection_close);
return callback(strm, last_connection, connection_close);
});
}

View file

@ -1280,6 +1280,42 @@ TEST_F(ServerTest, NoMultipleHeaders) {
EXPECT_EQ(200, res->status);
}
TEST_F(ServerTest, KeepAlive) {
cli_.set_keep_alive_max_count(4);
std::vector<Request> requests;
Get(requests, "/hi");
Get(requests, "/hi");
Get(requests, "/hi");
Get(requests, "/not-exist");
Post(requests, "/empty", "", "text/plain");
std::vector<Response> responses;
auto ret = cli_.send(requests, responses);
ASSERT_TRUE(ret == true);
ASSERT_TRUE(requests.size() == responses.size());
for (int i = 0; i < 3; i++) {
auto& res = responses[i];
EXPECT_EQ(200, res.status);
EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
EXPECT_EQ("Hello World!", res.body);
}
{
auto& res = responses[3];
EXPECT_EQ(404, res.status);
}
{
auto& res = responses[4];
EXPECT_EQ(200, res.status);
EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
EXPECT_EQ("empty", res.body);
}
}
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
TEST_F(ServerTest, Gzip) {
Headers headers;