1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
|
// *****************************************************************************
// * This file is part of the FreeFileSync project. It is distributed under *
// * GNU General Public License: https://www.gnu.org/licenses/gpl-3.0 *
// * Copyright (C) Zenju (zenju AT freefilesync DOT org) - All Rights Reserved *
// *****************************************************************************
#ifndef SOCKET_H_23498325972583947678456437
#define SOCKET_H_23498325972583947678456437
#include <zen/zstring.h>
#include "sys_error.h"
#include <unistd.h> //close
#include <sys/socket.h>
#include <netdb.h> //getaddrinfo
namespace zen
{
//Throw a SysError describing the failure of the call named by "functionName".
//The error code is captured via getLastError() *before* functionName is evaluated, so the argument expression cannot clobber it.
//The do { } while (false) wrapper lets the macro be used as a single statement, e.g. as an unbraced if-body followed by ';'.
#define THROW_LAST_SYS_ERROR_WSA(functionName)                          \
    do                                                                  \
    {                                                                   \
        const ErrorCode ecInternal = getLastError();                    \
        throw SysError(formatSystemError(functionName, ecInternal));    \
    } while (false)
//patch up socket portability:
using SocketType = int;                  //POSIX sockets are plain file descriptors
constexpr SocketType invalidSocket = -1; //value returned by ::socket() on failure

//close the descriptor; the return value of ::close() is not checked
inline void closeSocket(SocketType s)
{
    ::close(s);
}
//NOTE: only the POSIX implementation is present here. A Windows build would additionally require Winsock to be
//      initialized before calling any of these functions (WSAStartup/WSACleanup).

/// Owning wrapper around a connected TCP stream socket; the socket is closed in the destructor. Non-copyable.
class Socket //throw SysError
{
public:
    /// Resolve server/serviceName via getaddrinfo() and connect to the first resolved address that accepts the connection.
    /// @param server      host name or numeric address
    /// @param serviceName service name or numeric port
    /// @throws SysError if name resolution fails or returns no entries, or if no resolved address could be
    ///                  connected (the error of the *first* failed address is reported)
    Socket(const Zstring& server, const Zstring& serviceName) //throw SysError
    {
        ::addrinfo hints = {};
        hints.ai_socktype = SOCK_STREAM; //we *do* care about this one!
        hints.ai_flags = AI_ADDRCONFIG; //save a AAAA lookup on machines that can't use the returned data anyhow

        ::addrinfo* servinfo = nullptr;
        ZEN_ON_SCOPE_EXIT(if (servinfo) ::freeaddrinfo(servinfo));

        const int rcGai = ::getaddrinfo(server.c_str(), serviceName.c_str(), &hints, &servinfo);
        if (rcGai == EAI_SYSTEM) //POSIX: "system error returned in errno" => gai_strerror() would only say "System error"
            THROW_LAST_SYS_ERROR_WSA("getaddrinfo");
        if (rcGai != 0)
            throw SysError(formatSystemError("getaddrinfo", replaceCpy(_("Error code %x"), L"%x", numberTo<std::wstring>(rcGai)), utfTo<std::wstring>(::gai_strerror(rcGai))));
        if (!servinfo)
            throw SysError(formatSystemError("getaddrinfo", L"", L"Empty server info."));

        //create a socket for one resolved address and connect it; the socket is closed again if connecting fails
        const auto getConnectedSocket = [](const ::addrinfo& ai)
        {
            const SocketType testSocket = ::socket(ai.ai_family,                  //int socket_family
                                                   SOCK_CLOEXEC | ai.ai_socktype, //int socket_type
                                                   ai.ai_protocol);               //int protocol
            if (testSocket == invalidSocket)
                THROW_LAST_SYS_ERROR_WSA("socket");
            ZEN_ON_SCOPE_FAIL(closeSocket(testSocket));

            //ai_addrlen already has connect()'s length type (socklen_t): no narrowing round-trip through int
            if (::connect(testSocket, ai.ai_addr, ai.ai_addrlen) != 0)
                THROW_LAST_SYS_ERROR_WSA("connect");

            return testSocket;
        };

        std::optional<SysError> firstError;
        for (const ::addrinfo* si = servinfo; si; si = si->ai_next)
            try
            {
                socket_ = getConnectedSocket(*si); //throw SysError; pass ownership
                return;
            }
            catch (const SysError& e) { if (!firstError) firstError = e; }

        throw *firstError; //list was not empty, so there must have been an error!
    }

