summaryrefslogtreecommitdiff
path: root/zen/socket.h
blob: 461226d07fac8ab3f5d9048fd7c0081a9241505c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
// *****************************************************************************
// * This file is part of the FreeFileSync project. It is distributed under    *
// * GNU General Public License: https://www.gnu.org/licenses/gpl-3.0          *
// * Copyright (C) Zenju (zenju AT freefilesync DOT org) - All Rights Reserved *
// *****************************************************************************

#ifndef SOCKET_H_23498325972583947678456437
#define SOCKET_H_23498325972583947678456437

#include "sys_error.h"
#include <algorithm> //std::min
#include <chrono>    //std::chrono::steady_clock (connect time-out)
#include <limits>    //std::numeric_limits
    #include <unistd.h> //close
    #include <fcntl.h> //fcntl
    #include <poll.h> //poll
    #include <sys/socket.h>
    #include <netinet/tcp.h> //TCP_NODELAY
    #include <netdb.h> //getaddrinfo


namespace zen
{
//throw SysError for the most recent failed socket API call: functionName names the failed call,
//the error code comes from getLastError() (presumably errno on POSIX - NOTE(review): confirm in sys_error.h)
//"WSA" in the name refers to the Winsock flavor of this macro; do-while(false) makes it usable as a single statement
#define THROW_LAST_SYS_ERROR_WSA(functionName)                       \
    do { const ErrorCode ecInternal = getLastError(); throw SysError(formatSystemError(functionName, ecInternal)); } while (false)


//throw SysError for a failed getaddrinfo() call; rcGai: its (non-zero) return code
//EAI_SYSTEM means the real error is in errno => report that one instead of the generic EAI code
//the argument is parenthesized so that any expression can be passed safely
#define THROW_LAST_SYS_ERROR_GAI(rcGai)                          \
    do {                                                         \
        if ((rcGai) == EAI_SYSTEM) /*"check errno for details"*/ \
            THROW_LAST_SYS_ERROR("getaddrinfo");                 \
        \
        throw SysError(formatSystemError("getaddrinfo", formatGaiErrorCode(rcGai), utfTo<std::wstring>(::gai_strerror(rcGai)))); \
    } while (false)

//map a getaddrinfo() error code to a display text via ZEN_CHECK_CASE_FOR_CONSTANT (presumably the constant's name);
//codes not listed are formatted as "Error code %x"
inline
std::wstring formatGaiErrorCode(int ec)
{
    //POSIX defines EAI_AGAIN/BADFLAGS/FAIL/FAMILY/MEMORY/NONAME/SERVICE/SOCKTYPE/SYSTEM/OVERFLOW;
    //the remaining codes are extensions (glibc: only with _GNU_SOURCE; other platforms may lack some) => guard each one
    switch (ec)
    {
#ifdef EAI_ADDRFAMILY
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_ADDRFAMILY);
#endif
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_AGAIN);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_BADFLAGS);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_FAIL);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_FAMILY);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_MEMORY);
#ifdef EAI_NODATA
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_NODATA);
#endif
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_NONAME);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_SERVICE);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_SOCKTYPE);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_SYSTEM);
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_OVERFLOW);
#ifdef EAI_INPROGRESS
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_INPROGRESS);
#endif
#ifdef EAI_CANCELED
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_CANCELED);
#endif
#ifdef EAI_NOTCANCELED
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_NOTCANCELED);
#endif
#ifdef EAI_ALLDONE
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_ALLDONE);
#endif
#ifdef EAI_INTR
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_INTR);
#endif
#ifdef EAI_IDN_ENCODE
            ZEN_CHECK_CASE_FOR_CONSTANT(EAI_IDN_ENCODE);
#endif
        default:
            return replaceCpy(_("Error code %x"), L"%x", numberTo<std::wstring>(ec));
    }
}

//patch up socket portability: on POSIX a socket is a plain file descriptor
using SocketType = int;
const SocketType invalidSocket = -1; //value returned by ::socket() on failure
//best-effort close: the result of ::close() is deliberately ignored
inline void closeSocket(SocketType s) { ::close(s); }

