diff --git a/httplib.h b/httplib.h index 0b790e7..24a7c02 100644 --- a/httplib.h +++ b/httplib.h @@ -7228,63 +7228,62 @@ inline bool SSLSocketStream::is_writable() const { } inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { - size_t readbytes = 0; if (SSL_pending(ssl_) > 0) { - auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes); - if (ret == 1) { return static_cast(readbytes); } - if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; } - return -1; - } - if (!is_readable()) { return -1; } - - auto ret = SSL_read_ex(ssl_, ptr, size, &readbytes); - if (ret == 1) { return static_cast(readbytes); } - auto err = SSL_get_error(ssl_, ret); - int n = 1000; + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + int n = 1000; #ifdef _WIN32 - while (--n >= 0 && - (err == SSL_ERROR_WANT_READ || - (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { + while (--n >= 0 && (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { #else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { #endif - if (SSL_pending(ssl_) > 0) { - ret = SSL_read_ex(ssl_, ptr, size, &readbytes); - if (ret == 1) { return static_cast(readbytes); } - if (SSL_get_error(ssl_, ret) == SSL_ERROR_ZERO_RETURN) { return 0; } - return -1; + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } } - if (!is_readable()) { return -1; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - ret = SSL_read_ex(ssl_, ptr, size, &readbytes); - if (ret == 1) { return static_cast(readbytes); } - err = SSL_get_error(ssl_, ret); + return ret; } - if (err == SSL_ERROR_ZERO_RETURN) { return 0; } return -1; } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } - size_t written = 0; - auto ret = SSL_write_ex(ssl_, ptr, size, &written); - if (ret == 1) { return static_cast(written); } - auto err = SSL_get_error(ssl_, ret); - int n = 1000; + if (is_writable()) { + auto ret = SSL_write(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + int n = 1000; #ifdef _WIN32 - while (--n >= 0 && - (err == SSL_ERROR_WANT_WRITE || - (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { + while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { #else - while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif - if (!is_writable()) { return -1; } - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - ret = SSL_write_ex(ssl_, ptr, size, &written); - if (ret == 1) { return static_cast(written); } - err = SSL_get_error(ssl_, ret); + if (is_writable()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_write(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; } - if (err == SSL_ERROR_ZERO_RETURN) { return 0; } return -1; } diff --git a/test/test.cc b/test/test.cc index 322b208..e0cf90c 100644 --- a/test/test.cc +++ b/test/test.cc @@ -4660,50 +4660,6 @@ TEST(SSLClientServerTest, CustomizeServerSSLCtx) { t.join(); } - -// Disabled due to the out-of-memory problem on GitHub Actions Workflows -TEST(SSLClientServerTest, DISABLED_LargeDataTransfer) { - - // prepare large data - std::random_device seed_gen; - std::mt19937 random(seed_gen()); - constexpr auto large_size_byte = 2147483648UL + 1048576UL; // 2GiB + 1MiB - std::vector binary(large_size_byte / sizeof(std::uint32_t)); - std::generate(binary.begin(), binary.end(), [&random]() { return random(); }); - - // server - SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE); - ASSERT_TRUE(svr.is_valid()); - - svr.Post("/binary", [&](const Request &req, Response &res) { - EXPECT_EQ(large_size_byte, req.body.size()); - EXPECT_EQ(0, std::memcmp(binary.data(), req.body.data(), large_size_byte)); - res.set_content(req.body, "application/octet-stream"); - }); - - auto listen_thread = std::thread([&svr]() { svr.listen("localhost", PORT); }); - while (!svr.is_running()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - - // client POST - SSLClient cli("localhost", PORT); - cli.enable_server_certificate_verification(false); - cli.set_read_timeout(std::chrono::seconds(100)); - cli.set_write_timeout(std::chrono::seconds(100)); - auto res = cli.Post("/binary", reinterpret_cast(binary.data()), - large_size_byte, "application/octet-stream"); - - // compare - EXPECT_EQ(200, res->status); - EXPECT_EQ(large_size_byte, res->body.size()); - EXPECT_EQ(0, std::memcmp(binary.data(), res->body.data(), large_size_byte)); - - // cleanup - svr.stop(); - listen_thread.join(); - ASSERT_FALSE(svr.is_running()); -} #endif #ifdef _WIN32