Branch data Line data Source code
1 : : // Copyright (c) 2020-present The Bitcoin Core developers
2 : : // Distributed under the MIT software license, see the accompanying
3 : : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
4 : :
5 : : #include <util/sock.h>
6 : :
7 : : #include <compat/compat.h>
8 : : #include <span.h>
9 : : #include <tinyformat.h>
10 : : #include <util/check.h>
11 : : #include <util/log.h>
12 : : #include <util/syserror.h>
13 : : #include <util/threadinterrupt.h>
14 : : #include <util/time.h>
15 : :
16 : : #include <algorithm>
17 : : #include <compare>
18 : : #include <exception>
19 : : #include <memory>
20 : : #include <stdexcept>
21 : : #include <string>
22 : : #include <utility>
23 : : #include <vector>
24 : :
25 : : #ifdef USE_POLL
26 : : #include <poll.h>
27 : : #endif
28 : :
29 : 2444 : static inline bool IOErrorIsPermanent(int err)
30 : : {
31 : 2444 : return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
32 : : }
33 : :
34 : 136160 : Sock::Sock(SOCKET s) : m_socket(s) {}
35 : :
36 : 0 : Sock::Sock(Sock&& other)
37 : : {
38 : 0 : m_socket = other.m_socket;
39 : 0 : other.m_socket = INVALID_SOCKET;
40 : 0 : }
41 : :
42 : 136160 : Sock::~Sock() { Close(); }
43 : :
44 : 0 : Sock& Sock::operator=(Sock&& other)
45 : : {
46 : 0 : Close();
47 : 0 : m_socket = other.m_socket;
48 : 0 : other.m_socket = INVALID_SOCKET;
49 : 0 : return *this;
50 : : }
51 : :
52 : 0 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
53 : : {
54 : 0 : return send(m_socket, static_cast<const char*>(data), len, flags);
55 : : }
56 : :
57 : 0 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
58 : : {
59 [ # # ]: 0 : return recv(m_socket, static_cast<char*>(buf), len, flags);
60 : : }
61 : :
62 : 0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
63 : : {
64 : 0 : return connect(m_socket, addr, addr_len);
65 : : }
66 : :
67 : 0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
68 : : {
69 : 0 : return bind(m_socket, addr, addr_len);
70 : : }
71 : :
72 : 0 : int Sock::Listen(int backlog) const
73 : : {
74 : 0 : return listen(m_socket, backlog);
75 : : }
76 : :
77 : 0 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
78 : : {
79 : : #ifdef WIN32
80 : : static constexpr auto ERR = INVALID_SOCKET;
81 : : #else
82 : 0 : static constexpr auto ERR = SOCKET_ERROR;
83 : : #endif
84 : :
85 : 0 : std::unique_ptr<Sock> sock;
86 : :
87 [ # # ]: 0 : const auto socket = accept(m_socket, addr, addr_len);
88 [ # # ]: 0 : if (socket != ERR) {
89 : 0 : try {
90 [ # # ]: 0 : sock = std::make_unique<Sock>(socket);
91 [ - - ]: 0 : } catch (const std::exception&) {
92 : : #ifdef WIN32
93 : : closesocket(socket);
94 : : #else
95 [ - - ]: 0 : close(socket);
96 : : #endif
97 : 0 : }
98 : : }
99 : :
100 : 0 : return sock;
101 : 0 : }
102 : :
103 : 0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
104 : : {
105 : 0 : return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
106 : : }
107 : :
108 : 0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
109 : : {
110 : 0 : return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
111 : : }
112 : :
113 : 0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
114 : : {
115 : 0 : return getsockname(m_socket, name, name_len);
116 : : }
117 : :
118 : 0 : bool Sock::SetNonBlocking() const
119 : : {
120 : : #ifdef WIN32
121 : : u_long on{1};
122 : : if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
123 : : return false;
124 : : }
125 : : #else
126 : 0 : const int flags{fcntl(m_socket, F_GETFL, 0)};
127 [ # # ]: 0 : if (flags == SOCKET_ERROR) {
128 : : return false;
129 : : }
130 [ # # ]: 0 : if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
131 : 0 : return false;
132 : : }
133 : : #endif
134 : : return true;
135 : : }
136 : :
137 : 0 : bool Sock::IsSelectable() const
138 : : {
139 : : #if defined(USE_POLL) || defined(WIN32)
140 : 0 : return true;
141 : : #else
142 : : return m_socket < FD_SETSIZE;
143 : : #endif
144 : : }
145 : :
146 : 0 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
147 : : {
148 : : // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want
149 : : // `this` to be destroyed when the `shared_ptr` goes out of scope at the
150 : : // end of this function.
151 : : // Create it with an aliasing shared_ptr that points to `this` without
152 : : // owning it.
153 [ # # ]: 0 : std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this};
154 : :
155 [ # # # # : 0 : EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
# # # # #
# ]
156 : :
157 [ # # # # ]: 0 : if (!WaitMany(timeout, events_per_sock)) {
158 : : return false;
159 : : }
160 : :
161 [ # # ]: 0 : if (occurred != nullptr) {
162 : 0 : *occurred = events_per_sock.begin()->second.occurred;
163 : : }
164 : :
165 : : return true;
166 [ # # ]: 0 : }
167 : :
168 : 0 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
169 : : {
170 : : #ifdef USE_POLL
171 : 0 : std::vector<pollfd> pfds;
172 [ # # # # ]: 0 : for (const auto& [sock, events] : events_per_sock) {
173 [ # # ]: 0 : pfds.emplace_back();
174 : 0 : auto& pfd = pfds.back();
175 [ # # ]: 0 : pfd.fd = sock->m_socket;
176 [ # # ]: 0 : if (events.requested & RECV) {
177 : 0 : pfd.events |= POLLIN;
178 : : }
179 [ # # ]: 0 : if (events.requested & SEND) {
180 : 0 : pfd.events |= POLLOUT;
181 : : }
182 : : }
183 : :
184 [ # # # # : 0 : if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
# # ]
185 : : return false;
186 : : }
187 : :
188 [ # # # # ]: 0 : assert(pfds.size() == events_per_sock.size());
189 : 0 : size_t i{0};
190 [ # # # # ]: 0 : for (auto& [sock, events] : events_per_sock) {
191 [ # # ]: 0 : assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
192 : 0 : events.occurred = 0;
193 [ # # ]: 0 : if (pfds[i].revents & POLLIN) {
194 : 0 : events.occurred |= RECV;
195 : : }
196 [ # # ]: 0 : if (pfds[i].revents & POLLOUT) {
197 : 0 : events.occurred |= SEND;
198 : : }
199 [ # # ]: 0 : if (pfds[i].revents & (POLLERR | POLLHUP)) {
200 : 0 : events.occurred |= ERR;
201 : : }
202 : 0 : ++i;
203 : : }
204 : :
205 : : return true;
206 : : #else
207 : : fd_set recv;
208 : : fd_set send;
209 : : fd_set err;
210 : : FD_ZERO(&recv);
211 : : FD_ZERO(&send);
212 : : FD_ZERO(&err);
213 : : SOCKET socket_max{0};
214 : :
215 : : for (const auto& [sock, events] : events_per_sock) {
216 : : if (!sock->IsSelectable()) {
217 : : return false;
218 : : }
219 : : const auto& s = sock->m_socket;
220 : : if (events.requested & RECV) {
221 : : FD_SET(s, &recv);
222 : : }
223 : : if (events.requested & SEND) {
224 : : FD_SET(s, &send);
225 : : }
226 : : FD_SET(s, &err);
227 : : socket_max = std::max(socket_max, s);
228 : : }
229 : :
230 : : timeval tv = MillisToTimeval(timeout);
231 : :
232 : : if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
233 : : return false;
234 : : }
235 : :
236 : : for (auto& [sock, events] : events_per_sock) {
237 : : const auto& s = sock->m_socket;
238 : : events.occurred = 0;
239 : : if (FD_ISSET(s, &recv)) {
240 : : events.occurred |= RECV;
241 : : }
242 : : if (FD_ISSET(s, &send)) {
243 : : events.occurred |= SEND;
244 : : }
245 : : if (FD_ISSET(s, &err)) {
246 : : events.occurred |= ERR;
247 : : }
248 : : }
249 : :
250 : : return true;
251 : : #endif /* USE_POLL */
252 : 0 : }
253 : :
254 : 1766 : void Sock::SendComplete(std::span<const unsigned char> data,
255 : : std::chrono::milliseconds timeout,
256 : : CThreadInterrupt& interrupt) const
257 : : {
258 : 1766 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
259 : 1766 : size_t sent{0};
260 : :
261 : 12358 : for (;;) {
262 : 7062 : const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
263 : :
264 [ + + ]: 7062 : if (ret > 0) {
265 : 5306 : sent += static_cast<size_t>(ret);
266 [ + + ]: 5306 : if (sent == data.size()) {
267 : : break;
268 : : }
269 : : } else {
270 : 1756 : const int err{WSAGetLastError()};
271 [ + + ]: 1756 : if (IOErrorIsPermanent(err)) {
272 [ + - + - : 458 : throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
+ - ]
273 : : }
274 : : }
275 : :
276 : 5317 : const auto now = GetTime<std::chrono::milliseconds>();
277 : :
278 [ - + ]: 5317 : if (now >= deadline) {
279 [ # # ]: 0 : throw std::runtime_error(strprintf(
280 [ # # # # ]: 0 : "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
281 : : }
282 : :
283 [ + + ]: 5317 : if (interrupt) {
284 [ + - ]: 21 : throw std::runtime_error(strprintf(
285 [ + - + - ]: 42 : "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
286 : : }
287 : :
288 : : // Wait for a short while (or the socket to become ready for sending) before retrying
289 : : // if nothing was sent.
290 : 5296 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
291 : 5296 : (void)Wait(wait_time, SEND);
292 : 5296 : }
293 : 1516 : }
294 : :
295 : 1700 : void Sock::SendComplete(std::span<const char> data,
296 : : std::chrono::milliseconds timeout,
297 : : CThreadInterrupt& interrupt) const
298 : : {
299 : 1700 : SendComplete(MakeUCharSpan(data), timeout, interrupt);
300 : 1487 : }
301 : :
302 : 1576 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
303 : : std::chrono::milliseconds timeout,
304 : : CThreadInterrupt& interrupt,
305 : : size_t max_data) const
306 : : {
307 : 1576 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
308 : 1576 : std::string data;
309 : 1576 : bool terminator_found{false};
310 : :
311 : : // We must not consume any bytes past the terminator from the socket.
312 : : // One option is to read one byte at a time and check if we have read a terminator.
313 : : // However that is very slow. Instead, we peek at what is in the socket and only read
314 : : // as many bytes as possible without crossing the terminator.
315 : : // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
316 : : // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
317 : : // at a time is about 50 times slower.
318 : :
319 : 19376 : for (;;) {
320 [ - + + + ]: 19376 : if (data.size() >= max_data) {
321 : 15 : throw std::runtime_error(
322 [ - + + - : 30 : strprintf("Received too many bytes without a terminator (%u)", data.size()));
+ - ]
323 : : }
324 : :
325 : 19361 : char buf[512];
326 : :
327 [ + + + - ]: 38594 : const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
328 : :
329 [ + + + ]: 19361 : switch (peek_ret) {
330 : 688 : case -1: {
331 : 688 : const int err{WSAGetLastError()};
332 [ + + ]: 688 : if (IOErrorIsPermanent(err)) {
333 [ + - + - : 36 : throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
+ - ]
334 : : }
335 : : break;
336 : : }
337 : 62 : case 0:
338 [ + - ]: 62 : throw std::runtime_error("Connection unexpectedly closed by peer");
339 : 18611 : default:
340 : 18611 : auto end = buf + peek_ret;
341 : 18611 : auto terminator_pos = std::find(buf, end, terminator);
342 : 18611 : terminator_found = terminator_pos != end;
343 : :
344 [ + + ]: 18611 : const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
345 : : static_cast<size_t>(peek_ret)};
346 : :
347 [ + - ]: 18611 : const ssize_t read_ret{Recv(buf, try_len, 0)};
348 : :
349 [ + + + + ]: 18611 : if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
350 : 73 : throw std::runtime_error(
351 [ + - ]: 73 : strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
352 : : "peek claimed %u bytes are available",
353 [ + - ]: 219 : read_ret, try_len, peek_ret));
354 : : }
355 : :
356 : : // Don't include the terminator in the output.
357 [ + + ]: 18538 : const size_t append_len{terminator_found ? try_len - 1 : try_len};
358 : :
359 [ + - ]: 18538 : data.append(buf, buf + append_len);
360 : :
361 [ + + ]: 18538 : if (terminator_found) {
362 : 1386 : return data;
363 : : }
364 : : }
365 : :
366 : 17822 : const auto now = GetTime<std::chrono::milliseconds>();
367 : :
368 [ - + ]: 17822 : if (now >= deadline) {
369 : 0 : throw std::runtime_error(strprintf(
370 [ # # # # : 0 : "Receive timeout (received %u bytes without terminator before that)", data.size()));
# # ]
371 : : }
372 : :
373 [ + - + + ]: 17822 : if (interrupt) {
374 : 0 : throw std::runtime_error(strprintf(
375 : : "Receive interrupted (received %u bytes without terminator before that)",
376 [ - + + - : 44 : data.size()));
+ - ]
377 : : }
378 : :
379 : : // Wait for a short while (or the socket to become ready for reading) before retrying.
380 [ + - ]: 17800 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
381 [ + - ]: 17800 : (void)Wait(wait_time, RECV);
382 : : }
383 : 190 : }
384 : :
385 : 0 : bool Sock::IsConnected(std::string& errmsg) const
386 : : {
387 [ # # ]: 0 : if (m_socket == INVALID_SOCKET) {
388 : 0 : errmsg = "not connected";
389 : 0 : return false;
390 : : }
391 : :
392 : 0 : char c;
393 [ # # # ]: 0 : switch (Recv(&c, sizeof(c), MSG_PEEK)) {
394 : 0 : case -1: {
395 : 0 : const int err = WSAGetLastError();
396 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
397 : 0 : errmsg = NetworkErrorString(err);
398 : 0 : return false;
399 : : }
400 : : return true;
401 : : }
402 : 0 : case 0:
403 : 0 : errmsg = "closed";
404 : 0 : return false;
405 : : default:
406 : : return true;
407 : : }
408 : : }
409 : :
410 : 136160 : void Sock::Close()
411 : : {
412 [ - + ]: 136160 : if (m_socket == INVALID_SOCKET) {
413 : : return;
414 : : }
415 : : #ifdef WIN32
416 : : int ret = closesocket(m_socket);
417 : : #else
418 : 0 : int ret = close(m_socket);
419 : : #endif
420 [ # # ]: 0 : if (ret) {
421 [ # # ]: 0 : LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError()));
422 : : }
423 : 0 : m_socket = INVALID_SOCKET;
424 : : }
425 : :
426 : 0 : bool Sock::operator==(SOCKET s) const
427 : : {
428 : 0 : return m_socket == s;
429 : : };
430 : :
431 : 19656 : std::string NetworkErrorString(int err)
432 : : {
433 : : #if defined(WIN32)
434 : : return Win32ErrorString(err);
435 : : #else
436 : : // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
437 : 19656 : return SysErrorString(err);
438 : : #endif
439 : : }
|