Branch data Line data Source code
1 : : // Copyright (c) 2020-present The Bitcoin Core developers
2 : : // Distributed under the MIT software license, see the accompanying
3 : : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
4 : :
5 : : #include <util/sock.h>
6 : :
7 : : #include <common/system.h>
8 : : #include <compat/compat.h>
9 : : #include <span.h>
10 : : #include <tinyformat.h>
11 : : #include <util/log.h>
12 : : #include <util/syserror.h>
13 : : #include <util/threadinterrupt.h>
14 : : #include <util/time.h>
15 : :
16 : : #include <memory>
17 : : #include <stdexcept>
18 : : #include <string>
19 : :
20 : : #ifdef USE_POLL
21 : : #include <poll.h>
22 : : #endif
23 : :
24 : 0 : static inline bool IOErrorIsPermanent(int err)
25 : : {
26 : 0 : return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
27 : : }
28 : :
29 : 41 : Sock::Sock(SOCKET s) : m_socket(s) {}
30 : :
31 : 2 : Sock::Sock(Sock&& other)
32 : : {
33 : 2 : m_socket = other.m_socket;
34 : 2 : other.m_socket = INVALID_SOCKET;
35 : 2 : }
36 : :
37 : 55 : Sock::~Sock() { Close(); }
38 : :
39 : 2 : Sock& Sock::operator=(Sock&& other)
40 : : {
41 : 2 : Close();
42 : 2 : m_socket = other.m_socket;
43 : 2 : other.m_socket = INVALID_SOCKET;
44 : 2 : return *this;
45 : : }
46 : :
47 : 8 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
48 : : {
49 : 8 : return send(m_socket, static_cast<const char*>(data), len, flags);
50 : : }
51 : :
52 : 7 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
53 : : {
54 [ + - ]: 7 : return recv(m_socket, static_cast<char*>(buf), len, flags);
55 : : }
56 : :
57 : 1 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
58 : : {
59 : 1 : return connect(m_socket, addr, addr_len);
60 : : }
61 : :
62 : 1 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
63 : : {
64 : 1 : return bind(m_socket, addr, addr_len);
65 : : }
66 : :
67 : 0 : int Sock::Listen(int backlog) const
68 : : {
69 : 0 : return listen(m_socket, backlog);
70 : : }
71 : :
72 : 0 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
73 : : {
74 : : #ifdef WIN32
75 : : static constexpr auto ERR = INVALID_SOCKET;
76 : : #else
77 : 0 : static constexpr auto ERR = SOCKET_ERROR;
78 : : #endif
79 : :
80 : 0 : std::unique_ptr<Sock> sock;
81 : :
82 [ # # ]: 0 : const auto socket = accept(m_socket, addr, addr_len);
83 [ # # ]: 0 : if (socket != ERR) {
84 : 0 : try {
85 [ # # ]: 0 : sock = std::make_unique<Sock>(socket);
86 [ - - ]: 0 : } catch (const std::exception&) {
87 : : #ifdef WIN32
88 : : closesocket(socket);
89 : : #else
90 [ - - ]: 0 : close(socket);
91 : : #endif
92 : 0 : }
93 : : }
94 : :
95 : 0 : return sock;
96 : 0 : }
97 : :
98 : 0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
99 : : {
100 : 0 : return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
101 : : }
102 : :
103 : 0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
104 : : {
105 : 0 : return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
106 : : }
107 : :
108 : 1 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
109 : : {
110 : 1 : return getsockname(m_socket, name, name_len);
111 : : }
112 : :
113 : 1 : bool Sock::SetNonBlocking() const
114 : : {
115 : : #ifdef WIN32
116 : : u_long on{1};
117 : : if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
118 : : return false;
119 : : }
120 : : #else
121 : 1 : const int flags{fcntl(m_socket, F_GETFL, 0)};
122 [ + - ]: 1 : if (flags == SOCKET_ERROR) {
123 : : return false;
124 : : }
125 [ - + ]: 1 : if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
126 : 0 : return false;
127 : : }
128 : : #endif
129 : : return true;
130 : : }
131 : :
132 : 1 : bool Sock::IsSelectable() const
133 : : {
134 : : #if defined(USE_POLL) || defined(WIN32)
135 : 1 : return true;
136 : : #else
137 : : return m_socket < FD_SETSIZE;
138 : : #endif
139 : : }
140 : :
141 : 3 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
142 : : {
143 : : // We need a `shared_ptr` holding `this` for `WaitMany()`, but don't want
144 : : // `this` to be destroyed when the `shared_ptr` goes out of scope at the
145 : : // end of this function.
146 : : // Create it with an aliasing shared_ptr that points to `this` without
147 : : // owning it.
148 [ - + ]: 3 : std::shared_ptr<const Sock> shared{std::shared_ptr<const Sock>{}, this};
149 : :
150 [ - + + + : 9 : EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
- + - - -
- ]
151 : :
152 [ + - + - ]: 3 : if (!WaitMany(timeout, events_per_sock)) {
153 : : return false;
154 : : }
155 : :
156 [ + + ]: 3 : if (occurred != nullptr) {
157 : 1 : *occurred = events_per_sock.begin()->second.occurred;
158 : : }
159 : :
160 : : return true;
161 [ - + ]: 6 : }
162 : :
163 : 3 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
164 : : {
165 : : #ifdef USE_POLL
166 : 3 : std::vector<pollfd> pfds;
167 [ + + + - ]: 6 : for (const auto& [sock, events] : events_per_sock) {
168 [ + - ]: 3 : pfds.emplace_back();
169 : 3 : auto& pfd = pfds.back();
170 [ + - ]: 3 : pfd.fd = sock->m_socket;
171 [ + - ]: 3 : if (events.requested & RECV) {
172 : 3 : pfd.events |= POLLIN;
173 : : }
174 [ - + ]: 3 : if (events.requested & SEND) {
175 : 0 : pfd.events |= POLLOUT;
176 : : }
177 : : }
178 : :
179 [ - + + - : 6 : if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
+ - ]
180 : : return false;
181 : : }
182 : :
183 [ - + - + ]: 3 : assert(pfds.size() == events_per_sock.size());
184 : 3 : size_t i{0};
185 [ + + - + ]: 6 : for (auto& [sock, events] : events_per_sock) {
186 [ - + ]: 3 : assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
187 : 3 : events.occurred = 0;
188 [ + + ]: 3 : if (pfds[i].revents & POLLIN) {
189 : 2 : events.occurred |= RECV;
190 : : }
191 [ - + ]: 3 : if (pfds[i].revents & POLLOUT) {
192 : 0 : events.occurred |= SEND;
193 : : }
194 [ + + ]: 3 : if (pfds[i].revents & (POLLERR | POLLHUP)) {
195 : 1 : events.occurred |= ERR;
196 : : }
197 : 3 : ++i;
198 : : }
199 : :
200 : : return true;
201 : : #else
202 : : fd_set recv;
203 : : fd_set send;
204 : : fd_set err;
205 : : FD_ZERO(&recv);
206 : : FD_ZERO(&send);
207 : : FD_ZERO(&err);
208 : : SOCKET socket_max{0};
209 : :
210 : : for (const auto& [sock, events] : events_per_sock) {
211 : : if (!sock->IsSelectable()) {
212 : : return false;
213 : : }
214 : : const auto& s = sock->m_socket;
215 : : if (events.requested & RECV) {
216 : : FD_SET(s, &recv);
217 : : }
218 : : if (events.requested & SEND) {
219 : : FD_SET(s, &send);
220 : : }
221 : : FD_SET(s, &err);
222 : : socket_max = std::max(socket_max, s);
223 : : }
224 : :
225 : : timeval tv = MillisToTimeval(timeout);
226 : :
227 : : if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
228 : : return false;
229 : : }
230 : :
231 : : for (auto& [sock, events] : events_per_sock) {
232 : : const auto& s = sock->m_socket;
233 : : events.occurred = 0;
234 : : if (FD_ISSET(s, &recv)) {
235 : : events.occurred |= RECV;
236 : : }
237 : : if (FD_ISSET(s, &send)) {
238 : : events.occurred |= SEND;
239 : : }
240 : : if (FD_ISSET(s, &err)) {
241 : : events.occurred |= ERR;
242 : : }
243 : : }
244 : :
245 : : return true;
246 : : #endif /* USE_POLL */
247 : 3 : }
248 : :
249 : 34 : void Sock::SendComplete(std::span<const unsigned char> data,
250 : : std::chrono::milliseconds timeout,
251 : : CThreadInterrupt& interrupt) const
252 : : {
253 : 34 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
254 : 34 : size_t sent{0};
255 : :
256 : 34 : for (;;) {
257 : 34 : const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
258 : :
259 [ + - ]: 34 : if (ret > 0) {
260 : 34 : sent += static_cast<size_t>(ret);
261 [ - + ]: 34 : if (sent == data.size()) {
262 : : break;
263 : : }
264 : : } else {
265 : 0 : const int err{WSAGetLastError()};
266 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
267 [ # # # # : 0 : throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
# # ]
268 : : }
269 : : }
270 : :
271 : 0 : const auto now = GetTime<std::chrono::milliseconds>();
272 : :
273 [ # # ]: 0 : if (now >= deadline) {
274 [ # # ]: 0 : throw std::runtime_error(strprintf(
275 [ # # # # ]: 0 : "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
276 : : }
277 : :
278 [ # # ]: 0 : if (interrupt) {
279 [ # # ]: 0 : throw std::runtime_error(strprintf(
280 [ # # # # ]: 0 : "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
281 : : }
282 : :
283 : : // Wait for a short while (or the socket to become ready for sending) before retrying
284 : : // if nothing was sent.
285 : 0 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
286 : 0 : (void)Wait(wait_time, SEND);
287 : 0 : }
288 : 34 : }
289 : :
290 : 34 : void Sock::SendComplete(std::span<const char> data,
291 : : std::chrono::milliseconds timeout,
292 : : CThreadInterrupt& interrupt) const
293 : : {
294 : 34 : SendComplete(MakeUCharSpan(data), timeout, interrupt);
295 : 34 : }
296 : :
297 : 38 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
298 : : std::chrono::milliseconds timeout,
299 : : CThreadInterrupt& interrupt,
300 : : size_t max_data) const
301 : : {
302 : 38 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
303 : 38 : std::string data;
304 : 38 : bool terminator_found{false};
305 : :
306 : : // We must not consume any bytes past the terminator from the socket.
307 : : // One option is to read one byte at a time and check if we have read a terminator.
308 : : // However that is very slow. Instead, we peek at what is in the socket and only read
309 : : // as many bytes as possible without crossing the terminator.
310 : : // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
311 : : // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
312 : : // at a time is about 50 times slower.
313 : :
314 : 169 : for (;;) {
315 [ - + + + ]: 169 : if (data.size() >= max_data) {
316 : 2 : throw std::runtime_error(
317 [ - + + - : 4 : strprintf("Received too many bytes without a terminator (%u)", data.size()));
+ - ]
318 : : }
319 : :
320 : 167 : char buf[512];
321 : :
322 [ + + + - ]: 333 : const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
323 : :
324 [ - - + ]: 167 : switch (peek_ret) {
325 : 0 : case -1: {
326 : 0 : const int err{WSAGetLastError()};
327 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
328 [ # # # # : 0 : throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
# # ]
329 : : }
330 : : break;
331 : : }
332 : 0 : case 0:
333 [ # # ]: 0 : throw std::runtime_error("Connection unexpectedly closed by peer");
334 : 167 : default:
335 : 167 : auto end = buf + peek_ret;
336 : 167 : auto terminator_pos = std::find(buf, end, terminator);
337 : 167 : terminator_found = terminator_pos != end;
338 : :
339 [ + + ]: 167 : const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
340 : : static_cast<size_t>(peek_ret)};
341 : :
342 [ + - ]: 167 : const ssize_t read_ret{Recv(buf, try_len, 0)};
343 : :
344 [ + - - + ]: 167 : if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
345 : 0 : throw std::runtime_error(
346 [ # # ]: 0 : strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
347 : : "peek claimed %u bytes are available",
348 [ # # ]: 0 : read_ret, try_len, peek_ret));
349 : : }
350 : :
351 : : // Don't include the terminator in the output.
352 [ + + ]: 167 : const size_t append_len{terminator_found ? try_len - 1 : try_len};
353 : :
354 [ + - ]: 167 : data.append(buf, buf + append_len);
355 : :
356 [ + + ]: 167 : if (terminator_found) {
357 : 36 : return data;
358 : : }
359 : : }
360 : :
361 : 131 : const auto now = GetTime<std::chrono::milliseconds>();
362 : :
363 [ - + ]: 131 : if (now >= deadline) {
364 : 0 : throw std::runtime_error(strprintf(
365 [ # # # # : 0 : "Receive timeout (received %u bytes without terminator before that)", data.size()));
# # ]
366 : : }
367 : :
368 [ + - - + ]: 131 : if (interrupt) {
369 : 0 : throw std::runtime_error(strprintf(
370 : : "Receive interrupted (received %u bytes without terminator before that)",
371 [ # # # # : 0 : data.size()));
# # ]
372 : : }
373 : :
374 : : // Wait for a short while (or the socket to become ready for reading) before retrying.
375 [ + - ]: 131 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
376 [ + - ]: 131 : (void)Wait(wait_time, RECV);
377 : : }
378 : 2 : }
379 : :
380 : 0 : bool Sock::IsConnected(std::string& errmsg) const
381 : : {
382 [ # # ]: 0 : if (m_socket == INVALID_SOCKET) {
383 : 0 : errmsg = "not connected";
384 : 0 : return false;
385 : : }
386 : :
387 : 0 : char c;
388 [ # # # ]: 0 : switch (Recv(&c, sizeof(c), MSG_PEEK)) {
389 : 0 : case -1: {
390 : 0 : const int err = WSAGetLastError();
391 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
392 : 0 : errmsg = NetworkErrorString(err);
393 : 0 : return false;
394 : : }
395 : : return true;
396 : : }
397 : 0 : case 0:
398 : 0 : errmsg = "closed";
399 : 0 : return false;
400 : : default:
401 : : return true;
402 : : }
403 : : }
404 : :
405 : 45 : void Sock::Close()
406 : : {
407 [ + + ]: 45 : if (m_socket == INVALID_SOCKET) {
408 : : return;
409 : : }
410 : : #ifdef WIN32
411 : : int ret = closesocket(m_socket);
412 : : #else
413 : 13 : int ret = close(m_socket);
414 : : #endif
415 [ - + ]: 13 : if (ret) {
416 [ # # ]: 0 : LogWarning("Error closing socket %d: %s", m_socket, NetworkErrorString(WSAGetLastError()));
417 : : }
418 : 13 : m_socket = INVALID_SOCKET;
419 : : }
420 : :
421 : 4 : bool Sock::operator==(SOCKET s) const
422 : : {
423 : 4 : return m_socket == s;
424 : : };
425 : :
426 : 2 : std::string NetworkErrorString(int err)
427 : : {
428 : : #if defined(WIN32)
429 : : return Win32ErrorString(err);
430 : : #else
431 : : // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
432 : 2 : return SysErrorString(err);
433 : : #endif
434 : : }
|