This commit is contained in:
yhirose 2020-06-16 17:46:23 -04:00
parent 3dfb4ecac2
commit 7cd25fbd63
4 changed files with 304 additions and 336 deletions

View file

@ -482,29 +482,15 @@ httplib::make_range_header({{0, 0}, {-1, 1}}) // 'Range: bytes=0-0, -1'
### Keep-Alive connection
```cpp
cli.set_keep_alive_max_count(2); // Default is 5
httplib::Client cli("localhost", 1234);
std::vector<Request> requests;
Get(requests, "/get-request1");
Get(requests, "/get-request2");
Post(requests, "/post-request1", "text", "text/plain");
Post(requests, "/post-request2", "text", "text/plain");
cli.Get("/hello"); // with "Connection: close"
const size_t DATA_CHUNK_SIZE = 4;
std::string data("abcdefg");
Post(requests, "/post-request-with-content-provider",
data.size(),
[&](size_t offset, size_t length, DataSink &sink){
sink.write(&data[offset], std::min(length, DATA_CHUNK_SIZE));
},
"text/plain");
cli.set_keep_alive(true);
cli.Get("/world");
std::vector<Response> responses;
if (cli.send(requests, responses)) {
for (const auto& res: responses) {
...
}
}
cli.set_keep_alive(false);
cli.Get("/last-request"); // with "Connection: close"
```
### Redirect

444
httplib.h
View file