//switch O_NONBLOCK on (nonBlocking == true) or off for the given socket; defined further below
inline void setNonBlocking(SocketType socket, bool nonBlocking); //throw SysError


//Windows builds only: Winsock needs to be initialized before calling any of these functions! (WSAStartup/WSACleanup) - the POSIX socket API used here needs no such setup



class Socket //throw SysError
{
public:
    /* Resolve server/serviceName via getaddrinfo() and connect a stream socket to the first resolved address that succeeds.
       - server:      must not be empty or whitespace-only
       - serviceName: passed to getaddrinfo() as its service argument (presumably a port number or service name)
       - timeoutSec:  connect time-out, applied separately to each resolved address; values < 0 are treated as 0
       throws SysError if name resolution fails, if every resolved address fails to connect (the first error is reported),
       or if TCP_NODELAY cannot be set on the connected socket.
       On success the owned socket is close-on-exec, in blocking mode, and has the Nagle algorithm disabled. */
    Socket(const Zstring& server, const Zstring& serviceName, int timeoutSec) //throw SysError
    {
        //GetAddrInfo(): "If the pNodeName parameter contains an empty string, all registered addresses on the local computer are returned."
        //               "If the pNodeName parameter points to a string equal to "localhost", all loopback addresses on the local computer are returned."
        if (trimCpy(server).empty())
            throw SysError(_("Server name must not be empty."));

        const addrinfo hints
        {
            .ai_flags = AI_ADDRCONFIG, //save a AAAA lookup on machines that can't use the returned data anyhow
            .ai_socktype = SOCK_STREAM, //we *do* care about this one!
        };

        addrinfo* servinfo = nullptr;
        ZEN_ON_SCOPE_EXIT(if (servinfo) ::freeaddrinfo(servinfo));

        const int rcGai = ::getaddrinfo(server.c_str(), serviceName.c_str(), &hints, &servinfo);
        if (rcGai != 0)
            THROW_LAST_SYS_ERROR_GAI(rcGai);
        if (!servinfo)
            throw SysError(formatSystemError("getaddrinfo", L"", L"Empty server info."));

        //create a non-blocking socket for one resolved address, connect it (waiting at most timeoutSec), then make it blocking again
        const auto getConnectedSocket = [timeoutSec](const auto& /*addrinfo*/ ai)
        {
            SocketType testSocket = ::socket(ai.ai_family,    //int socket_family
                                             SOCK_CLOEXEC | SOCK_NONBLOCK |
                                             ai.ai_socktype,  //int socket_type
                                             ai.ai_protocol); //int protocol
            if (testSocket == invalidSocket)
                THROW_LAST_SYS_ERROR_WSA("socket");
            ZEN_ON_SCOPE_FAIL(closeSocket(testSocket));

            if (::connect(testSocket, ai.ai_addr, ai.ai_addrlen) != 0) //0 or -1
            {
                if (errno != EINPROGRESS) //non-blocking socket: EINPROGRESS = connection attempt still under way
                    THROW_LAST_SYS_ERROR_WSA("connect");

                waitForConnect(testSocket, timeoutSec); //throw SysError
            }

            setNonBlocking(testSocket, false); //throw SysError

            return testSocket;
        };

        /* getAddrInfo() often returns only one ai_family == AF_INET address, but more items are possible:
            facebook.com:  1 x AF_INET6, 3 x AF_INET
            microsoft.com: 5 x AF_INET            => server not allowing connection: hanging for 5x timeoutSec :(       */
        std::optional<SysError> firstError;
        for (const auto* /*::addrinfo*/ si = servinfo; si; si = si->ai_next)
            try
            {
                socket_ = getConnectedSocket(*si); //throw SysError; pass ownership
                firstError = std::nullopt;
                break;
            }
            catch (const SysError& e) { if (!firstError) firstError = e; }

        if (firstError)
            throw *firstError;
        assert(socket_ != invalidSocket); //list was non-empty, so there's either an error, or a valid socket
        ZEN_ON_SCOPE_FAIL(closeSocket(socket_));
        //-----------------------------------------------------------
        //configure *after* selecting appropriate socket: cfg-failure should not discard otherwise fine connection!
        //NOTE(review): a TCP_NODELAY failure below still throws, closing the socket via ZEN_ON_SCOPE_FAIL

        int noDelay = 1; //disable Nagle algorithm: https://brooker.co.za/blog/2024/05/09/nagle.html
        //e.g. test case "website sync": 23% shorter comparison time!
        if (::setsockopt(socket_,         //int socket
                         IPPROTO_TCP,     //int level
                         TCP_NODELAY,     //int option_name
                         &noDelay,        //const void* option_value
                         sizeof(noDelay)) //socklen_t option_len
            != 0)
            THROW_LAST_SYS_ERROR_WSA("setsockopt(TCP_NODELAY)");
    }

