LCOV - code coverage report
Current view: top level - src/util - sock.cpp (source / functions) Coverage Total Hit
Test: test_bitcoin_coverage.info Lines: 52.7 % 169 89
Test Date: 2024-08-28 04:44:32 Functions: 56.0 % 25 14
Branches: 33.3 % 168 56

             Branch data     Line data    Source code
       1                 :             : // Copyright (c) 2020-2022 The Bitcoin Core developers
       2                 :             : // Distributed under the MIT software license, see the accompanying
       3                 :             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       4                 :             : 
       5                 :             : #include <common/system.h>
       6                 :             : #include <compat/compat.h>
       7                 :             : #include <logging.h>
       8                 :             : #include <tinyformat.h>
       9                 :             : #include <util/sock.h>
      10                 :             : #include <util/syserror.h>
      11                 :             : #include <util/threadinterrupt.h>
      12                 :             : #include <util/time.h>
      13                 :             : 
      14                 :             : #include <memory>
      15                 :             : #include <stdexcept>
      16                 :             : #include <string>
      17                 :             : 
      18                 :             : #ifdef USE_POLL
      19                 :             : #include <poll.h>
      20                 :             : #endif
      21                 :             : 
      22                 :           0 : static inline bool IOErrorIsPermanent(int err)
      23                 :             : {
      24                 :           0 :     return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
      25                 :             : }
      26                 :             : 
      27                 :          27 : Sock::Sock(SOCKET s) : m_socket(s) {}
      28                 :             : 
      29                 :           2 : Sock::Sock(Sock&& other)
      30                 :             : {
      31                 :           2 :     m_socket = other.m_socket;
      32                 :           2 :     other.m_socket = INVALID_SOCKET;
      33                 :           2 : }
      34                 :             : 
      35                 :          38 : Sock::~Sock() { Close(); }
      36                 :             : 
      37                 :           2 : Sock& Sock::operator=(Sock&& other)
      38                 :             : {
      39                 :           2 :     Close();
      40                 :           2 :     m_socket = other.m_socket;
      41                 :           2 :     other.m_socket = INVALID_SOCKET;
      42                 :           2 :     return *this;
      43                 :             : }
      44                 :             : 
      45                 :           5 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
      46                 :             : {
      47                 :           5 :     return send(m_socket, static_cast<const char*>(data), len, flags);
      48                 :             : }
      49                 :             : 
      50                 :           4 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
      51                 :             : {
      52         [ +  - ]:           4 :     return recv(m_socket, static_cast<char*>(buf), len, flags);
      53                 :             : }
      54                 :             : 
      55                 :           0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
      56                 :             : {
      57                 :           0 :     return connect(m_socket, addr, addr_len);
      58                 :             : }
      59                 :             : 
      60                 :           0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
      61                 :             : {
      62                 :           0 :     return bind(m_socket, addr, addr_len);
      63                 :             : }
      64                 :             : 
      65                 :           0 : int Sock::Listen(int backlog) const
      66                 :             : {
      67                 :           0 :     return listen(m_socket, backlog);
      68                 :             : }
      69                 :             : 
      70                 :           0 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
      71                 :             : {
      72                 :             : #ifdef WIN32
      73                 :             :     static constexpr auto ERR = INVALID_SOCKET;
      74                 :             : #else
      75                 :           0 :     static constexpr auto ERR = SOCKET_ERROR;
      76                 :             : #endif
      77                 :             : 
      78                 :           0 :     std::unique_ptr<Sock> sock;
      79                 :             : 
      80         [ #  # ]:           0 :     const auto socket = accept(m_socket, addr, addr_len);
      81         [ #  # ]:           0 :     if (socket != ERR) {
      82                 :           0 :         try {
      83         [ #  # ]:           0 :             sock = std::make_unique<Sock>(socket);
      84         [ -  - ]:           0 :         } catch (const std::exception&) {
      85                 :             : #ifdef WIN32
      86                 :             :             closesocket(socket);
      87                 :             : #else
      88         [ -  - ]:           0 :             close(socket);
      89                 :             : #endif
      90                 :           0 :         }
      91                 :             :     }
      92                 :             : 
      93                 :           0 :     return sock;
      94                 :           0 : }
      95                 :             : 
      96                 :           0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
      97                 :             : {
      98                 :           0 :     return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
      99                 :             : }
     100                 :             : 
     101                 :           0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
     102                 :             : {
     103                 :           0 :     return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
     104                 :             : }
     105                 :             : 
     106                 :           0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
     107                 :             : {
     108                 :           0 :     return getsockname(m_socket, name, name_len);
     109                 :             : }
     110                 :             : 
     111                 :           0 : bool Sock::SetNonBlocking() const
     112                 :             : {
     113                 :             : #ifdef WIN32
     114                 :             :     u_long on{1};
     115                 :             :     if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
     116                 :             :         return false;
     117                 :             :     }
     118                 :             : #else
     119                 :           0 :     const int flags{fcntl(m_socket, F_GETFL, 0)};
     120         [ #  # ]:           0 :     if (flags == SOCKET_ERROR) {
     121                 :             :         return false;
     122                 :             :     }
     123         [ #  # ]:           0 :     if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
     124                 :           0 :         return false;
     125                 :             :     }
     126                 :             : #endif
     127                 :             :     return true;
     128                 :             : }
     129                 :             : 
     130                 :           0 : bool Sock::IsSelectable() const
     131                 :             : {
     132                 :             : #if defined(USE_POLL) || defined(WIN32)
     133                 :           0 :     return true;
     134                 :             : #else
     135                 :             :     return m_socket < FD_SETSIZE;
     136                 :             : #endif
     137                 :             : }
     138                 :             : 
     139                 :           2 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
     140                 :             : {
     141                 :             :     // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
     142                 :             :     // `this` to be destroyed when the `shared_ptr` goes out of scope at the
     143                 :             :     // end of this function. Create it with a custom noop deleter.
     144                 :           2 :     std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
     145                 :             : 
     146   [ +  -  -  +  :           8 :     EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
          +  +  +  -  -  
                -  -  - ]
     147                 :             : 
     148   [ +  -  +  - ]:           2 :     if (!WaitMany(timeout, events_per_sock)) {
     149                 :             :         return false;
     150                 :             :     }
     151                 :             : 
     152         [ -  + ]:           2 :     if (occurred != nullptr) {
     153                 :           0 :         *occurred = events_per_sock.begin()->second.occurred;
     154                 :             :     }
     155                 :             : 
     156                 :             :     return true;
     157   [ -  +  +  - ]:           6 : }
     158                 :             : 
     159                 :           2 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
     160                 :             : {
     161                 :             : #ifdef USE_POLL
     162                 :           2 :     std::vector<pollfd> pfds;
     163   [ +  -  +  + ]:           4 :     for (const auto& [sock, events] : events_per_sock) {
     164         [ +  - ]:           2 :         pfds.emplace_back();
     165                 :           2 :         auto& pfd = pfds.back();
     166         [ +  - ]:           2 :         pfd.fd = sock->m_socket;
     167         [ +  - ]:           2 :         if (events.requested & RECV) {
     168                 :           2 :             pfd.events |= POLLIN;
     169                 :             :         }
     170         [ -  + ]:           2 :         if (events.requested & SEND) {
     171                 :           0 :             pfd.events |= POLLOUT;
     172                 :             :         }
     173                 :             :     }
     174                 :             : 
     175   [ +  -  +  - ]:           4 :     if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
     176                 :             :         return false;
     177                 :             :     }
     178                 :             : 
     179         [ -  + ]:           2 :     assert(pfds.size() == events_per_sock.size());
     180                 :           2 :     size_t i{0};
     181   [ -  +  +  + ]:           4 :     for (auto& [sock, events] : events_per_sock) {
     182         [ -  + ]:           2 :         assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
     183                 :           2 :         events.occurred = 0;
     184         [ +  - ]:           2 :         if (pfds[i].revents & POLLIN) {
     185                 :           2 :             events.occurred |= RECV;
     186                 :             :         }
     187         [ -  + ]:           2 :         if (pfds[i].revents & POLLOUT) {
     188                 :           0 :             events.occurred |= SEND;
     189                 :             :         }
     190         [ -  + ]:           2 :         if (pfds[i].revents & (POLLERR | POLLHUP)) {
     191                 :           0 :             events.occurred |= ERR;
     192                 :             :         }
     193                 :           2 :         ++i;
     194                 :             :     }
     195                 :             : 
     196                 :             :     return true;
     197                 :             : #else
     198                 :             :     fd_set recv;
     199                 :             :     fd_set send;
     200                 :             :     fd_set err;
     201                 :             :     FD_ZERO(&recv);
     202                 :             :     FD_ZERO(&send);
     203                 :             :     FD_ZERO(&err);
     204                 :             :     SOCKET socket_max{0};
     205                 :             : 
     206                 :             :     for (const auto& [sock, events] : events_per_sock) {
     207                 :             :         if (!sock->IsSelectable()) {
     208                 :             :             return false;
     209                 :             :         }
     210                 :             :         const auto& s = sock->m_socket;
     211                 :             :         if (events.requested & RECV) {
     212                 :             :             FD_SET(s, &recv);
     213                 :             :         }
     214                 :             :         if (events.requested & SEND) {
     215                 :             :             FD_SET(s, &send);
     216                 :             :         }
     217                 :             :         FD_SET(s, &err);
     218                 :             :         socket_max = std::max(socket_max, s);
     219                 :             :     }
     220                 :             : 
     221                 :             :     timeval tv = MillisToTimeval(timeout);
     222                 :             : 
     223                 :             :     if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
     224                 :             :         return false;
     225                 :             :     }
     226                 :             : 
     227                 :             :     for (auto& [sock, events] : events_per_sock) {
     228                 :             :         const auto& s = sock->m_socket;
     229                 :             :         events.occurred = 0;
     230                 :             :         if (FD_ISSET(s, &recv)) {
     231                 :             :             events.occurred |= RECV;
     232                 :             :         }
     233                 :             :         if (FD_ISSET(s, &send)) {
     234                 :             :             events.occurred |= SEND;
     235                 :             :         }
     236                 :             :         if (FD_ISSET(s, &err)) {
     237                 :             :             events.occurred |= ERR;
     238                 :             :         }
     239                 :             :     }
     240                 :             : 
     241                 :             :     return true;
     242                 :             : #endif /* USE_POLL */
     243                 :           2 : }
     244                 :             : 
     245                 :          34 : void Sock::SendComplete(Span<const unsigned char> data,
     246                 :             :                         std::chrono::milliseconds timeout,
     247                 :             :                         CThreadInterrupt& interrupt) const
     248                 :             : {
     249                 :          34 :     const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
     250                 :          34 :     size_t sent{0};
     251                 :             : 
     252                 :          34 :     for (;;) {
     253                 :          34 :         const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
     254                 :             : 
     255         [ +  - ]:          34 :         if (ret > 0) {
     256                 :          34 :             sent += static_cast<size_t>(ret);
     257         [ -  + ]:          34 :             if (sent == data.size()) {
     258                 :             :                 break;
     259                 :             :             }
     260                 :             :         } else {
     261                 :           0 :             const int err{WSAGetLastError()};
     262         [ #  # ]:           0 :             if (IOErrorIsPermanent(err)) {
     263   [ #  #  #  #  :           0 :                 throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
                   #  # ]
     264                 :             :             }
     265                 :             :         }
     266                 :             : 
     267                 :           0 :         const auto now = GetTime<std::chrono::milliseconds>();
     268                 :             : 
     269         [ #  # ]:           0 :         if (now >= deadline) {
     270                 :           0 :             throw std::runtime_error(strprintf(
     271   [ #  #  #  # ]:           0 :                 "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
     272                 :             :         }
     273                 :             : 
     274         [ #  # ]:           0 :         if (interrupt) {
     275                 :           0 :             throw std::runtime_error(strprintf(
     276   [ #  #  #  # ]:           0 :                 "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
     277                 :             :         }
     278                 :             : 
     279                 :             :         // Wait for a short while (or the socket to become ready for sending) before retrying
     280                 :             :         // if nothing was sent.
     281                 :           0 :         const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
     282                 :           0 :         (void)Wait(wait_time, SEND);
     283                 :           0 :     }
     284                 :          34 : }
     285                 :             : 
     286                 :          34 : void Sock::SendComplete(Span<const char> data,
     287                 :             :                         std::chrono::milliseconds timeout,
     288                 :             :                         CThreadInterrupt& interrupt) const
     289                 :             : {
     290                 :          34 :     SendComplete(MakeUCharSpan(data), timeout, interrupt);
     291                 :          34 : }
     292                 :             : 
     293                 :          38 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
     294                 :             :                                       std::chrono::milliseconds timeout,
     295                 :             :                                       CThreadInterrupt& interrupt,
     296                 :             :                                       size_t max_data) const
     297                 :             : {
     298                 :          38 :     const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
     299                 :          38 :     std::string data;
     300                 :          38 :     bool terminator_found{false};
     301                 :             : 
     302                 :             :     // We must not consume any bytes past the terminator from the socket.
     303                 :             :     // One option is to read one byte at a time and check if we have read a terminator.
     304                 :             :     // However that is very slow. Instead, we peek at what is in the socket and only read
     305                 :             :     // as many bytes as possible without crossing the terminator.
     306                 :             :     // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
     307                 :             :     // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
     308                 :             :     // at a time is about 50 times slower.
     309                 :             : 
     310                 :         169 :     for (;;) {
     311         [ +  + ]:         169 :         if (data.size() >= max_data) {
     312                 :           2 :             throw std::runtime_error(
     313   [ +  -  +  - ]:           4 :                 strprintf("Received too many bytes without a terminator (%u)", data.size()));
     314                 :             :         }
     315                 :             : 
     316                 :         167 :         char buf[512];
     317                 :             : 
     318   [ +  +  +  - ]:         333 :         const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
     319                 :             : 
     320      [ -  -  + ]:         167 :         switch (peek_ret) {
     321                 :           0 :         case -1: {
     322                 :           0 :             const int err{WSAGetLastError()};
     323         [ #  # ]:           0 :             if (IOErrorIsPermanent(err)) {
     324   [ #  #  #  #  :           0 :                 throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
                   #  # ]
     325                 :             :             }
     326                 :             :             break;
     327                 :             :         }
     328                 :           0 :         case 0:
     329         [ #  # ]:           0 :             throw std::runtime_error("Connection unexpectedly closed by peer");
     330                 :         167 :         default:
     331                 :         167 :             auto end = buf + peek_ret;
     332                 :         167 :             auto terminator_pos = std::find(buf, end, terminator);
     333                 :         167 :             terminator_found = terminator_pos != end;
     334                 :             : 
     335         [ +  + ]:         167 :             const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
     336                 :             :                                                     static_cast<size_t>(peek_ret)};
     337                 :             : 
     338         [ +  - ]:         167 :             const ssize_t read_ret{Recv(buf, try_len, 0)};
     339                 :             : 
     340   [ +  -  -  + ]:         167 :             if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
     341                 :           0 :                 throw std::runtime_error(
     342         [ #  # ]:           0 :                     strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
     343                 :             :                               "peek claimed %u bytes are available",
     344         [ #  # ]:           0 :                               read_ret, try_len, peek_ret));
     345                 :             :             }
     346                 :             : 
     347                 :             :             // Don't include the terminator in the output.
     348         [ +  + ]:         167 :             const size_t append_len{terminator_found ? try_len - 1 : try_len};
     349                 :             : 
     350         [ +  - ]:         167 :             data.append(buf, buf + append_len);
     351                 :             : 
     352         [ +  + ]:         167 :             if (terminator_found) {
     353                 :          36 :                 return data;
     354                 :             :             }
     355                 :             :         }
     356                 :             : 
     357                 :         131 :         const auto now = GetTime<std::chrono::milliseconds>();
     358                 :             : 
     359         [ -  + ]:         131 :         if (now >= deadline) {
     360         [ #  # ]:           0 :             throw std::runtime_error(strprintf(
     361   [ #  #  #  # ]:           0 :                 "Receive timeout (received %u bytes without terminator before that)", data.size()));
     362                 :             :         }
     363                 :             : 
     364   [ +  -  -  + ]:         131 :         if (interrupt) {
     365         [ #  # ]:           0 :             throw std::runtime_error(strprintf(
     366                 :             :                 "Receive interrupted (received %u bytes without terminator before that)",
     367   [ #  #  #  # ]:           0 :                 data.size()));
     368                 :             :         }
     369                 :             : 
     370                 :             :         // Wait for a short while (or the socket to become ready for reading) before retrying.
     371         [ +  - ]:         131 :         const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
     372         [ +  - ]:         131 :         (void)Wait(wait_time, RECV);
     373                 :             :     }
     374                 :           2 : }
     375                 :             : 
     376                 :           0 : bool Sock::IsConnected(std::string& errmsg) const
     377                 :             : {
     378         [ #  # ]:           0 :     if (m_socket == INVALID_SOCKET) {
     379                 :           0 :         errmsg = "not connected";
     380                 :           0 :         return false;
     381                 :             :     }
     382                 :             : 
     383                 :           0 :     char c;
     384      [ #  #  # ]:           0 :     switch (Recv(&c, sizeof(c), MSG_PEEK)) {
     385                 :           0 :     case -1: {
     386                 :           0 :         const int err = WSAGetLastError();
     387         [ #  # ]:           0 :         if (IOErrorIsPermanent(err)) {
     388                 :           0 :             errmsg = NetworkErrorString(err);
     389                 :           0 :             return false;
     390                 :             :         }
     391                 :             :         return true;
     392                 :             :     }
     393                 :           0 :     case 0:
     394                 :           0 :         errmsg = "closed";
     395                 :           0 :         return false;
     396                 :             :     default:
     397                 :             :         return true;
     398                 :             :     }
     399                 :             : }
     400                 :             : 
     401                 :          31 : void Sock::Close()
     402                 :             : {
     403         [ +  + ]:          31 :     if (m_socket == INVALID_SOCKET) {
     404                 :             :         return;
     405                 :             :     }
     406                 :             : #ifdef WIN32
     407                 :             :     int ret = closesocket(m_socket);
     408                 :             : #else
     409                 :          10 :     int ret = close(m_socket);
     410                 :             : #endif
     411         [ -  + ]:          10 :     if (ret) {
     412         [ #  # ]:           0 :         LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
     413                 :             :     }
     414                 :          10 :     m_socket = INVALID_SOCKET;
     415                 :             : }
     416                 :             : 
     417                 :           4 : bool Sock::operator==(SOCKET s) const
     418                 :             : {
     419                 :           4 :     return m_socket == s;
     420                 :             : };
     421                 :             : 
     422                 :           0 : std::string NetworkErrorString(int err)
     423                 :             : {
     424                 :             : #if defined(WIN32)
     425                 :             :     return Win32ErrorString(err);
     426                 :             : #else
     427                 :             :     // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
     428                 :           0 :     return SysErrorString(err);
     429                 :             : #endif
     430                 :             : }
        

Generated by: LCOV version 2.0-1