Branch data Line data Source code
1 : : // Copyright (c) 2020-present The Bitcoin Core developers
2 : : // Distributed under the MIT software license, see the accompanying
3 : : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
4 : :
5 : : #include <util/sock.h>
6 : :
7 : : #include <common/system.h>
8 : : #include <compat/compat.h>
9 : : #include <span.h>
10 : : #include <tinyformat.h>
11 : : #include <util/log.h>
12 : : #include <util/syserror.h>
13 : : #include <util/threadinterrupt.h>
14 : : #include <util/time.h>
15 : :
16 : : #include <memory>
17 : : #include <stdexcept>
18 : : #include <string>
19 : :
20 : : #ifdef USE_POLL
21 : : #include <poll.h>
22 : : #endif
23 : :
24 : 2544 : static inline bool IOErrorIsPermanent(int err)
25 : : {
26 : 2544 : return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
27 : : }
28 : :
29 : 149984 : Sock::Sock(SOCKET s) : m_socket(s) {}
30 : :
31 : 0 : Sock::Sock(Sock&& other)
32 : : {
33 : 0 : m_socket = other.m_socket;
34 : 0 : other.m_socket = INVALID_SOCKET;
35 : 0 : }
36 : :
37 : 149984 : Sock::~Sock() { Close(); }
38 : :
39 : 0 : Sock& Sock::operator=(Sock&& other)
40 : : {
41 : 0 : Close();
42 : 0 : m_socket = other.m_socket;
43 : 0 : other.m_socket = INVALID_SOCKET;
44 : 0 : return *this;
45 : : }
46 : :
47 : 0 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
48 : : {
49 : 0 : return send(m_socket, static_cast<const char*>(data), len, flags);
50 : : }
51 : :
52 : 0 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
53 : : {
54 [ # # ]: 0 : return recv(m_socket, static_cast<char*>(buf), len, flags);
55 : : }
56 : :
57 : 0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
58 : : {
59 : 0 : return connect(m_socket, addr, addr_len);
60 : : }
61 : :
62 : 0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
63 : : {
64 : 0 : return bind(m_socket, addr, addr_len);
65 : : }
66 : :
67 : 0 : int Sock::Listen(int backlog) const
68 : : {
69 : 0 : return listen(m_socket, backlog);
70 : : }
71 : :
72 : 0 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
73 : : {
74 : : #ifdef WIN32
75 : : static constexpr auto ERR = INVALID_SOCKET;
76 : : #else
77 : 0 : static constexpr auto ERR = SOCKET_ERROR;
78 : : #endif
79 : :
80 : 0 : std::unique_ptr<Sock> sock;
81 : :
82 [ # # ]: 0 : const auto socket = accept(m_socket, addr, addr_len);
83 [ # # ]: 0 : if (socket != ERR) {
84 : 0 : try {
85 [ # # ]: 0 : sock = std::make_unique<Sock>(socket);
86 [ - - ]: 0 : } catch (const std::exception&) {
87 : : #ifdef WIN32
88 : : closesocket(socket);
89 : : #else
90 [ - - ]: 0 : close(socket);
91 : : #endif
92 : 0 : }
93 : : }
94 : :
95 : 0 : return sock;
96 : 0 : }
97 : :
98 : 0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
99 : : {
100 : 0 : return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
101 : : }
102 : :
103 : 0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
104 : : {
105 : 0 : return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
106 : : }
107 : :
108 : 0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
109 : : {
110 : 0 : return getsockname(m_socket, name, name_len);
111 : : }
112 : :
113 : 0 : bool Sock::SetNonBlocking() const
114 : : {
115 : : #ifdef WIN32
116 : : u_long on{1};
117 : : if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
118 : : return false;
119 : : }
120 : : #else
121 : 0 : const int flags{fcntl(m_socket, F_GETFL, 0)};
122 [ # # ]: 0 : if (flags == SOCKET_ERROR) {
123 : : return false;
124 : : }
125 [ # # ]: 0 : if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
126 : 0 : return false;
127 : : }
128 : : #endif
129 : : return true;
130 : : }
131 : :
132 : 0 : bool Sock::IsSelectable() const
133 : : {
134 : : #if defined(USE_POLL) || defined(WIN32)
135 : 0 : return true;
136 : : #else
137 : : return m_socket < FD_SETSIZE;
138 : : #endif
139 : : }
140 : :
141 : 0 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
142 : : {
143 : : // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
144 : : // `this` to be destroyed when the `shared_ptr` goes out of scope at the
145 : : // end of this function. Create it with a custom noop deleter.
146 : 0 : std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
147 : :
148 [ # # # # : 0 : EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
# # # # #
# # # ]
149 : :
150 [ # # # # ]: 0 : if (!WaitMany(timeout, events_per_sock)) {
151 : : return false;
152 : : }
153 : :
154 [ # # ]: 0 : if (occurred != nullptr) {
155 : 0 : *occurred = events_per_sock.begin()->second.occurred;
156 : : }
157 : :
158 : : return true;
159 [ # # # # ]: 0 : }
160 : :
161 : 0 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
162 : : {
163 : : #ifdef USE_POLL
164 : 0 : std::vector<pollfd> pfds;
165 [ # # # # ]: 0 : for (const auto& [sock, events] : events_per_sock) {
166 [ # # ]: 0 : pfds.emplace_back();
167 : 0 : auto& pfd = pfds.back();
168 [ # # ]: 0 : pfd.fd = sock->m_socket;
169 [ # # ]: 0 : if (events.requested & RECV) {
170 : 0 : pfd.events |= POLLIN;
171 : : }
172 [ # # ]: 0 : if (events.requested & SEND) {
173 : 0 : pfd.events |= POLLOUT;
174 : : }
175 : : }
176 : :
177 [ # # # # : 0 : if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
# # ]
178 : : return false;
179 : : }
180 : :
181 [ # # # # ]: 0 : assert(pfds.size() == events_per_sock.size());
182 : 0 : size_t i{0};
183 [ # # # # ]: 0 : for (auto& [sock, events] : events_per_sock) {
184 [ # # ]: 0 : assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
185 : 0 : events.occurred = 0;
186 [ # # ]: 0 : if (pfds[i].revents & POLLIN) {
187 : 0 : events.occurred |= RECV;
188 : : }
189 [ # # ]: 0 : if (pfds[i].revents & POLLOUT) {
190 : 0 : events.occurred |= SEND;
191 : : }
192 [ # # ]: 0 : if (pfds[i].revents & (POLLERR | POLLHUP)) {
193 : 0 : events.occurred |= ERR;
194 : : }
195 : 0 : ++i;
196 : : }
197 : :
198 : : return true;
199 : : #else
200 : : fd_set recv;
201 : : fd_set send;
202 : : fd_set err;
203 : : FD_ZERO(&recv);
204 : : FD_ZERO(&send);
205 : : FD_ZERO(&err);
206 : : SOCKET socket_max{0};
207 : :
208 : : for (const auto& [sock, events] : events_per_sock) {
209 : : if (!sock->IsSelectable()) {
210 : : return false;
211 : : }
212 : : const auto& s = sock->m_socket;
213 : : if (events.requested & RECV) {
214 : : FD_SET(s, &recv);
215 : : }
216 : : if (events.requested & SEND) {
217 : : FD_SET(s, &send);
218 : : }
219 : : FD_SET(s, &err);
220 : : socket_max = std::max(socket_max, s);
221 : : }
222 : :
223 : : timeval tv = MillisToTimeval(timeout);
224 : :
225 : : if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
226 : : return false;
227 : : }
228 : :
229 : : for (auto& [sock, events] : events_per_sock) {
230 : : const auto& s = sock->m_socket;
231 : : events.occurred = 0;
232 : : if (FD_ISSET(s, &recv)) {
233 : : events.occurred |= RECV;
234 : : }
235 : : if (FD_ISSET(s, &send)) {
236 : : events.occurred |= SEND;
237 : : }
238 : : if (FD_ISSET(s, &err)) {
239 : : events.occurred |= ERR;
240 : : }
241 : : }
242 : :
243 : : return true;
244 : : #endif /* USE_POLL */
245 : 0 : }
246 : :
247 : 2758 : void Sock::SendComplete(std::span<const unsigned char> data,
248 : : std::chrono::milliseconds timeout,
249 : : CThreadInterrupt& interrupt) const
250 : : {
251 : 2758 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
252 : 2758 : size_t sent{0};
253 : :
254 : 19678 : for (;;) {
255 : 11218 : const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
256 : :
257 [ + + ]: 11218 : if (ret > 0) {
258 : 9172 : sent += static_cast<size_t>(ret);
259 [ + + ]: 9172 : if (sent == data.size()) {
260 : : break;
261 : : }
262 : : } else {
263 : 2046 : const int err{WSAGetLastError()};
264 [ + + ]: 2046 : if (IOErrorIsPermanent(err)) {
265 [ + - + - : 688 : throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
+ - ]
266 : : }
267 : : }
268 : :
269 : 8519 : const auto now = GetTime<std::chrono::milliseconds>();
270 : :
271 [ - + ]: 8519 : if (now >= deadline) {
272 [ # # ]: 0 : throw std::runtime_error(strprintf(
273 [ # # # # ]: 0 : "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
274 : : }
275 : :
276 [ + + ]: 8519 : if (interrupt) {
277 [ + - ]: 59 : throw std::runtime_error(strprintf(
278 [ + - + - ]: 118 : "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
279 : : }
280 : :
281 : : // Wait for a short while (or the socket to become ready for sending) before retrying
282 : : // if nothing was sent.
283 : 8460 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
284 : 8460 : (void)Wait(wait_time, SEND);
285 : 8460 : }
286 : 2355 : }
287 : :
288 : 2670 : void Sock::SendComplete(std::span<const char> data,
289 : : std::chrono::milliseconds timeout,
290 : : CThreadInterrupt& interrupt) const
291 : : {
292 : 2670 : SendComplete(MakeUCharSpan(data), timeout, interrupt);
293 : 2316 : }
294 : :
295 : 2444 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
296 : : std::chrono::milliseconds timeout,
297 : : CThreadInterrupt& interrupt,
298 : : size_t max_data) const
299 : : {
300 : 2444 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
301 : 2444 : std::string data;
302 : 2444 : bool terminator_found{false};
303 : :
304 : : // We must not consume any bytes past the terminator from the socket.
305 : : // One option is to read one byte at a time and check if we have read a terminator.
306 : : // However that is very slow. Instead, we peek at what is in the socket and only read
307 : : // as many bytes as possible without crossing the terminator.
308 : : // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
309 : : // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
310 : : // at a time is about 50 times slower.
311 : :
312 : 25931 : for (;;) {
313 [ - + + + ]: 25931 : if (data.size() >= max_data) {
314 : 21 : throw std::runtime_error(
315 [ - + + - : 42 : strprintf("Received too many bytes without a terminator (%u)", data.size()));
+ - ]
316 : : }
317 : :
318 : 25910 : char buf[512];
319 : :
320 [ + + + - ]: 51712 : const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
321 : :
322 [ + + + ]: 25910 : switch (peek_ret) {
323 : 498 : case -1: {
324 : 498 : const int err{WSAGetLastError()};
325 [ + + ]: 498 : if (IOErrorIsPermanent(err)) {
326 [ + - + - : 128 : throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
+ - ]
327 : : }
328 : : break;
329 : : }
330 : 83 : case 0:
331 [ + - ]: 83 : throw std::runtime_error("Connection unexpectedly closed by peer");
332 : 25329 : default:
333 : 25329 : auto end = buf + peek_ret;
334 : 25329 : auto terminator_pos = std::find(buf, end, terminator);
335 : 25329 : terminator_found = terminator_pos != end;
336 : :
337 [ + + ]: 25329 : const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
338 : : static_cast<size_t>(peek_ret)};
339 : :
340 [ + - ]: 25329 : const ssize_t read_ret{Recv(buf, try_len, 0)};
341 : :
342 [ + + + + ]: 25329 : if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
343 : 148 : throw std::runtime_error(
344 [ + - ]: 148 : strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
345 : : "peek claimed %u bytes are available",
346 [ + - ]: 444 : read_ret, try_len, peek_ret));
347 : : }
348 : :
349 : : // Don't include the terminator in the output.
350 [ + + ]: 25181 : const size_t append_len{terminator_found ? try_len - 1 : try_len};
351 : :
352 [ + - ]: 25181 : data.append(buf, buf + append_len);
353 : :
354 [ + + ]: 25181 : if (terminator_found) {
355 : 2072 : return data;
356 : : }
357 : : }
358 : :
359 : 23543 : const auto now = GetTime<std::chrono::milliseconds>();
360 : :
361 [ - + ]: 23543 : if (now >= deadline) {
362 : 0 : throw std::runtime_error(strprintf(
363 [ # # # # : 0 : "Receive timeout (received %u bytes without terminator before that)", data.size()));
# # ]
364 : : }
365 : :
366 [ + - + + ]: 23543 : if (interrupt) {
367 : 0 : throw std::runtime_error(strprintf(
368 : : "Receive interrupted (received %u bytes without terminator before that)",
369 [ - + + - : 112 : data.size()));
+ - ]
370 : : }
371 : :
372 : : // Wait for a short while (or the socket to become ready for reading) before retrying.
373 [ + - ]: 23487 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
374 [ + - ]: 23487 : (void)Wait(wait_time, RECV);
375 : : }
376 : 372 : }
377 : :
378 : 0 : bool Sock::IsConnected(std::string& errmsg) const
379 : : {
380 [ # # ]: 0 : if (m_socket == INVALID_SOCKET) {
381 : 0 : errmsg = "not connected";
382 : 0 : return false;
383 : : }
384 : :
385 : 0 : char c;
386 [ # # # ]: 0 : switch (Recv(&c, sizeof(c), MSG_PEEK)) {
387 : 0 : case -1: {
388 : 0 : const int err = WSAGetLastError();
389 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
390 : 0 : errmsg = NetworkErrorString(err);
391 : 0 : return false;
392 : : }
393 : : return true;
394 : : }
395 : 0 : case 0:
396 : 0 : errmsg = "closed";
397 : 0 : return false;
398 : : default:
399 : : return true;
400 : : }
401 : : }
402 : :
403 : 149984 : void Sock::Close()
404 : : {
405 [ - + ]: 149984 : if (m_socket == INVALID_SOCKET) {
406 : : return;
407 : : }
408 : : #ifdef WIN32
409 : : int ret = closesocket(m_socket);
410 : : #else
411 : 0 : int ret = close(m_socket);
412 : : #endif
413 [ # # ]: 0 : if (ret) {
414 [ # # ]: 0 : LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError()));
415 : : }
416 : 0 : m_socket = INVALID_SOCKET;
417 : : }
418 : :
419 : 0 : bool Sock::operator==(SOCKET s) const
420 : : {
421 : 0 : return m_socket == s;
422 : : };
423 : :
424 : 9697 : std::string NetworkErrorString(int err)
425 : : {
426 : : #if defined(WIN32)
427 : : return Win32ErrorString(err);
428 : : #else
429 : : // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
430 : 9697 : return SysErrorString(err);
431 : : #endif
432 : : }
|