    ~Socket() { closeSocket(socket_); }

    SocketType get() const { return socket_; }

private:
    Socket           (const Socket&) = delete;
    Socket& operator=(const Socket&) = delete;

    /* Wait until a non-blocking connect() that reported EINPROGRESS has finished, then check its outcome.
       throws SysError on time-out (ETIMEDOUT), on poll()/getsockopt() failure, or if the connection attempt failed (SO_ERROR != 0) */
    static void waitForConnect(SocketType socket, int timeoutSec) //throw SysError
    {
        //poll() instead of select(): FD_SET() with a descriptor >= FD_SETSIZE is undefined behavior
        pollfd pfd{.fd = socket, .events = POLLOUT}; //POLLERR/POLLHUP are always reported, no need to request them

        const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(timeoutSec > 0 ? timeoutSec : 0);
        for (;;)
        {
            const auto remainingMs = std::chrono::ceil<std::chrono::milliseconds>(deadline - std::chrono::steady_clock::now()).count();
            const int waitMs = remainingMs <= 0 ? 0 :
                               remainingMs >= std::numeric_limits<int>::max() ? std::numeric_limits<int>::max() :
                               static_cast<int>(remainingMs);

            const int rv = ::poll(&pfd, 1, waitMs);
            if (rv < 0)
            {
                if (errno == EINTR) //interrupted by a signal: keep waiting for the remaining time (like the recv/send EINTR loops)
                    continue;
                THROW_LAST_SYS_ERROR_WSA("poll");
            }
            if (rv > 0) //connect() finished: successfully or not => check SO_ERROR
                break;

            if (remainingMs <= waitMs) //time-out! (unless waitMs had to be capped for a huge timeoutSec)
                throw SysError(formatSystemError("poll, " + utfTo<std::string>(_P("1 sec", "%x sec", timeoutSec)), ETIMEDOUT));
        }

        int error = 0;
        socklen_t optLen = sizeof(error);
        if (::getsockopt(socket,     //int socket
                         SOL_SOCKET, //int level
                         SO_ERROR,   //int option_name
                         &error,     //void* option_value
                         &optLen)    //socklen_t* option_len
            != 0)
            THROW_LAST_SYS_ERROR_WSA("getsockopt(SO_ERROR)");

        if (error != 0)
            throw SysError(formatSystemError("connect, SO_ERROR", static_cast<ErrorCode>(error))/*== system error code, apparently!?*/);
    }

    SocketType socket_ = invalidSocket;
};


//more socket helper functions:
//defined "inline" rather than inside an anonymous namespace: in a header, an anonymous namespace gives every including translation unit its own private copy

/* Receive up to bytesToRead bytes from a connected socket into buffer; EINTR is retried.
   Returns the number of bytes received: may be less than requested; 0 means end of file (peer shut down its sending side).
   CONTRACT: bytesToRead > 0 - otherwise std::logic_error (a 0-byte read could not be told apart from EOF)
   throws SysError if recv() fails */
