Fix multiple threading bugs including #699 and #697

This commit is contained in:
David Wu 2020-10-16 21:34:51 +00:00 committed by yhirose
parent 47e5af15ea
commit 02d3cd5909
3 changed files with 178 additions and 49 deletions

View file

@ -645,6 +645,12 @@ cli.set_ca_cert_path("./ca-bundle.crt");
cli.enable_server_certificate_verification(true); cli.enable_server_certificate_verification(true);
``` ```
Note: When using SSL, it seems impossible to avoid SIGPIPE in all cases, since on some operating systems, SIGPIPE
can only be suppressed on a per-message basis, but there is no way to make the OpenSSL library do so for its
internal communications. If your program needs to avoid being terminated on SIGPIPE, the only fully general way might
be to set up a signal handler for SIGPIPE to handle or ignore it yourself.
Compression Compression
----------- -----------

218
httplib.h
View file

@ -932,7 +932,21 @@ protected:
}; };
virtual bool create_and_connect_socket(Socket &socket); virtual bool create_and_connect_socket(Socket &socket);
virtual void close_socket(Socket &socket, bool process_socket_ret);
// All of:
// shutdown_ssl
// shutdown_socket
// close_socket
// should ONLY be called when socket_mutex_ is locked.
// Also, shutdown_ssl and close_socket should also NOT be called concurrently
// with a DIFFERENT thread sending requests using that socket.
virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully);
void shutdown_socket(Socket &socket);
void close_socket(Socket &socket);
// Similar to shutdown_ssl and close_socket, this should NOT be called
// concurrently with a DIFFERENT thread sending requests from the socket
void lock_socket_and_shutdown_and_close();
bool process_request(Stream &strm, const Request &req, Response &res, bool process_request(Stream &strm, const Request &req, Response &res,
bool close_connection); bool close_connection);
@ -943,7 +957,7 @@ protected:
void copy_settings(const ClientImpl &rhs); void copy_settings(const ClientImpl &rhs);
// Error state // Error state
mutable Error error_ = Error::Success; mutable std::atomic<Error> error_;
// Socket endoint information // Socket endoint information
const std::string host_; const std::string host_;
@ -955,6 +969,11 @@ protected:
mutable std::mutex socket_mutex_; mutable std::mutex socket_mutex_;
std::recursive_mutex request_mutex_; std::recursive_mutex request_mutex_;
// These are all protected under socket_mutex
int socket_requests_in_flight_ = 0;
std::thread::id socket_requests_are_from_thread_ = std::thread::id();
bool socket_should_be_closed_when_request_is_done_ = false;
// Default headers // Default headers
Headers default_headers_; Headers default_headers_;
@ -1012,7 +1031,6 @@ private:
bool redirect(const Request &req, Response &res); bool redirect(const Request &req, Response &res);
bool handle_request(Stream &strm, const Request &req, Response &res, bool handle_request(Stream &strm, const Request &req, Response &res,
bool close_connection); bool close_connection);
void stop_core();
std::unique_ptr<Response> send_with_content_provider( std::unique_ptr<Response> send_with_content_provider(
const char *method, const char *path, const Headers &headers, const char *method, const char *path, const Headers &headers,
const std::string &body, size_t content_length, const std::string &body, size_t content_length,
@ -1020,7 +1038,8 @@ private:
ContentProviderWithoutLength content_provider_without_length, ContentProviderWithoutLength content_provider_without_length,
const char *content_type); const char *content_type);
virtual bool process_socket(Socket &socket, // socket is const because this function is called when socket_mutex_ is not locked
virtual bool process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback); std::function<bool(Stream &strm)> callback);
virtual bool is_ssl() const; virtual bool is_ssl() const;
}; };
@ -1243,9 +1262,9 @@ public:
private: private:
bool create_and_connect_socket(Socket &socket) override; bool create_and_connect_socket(Socket &socket) override;
void close_socket(Socket &socket, bool process_socket_ret) override; void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override;
bool process_socket(Socket &socket, bool process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback) override; std::function<bool(Stream &strm)> callback) override;
bool is_ssl() const override; bool is_ssl() const override;
@ -2046,7 +2065,7 @@ inline socket_t create_client_socket(const char *host, int port,
bool tcp_nodelay, bool tcp_nodelay,
SocketOptions socket_options, SocketOptions socket_options,
time_t timeout_sec, time_t timeout_usec, time_t timeout_sec, time_t timeout_usec,
const std::string &intf, Error &error) { const std::string &intf, std::atomic<Error> &error) {
auto sock = create_socket( auto sock = create_socket(
host, port, 0, tcp_nodelay, std::move(socket_options), host, port, 0, tcp_nodelay, std::move(socket_options),
[&](socket_t sock, struct addrinfo &ai) -> bool { [&](socket_t sock, struct addrinfo &ai) -> bool {
@ -4793,11 +4812,11 @@ inline ClientImpl::ClientImpl(const std::string &host, int port)
inline ClientImpl::ClientImpl(const std::string &host, int port, inline ClientImpl::ClientImpl(const std::string &host, int port,
const std::string &client_cert_path, const std::string &client_cert_path,
const std::string &client_key_path) const std::string &client_key_path)
: host_(host), port_(port), : error_(Error::Success), host_(host), port_(port),
host_and_port_(host_ + ":" + std::to_string(port_)), host_and_port_(host_ + ":" + std::to_string(port_)),
client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} client_cert_path_(client_cert_path), client_key_path_(client_key_path) {}
inline ClientImpl::~ClientImpl() { stop_core(); } inline ClientImpl::~ClientImpl() { lock_socket_and_shutdown_and_close(); }
inline bool ClientImpl::is_valid() const { return true; } inline bool ClientImpl::is_valid() const { return true; }
@ -4858,15 +4877,47 @@ inline bool ClientImpl::create_and_connect_socket(Socket &socket) {
return true; return true;
} }
inline void ClientImpl::close_socket(Socket &socket, inline void ClientImpl::shutdown_ssl(Socket &socket, bool shutdown_gracefully) {
bool /*process_socket_ret*/) { (void)socket;
detail::close_socket(socket.sock); (void)shutdown_gracefully;
socket_.sock = INVALID_SOCKET; //If there are any requests in flight from threads other than us, then it's
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT //a thread-unsafe race because individual ssl* objects are not thread-safe.
socket_.ssl = nullptr; assert(socket_requests_in_flight_ == 0 ||
#endif socket_requests_are_from_thread_ == std::this_thread::get_id());
} }
inline void ClientImpl::shutdown_socket(Socket &socket) {
if (socket.sock == INVALID_SOCKET)
return;
detail::shutdown_socket(socket.sock);
}
inline void ClientImpl::close_socket(Socket &socket) {
// If there are requests in flight in another thread, usually closing
// the socket will be fine and they will simply receive an error when
// using the closed socket, but it is still a bug since rarely the OS
// may reassign the socket id to be used for a new socket, and then
// suddenly they will be operating on a live socket that is different
// than the one they intended!
assert(socket_requests_in_flight_ == 0 ||
socket_requests_are_from_thread_ == std::this_thread::get_id());
// It is also a bug if this happens while SSL is still active
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
assert(socket.ssl == nullptr);
#endif
if (socket.sock == INVALID_SOCKET)
return;
detail::close_socket(socket.sock);
socket.sock = INVALID_SOCKET;
}
inline void ClientImpl::lock_socket_and_shutdown_and_close() {
std::lock_guard<std::mutex> guard(socket_mutex_);
shutdown_ssl(socket_, true);
shutdown_socket(socket_);
close_socket(socket_);
}
inline bool ClientImpl::read_response_line(Stream &strm, Response &res) { inline bool ClientImpl::read_response_line(Stream &strm, Response &res) {
std::array<char, 2048> buf; std::array<char, 2048> buf;
@ -4901,11 +4952,23 @@ inline bool ClientImpl::send(const Request &req, Response &res) {
{ {
std::lock_guard<std::mutex> guard(socket_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
// Set this to false immediately - if it ever gets set to true by the end of the
// request, we know another thread instructed us to close the socket.
socket_should_be_closed_when_request_is_done_ = false;
auto is_alive = false; auto is_alive = false;
if (socket_.is_open()) { if (socket_.is_open()) {
is_alive = detail::select_write(socket_.sock, 0, 0) > 0; is_alive = detail::select_write(socket_.sock, 0, 0) > 0;
if (!is_alive) { close_socket(socket_, false); } if (!is_alive) {
// Attempt to avoid sigpipe by shutting down nongracefully if it seems like
// the other side has already closed the connection
// Also, there cannot be any requests in flight from other threads since we locked
// request_mutex_, so safe to close everything immediately
const bool shutdown_gracefully = false;
shutdown_ssl(socket_, shutdown_gracefully);
shutdown_socket(socket_);
close_socket(socket_);
}
} }
if (!is_alive) { if (!is_alive) {
@ -4926,15 +4989,38 @@ inline bool ClientImpl::send(const Request &req, Response &res) {
} }
#endif #endif
} }
// Mark the current socket as being in use so that it cannot be closed by anyone
// else while this request is ongoing, even though we will be releasing the mutex.
if (socket_requests_in_flight_ > 1) {
assert(socket_requests_are_from_thread_ == std::this_thread::get_id());
}
socket_requests_in_flight_ += 1;
socket_requests_are_from_thread_ = std::this_thread::get_id();
} }
auto close_connection = !keep_alive_; auto close_connection = !keep_alive_;
auto ret = process_socket(socket_, [&](Stream &strm) { auto ret = process_socket(socket_, [&](Stream &strm) {
return handle_request(strm, req, res, close_connection); return handle_request(strm, req, res, close_connection);
}); });
if (close_connection || !ret) { stop_core(); } //Briefly lock mutex in order to mark that a request is no longer ongoing
{
std::lock_guard<std::mutex> guard(socket_mutex_);
socket_requests_in_flight_ -= 1;
if (socket_requests_in_flight_ <= 0) {
assert(socket_requests_in_flight_ == 0);
socket_requests_are_from_thread_ = std::thread::id();
}
if (socket_should_be_closed_when_request_is_done_ ||
close_connection ||
!ret ) {
shutdown_ssl(socket_, true);
shutdown_socket(socket_);
close_socket(socket_);
}
}
if (!ret) { if (!ret) {
if (error_ == Error::Success) { error_ = Error::Unknown; } if (error_ == Error::Success) { error_ = Error::Unknown; }
@ -5320,7 +5406,16 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
if (res.get_header_value("Connection") == "close" || if (res.get_header_value("Connection") == "close" ||
(res.version == "HTTP/1.0" && res.reason != "Connection established")) { (res.version == "HTTP/1.0" && res.reason != "Connection established")) {
stop_core(); // TODO this requires a not-entirely-obvious chain of calls to be correct
// for this to be safe. Maybe a code refactor (such as moving this out to
// the send function and getting rid of the recursiveness of the mutex)
// could make this more obvious.
// This is safe to call because process_request is only called by handle_request
// which is only called by send, which locks the request mutex during the process.
// It would be a bug to call it from a different thread since it's a thread-safety
// issue to do these things to the socket if another thread is using the socket.
lock_socket_and_shutdown_and_close();
} }
// Log // Log
@ -5330,7 +5425,7 @@ inline bool ClientImpl::process_request(Stream &strm, const Request &req,
} }
inline bool inline bool
ClientImpl::process_socket(Socket &socket, ClientImpl::process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback) { std::function<bool(Stream &strm)> callback) {
return detail::process_client_socket( return detail::process_client_socket(
socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_,
@ -5706,18 +5801,27 @@ inline size_t ClientImpl::is_socket_open() const {
} }
inline void ClientImpl::stop() { inline void ClientImpl::stop() {
stop_core();
error_ = Error::Canceled;
}
inline void ClientImpl::stop_core() {
std::lock_guard<std::mutex> guard(socket_mutex_); std::lock_guard<std::mutex> guard(socket_mutex_);
if (socket_.is_open()) { // There is no guarantee that this doesn't get overwritten later, but set it so that
detail::shutdown_socket(socket_.sock); // there is a good chance that any threads stopping as a result pick up this error.
std::this_thread::sleep_for(std::chrono::milliseconds(1)); error_ = Error::Canceled;
close_socket(socket_, true);
std::this_thread::sleep_for(std::chrono::milliseconds(1)); // If there is anything ongoing right now, the ONLY thread-safe thing we can do
// is to shutdown_socket, so that threads using this socket suddenly discover
// they can't read/write any more and error out.
// Everything else (closing the socket, shutting ssl down) is unsafe because these
// actions are not thread-safe.
if (socket_requests_in_flight_ > 0) {
shutdown_socket(socket_);
// Aside from that, we set a flag for the socket to be closed when we're done.
socket_should_be_closed_when_request_is_done_ = true;
return;
} }
//Otherwise, sitll holding the mutex, we can shut everything down ourselves
shutdown_ssl(socket_, true);
shutdown_socket(socket_);
close_socket(socket_);
} }
inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) {
@ -5844,9 +5948,12 @@ inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex,
} }
inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl,
bool process_socket_ret) { bool shutdown_gracefully) {
if (process_socket_ret) { // sometimes we may want to skip this to try to avoid SIGPIPE if we know
SSL_shutdown(ssl); // shutdown only if not already closed by remote // the remote has closed the network connection
// Note that it is not always possible to avoid SIGPIPE, this is merely a best-efforts.
if (shutdown_gracefully) {
SSL_shutdown(ssl);
} }
std::lock_guard<std::mutex> guard(ctx_mutex); std::lock_guard<std::mutex> guard(ctx_mutex);
@ -6108,9 +6215,10 @@ inline bool SSLServer::process_and_close_socket(socket_t sock) {
[&](Request &req) { req.ssl = ssl; }); [&](Request &req) { req.ssl = ssl; });
}); });
detail::ssl_delete(ctx_mutex_, ssl, ret); // Shutdown gracefully if the result seemed successful, non-gracefully if the
detail::shutdown_socket(sock); // connection appeared to be closed.
detail::close_socket(sock); const bool shutdown_gracefully = ret;
detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully);
return ret; return ret;
} }
@ -6167,6 +6275,10 @@ inline SSLClient::SSLClient(const std::string &host, int port,
inline SSLClient::~SSLClient() { inline SSLClient::~SSLClient() {
if (ctx_) { SSL_CTX_free(ctx_); } if (ctx_) { SSL_CTX_free(ctx_); }
// Make sure to shut down SSL since shutdown_ssl will resolve to the
// base function rather than the derived function once we get to the
// base class destructor, and won't free the SSL (causing a leak).
SSLClient::shutdown_ssl(socket_, true);
} }
inline bool SSLClient::is_valid() const { return ctx_; } inline bool SSLClient::is_valid() const { return ctx_; }
@ -6200,11 +6312,11 @@ inline bool SSLClient::create_and_connect_socket(Socket &socket) {
return is_valid() && ClientImpl::create_and_connect_socket(socket); return is_valid() && ClientImpl::create_and_connect_socket(socket);
} }
// Assumes that socket_mutex_ is locked and that there are no requests in flight
inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
bool &success) { bool &success) {
success = true; success = true;
Response res2; Response res2;
if (!detail::process_client_socket( if (!detail::process_client_socket(
socket.sock, read_timeout_sec_, read_timeout_usec_, socket.sock, read_timeout_sec_, read_timeout_usec_,
write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) {
@ -6213,7 +6325,10 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
req2.path = host_and_port_; req2.path = host_and_port_;
return process_request(strm, req2, res2, false); return process_request(strm, req2, res2, false);
})) { })) {
close_socket(socket, true); // Thread-safe to close everything because we are assuming there are no requests in flight
shutdown_ssl(socket, true);
shutdown_socket(socket);
close_socket(socket);
success = false; success = false;
return false; return false;
} }
@ -6236,7 +6351,10 @@ inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res,
true)); true));
return process_request(strm, req3, res3, false); return process_request(strm, req3, res3, false);
})) { })) {
close_socket(socket, true); // Thread-safe to close everything because we are assuming there are no requests in flight
shutdown_ssl(socket, true);
shutdown_socket(socket);
close_socket(socket);
success = false; success = false;
return false; return false;
} }
@ -6331,21 +6449,25 @@ inline bool SSLClient::initialize_ssl(Socket &socket) {
return true; return true;
} }
close_socket(socket, false); shutdown_socket(socket);
close_socket(socket);
return false; return false;
} }
inline void SSLClient::close_socket(Socket &socket, bool process_socket_ret) { inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) {
detail::close_socket(socket.sock); if (socket.sock == INVALID_SOCKET) {
socket_.sock = INVALID_SOCKET; assert(socket.ssl == nullptr);
if (socket.ssl) { return;
detail::ssl_delete(ctx_mutex_, socket.ssl, process_socket_ret);
socket_.ssl = nullptr;
} }
if (socket.ssl) {
detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully);
socket.ssl = nullptr;
}
assert(socket.ssl == nullptr);
} }
inline bool inline bool
SSLClient::process_socket(Socket &socket, SSLClient::process_socket(const Socket &socket,
std::function<bool(Stream &strm)> callback) { std::function<bool(Stream &strm)> callback) {
assert(socket.ssl); assert(socket.ssl);
return detail::process_client_socket_ssl( return detail::process_client_socket_ssl(

View file

@ -5,6 +5,7 @@
#include <chrono> #include <chrono>
#include <future> #include <future>
#include <thread> #include <thread>
#include <atomic>
#define SERVER_CERT_FILE "./cert.pem" #define SERVER_CERT_FILE "./cert.pem"
#define SERVER_CERT2_FILE "./cert2.pem" #define SERVER_CERT2_FILE "./cert2.pem"
@ -2761,7 +2762,7 @@ TEST_F(ServerTest, Brotli) {
// Sends a raw request to a server listening at HOST:PORT. // Sends a raw request to a server listening at HOST:PORT.
static bool send_request(time_t read_timeout_sec, const std::string &req, static bool send_request(time_t read_timeout_sec, const std::string &req,
std::string *resp = nullptr) { std::string *resp = nullptr) {
Error error = Error::Success; std::atomic<Error> error(Error::Success);
auto client_sock = auto client_sock =
detail::create_client_socket(HOST, PORT, false, nullptr, detail::create_client_socket(HOST, PORT, false, nullptr,