    ~Socket() { closeSocket(socket_); }

    /// Underlying descriptor; ownership stays with this object.
    SocketType get() const { return socket_; }

private:
    Socket           (const Socket&) = delete;
    Socket& operator=(const Socket&) = delete;

    SocketType socket_ = invalidSocket;
};
//more socket helper functions:
//(declared "inline" instead of living in an anonymous namespace: this is a header, and internal linkage would
// give every including translation unit its own private copy)

/// Receive up to bytesToRead bytes from a connected socket, retrying recv() while it fails with EINTR.
/// @return number of bytes received; may be less than requested; 0 means end of file
/// @throws std::logic_error if bytesToRead == 0 (a 0 result would be indistinguishable from EOF)
/// @throws SysError if recv() fails
inline
size_t tryReadSocket(SocketType socket, void* buffer, size_t bytesToRead) //throw SysError; may return short, only 0 means EOF!
{
    if (bytesToRead == 0) //"read() with a count of 0 returns zero" => indistinguishable from end of file! => check!
        throw std::logic_error("Contract violation! " + std::string(__FILE__) + ':' + numberTo<std::string>(__LINE__));

    //POSIX recv() takes a size_t length and returns ssize_t: casting through int (a Winsock leftover) turned
    //lengths > INT_MAX into negative values, which then converted back into a huge size_t => buffer overflow
    ssize_t bytesReceived = 0;
    for (;;)
    {
        bytesReceived = ::recv(socket,      //int sockfd
                               buffer,      //void* buf
                               bytesToRead, //size_t len
                               0);          //int flags
        if (bytesReceived >= 0 || errno != EINTR)
            break;
    }
    if (bytesReceived < 0)
        THROW_LAST_SYS_ERROR_WSA("recv");

    if (static_cast<size_t>(bytesReceived) > bytesToRead) //better safe than sorry
        throw SysError(formatSystemError("recv", L"", L"Buffer overflow."));

    return static_cast<size_t>(bytesReceived); //"zero indicates end of file"
}

/// Send up to bytesToWrite bytes over a connected socket, retrying send() while it fails with EINTR.
/// @return number of bytes sent (> 0); may be less than requested
/// @throws std::logic_error if bytesToWrite == 0
/// @throws SysError if send() fails or reports an implausible byte count
inline
size_t tryWriteSocket(SocketType socket, const void* buffer, size_t bytesToWrite) //throw SysError; may return short! CONTRACT: bytesToWrite > 0
{
    if (bytesToWrite == 0)
        throw std::logic_error("Contract violation! " + std::string(__FILE__) + ':' + numberTo<std::string>(__LINE__));

#ifdef MSG_NOSIGNAL
    //peer has closed the connection => fail with EPIPE (reported as SysError) instead of raising SIGPIPE,
    //whose default action terminates the process
    const int sendFlags = MSG_NOSIGNAL;
#else
    const int sendFlags = 0;
#endif

    //size_t length / ssize_t result as specified by POSIX: no narrowing through int (see tryReadSocket())
    ssize_t bytesWritten = 0;
    for (;;)
    {
        bytesWritten = ::send(socket,       //int sockfd
                              buffer,       //const void* buf
                              bytesToWrite, //size_t len
                              sendFlags);   //int flags
        if (bytesWritten >= 0 || errno != EINTR)
            break;
    }
    if (bytesWritten < 0)
        THROW_LAST_SYS_ERROR_WSA("send");

    if (static_cast<size_t>(bytesWritten) > bytesToWrite)
        throw SysError(formatSystemError("send", L"", L"Buffer overflow."));
    if (bytesWritten == 0)
        throw SysError(formatSystemError("send", L"", L"Zero bytes processed."));

    return static_cast<size_t>(bytesWritten);
}
/// Shut down the sending direction of the socket (SHUT_WR).
/// This initiates termination of the connection by sending a TCP FIN packet.
/// @throws SysError if shutdown() fails
inline void shutdownSocketSend(SocketType socket) //throw SysError
{
    const int rcShutdown = ::shutdown(socket, SHUT_WR);
    if (rcShutdown == 0)
        return;

    THROW_LAST_SYS_ERROR_WSA("shutdown");
}
}
#endif //SOCKET_H_23498325972583947678456437
|