inline
size_t tryReadSocket(SocketType socket, void* buffer, size_t bytesToRead) //throw SysError; may return short, only 0 means EOF!
{
    if (bytesToRead == 0) //"read() with a count of 0 returns zero" => indistinguishable from end of file! => check!
        throw std::logic_error(std::string(__FILE__) + '[' + numberTo<std::string>(__LINE__) + "] Contract violation!");

    //cap the request at INT_MAX bytes instead of narrowing size_t -> int (which wraps, e.g. 4 GiB -> 0 => fake EOF); returning short is allowed
    const size_t bytesToRecv = std::min(bytesToRead, static_cast<size_t>(std::numeric_limits<int>::max()));

    ssize_t bytesReceived = 0;
    for (;;)
    {
        bytesReceived = ::recv(socket,      //int socket
                               buffer,      //void* buffer
                               bytesToRecv, //size_t length
                               0);          //int flags
        if (bytesReceived >= 0 || errno != EINTR)
            break;
    }
    if (bytesReceived < 0)
        THROW_LAST_SYS_ERROR_WSA("recv");

    ASSERT_SYSERROR(makeUnsigned(bytesReceived) <= bytesToRead); //better safe than sorry

    return static_cast<size_t>(bytesReceived); //"zero indicates end of file"
}


/* Send up to bytesToWrite bytes from buffer over a connected socket; EINTR is retried.
   Returns the number of bytes sent: may be less than requested, never 0.
   CONTRACT: bytesToWrite > 0 - otherwise std::logic_error
   throws SysError if send() fails or reports zero bytes sent */
inline
size_t tryWriteSocket(SocketType socket, const void* buffer, size_t bytesToWrite) //throw SysError; may return short! CONTRACT: bytesToWrite > 0
{
    if (bytesToWrite == 0)
        throw std::logic_error(std::string(__FILE__) + '[' + numberTo<std::string>(__LINE__) + "] Contract violation!");

    //cap the request at INT_MAX bytes instead of narrowing size_t -> int (which wraps, e.g. 4 GiB -> 0); returning short is allowed
    const size_t bytesToSend = std::min(bytesToWrite, static_cast<size_t>(std::numeric_limits<int>::max()));

    //writing to a connection the peer has closed would raise SIGPIPE, terminating the process by default
    //=> with MSG_NOSIGNAL send() fails with EPIPE instead, which is reported as SysError below
#ifdef MSG_NOSIGNAL
    constexpr int sendFlags = MSG_NOSIGNAL;
#else
    constexpr int sendFlags = 0;
#endif

    ssize_t bytesWritten = 0;
    for (;;)
    {
        bytesWritten = ::send(socket,      //int socket
                              buffer,      //const void* buffer
                              bytesToSend, //size_t length
                              sendFlags);  //int flags
        if (bytesWritten >= 0 || errno != EINTR)
            break;
    }
    if (bytesWritten < 0)
        THROW_LAST_SYS_ERROR_WSA("send");

    if (bytesWritten == 0)
        throw SysError(formatSystemError("send", L"", L"Zero bytes processed."));

    ASSERT_SYSERROR(makeUnsigned(bytesWritten) <= bytesToWrite); //better safe than sorry

    return static_cast<size_t>(bytesWritten);
}


//half-close: disable further sends on the socket (SHUT_WR), i.e. tell the peer we're done writing (TCP FIN);
//the receive direction stays open. throws SysError if shutdown() fails
inline
void shutdownSocketSend(SocketType socket) //throw SysError
{
    const int rcShutdown = ::shutdown(socket, SHUT_WR);
    if (rcShutdown == 0)
        return;

    THROW_LAST_SYS_ERROR_WSA("shutdown");
}


//turn O_NONBLOCK on (nonBlocking == true) or off for the socket, leaving all other file status flags as they are
//throws SysError if reading or writing the flags via fcntl() fails
inline
void setNonBlocking(SocketType socket, bool nonBlocking) //throw SysError
{
    const int oldFlags = ::fcntl(socket, F_GETFL);
    if (oldFlags == -1)
        THROW_LAST_SYS_ERROR("fcntl(F_GETFL)");

    const int newFlags = nonBlocking ? (oldFlags | O_NONBLOCK) : (oldFlags & ~O_NONBLOCK);

    const int rcSet = ::fcntl(socket, F_SETFL, newFlags);
    if (rcSet != 0)
        THROW_LAST_SYS_ERROR(nonBlocking ? "fcntl(F_SETFL, O_NONBLOCK)" : "fcntl(F_SETFL, ~O_NONBLOCK)");
}
}

#endif //SOCKET_H_23498325972583947678456437
bgstack15