summaryrefslogtreecommitdiff
path: root/zen/socket.h
blob: f981385252f52308c800c7268bc815b27eaabe7f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// *****************************************************************************
// * This file is part of the FreeFileSync project. It is distributed under    *
// * GNU General Public License: https://www.gnu.org/licenses/gpl-3.0          *
// * Copyright (C) Zenju (zenju AT freefilesync DOT org) - All Rights Reserved *
// *****************************************************************************

#ifndef SOCKET_H_23498325972583947678456437
#define SOCKET_H_23498325972583947678456437

#include <zen/zstring.h>
#include "sys_error.h"
    #include <unistd.h> //close
    #include <sys/socket.h>
    #include <netdb.h> //getaddrinfo



namespace zen
{
//Throw a SysError for the failed system call 'functionName', using the last system error code from getLastError()
//(presumably errno on this POSIX build - NOTE(review): confirm in sys_error.h).
//The error code is captured into a local BEFORE the message is built, so no intermediate work can overwrite it.
//Invoke immediately after the failing call. The do/while(false) wrapper makes the macro behave like a single statement.
#define THROW_LAST_SYS_ERROR_WSA(functionName)                       \
    do { const ErrorCode ecInternal = getLastError(); throw SysError(formatSystemError(functionName, ecInternal)); } while (false)


//portability layer: on POSIX a socket is a plain file descriptor, -1 means "no socket", close() releases it
using SocketType = int;
constexpr SocketType invalidSocket = -1;

inline void closeSocket(SocketType s)
{
    static_cast<void>(::close(s)); //result of close() is not checked
}


//Windows only: Winsock must be initialized (WSAStartup/WSACleanup) before calling any of these functions!

//Owns a connected stream-socket client connection; the descriptor is closed by the destructor. Not copyable.
//The constructor resolves 'server'/'serviceName' with getaddrinfo() and tries every returned address in list order:
//the first address that connects is kept. If none can be connected, the SysError of the FIRST failed attempt is thrown.
class Socket //throw SysError
{
public:
    Socket(const Zstring& server, const Zstring& serviceName) //throw SysError
    {
        ::addrinfo hints = {};
        hints.ai_socktype = SOCK_STREAM; //we *do* care about this one!
        hints.ai_flags = AI_ADDRCONFIG; //save a AAAA lookup on machines that can't use the returned data anyhow

        ::addrinfo* servinfo = nullptr;
        ZEN_ON_SCOPE_EXIT(if (servinfo) ::freeaddrinfo(servinfo));

        //getaddrinfo() reports failure through its return value (not errno) => describe it via gai_strerror()
        const int rcGai = ::getaddrinfo(server.c_str(), serviceName.c_str(), &hints, &servinfo);
        if (rcGai != 0)
            throw SysError(formatSystemError("getaddrinfo", replaceCpy(_("Error code %x"), L"%x", numberTo<std::wstring>(rcGai)), utfTo<std::wstring>(::gai_strerror(rcGai))));
        if (!servinfo)
            throw SysError(formatSystemError("getaddrinfo", L"", L"Empty server info."));

        //create a socket for one candidate address and connect it; if anything fails the socket is closed again
        const auto getConnectedSocket = [](const ::addrinfo& ai)
        {
            const SocketType testSocket = ::socket(ai.ai_family,                  //int socket_family
                                                   SOCK_CLOEXEC | ai.ai_socktype, //int socket_type (descriptor is closed on exec())
                                                   ai.ai_protocol);               //int protocol
            if (testSocket == invalidSocket)
                THROW_LAST_SYS_ERROR_WSA("socket");
            ZEN_ON_SCOPE_FAIL(closeSocket(testSocket));

            //ai_addrlen already has connect()'s parameter type (socklen_t): pass it through instead of narrowing via int
            if (::connect(testSocket, ai.ai_addr, ai.ai_addrlen) != 0)
                THROW_LAST_SYS_ERROR_WSA("connect");

            return testSocket;
        };

        std::optional<SysError> firstError;
        for (const ::addrinfo* si = servinfo; si; si = si->ai_next)
            try
            {
                socket_ = getConnectedSocket(*si); //throw SysError; pass ownership
                return;
            }
            catch (const SysError& e) { if (!firstError) firstError = e; }

        throw *firstError; //list was not empty, so there must have been an error!
    }

