This commit is contained in:
yhirose 2021-02-02 22:09:35 -05:00
parent 0542fdb8e4
commit b7566f6961
2 changed files with 180 additions and 103 deletions

232
httplib.h
View file

@ -390,6 +390,9 @@ struct Request {
Match matches;
// for client
ResponseHandler response_handler;
ContentReceiverWithProgress content_receiver;
Progress progress;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
const SSL *ssl;
#endif
@ -413,12 +416,9 @@ struct Request {
// private members...
size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT;
ResponseHandler response_handler_;
ContentReceiverWithProgress content_receiver_;
size_t content_length_ = 0;
ContentProvider content_provider_;
bool is_chunked_content_provider_ = false;
Progress progress_;
size_t authorization_count_ = 0;
};
@ -794,8 +794,11 @@ enum Error {
class Result {
public:
Result(std::unique_ptr<Response> res, Error err)
: res_(std::move(res)), err_(err) {}
Result(std::unique_ptr<Response> &&res, Error err,
Headers &&request_headers = Headers{})
: res_(std::move(res)), err_(err),
request_headers_(std::move(request_headers)) {}
// Response
operator bool() const { return res_ != nullptr; }
bool operator==(std::nullptr_t) const { return res_ == nullptr; }
bool operator!=(std::nullptr_t) const { return res_ != nullptr; }
@ -805,11 +808,21 @@ public:
Response &operator*() { return *res_; }
const Response *operator->() const { return res_.get(); }
Response *operator->() { return res_.get(); }
// Error
Error error() const { return err_; }
// Request Headers
bool has_request_header(const char *key) const;
std::string get_request_header_value(const char *key, size_t id = 0) const;
template <typename T>
T get_request_header_value(const char *key, size_t id = 0) const;
size_t get_request_header_value_count(const char *key) const;
private:
std::unique_ptr<Response> res_;
Error err_;
Headers request_headers_;
};
class ClientImpl {
@ -939,7 +952,7 @@ public:
Result Options(const char *path);
Result Options(const char *path, const Headers &headers);
bool send(const Request &req, Response &res, Error &error);
bool send(Request &req, Response &res, Error &error);
Result send(const Request &req);
size_t is_socket_open() const;
@ -993,6 +1006,8 @@ protected:
bool is_open() const { return sock != INVALID_SOCKET; }
};
Result send_(Request &&req);
virtual bool create_and_connect_socket(Socket &socket, Error &error);
// All of:
@ -1010,7 +1025,7 @@ protected:
// concurrently with a DIFFERENT thread sending requests from the socket
void lock_socket_and_shutdown_and_close();
bool process_request(Stream &strm, const Request &req, Response &res,
bool process_request(Stream &strm, Request &req, Response &res,
bool close_connection, Error &error);
bool write_content_with_provider(Stream &strm, const Request &req,
@ -1086,13 +1101,14 @@ protected:
private:
socket_t create_client_socket(Error &error) const;
bool read_response_line(Stream &strm, const Request &req, Response &res);
bool write_request(Stream &strm, const Request &req, bool close_connection,
bool write_request(Stream &strm, Request &req, bool close_connection,
Error &error);
bool redirect(const Request &req, Response &res, Error &error);
bool handle_request(Stream &strm, const Request &req, Response &res,
bool redirect(Request &req, Response &res, Error &error);
bool handle_request(Stream &strm, Request &req, Response &res,
bool close_connection, Error &error);
std::unique_ptr<Response> send_with_content_provider(
const char *method, const char *path, const Headers &headers,
Request &req,
// const char *method, const char *path, const Headers &headers,
const char *body, size_t content_length, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const char *content_type, Error &error);
@ -1238,7 +1254,7 @@ public:
Result Options(const char *path);
Result Options(const char *path, const Headers &headers);
bool send(const Request &req, Response &res, Error &error);
bool send(Request &req, Response &res, Error &error);
Result send(const Request &req);
size_t is_socket_open() const;
@ -2922,17 +2938,8 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status,
});
}
template <typename T>
inline ssize_t write_headers(Stream &strm, const T &info,
const Headers &headers) {
inline ssize_t write_headers(Stream &strm, const Headers &headers) {
ssize_t write_len = 0;
for (const auto &x : info.headers) {
if (x.first == "EXCEPTION_WHAT") { continue; }
auto len =
strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
if (len < 0) { return len; }
write_len += len;
}
for (const auto &x : headers) {
auto len =
strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str());
@ -3119,7 +3126,7 @@ inline bool write_content_chunked(Stream &strm,
}
template <typename T>
inline bool redirect(T &cli, const Request &req, Response &res,
inline bool redirect(T &cli, Request &req, Response &res,
const std::string &path, const std::string &location,
Error &error) {
Request new_req = req;
@ -3136,8 +3143,9 @@ inline bool redirect(T &cli, const Request &req, Response &res,
auto ret = cli.send(new_req, new_res, error);
if (ret) {
new_res.location = location;
req = new_req;
res = new_res;
res.location = location;
}
return ret;
}
@ -3978,7 +3986,27 @@ inline void Response::set_chunked_content_provider(
is_chunked_content_provider_ = true;
}
// Rstream implementation
// Result implementation
inline bool Result::has_request_header(const char *key) const {
return request_headers_.find(key) != request_headers_.end();
}
inline std::string Result::get_request_header_value(const char *key,
size_t id) const {
return detail::get_header_value(request_headers_, key, id, "");
}
template <typename T>
inline T Result::get_request_header_value(const char *key, size_t id) const {
return detail::get_header_value<T>(request_headers_, key, id, 0);
}
inline size_t Result::get_request_header_value_count(const char *key) const {
auto r = request_headers_.equal_range(key);
return static_cast<size_t>(std::distance(r.first, r.second));
}
// Stream implementation
inline ssize_t Stream::write(const char *ptr) {
return write(ptr, strlen(ptr));
}
@ -4473,7 +4501,7 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
return false;
}
if (!detail::write_headers(bstrm, res, Headers())) { return false; }
if (!detail::write_headers(bstrm, res.headers)) { return false; }
// Flush buffer
auto &data = bstrm.get_buffer();
@ -4795,26 +4823,26 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm) {
if (req.method == "POST") {
if (dispatch_request_for_content_reader(
req, res, std::move(reader),
post_handlers_for_content_reader_)) {
req, res, std::move(reader),
post_handlers_for_content_reader_)) {
return true;
}
} else if (req.method == "PUT") {
if (dispatch_request_for_content_reader(
req, res, std::move(reader),
put_handlers_for_content_reader_)) {
req, res, std::move(reader),
put_handlers_for_content_reader_)) {
return true;
}
} else if (req.method == "PATCH") {
if (dispatch_request_for_content_reader(
req, res, std::move(reader),
patch_handlers_for_content_reader_)) {
req, res, std::move(reader),
patch_handlers_for_content_reader_)) {
return true;
}
} else if (req.method == "DELETE") {
if (dispatch_request_for_content_reader(
req, res, std::move(reader),
delete_handlers_for_content_reader_)) {
req, res, std::move(reader),
delete_handlers_for_content_reader_)) {
return true;
}
}
@ -5069,7 +5097,7 @@ Server::process_request(Stream &strm, bool close_connection,
bool routed = false;
try {
routed = routing(req, res, strm);
} catch (std::exception & e) {
} catch (std::exception &e) {
if (exception_handler_) {
exception_handler_(req, res, e);
routed = true;
@ -5253,7 +5281,7 @@ inline bool ClientImpl::read_response_line(Stream &strm, const Request &req,
return true;
}
inline bool ClientImpl::send(const Request &req, Response &res, Error &error) {
inline bool ClientImpl::send(Request &req, Response &res, Error &error) {
std::lock_guard<std::recursive_mutex> request_mutex_guard(request_mutex_);
{
@ -5306,6 +5334,12 @@ inline bool ClientImpl::send(const Request &req, Response &res, Error &error) {
socket_requests_are_from_thread_ = std::this_thread::get_id();
}
for (const auto &header : default_headers_) {
if (req.headers.find(header.first) == req.headers.end()) {
req.headers.insert(header);
}
}
auto close_connection = !keep_alive_;
auto ret = process_socket(socket_, [&](Stream &strm) {
return handle_request(strm, req, res, close_connection, error);
@ -5336,13 +5370,18 @@ inline bool ClientImpl::send(const Request &req, Response &res, Error &error) {
}
inline Result ClientImpl::send(const Request &req) {
auto req2 = req;
return send_(std::move(req2));
}
inline Result ClientImpl::send_(Request &&req) {
auto res = detail::make_unique<Response>();
auto error = Error::Success;
auto ret = send(req, *res, error);
return Result{ret ? std::move(res) : nullptr, error};
return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)};
}
inline bool ClientImpl::handle_request(Stream &strm, const Request &req,
inline bool ClientImpl::handle_request(Stream &strm, Request &req,
Response &res, bool close_connection,
Error &error) {
if (req.path.empty()) {
@ -5350,12 +5389,16 @@ inline bool ClientImpl::handle_request(Stream &strm, const Request &req,
return false;
}
auto req_save = req;
bool ret;
if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) {
auto req2 = req;
req2.path = "http://" + host_and_port_ + req.path;
ret = process_request(strm, req2, res, close_connection, error);
req = req2;
req.path = req_save.path;
} else {
ret = process_request(strm, req, res, close_connection, error);
}
@ -5363,6 +5406,7 @@ inline bool ClientImpl::handle_request(Stream &strm, const Request &req,
if (!ret) { return false; }
if (300 < res.status && res.status < 400 && follow_location_) {
req = req_save;
ret = redirect(req, res, error);
}
@ -5398,8 +5442,7 @@ inline bool ClientImpl::handle_request(Stream &strm, const Request &req,
return ret;
}
inline bool ClientImpl::redirect(const Request &req, Response &res,
Error &error) {
inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) {
if (req.redirect_count_ == 0) {
error = Error::ExceedRedirectCount;
return false;
@ -5476,75 +5519,74 @@ inline bool ClientImpl::write_content_with_provider(Stream &strm,
}
} // namespace httplib
inline bool ClientImpl::write_request(Stream &strm, const Request &req,
inline bool ClientImpl::write_request(Stream &strm, Request &req,
bool close_connection, Error &error) {
// Prepare additional headers
Headers headers;
if (close_connection) { headers.emplace("Connection", "close"); }
if (close_connection) { req.headers.emplace("Connection", "close"); }
if (!req.has_header("Host")) {
if (is_ssl()) {
if (port_ == 443) {
headers.emplace("Host", host_);
req.headers.emplace("Host", host_);
} else {
headers.emplace("Host", host_and_port_);
req.headers.emplace("Host", host_and_port_);
}
} else {
if (port_ == 80) {
headers.emplace("Host", host_);
req.headers.emplace("Host", host_);
} else {
headers.emplace("Host", host_and_port_);
req.headers.emplace("Host", host_and_port_);
}
}
}
if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); }
if (!req.has_header("Accept")) { req.headers.emplace("Accept", "*/*"); }
if (!req.has_header("User-Agent")) {
headers.emplace("User-Agent", "cpp-httplib/0.7");
req.headers.emplace("User-Agent", "cpp-httplib/0.7");
}
if (req.body.empty()) {
if (req.content_provider_) {
if (!req.is_chunked_content_provider_) {
auto length = std::to_string(req.content_length_);
headers.emplace("Content-Length", length);
req.headers.emplace("Content-Length", length);
}
} else {
if (req.method == "POST" || req.method == "PUT" ||
req.method == "PATCH") {
headers.emplace("Content-Length", "0");
req.headers.emplace("Content-Length", "0");
}
}
} else {
if (!req.has_header("Content-Type")) {
headers.emplace("Content-Type", "text/plain");
req.headers.emplace("Content-Type", "text/plain");
}
if (!req.has_header("Content-Length")) {
auto length = std::to_string(req.body.size());
headers.emplace("Content-Length", length);
req.headers.emplace("Content-Length", length);
}
}
if (!basic_auth_password_.empty()) {
headers.insert(make_basic_authentication_header(
req.headers.insert(make_basic_authentication_header(
basic_auth_username_, basic_auth_password_, false));
}
if (!proxy_basic_auth_username_.empty() &&
!proxy_basic_auth_password_.empty()) {
headers.insert(make_basic_authentication_header(
req.headers.insert(make_basic_authentication_header(
proxy_basic_auth_username_, proxy_basic_auth_password_, true));
}
if (!bearer_token_auth_token_.empty()) {
headers.insert(make_bearer_token_authentication_header(
req.headers.insert(make_bearer_token_authentication_header(
bearer_token_auth_token_, false));
}
if (!proxy_bearer_token_auth_token_.empty()) {
headers.insert(make_bearer_token_authentication_header(
req.headers.insert(make_bearer_token_authentication_header(
proxy_bearer_token_auth_token_, true));
}
@ -5555,7 +5597,7 @@ inline bool ClientImpl::write_request(Stream &strm, const Request &req,
const auto &path = detail::encode_url(req.path);
bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str());
detail::write_headers(bstrm, req, headers);
detail::write_headers(bstrm, req.headers);
// Flush buffer
auto &data = bstrm.get_buffer();
@ -5576,16 +5618,16 @@ inline bool ClientImpl::write_request(Stream &strm, const Request &req,
}
inline std::unique_ptr<Response> ClientImpl::send_with_content_provider(
const char *method, const char *path, const Headers &headers,
Request &req,
// const char *method, const char *path, const Headers &headers,
const char *body, size_t content_length, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const char *content_type, Error &error) {
Request req;
req.method = method;
req.headers = default_headers_;
req.headers.insert(headers.begin(), headers.end());
req.path = path;
// Request req;
// req.method = method;
// req.headers = headers;
// req.path = path;
if (content_type) { req.headers.emplace("Content-Type", content_type); }
@ -5667,14 +5709,23 @@ inline Result ClientImpl::send_with_content_provider(
const char *body, size_t content_length, ContentProvider content_provider,
ContentProviderWithoutLength content_provider_without_length,
const char *content_type) {
Request req;
req.method = method;
req.headers = headers;
req.path = path;
auto error = Error::Success;
auto res = send_with_content_provider(
method, path, headers, body, content_length, std::move(content_provider),
req,
// method, path, headers,
body, content_length, std::move(content_provider),
std::move(content_provider_without_length), content_type, error);
return Result{std::move(res), error};
return Result{std::move(res), error, std::move(req.headers)};
}
inline bool ClientImpl::process_request(Stream &strm, const Request &req,
inline bool ClientImpl::process_request(Stream &strm, Request &req,
Response &res, bool close_connection,
Error &error) {
// Send request
@ -5687,8 +5738,8 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
return false;
}
if (req.response_handler_) {
if (!req.response_handler_(res)) {
if (req.response_handler) {
if (!req.response_handler(res)) {
error = Error::Canceled;
return false;
}
@ -5697,10 +5748,10 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
// Body
if ((res.status != 204) && req.method != "HEAD" && req.method != "CONNECT") {
auto out =
req.content_receiver_
req.content_receiver
? static_cast<ContentReceiverWithProgress>(
[&](const char *buf, size_t n, uint64_t off, uint64_t len) {
auto ret = req.content_receiver_(buf, n, off, len);
auto ret = req.content_receiver(buf, n, off, len);
if (!ret) { error = Error::Canceled; }
return ret;
})
@ -5715,8 +5766,8 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
});
auto progress = [&](uint64_t current, uint64_t total) {
if (!req.progress_) { return true; }
auto ret = req.progress_(current, total);
if (!req.progress) { return true; }
auto ret = req.progress(current, total);
if (!ret) { error = Error::Canceled; }
return ret;
};
@ -5778,11 +5829,10 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers,
Request req;
req.method = "GET";
req.path = path;
req.headers = default_headers_;
req.headers.insert(headers.begin(), headers.end());
req.progress_ = std::move(progress);
req.headers = headers;
req.progress = std::move(progress);
return send(req);
return send_(std::move(req));
}
inline Result ClientImpl::Get(const char *path,
@ -5838,17 +5888,16 @@ inline Result ClientImpl::Get(const char *path, const Headers &headers,
Request req;
req.method = "GET";
req.path = path;
req.headers = default_headers_;
req.headers.insert(headers.begin(), headers.end());
req.response_handler_ = std::move(response_handler);
req.content_receiver_ =
req.headers = headers;
req.response_handler = std::move(response_handler);
req.content_receiver =
[content_receiver](const char *data, size_t data_length,
uint64_t /*offset*/, uint64_t /*total_length*/) {
return content_receiver(data, data_length);
};
req.progress_ = std::move(progress);
req.progress = std::move(progress);
return send(req);
return send_(std::move(req));
}
inline Result ClientImpl::Get(const char *path, const Params &params,
@ -5887,11 +5936,10 @@ inline Result ClientImpl::Head(const char *path) {
inline Result ClientImpl::Head(const char *path, const Headers &headers) {
Request req;
req.method = "HEAD";
req.headers = default_headers_;
req.headers.insert(headers.begin(), headers.end());
req.headers = headers;
req.path = path;
return send(req);
return send_(std::move(req));
}
inline Result ClientImpl::Post(const char *path) {
@ -6151,14 +6199,13 @@ inline Result ClientImpl::Delete(const char *path, const Headers &headers,
const char *content_type) {
Request req;
req.method = "DELETE";
req.headers = default_headers_;
req.headers.insert(headers.begin(), headers.end());
req.headers = headers;
req.path = path;
if (content_type) { req.headers.emplace("Content-Type", content_type); }
req.body.assign(body, content_length);
return send(req);
return send_(std::move(req));
}
inline Result ClientImpl::Delete(const char *path, const std::string &body,
@ -6179,11 +6226,10 @@ inline Result ClientImpl::Options(const char *path) {
inline Result ClientImpl::Options(const char *path, const Headers &headers) {
Request req;
req.method = "OPTIONS";
req.headers = default_headers_;
req.headers.insert(headers.begin(), headers.end());
req.headers = headers;
req.path = path;
return send(req);
return send_(std::move(req));
}
inline size_t ClientImpl::is_socket_open() const {
@ -7303,7 +7349,7 @@ inline Result Client::Options(const char *path, const Headers &headers) {
return cli_->Options(path, headers);
}
inline bool Client::send(const Request &req, Response &res, Error &error) {
inline bool Client::send(Request &req, Response &res, Error &error) {
return cli_->send(req, res, error);
}

View file

@ -5,8 +5,8 @@
#include <atomic>
#include <chrono>
#include <future>
#include <thread>
#include <stdexcept>
#include <thread>
#define SERVER_CERT_FILE "./cert.pem"
#define SERVER_CERT2_FILE "./cert2.pem"
@ -549,12 +549,11 @@ TEST(ConnectionErrorTest, InvalidHost2) {
TEST(ConnectionErrorTest, InvalidPort) {
auto host = "localhost";
auto port = 44380;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
auto port = 44380;
SSLClient cli(host, port);
#else
auto port = 8080;
Client cli(host, port);
#endif
cli.set_connection_timeout(2);
@ -982,11 +981,12 @@ TEST(ErrorHandlerTest, ContentLength) {
TEST(ExceptionHandlerTest, ContentLength) {
Server svr;
svr.set_exception_handler([](const Request & /*req*/, Response &res, std::exception & /*e*/) {
res.status = 500;
res.set_content("abcdefghijklmnopqrstuvwxyz",
"text/html"); // <= Content-Length still 13
});
svr.set_exception_handler(
[](const Request & /*req*/, Response &res, std::exception & /*e*/) {
res.status = 500;
res.set_content("abcdefghijklmnopqrstuvwxyz",
"text/html"); // <= Content-Length still 13
});
svr.Get("/hi", [](const Request & /*req*/, Response &res) {
res.set_content("Hello World!\n", "text/plain");
@ -2614,6 +2614,24 @@ TEST_F(ServerTest, PutLargeFileWithGzip) {
EXPECT_EQ(LARGE_DATA, res->body);
}
TEST_F(ServerTest, PutLargeFileWithGzip2) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
Client cli("https://localhost:1234");
cli.enable_server_certificate_verification(false);
#else
Client cli("http://localhost:1234");
#endif
cli.set_compress(true);
auto res = cli.Put("/put-large", LARGE_DATA, "text/plain");
ASSERT_TRUE(res);
EXPECT_EQ(200, res->status);
EXPECT_EQ(LARGE_DATA, res->body);
EXPECT_EQ(101942u, res.get_request_header_value<uint64_t>("Content-Length"));
EXPECT_EQ("gzip", res.get_request_header_value("Content-Encoding"));
}
TEST_F(ServerTest, PutContentWithDeflate) {
cli_.set_compress(false);
Headers headers;
@ -3405,7 +3423,8 @@ TEST(ExceptionTest, ThrowExceptionInHandler) {
auto res = cli.Get("/hi");
ASSERT_TRUE(res);
EXPECT_EQ(500, res->status);
ASSERT_FALSE(res->has_header("EXCEPTION_WHAT"));
ASSERT_TRUE(res->has_header("EXCEPTION_WHAT"));
EXPECT_EQ("exception...", res->get_header_value("EXCEPTION_WHAT"));
svr.stop();
listen_thread.join();
@ -3963,7 +3982,6 @@ TEST(NoSSLSupport, SimpleInterface) {
}
#endif
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(InvalidScheme, SimpleInterface) {
ASSERT_ANY_THROW(Client cli("scheme://yahoo.com"));
}
@ -3973,6 +3991,19 @@ TEST(NoScheme, SimpleInterface) {
ASSERT_TRUE(cli.is_valid());
}
TEST(SendAPI, SimpleInterface) {
Client cli("http://yahoo.com");
Request req;
req.method = "GET";
req.path = "/";
auto res = cli.send(req);
ASSERT_TRUE(res);
EXPECT_EQ(301, res->status);
}
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(YahooRedirectTest2, SimpleInterface) {
Client cli("http://yahoo.com");