@ -188,6 +188,7 @@ using socket_t = int;
#include <fcntl.h>
#include <fstream>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
@ -593,10 +594,11 @@ public:
std::function<TaskQueue *(void)> new_task_queue;
protected:
bool process_request(Stream &strm, bool last_connection,
bool &connection_close,
bool process_request(Stream &strm, bool close_connection,
bool &connection_closed,
const std::function<void(Request &)> &setup_request);
std::atomic<socket_t> svr_sock_;
size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND;
time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND;
@ -624,7 +626,7 @@ private:
HandlersForContentReader &handlers);
bool parse_request_line(const char *s, Request &req);
bool write_response(Stream &strm, bool last_connection, const Request &req,
bool write_response(Stream &strm, bool close_connection, const Request &req,
Response &res);
bool write_content_with_provider(Stream &strm, const Request &req,
Response &res, const std::string &boundary,
@ -643,7 +645,6 @@ private:
virtual bool process_and_close_socket(socket_t sock);
std::atomic<bool> is_running_;
std::atomic<socket_t> svr_sock_;
std::vector<std::pair<std::string, std::string>> base_dirs_;
std::map<std::string, std::string> file_extension_and_mimetype_map_;
Handler file_request_handler_;
@ -797,9 +798,6 @@ public:
bool send(const Request &req, Response &res);
bool send(const std::vector<Request> &requests,
std::vector<Response> &responses);
size_t is_socket_open() const;
void stop();
@ -809,13 +807,12 @@ public:
void set_read_timeout(time_t sec, time_t usec = 0);
void set_write_timeout(time_t sec, time_t usec = 0);
void set_keep_alive_max_count(size_t count);
void set_basic_auth(const char *username, const char *password);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
void set_digest_auth(const char *username, const char *password);
#endif
void set_keep_alive(bool on);
void set_follow_location(bool on);
void set_compress(bool on);
@ -846,7 +843,7 @@ protected:
virtual void close_socket(Socket &socket, bool process_socket_ret);
bool process_request(Stream &strm, const Request &req, Response &res,
bool &connection_close);
bool close_connection);
// Socket endoint information
const std::string host_;
@ -869,8 +866,6 @@ protected:
time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND;
time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND;
size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
std::string basic_auth_username_;
std::string basic_auth_password_;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
@ -878,6 +873,7 @@ protected:
std::string digest_auth_password_;
#endif
bool keep_alive_ = false;
bool follow_location_ = false;
bool compress_ = false;
@ -905,13 +901,13 @@ protected:
read_timeout_usec_ = rhs.read_timeout_usec_;
write_timeout_sec_ = rhs.write_timeout_sec_;
write_timeout_usec_ = rhs.write_timeout_usec_;
keep_alive_max_count_ = rhs.keep_alive_max_count_;
basic_auth_username_ = rhs.basic_auth_username_;
basic_auth_password_ = rhs.basic_auth_password_;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
digest_auth_username_ = rhs.digest_auth_username_;
digest_auth_password_ = rhs.digest_auth_password_;
#endif
keep_alive_ = rhs.keep_alive_;
follow_location_ = rhs.follow_location_;
compress_ = rhs.compress_;
decompress_ = rhs.decompress_;
@ -930,22 +926,18 @@ protected:
private:
socket_t create_client_socket() const;
bool read_response_line(Stream &strm, Response &res);
bool write_request(Stream &strm, const Request &req);
bool write_request(Stream &strm, const Request &req, bool close_connection);
bool redirect(const Request &req, Response &res);
bool handle_request(Stream &strm, const Request &req, Response &res,
bool &connection_close);
bool close_connection);
std::shared_ptr<Response> send_with_content_provider(
const char *method, const char *path, const Headers &headers,
const std::string &body, size_t content_length,
ContentProvider content_provider, const char *content_type);
virtual bool
process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback);
virtual bool process_socket(Socket &socket,
std::function<bool(Stream &strm)> callback);
virtual bool is_ssl() const;
};
@ -1045,15 +1037,13 @@ public:
private:
bool create_and_connect_socket(Socket &socket) override;
bool connect_with_proxy(Socket &sock, bool &error);
void close_socket(Socket &socket, bool process_socket_ret) override;
bool process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) override;
bool process_socket(Socket &socket,
std::function<bool(Stream &strm)> callback) override;
bool is_ssl() const override;
bool connect_with_proxy(Socket &sock, Response &res, bool &success);
bool initialize_ssl(Socket &socket);
bool verify_host(X509 *server_cert) const;
@ -1070,6 +1060,8 @@ private:
X509_STORE *ca_cert_store_ = nullptr;
bool server_certificate_verification_ = false;
long verify_result_ = 0;
friend class Client;
};
#endif
@ -1301,11 +1293,6 @@ public:
bool send(const Request &req, Response &res) { return cli_->send(req, res); }
bool send(const std::vector<Request> &requests,
std::vector<Response> &responses) {
return cli_->send(requests, responses);
}
bool is_socket_open() { return cli_->is_socket_open(); }
void stop() { cli_->stop(); }
@ -1320,11 +1307,6 @@ public:
return *this;
}
Client2 &set_keep_alive_max_count(size_t count) {
cli_->set_keep_alive_max_count(count);
return *this;
}
Client2 &set_basic_auth(const char *username, const char *password) {
cli_->set_basic_auth(username, password);
return *this;
@ -1337,6 +1319,11 @@ public:
}
#endif
Client2 &set_keep_alive(bool on) {
cli_->set_keep_alive(on);
return *this;
}
Client2 &set_follow_location(bool on) {
cli_->set_follow_location(on);
return *this;
@ -1863,49 +1850,75 @@ private:
size_t position = 0;
};
template <typename T>
inline bool process_socket_core(bool is_client_request, socket_t sock,
size_t keep_alive_max_count, T callback) {
inline bool keep_alive(socket_t sock, std::function<bool()> is_shutting_down) {
using namespace std::chrono;
auto start = steady_clock::now();
while (true) {
auto val = select_read(sock, 0, 10000);
if (is_shutting_down && is_shutting_down()) {
return false;
} else if (val < 0) {
return false;
} else if (val == 0) {
auto current = steady_clock::now();
auto sec = duration_cast<seconds>(current - start);
if (sec.count() > CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND) {
return false;
} else if (sec.count() == CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND) {
auto usec = duration_cast<nanoseconds>(current - start);
if (usec.count() > CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) {
return false;
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(1));
} else {
return true;
}
}
}
template <typename T, typename U>
inline bool process_server_socket_core(socket_t sock,
size_t keep_alive_max_count,
T is_shutting_down, U callback) {
assert(keep_alive_max_count > 0);
auto ret = false;
if (keep_alive_max_count > 1) {
auto count = keep_alive_max_count;
while (count > 0 &&
(is_client_request ||
select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND,
CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) {
auto last_connection = count == 1;
auto connection_close = false;
ret = callback(last_connection, connection_close);
if (!ret || connection_close) { break; }
while (count > 0 && keep_alive(sock, is_shutting_down)) {
auto close_connection = count == 1;
auto connection_closed = false;
ret = callback(close_connection, connection_closed);
if (!ret || connection_closed) { break; }
count--;
}
} else { // keep_alive_max_count is 0 or 1
auto dummy_connection_close = false;
ret = callback(true, dummy_connection_close);
}
return ret;
}
template <typename T>
inline bool process_socket(bool is_client_request, socket_t sock,
size_t keep_alive_max_count, time_t read_timeout_sec,
time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, T callback) {
return process_socket_core(
is_client_request, sock, keep_alive_max_count,
[&](bool last_connection, bool connection_close) {
template <typename T, typename U>
inline bool
process_server_socket(socket_t sock, size_t keep_alive_max_count,
time_t read_timeout_sec, time_t read_timeout_usec,
time_t write_timeout_sec, time_t write_timeout_usec,
T is_shutting_down, U callback) {
return process_server_socket_core(
sock, keep_alive_max_count, is_shutting_down,
[&](bool close_connection, bool connection_closed) {
SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
return callback(strm, last_connection, connection_close);
return callback(strm, close_connection, connection_closed);
});
}
template <typename T>
inline bool process_client_socket(socket_t sock, time_t read_timeout_sec,
time_t read_timeout_usec,
time_t write_timeout_sec,
time_t write_timeout_usec, T callback) {
SocketStream strm(sock, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
return callback(strm);
}
inline int shutdown_socket(socket_t sock) {
#ifdef _WIN32
return shutdown(sock, SD_BOTH);
@ -2545,7 +2558,6 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
}
if (!ret) { status = exceed_payload_max_length ? 413 : 400; }
return ret;
}
@ -2582,8 +2594,9 @@ inline bool write_data(Stream &strm, const char *d, size_t l) {
return true;
}
template <typename T>
inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
size_t offset, size_t length) {
size_t offset, size_t length, T is_shutting_down) {
size_t begin_offset = offset;
size_t end_offset = offset + length;
@ -2598,7 +2611,7 @@ inline ssize_t write_content(Stream &strm, ContentProvider content_provider,
};
data_sink.is_writable = [&](void) { return ok && strm.is_writable(); };
while (ok && offset < end_offset) {
while (ok && offset < end_offset && !is_shutting_down()) {
if (!content_provider(offset, end_offset - offset, data_sink)) {
return -1;
}
@ -3110,16 +3123,19 @@ get_multipart_ranges_data_length(const Request &req, Response &res,
return data_length;
}
template <typename T>
inline bool write_multipart_ranges_data(Stream &strm, const Request &req,
Response &res,
const std::string &boundary,
const std::string &content_type) {
const std::string &content_type,
T is_shutting_down) {
return process_multipart_ranges_data(
req, res, boundary, content_type,
[&](const std::string &token) { strm.write(token); },
[&](const char *token) { strm.write(token); },
[&](size_t offset, size_t length) {
return write_content(strm, res.content_provider_, offset, length) >= 0;
return write_content(strm, res.content_provider_, offset, length,
is_shutting_down) >= 0;
});
}
@ -3576,7 +3592,7 @@ inline const std::string &BufferStream::get_buffer() const { return buffer; }
} // namespace detail
// HTTP server implementation
inline Server::Server() : is_running_(false), svr_sock_(INVALID_SOCKET) {
inline Server::Server() : svr_sock_(INVALID_SOCKET), is_running_(false) {
#ifndef _WIN32
signal(SIGPIPE, SIG_IGN);
#endif
@ -3758,7 +3774,7 @@ inline bool Server::parse_request_line(const char *s, Request &req) {
return false;
}
inline bool Server::write_response(Stream &strm, bool last_connection,
inline bool Server::write_response(Stream &strm, bool close_connection,
const Request &req, Response &res) {
assert(res.status != -1);
@ -3773,11 +3789,11 @@ inline bool Server::write_response(Stream &strm, bool last_connection,
}
// Headers
if (last_connection || req.get_header_value("Connection") == "close") {
if (close_connection || req.get_header_value("Connection") == "close") {
res.set_header("Connection", "close");
}
if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") {
if (!close_connection && req.get_header_value("Connection") == "Keep-Alive") {
res.set_header("Connection", "Keep-Alive");
}
@ -3891,10 +3907,14 @@ inline bool
Server::write_content_with_provider(Stream &strm, const Request &req,
Response &res, const std::string &boundary,
const std::string &content_type) {
auto is_shutting_down = [this]() {
return this->svr_sock_ == INVALID_SOCKET;
};
if (res.content_length_) {
if (req.ranges.empty()) {
if (detail::write_content(strm, res.content_provider_, 0,
res.content_length_) < 0) {
res.content_length_, is_shutting_down) < 0) {
return false;
}
} else if (req.ranges.size() == 1) {
@ -3902,20 +3922,17 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
detail::get_range_offset_and_length(req, res.content_length_, 0);
auto offset = offsets.first;
auto length = offsets.second;
if (detail::write_content(strm, res.content_provider_, offset, length) <
0) {
if (detail::write_content(strm, res.content_provider_, offset, length,
is_shutting_down) < 0) {
return false;
}
} else {
if (!detail::write_multipart_ranges_data(strm, req, res, boundary,
content_type)) {
if (!detail::write_multipart_ranges_data(
strm, req, res, boundary, content_type, is_shutting_down)) {
return false;
}
}
} else {
auto is_shutting_down = [this]() {
return this->svr_sock_ == INVALID_SOCKET;
};
if (detail::write_content_chunked(strm, res.content_provider_,
is_shutting_down) < 0) {
return false;
@ -4241,8 +4258,8 @@ inline bool Server::dispatch_request_for_content_reader(
}
inline bool
Server::process_request(Stream &strm, bool last_connection,
bool &connection_close,
Server::process_request(Stream &strm, bool close_connection,
bool &connection_closed,
const std::function<void(Request &)> &setup_request) {
std::array<char, 2048> buf{};
@ -4261,23 +4278,23 @@ Server::process_request(Stream &strm, bool last_connection,
Headers dummy;
detail::read_headers(strm, dummy);
res.status = 414;
return write_response(strm, last_connection, req, res);
return write_response(strm, close_connection, req, res);
}
// Request line and headers
if (!parse_request_line(line_reader.ptr(), req) ||
!detail::read_headers(strm, req.headers)) {
res.status = 400;
return write_response(strm, last_connection, req, res);
return write_response(strm, close_connection, req, res);
}
if (req.get_header_value("Connection") == "close") {
connection_close = true;
connection_closed = true;
}
if (req.version == "HTTP/1.0" &&
req.get_header_value("Connection") != "Keep-Alive") {
connection_close = true;
connection_closed = true;
}
strm.get_remote_ip_and_port(req.remote_addr, req.remote_port);
@ -4304,7 +4321,7 @@ Server::process_request(Stream &strm, bool last_connection,
strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status,
detail::status_message(status));
break;
default: return write_response(strm, last_connection, req, res);
default: return write_response(strm, close_connection, req, res);
}
}
@ -4315,20 +4332,23 @@ Server::process_request(Stream &strm, bool last_connection,
if (res.status == -1) { res.status = 404; }
}
return write_response(strm, last_connection, req, res);
return write_response(strm, close_connection, req, res);
}
inline bool Server::is_valid() const { return true; }
inline bool Server::process_and_close_socket(socket_t sock) {
auto ret = detail::process_socket(
false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
auto ret = detail::process_server_socket(
sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_,
[this](Stream &strm, bool last_connection, bool &connection_close) {
return process_request(strm, last_connection, connection_close,
[this]() { return this->svr_sock_ == INVALID_SOCKET; },
[this](Stream &strm, bool close_connection, bool &connection_closed) {
return process_request(strm, close_connection, connection_closed,
nullptr);
});
std::this_thread::sleep_for(std::chrono::milliseconds(1));
detail::shutdown_socket(sock);
detail::close_socket(sock);
return ret;
}
@ -4347,12 +4367,7 @@ inline Client::Client(const std::string &host, int port,
host_and_port_(host_ + ":" + std::to_string(port_)),
client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
inline Client::~Client() {
assert(socket_.sock == INVALID_SOCKET);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
assert(socket_.ssl == nullptr);
#endif
}
inline Client::~Client() { stop(); }
inline bool Client::is_valid() const { return true; }
@ -4402,63 +4417,49 @@ inline bool Client::read_response_line(Stream &strm, Response &res) {
inline bool Client::send(const Request &req, Response &res) {
std::lock_guard<std::recursive_mutex> request_mutex_guard(request_mutex_);
auto need_new_socket = !is_socket_open();
if (need_new_socket) {
std::lock_guard<std::mutex> guard(socket_mutex_);
if (!create_and_connect_socket(socket_)) { return false; }
}
auto ret = process_socket(
socket_, 1,
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
return handle_request(strm, req, res, connection_close);
});
if (need_new_socket) {
std::lock_guard<std::mutex> guard(socket_mutex_);
if (socket_.is_open()) { close_socket(socket_, ret); }
}
return ret;
}
inline bool Client::send(const std::vector<Request> &requests,
std::vector<Response> &responses) {
std::lock_guard<std::recursive_mutex> request_mutex_guard(request_mutex_);
size_t i = 0;
while (i < requests.size()) {
{
std::lock_guard<std::mutex> guard(socket_mutex_);
if (!create_and_connect_socket(socket_)) { return false; }
}
auto request_count = (std::min)(requests.size() - i, keep_alive_max_count_);
auto ret = process_socket(
socket_, request_count,
[&](Stream &strm, bool /*last_connection*/, bool &connection_close) {
auto &req = requests[i++];
auto res = Response();
auto ret = handle_request(strm, req, res, connection_close);
if (ret) { responses.emplace_back(std::move(res)); }
return ret;
});
{
std::lock_guard<std::mutex> guard(socket_mutex_);
if (socket_.is_open()) { close_socket(socket_, ret); }
auto is_alive = false;
if (socket_.is_open()) {
is_alive = detail::select_write(socket_.sock, 0, 0) > 0;
if (!is_alive) { close_socket(socket_, false); }
}
if (!ret) { return false; }
if (!is_alive) {
if (!create_and_connect_socket(socket_)) { return false; }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
// TODO: refactoring
if (is_ssl()) {
auto &scli = static_cast<SSLClient &>(*this);
if (!proxy_host_.empty()) {
bool success = false;
if (!scli.connect_with_proxy(socket_, res, success)) {
return success;
}
}
return true;
if (!scli.initialize_ssl(socket_)) { return false; }
}
#endif
}
}
auto close_connection = !keep_alive_;
auto ret = process_socket(socket_, [&](Stream &strm) {
return handle_request(strm, req, res, close_connection);
});
if (close_connection) { stop(); }
return ret;
}
inline bool Client::handle_request(Stream &strm, const Request &req,
Response &res, bool &connection_close) {
Response &res, bool close_connection) {
if (req.path.empty()) { return false; }
bool ret;
@ -4466,9 +4467,9 @@ inline bool Client::handle_request(Stream &strm, const Request &req,
if (!is_ssl() && !proxy_host_.empty()) {
auto req2 = req;
req2.path = "http://" + host_and_port_ + req.path;
ret = process_request(strm, req2, res, connection_close);
ret = process_request(strm, req2, res, close_connection);
} else {
ret = process_request(strm, req, res, connection_close);
ret = process_request(strm, req, res, close_connection);
}
if (!ret) { return false; }
@ -4558,7 +4559,8 @@ inline bool Client::redirect(const Request &req, Response &res) {
}
}
inline bool Client::write_request(Stream &strm, const Request &req) {
inline bool Client::write_request(Stream &strm, const Request &req,
bool close_connection) {
detail::BufferStream bstrm;
// Request line
@ -4568,6 +4570,8 @@ inline bool Client::write_request(Stream &strm, const Request &req) {
// Additonal headers
Headers headers;
if (close_connection) { headers.emplace("Connection", "close"); }
if (!req.has_header("Host")) {
if (is_ssl()) {
if (port_ == 443) {
@ -4710,9 +4714,9 @@ inline std::shared_ptr<Response> Client::send_with_content_provider(
}
inline bool Client::process_request(Stream &strm, const Request &req,
Response &res, bool &connection_close) {
Response &res, bool close_connection) {
// Send request
if (!write_request(strm, req)) { return false; }
if (!write_request(strm, req, close_connection)) { return false; }
// Receive response and headers
if (!read_response_line(strm, res) ||
@ -4720,11 +4724,6 @@ inline bool Client::process_request(Stream &strm, const Request &req,
return false;
}
if (res.get_header_value("Connection") == "close" ||
res.version == "HTTP/1.0") {
connection_close = true;
}
if (req.response_handler) {
if (!req.response_handler(res)) { return false; }
}
@ -4749,20 +4748,22 @@ inline bool Client::process_request(Stream &strm, const Request &req,
}
}
if (res.get_header_value("Connection") == "close" ||
res.version == "HTTP/1.0") {
stop();
}
// Log
if (logger_) { logger_(req, res); }
return true;
}
inline bool
Client::process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) {
return detail::process_socket(
true, socket.sock, request_count, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, callback);
inline bool Client::process_socket(Socket &socket,
std::function<bool(Stream &strm)> callback) {
return detail::process_client_socket(socket.sock, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_, callback);
}
inline bool Client::is_ssl() const { return false; }
@ -5066,9 +5067,9 @@ inline void Client::stop() {
std::lock_guard<std::mutex> guard(socket_mutex_);
if (socket_.is_open()) {
detail::shutdown_socket(socket_.sock);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
std::this_thread::sleep_for(std::chrono::milliseconds(1));
close_socket(socket_, true);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
@ -5091,10 +5092,6 @@ inline void Client::set_write_timeout(time_t sec, time_t usec) {
write_timeout_usec_ = usec;
}
inline void Client::set_keep_alive_max_count(size_t count) {
keep_alive_max_count_ = count;
}
inline void Client::set_basic_auth(const char *username, const char *password) {
basic_auth_username_ = username;
basic_auth_password_ = password;
@ -5108,6 +5105,8 @@ inline void Client::set_digest_auth(const char *username,
}
#endif
inline void Client::set_keep_alive(bool on) { keep_alive_ = on; }
inline void Client::set_follow_location(bool on) { follow_location_ = on; }
inline void Client::set_compress(bool on) { compress_ = on; }
@ -5181,19 +5180,29 @@ inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
template <typename T>
inline bool
process_socket_ssl(SSL *ssl, bool is_client_request, socket_t sock,
size_t keep_alive_max_count, time_t read_timeout_sec,
time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, T callback) {
return process_socket_core(
is_client_request, sock, keep_alive_max_count,
[&](bool last_connection, bool connection_close) {
process_server_socket_ssl(SSL *ssl, socket_t sock, size_t keep_alive_max_count,
time_t read_timeout_sec, time_t read_timeout_usec,
time_t write_timeout_sec, time_t write_timeout_usec,
std::function<bool()> is_shutting_down, T callback) {
return process_server_socket_core(
sock, keep_alive_max_count, is_shutting_down,
[&](bool close_connection, bool connection_closed) {
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
return callback(strm, last_connection, connection_close);
return callback(strm, close_connection, connection_closed);
});
}
template <typename T>
inline bool
process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec,
time_t read_timeout_usec, time_t write_timeout_sec,
time_t write_timeout_usec, T callback) {
SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec,
write_timeout_sec, write_timeout_usec);
return callback(strm);
}
#if OPENSSL_VERSION_NUMBER < 0x10100000L
static std::shared_ptr<std::vector<std::mutex>> openSSL_locks_;
@ -5365,12 +5374,13 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
[](SSL * /*ssl*/) { return true; });
if (ssl) {
auto ret = detail::process_socket_ssl(
ssl, false, sock, keep_alive_max_count_, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
[this, ssl](Stream &strm, bool last_connection,
bool &connection_close) {
return process_request(strm, last_connection, connection_close,
auto ret = detail::process_server_socket_ssl(
ssl, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_,
[this]() { return this->svr_sock_ == INVALID_SOCKET; },
[this, ssl](Stream &strm, bool close_connection,
bool &connection_closed) {
return process_request(strm, close_connection, connection_closed,
[&](Request &req) { req.ssl = ssl; });
});
@ -5455,49 +5465,36 @@ inline long SSLClient::get_openssl_verify_result() const {
inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; }
inline bool SSLClient::create_and_connect_socket(Socket &socket) {
if (is_valid() && Client::create_and_connect_socket(socket) &&
initialize_ssl(socket)) {
if (!proxy_host_.empty()) {
bool error;
if (!connect_with_proxy(socket, error)) { return error; }
}
return true;
}
return false;
return is_valid() && Client::create_and_connect_socket(socket);
}
inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) {
error = true;
Response res;
inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
bool &success) {
success = true;
Response res2;
if (!detail::process_socket_core(
true, socket.sock, 1,
[&](bool /*last_connection*/, bool &connection_close) {
detail::SocketStream strm(socket.sock, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_,
write_timeout_usec_);
if (!detail::process_client_socket(
socket.sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
Request req2;
req2.method = "CONNECT";
req2.path = host_and_port_;
return process_request(strm, req2, res, connection_close);
return process_request(strm, req2, res2, false);
})) {
close_socket(socket, true);
error = false;
success = false;
return false;
}
if (res.status == 407) {
if (res2.status == 407) {
if (!proxy_digest_auth_username_.empty() &&
!proxy_digest_auth_password_.empty()) {
std::map<std::string, std::string> auth;
if (parse_www_authenticate(res, auth, true)) {
if (parse_www_authenticate(res2, auth, true)) {
Response res3;
if (!detail::process_socket_core(
true, socket.sock, 1,
[&](bool /*last_connection*/, bool &connection_close) {
detail::SocketStream strm(
if (!detail::process_client_socket(
socket.sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_);
write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
Request req3;
req3.method = "CONNECT";
req3.path = host_and_port_;
@ -5505,14 +5502,15 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, bool &error) {
req3, auth, 1, random_string(10),
proxy_digest_auth_username_, proxy_digest_auth_password_,
true));
return process_request(strm, req3, res3, connection_close);
return process_request(strm, req3, res3, false);
})) {
close_socket(socket, true);
error = false;
success = false;
return false;
}
}
} else {
res = res2;
return false;
}
}
@ -5583,17 +5581,12 @@ inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) {
}
inline bool
SSLClient::process_socket(Socket &socket, size_t request_count,
std::function<bool(Stream &strm, bool last_connection,
bool &connection_close)>
callback) {
SSLClient::process_socket(Socket &socket,
std::function<bool(Stream &strm)> callback) {
assert(socket.ssl);
return detail::process_socket_ssl(
socket.ssl, true, socket.sock, request_count, read_timeout_sec_,
read_timeout_usec_, write_timeout_sec_, write_timeout_usec_,
[&](Stream &strm, bool last_connection, bool &connection_close) {
return callback(strm, last_connection, connection_close);
});
return detail::process_client_socket_ssl(
socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, callback);
}
inline bool SSLClient::is_ssl() const { return true; }
@ -5678,7 +5671,6 @@ SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const {
}
GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names);
return ret;
}

View file

@ -1136,6 +1136,10 @@ protected:
EXPECT_EQ(req.get_param_value("key"), "value");
EXPECT_EQ(req.body, "content");
})
.Get("/last-request",
[&](const Request & req, Response &/*res*/) {
EXPECT_EQ("close", req.get_header_value("Connection"));
})
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
.Get("/gzip",
[&](const Request & /*req*/, Response &res) {
@ -2127,42 +2131,48 @@ TEST_F(ServerTest, HTTP2Magic) {
}
TEST_F(ServerTest, KeepAlive) {
cli_.set_keep_alive_max_count(4);
auto res = cli_.Get("/hi");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
EXPECT_EQ("Hello World!", res->body);
std::vector<Request> requests;
Get(requests, "/hi");
Get(requests, "/hi");
Get(requests, "/hi");
Get(requests, "/not-exist");
Post(requests, "/empty", "", "text/plain");
Post(
requests, "/empty", 0,
[&](size_t, size_t, httplib::DataSink &) { return true; }, "text/plain");
res = cli_.Get("/hi");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
EXPECT_EQ("Hello World!", res->body);
std::vector<Response> responses;
auto ret = cli_.send(requests, responses);
res = cli_.Get("/hi");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
EXPECT_EQ("Hello World!", res->body);
ASSERT_TRUE(ret == true);
ASSERT_TRUE(requests.size() == responses.size());
res = cli_.Get("/not-exist");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(404, res->status);
for (size_t i = 0; i < 3; i++) {
auto &res = responses[i];
EXPECT_EQ(200, res.status);
EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
EXPECT_EQ("Hello World!", res.body);
}
res = cli_.Post("/empty", "", "text/plain");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
EXPECT_EQ("empty", res->body);
EXPECT_EQ("close", res->get_header_value("Connection"));
{
auto &res = responses[3];
EXPECT_EQ(404, res.status);
}
res = cli_.Post(
"/empty", 0, [&](size_t, size_t, httplib::DataSink &) { return true; },
"text/plain");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ("text/plain", res->get_header_value("Content-Type"));
EXPECT_EQ("empty", res->body);
for (size_t i = 4; i < 6; i++) {
auto &res = responses[i];
EXPECT_EQ(200, res.status);
EXPECT_EQ("text/plain", res.get_header_value("Content-Type"));
EXPECT_EQ("empty", res.body);
}
cli_.set_keep_alive(false);
res = cli_.Get("/last-request");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(200, res->status);
EXPECT_EQ("close", res->get_header_value("Connection"));
}
#ifdef CPPHTTPLIB_ZLIB_SUPPORT
@ -2310,10 +2320,8 @@ static bool send_request(time_t read_timeout_sec, const std::string &req,
if (client_sock == INVALID_SOCKET) { return false; }
auto ret = detail::process_socket(
true, client_sock, 1, read_timeout_sec, 0, 0, 0,
[&](Stream &strm, bool /*last_connection*/, bool &
/*connection_close*/) -> bool {
auto ret = detail::process_client_socket(
client_sock, read_timeout_sec, 0, 0, 0, [&](Stream &strm) {
if (req.size() !=
static_cast<size_t>(strm.write(req.data(), req.size()))) {
return false;
@ -2515,8 +2523,7 @@ TEST(ServerStopTest, StopServerWithChunkedTransmission) {
}
Client client(HOST, PORT);
const Headers headers = {{"Accept", "text/event-stream"},
{"Connection", "Keep-Alive"}};
const Headers headers = {{"Accept", "text/event-stream"}};
auto get_thread = std::thread([&client, &headers]() {
std::shared_ptr<Response> res = client.Get(
@ -2742,19 +2749,24 @@ TEST(SSLClientTest, ServerNameIndication) {
ASSERT_EQ(200, res->status);
}
TEST(SSLClientTest, ServerCertificateVerification) {
TEST(SSLClientTest, ServerCertificateVerification1) {
SSLClient cli("google.com");
auto res = cli.Get("/");
ASSERT_TRUE(res != nullptr);
ASSERT_EQ(301, res->status);
}
TEST(SSLClientTest, ServerCertificateVerification2) {
SSLClient cli("google.com");
cli.enable_server_certificate_verification(true);
res = cli.Get("/");
auto res = cli.Get("/");
ASSERT_TRUE(res == nullptr);
}
TEST(SSLClientTest, ServerCertificateVerification3) {
SSLClient cli("google.com");
cli.set_ca_cert_path(CA_CERT_FILE);
res = cli.Get("/");
auto res = cli.Get("/");
ASSERT_TRUE(res != nullptr);
ASSERT_EQ(301, res->status);
}

View file

@ -222,15 +222,21 @@ void KeepAliveTest(Client& cli, bool basic) {
#endif
}
cli.set_keep_alive_max_count(4);
cli.set_follow_location(true);
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
cli.set_digest_auth("hello", "world");
#endif
std::vector<Request> requests;
Get(requests, "/get");
Get(requests, "/redirect/2");
{
auto res = cli.Get("/get");
EXPECT_EQ(200, res->status);
}
{
auto res = cli.Get("/redirect/2");
EXPECT_EQ(200, res->status);
}
{
std::vector<std::string> paths = {
"/digest-auth/auth/hello/world/MD5",
"/digest-auth/auth/hello/world/SHA-256",
@ -238,50 +244,23 @@ void KeepAliveTest(Client& cli, bool basic) {
"/digest-auth/auth-int/hello/world/MD5",
};
for (auto path : paths) {
Get(requests, path.c_str());
for (auto path: paths) {
auto res = cli.Get(path.c_str());
EXPECT_EQ("{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n", res->body);
EXPECT_EQ(200, res->status);
}
}
{
int count = 100;
while (count--) {
Get(requests, "/get");
auto res = cli.Get("/get");
EXPECT_EQ(200, res->status);
}
}
std::vector<Response> responses;
auto ret = cli.send(requests, responses);
ASSERT_TRUE(ret == true);
ASSERT_TRUE(requests.size() == responses.size());
size_t i = 0;
{
auto &res = responses[i++];
EXPECT_EQ(200, res.status);
}
{
auto &res = responses[i++];
EXPECT_EQ(200, res.status);
}
{
int count = static_cast<int>(paths.size());
while (count--) {
auto &res = responses[i++];
EXPECT_EQ("{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n", res.body);
EXPECT_EQ(200, res.status);
}
}
for (; i < responses.size(); i++) {
auto &res = responses[i];
EXPECT_EQ(200, res.status);
}
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(KeepAliveTest, NoSSLWithBasic) {
Client cli("httpbin.org");
KeepAliveTest(cli, true);
@ -292,7 +271,6 @@ TEST(KeepAliveTest, SSLWithBasic) {
KeepAliveTest(cli, true);
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(KeepAliveTest, NoSSLWithDigest) {
Client cli("httpbin.org");
KeepAliveTest(cli, false);