    //closes the owned socket; closeSocket() does not report failures
    ~Socket() { closeSocket(socket_); }

    //native socket handle; ownership stays with this object
    SocketType get() const { return socket_; }

private:
    Socket           (const Socket&) = delete;
    Socket& operator=(const Socket&) = delete;

    SocketType socket_ = invalidSocket;
};


//more socket helper functions:
namespace
{
//Receive up to 'bytesToRead' bytes (must be > 0) from a connected socket into 'buffer', retrying if interrupted by a signal (EINTR).
//Returns the number of bytes received: may be less than requested; 0 means end of file (peer closed the connection).
size_t tryReadSocket(SocketType socket, void* buffer, size_t bytesToRead) //throw SysError; may return short, only 0 means EOF!
{
    if (bytesToRead == 0) //"read() with a count of 0 returns zero" => indistinguishable from end of file! => check!
        throw std::logic_error("Contract violation! " + std::string(__FILE__) + ':' + numberTo<std::string>(__LINE__));

    //POSIX recv() takes a size_t length and returns ssize_t: don't cast through int! That would truncate requests > INT_MAX,
    //or wrap them to a negative int that converts back into a huge size_t length => recv() could overrun 'buffer'
    ssize_t bytesReceived = 0;
    do
        bytesReceived = ::recv(socket, buffer, bytesToRead, 0 /*flags*/);
    while (bytesReceived < 0 && errno == EINTR);

    if (bytesReceived < 0)
        THROW_LAST_SYS_ERROR_WSA("recv");

    if (static_cast<size_t>(bytesReceived) > bytesToRead) //better safe than sorry
        throw SysError(formatSystemError("recv", L"", L"Buffer overflow."));

    return static_cast<size_t>(bytesReceived); //"zero indicates end of file"
}


//Send up to 'bytesToWrite' bytes (must be > 0) from 'buffer' over a connected socket, retrying if interrupted by a signal (EINTR).
//Returns the number of bytes actually sent: always > 0, but may be less than requested.
size_t tryWriteSocket(SocketType socket, const void* buffer, size_t bytesToWrite) //throw SysError; may return short! CONTRACT: bytesToWrite > 0
{
    if (bytesToWrite == 0)
        throw std::logic_error("Contract violation! " + std::string(__FILE__) + ':' + numberTo<std::string>(__LINE__));

    //POSIX send() takes a size_t length and returns ssize_t: no narrowing through int (would truncate/wrap requests > INT_MAX)
    ssize_t bytesWritten = 0;
    do
        bytesWritten = ::send(socket, buffer, bytesToWrite,
                              MSG_NOSIGNAL); //peer closed connection => fail with EPIPE (=> SysError) instead of raising SIGPIPE, which terminates the process by default
    while (bytesWritten < 0 && errno == EINTR);

    if (bytesWritten < 0)
        THROW_LAST_SYS_ERROR_WSA("send");
    if (static_cast<size_t>(bytesWritten) > bytesToWrite)
        throw SysError(formatSystemError("send", L"", L"Buffer overflow."));
    if (bytesWritten == 0)
        throw SysError(formatSystemError("send", L"", L"Zero bytes processed."));

    return static_cast<size_t>(bytesWritten);
}
}


//Half-close the connection: SHUT_WR disallows further sends, so TCP sends a FIN packet to the peer
//to initiate connection termination. Receiving on the socket is still possible afterwards.
inline void shutdownSocketSend(SocketType socket) //throw SysError
{
    const int rcShutdown = ::shutdown(socket, SHUT_WR);
    if (rcShutdown == 0)
        return;

    THROW_LAST_SYS_ERROR_WSA("shutdown");
}
}

#endif //SOCKET_H_23498325972583947678456437
bgstack15