// Copyright (C) 2003 Davis E. King (davis@dlib.net), Miguel Grinberg
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SOCKETS_KERNEL_1_CPp_
#define DLIB_SOCKETS_KERNEL_1_CPp_
#include "../platform.h"
#ifdef WIN32
#include <winsock2.h>
#ifndef _WINSOCKAPI_
#define _WINSOCKAPI_ /* Prevent inclusion of winsock.h in windows.h */
#endif
#include "../windows_magic.h"
#include "sockets_kernel_1.h"
#include <windows.h>
#ifndef NI_MAXHOST
#define NI_MAXHOST 1025
#endif
// tell visual studio to link to the libraries we need if we are
// in fact using visual studio
#ifdef _MSC_VER
#pragma comment (lib, "ws2_32.lib")
#endif
#include "../assert.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class SOCKET_container
{
/*!
This object is just a wrapper around the SOCKET type. It exists
so that we can #include the windows.h and Winsock2.h header files
in this cpp file and not at all in the header file.
!*/
public:
SOCKET_container (
SOCKET s = INVALID_SOCKET
) : val(s) {}
SOCKET val;
operator SOCKET&() { return val; }
SOCKET_container& operator= (
const SOCKET& s
) { val = s; return *this; }
bool operator== (
const SOCKET& s
) const { return s == val; }
};
// ----------------------------------------------------------------------------------------
// stuff to ensure that WSAStartup() is always called before any sockets stuff is needed
namespace sockets_kernel_1_mutex
{
// Lock taken by sockets_startup() while it creates its static startup object.
// It is also held around the gethostbyname()/gethostbyaddr() lookups in this
// file, since those calls are not treated as thread safe.
mutex startup_lock;
}
class sockets_startupdown
{
    /*!
        Pairs a WSAStartup() call (made by the constructor) with a
        WSACleanup() call (made by the destructor).
    !*/
public:
    sockets_startupdown();

    ~sockets_startupdown()
    {
        WSACleanup();
    }
};
// Initializes winsock, requesting version 2.0.  The value returned by
// WSAStartup() is not inspected.
sockets_startupdown::sockets_startupdown ()
{
    WSADATA startup_info;
    WSAStartup(MAKEWORD(2,0), &startup_info);
}
void sockets_startup()
{
// mutex crap to make this function thread-safe
sockets_kernel_1_mutex::startup_lock.lock();
static sockets_startupdown a;
sockets_kernel_1_mutex::startup_lock.unlock();
}
// ----------------------------------------------------------------------------------------
// lookup functions
int
get_local_hostname (
std::string& hostname
)
{
// ensure that WSAStartup has been called and WSACleanup will eventually
// be called when program ends
sockets_startup();
try
{
char temp[NI_MAXHOST];
if (gethostname(temp,NI_MAXHOST) == SOCKET_ERROR )
{
return OTHER_ERROR;
}
hostname = temp;
}
catch (...)
{
return OTHER_ERROR;
}
return 0;
}
// -----------------
// Resolves hostname with gethostbyname() and stores the dotted-decimal text
// of its nth address (n is zero based) into ip.
//
// Returns 0 on success.  Returns OTHER_ERROR when hostname is empty, the
// lookup fails, the host has no addresses, there is no nth address,
// inet_ntoa() fails, or an exception is thrown.  A negative n selects the
// first address.
int
hostname_to_ip (
    const std::string& hostname,
    std::string& ip,
    int n
)
{
    // ensure that WSAStartup has been called and WSACleanup will eventually
    // be called when program ends
    sockets_startup();

    try
    {
        // lock this mutex since gethostbyname isn't really thread safe
        auto_mutex M(sockets_kernel_1_mutex::startup_lock);

        // if no hostname was given then return error
        if (hostname.empty())
            return OTHER_ERROR;

        hostent* address = gethostbyname(hostname.c_str());
        if (address == 0 || address->h_addr_list == 0)
            return OTHER_ERROR;

        // The address list is null terminated, so even the first entry may
        // be null.  Check it before it is dereferenced below.
        in_addr* addr = reinterpret_cast<in_addr*>(address->h_addr_list[0]);
        if (addr == 0)
            return OTHER_ERROR;

        // walk forward to the nth address, stopping at the terminator
        for (int i = 1; i <= n; ++i)
        {
            addr = reinterpret_cast<in_addr*>(address->h_addr_list[i]);
            // if there is no nth address then return error
            if (addr == 0)
                return OTHER_ERROR;
        }

        char* resolved_ip = inet_ntoa(*addr);
        // check if inet_ntoa returned an error
        if (resolved_ip == NULL)
            return OTHER_ERROR;

        ip.assign(resolved_ip);
    }
    catch(...)
    {
        return OTHER_ERROR;
    }

    return 0;
}
// -----------------
int
ip_to_hostname (
const std::string& ip,
std::string& hostname
)
{
// ensure that WSAStartup has been called and WSACleanup will eventually
// be called when program ends
sockets_startup();
try
{
// lock this mutex since gethostbyaddr isn't really thread safe
auto_mutex M(sockets_kernel_1_mutex::startup_lock);
// if no ip was given then return error
if (ip.empty())
return OTHER_ERROR;
hostent* address;
unsigned long ipnum = inet_addr(ip.c_str());
// if inet_addr couldn't convert ip then return an error
if (ipnum == INADDR_NONE)
{
return OTHER_ERROR;
}
address = gethostbyaddr(reinterpret_cast<char*>(&ipnum),4,AF_INET);
// check if gethostbyaddr returned an error
if (address == 0)
{
return OTHER_ERROR;
}
hostname.assign(address->h_name);
}
catch (...)
{
return OTHER_ERROR;
}
return 0;
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// connection object
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Builds a connection around an already connected socket.  The socket
// handle is kept in a heap allocated SOCKET_container that is released by
// the destructor.  The port/ip arguments are simply recorded, and the
// shutdown bookkeeping (sd, sdo, sdr) starts out cleared.
connection::
connection(
    SOCKET_container sock,
    unsigned short foreign_port,
    const std::string& foreign_ip,
    unsigned short local_port,
    const std::string& local_ip
) :
    user_data(0),
    connection_socket(*(new SOCKET_container(sock))),
    connection_foreign_port(foreign_port),
    connection_foreign_ip(foreign_ip),
    connection_local_port(local_port),
    connection_local_ip(local_ip),
    sd(false),
    sdo(false),
    sdr(0)
{
}
// ----------------------------------------------------------------------------------------
// Closes the socket if it is still open (shutdown() replaces it with
// INVALID_SOCKET once it has been closed) and frees the heap allocated
// SOCKET_container.
connection::
~connection (
)
{
    const SOCKET handle = connection_socket.val;
    if (handle != INVALID_SOCKET)
    {
        closesocket(handle);
    }
    delete &connection_socket;
}
// ----------------------------------------------------------------------------------------
// Sets the TCP_NODELAY option to 1 on this connection's socket.
// Returns 0 on success and OTHER_ERROR if setsockopt() fails.
int connection::
disable_nagle()
{
    int enable = 1;
    const int rc = setsockopt(
        connection_socket,
        IPPROTO_TCP,
        TCP_NODELAY,
        reinterpret_cast<const char*>(&enable),
        sizeof(enable)
    );
    return (rc == SOCKET_ERROR) ? OTHER_ERROR : 0;
}
// ----------------------------------------------------------------------------------------
long connection::
write (
const char* buf,
long num
)
{
const long old_num = num;
long status;
const long max_send_length = 1024*1024*100;
while (num > 0)
{
// Make sure to cap the max value num can take on so that if it is
// really large (it might be big on 64bit platforms) so that the OS
// can't possibly get upset about it being large.
const long length = std::min(max_send_length, num);
if ( (status = send(connection_socket,buf,length,0)) == SOCKET_ERROR)
{
if (sdo_called())
return SHUTDOWN;
else
return OTHER_ERROR;
}
num -= status;
buf += status;
}
return old_num;
}
// ----------------------------------------------------------------------------------------
// Performs a single recv() of at most num bytes into buf.
//
// Returns the value reported by recv(), except:
//   - if recv() fails: SHUTDOWN when sd_called() is true, else OTHER_ERROR
//   - if recv() returns 0 and sd_called() is true: SHUTDOWN
long connection::
read (
    char* buf,
    long num
)
{
    // Never ask for more than 100MB in one recv() call so that a very
    // large num (possible on 64bit platforms) can't upset the OS.
    const long recv_limit = 1024*1024*100;
    const long request = (num < recv_limit) ? num : recv_limit;

    const long received = recv(connection_socket,buf,request,0);

    if (received == SOCKET_ERROR)
    {
        // an error that follows a shutdown call is reported as SHUTDOWN
        if (sd_called())
            return SHUTDOWN;
        return OTHER_ERROR;
    }

    if (received == 0 && sd_called())
        return SHUTDOWN;

    return received;
}
// ----------------------------------------------------------------------------------------
long connection::
read (
char* buf,
long num,
unsigned long timeout
)
{
if (readable(timeout) == false)
return TIMEOUT;
const long max_recv_length = 1024*1024*100;
// Make sure to cap the max value num can take on so that if it is
// really large (it might be big on 64bit platforms) so that the OS
// can't possibly get upset about it being large.
const long length = std::min(max_recv_length, num);
long status = recv(connection_socket,buf,length,0);
if (status == SOCKET_ERROR)
{
// if this error is the result of a shutdown call then return SHUTDOWN
if (sd_called())
return SHUTDOWN;
else
return OTHER_ERROR;
}
else if (status == 0 && sd_called())
{
return SHUTDOWN;
}
return status;
}
// ----------------------------------------------------------------------------------------
bool connection::
readable (
unsigned long timeout
) const
{
fd_set read_set;
// initialize read_set
FD_ZERO(&read_set);
// add the listening socket to read_set
FD_SET(connection_socket, &read_set);
// setup a timeval structure
timeval time_to_wait;
time_to_wait.tv_sec = static_cast<long>(timeout/1000);
time_to_wait.tv_usec = static_cast<long>((timeout%1000)*1000);
// wait on select
int status = select(0,&read_set,0,0,&time_to_wait);
// if select timed out or there was an error
if (status <= 0)
return false;
// data is ready to be read
return true;
}
// ----------------------------------------------------------------------------------------
// Shuts down the sending side of the socket with ::shutdown(SD_SEND).
//
// Only the first call does any work, and only if shutdown() has not already
// been called; later calls return the stored result (sdr) of the earlier
// shutdown operation.  A -1 from ::shutdown() is reported as OTHER_ERROR.
int connection::
shutdown_outgoing (
)
{
    sd_mutex.lock();

    if (!sdo && !sd)
    {
        sdo = true;
        sdr = ::shutdown(connection_socket,SD_SEND);
        // convert -1 error code into the OTHER_ERROR error code
        if (sdr == -1)
            sdr = OTHER_ERROR;
    }

    // Copy sdr while the lock is still held.  Reading it after unlocking
    // would race with another thread writing it from shutdown().
    const int result = sdr;
    sd_mutex.unlock();
    return result;
}
// ----------------------------------------------------------------------------------------
// Closes the connection's socket.
//
// Only the first call does any work; later calls return the stored result
// (sdr).  The stored handle is replaced with INVALID_SOCKET before the
// socket is closed.  A SOCKET_ERROR from closesocket() is reported as
// OTHER_ERROR.
int connection::
shutdown (
)
{
    sd_mutex.lock();

    if (!sd)
    {
        sd = true;
        SOCKET stemp = connection_socket;
        connection_socket = INVALID_SOCKET;
        sdr = closesocket(stemp);
        // convert SOCKET_ERROR error code into the OTHER_ERROR error code
        if (sdr == SOCKET_ERROR)
            sdr = OTHER_ERROR;
    }

    // Copy sdr while the lock is still held.  Reading it after unlocking
    // would race with another thread writing it.
    const int result = sdr;
    sd_mutex.unlock();
    return result;
}
// ----------------------------------------------------------------------------------------
// Returns the raw handle currently stored in this connection's
// SOCKET_container.
connection::socket_descriptor_type connection::
get_socket_descriptor (
) const
{
    const SOCKET_container& holder = connection_socket;
    return holder.val;
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// listener object
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Builds a listener around an already listening socket.  The socket handle
// is kept in a heap allocated SOCKET_container that is released by the
// destructor.  inaddr_any records whether the given ip string was empty.
listener::
listener(
    SOCKET_container sock,
    unsigned short port,
    const std::string& ip
) :
    listening_socket(*(new SOCKET_container(sock))),
    listening_port(port),
    listening_ip(ip),
    inaddr_any(listening_ip.empty())
{
}
// ----------------------------------------------------------------------------------------
// Closes the listening socket and frees the heap allocated SOCKET_container.
listener::
~listener (
)
{
    closesocket(listening_socket.val);
    delete &listening_socket;
}
// ----------------------------------------------------------------------------------------
int listener::
accept (
std::unique_ptr<connection>& new_connection,
unsigned long timeout
)
{
new_connection.reset(0);
connection* con;
int status = this->accept(con, timeout);
if (status == 0)
new_connection.reset(con);
return status;
}
// ----------------------------------------------------------------------------------------
// Accepts one incoming connection on the listening socket and stores a
// newly allocated connection object in new_connection.
//
// A nonzero timeout (milliseconds) is enforced by waiting in select() before
// calling ::accept(); a timeout of 0 skips the wait.
//
// Returns 0 on success, TIMEOUT if select() timed out, and OTHER_ERROR for
// any other failure.  Once a socket has been accepted, every failure path
// closes it before returning.
int listener::
accept (
    connection*& new_connection,
    unsigned long timeout
)
{
    if (timeout > 0)
    {
        // descriptor set holding only the listening socket
        fd_set pending;
        FD_ZERO(&pending);
        FD_SET(listening_socket, &pending);

        // split the millisecond timeout into whole seconds plus microseconds
        timeval wait_time;
        wait_time.tv_sec = static_cast<long>(timeout/1000);
        wait_time.tv_usec = static_cast<long>((timeout%1000)*1000);

        const int ready = select(0,&pending,0,0,&wait_time);
        if (ready == 0)
            return TIMEOUT;
        if (ready == SOCKET_ERROR)
            return OTHER_ERROR;
    }

    // take the next pending connection
    sockaddr_in peer_addr;
    int addr_len = sizeof(sockaddr_in);
    SOCKET incoming = ::accept(listening_socket,reinterpret_cast<sockaddr*>(&peer_addr),&addr_len);
    if (incoming == INVALID_SOCKET)
        return OTHER_ERROR;

    // port and IP of the foreign host
    int foreign_port = ntohs(peer_addr.sin_port);
    std::string foreign_ip;
    {
        char* peer_text = inet_ntoa(peer_addr.sin_addr);
        if (peer_text == NULL)
        {
            closesocket(incoming);
            return OTHER_ERROR;
        }
        foreign_ip.assign(peer_text);
    }

    // local IP: ask the accepted socket when listening on any address,
    // otherwise it is the IP this listener was created with
    std::string local_ip;
    if (!inaddr_any)
    {
        local_ip = listening_ip;
    }
    else
    {
        sockaddr_in local_info;
        addr_len = sizeof(sockaddr_in);
        const int rc = getsockname(incoming,reinterpret_cast<sockaddr*>(&local_info),&addr_len);
        if (rc == SOCKET_ERROR)
        {
            closesocket(incoming);
            return OTHER_ERROR;
        }

        char* local_text = inet_ntoa(local_info.sin_addr);
        if (local_text == NULL)
        {
            closesocket(incoming);
            return OTHER_ERROR;
        }
        local_ip.assign(local_text);
    }

    // set the SO_OOBINLINE option
    int flag_value = 1;
    if (setsockopt(incoming,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast<const char*>(&flag_value),sizeof(int)) == SOCKET_ERROR)
    {
        closesocket(incoming);
        return OTHER_ERROR;
    }

    // wrap the accepted socket in a connection object
    try
    {
        new_connection = new connection(incoming, foreign_port, foreign_ip, listening_port, local_ip);
    }
    catch (...)
    {
        closesocket(incoming);
        return OTHER_ERROR;
    }

    return 0;
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// socket creation functions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
int create_listener (
std::unique_ptr<listener>& new_listener,
unsigned short port,
const std::string& ip
)
{
new_listener.reset();
listener* temp;
int status = create_listener(temp,port,ip);
if (status == 0)
new_listener.reset(temp);
return status;
}
// Creates a TCP socket, binds it to (ip, port), puts it into the listening
// state and stores a newly allocated listener object in new_listener.
//
// An empty ip binds to INADDR_ANY.  If port is 0, the port actually bound is
// read back with getsockname() and recorded in the listener.
//
// Returns 0 on success, PORTINUSE if bind() or listen() fails with
// WSAEADDRINUSE, and OTHER_ERROR for any other failure.  The socket is
// closed on every failure path.
int create_listener (
    listener*& new_listener,
    unsigned short port,
    const std::string& ip
)
{
    // ensure that WSAStartup has been called and WSACleanup will eventually
    // be called when program ends
    sockets_startup();

    sockaddr_in sa;                         // local socket structure
    ZeroMemory(&sa,sizeof(sockaddr_in));    // initialize sa

    SOCKET sock = socket (AF_INET, SOCK_STREAM, 0);  // get a new socket
    if (sock == INVALID_SOCKET)
        return OTHER_ERROR;

    // set the local socket structure
    sa.sin_family = AF_INET;
    sa.sin_port = htons(port);
    if (ip.empty())
    {
        // Listen on any IP.  The address is 32 bits wide so it is converted
        // with htonl() rather than the 16 bit htons().
        sa.sin_addr.S_un.S_addr = htonl(INADDR_ANY);
    }
    else
    {
        // listen on the specific ip that was requested
        sa.sin_addr.S_un.S_addr = inet_addr(ip.c_str());
        // if inet_addr couldn't convert the ip then return an error
        if (sa.sin_addr.S_un.S_addr == INADDR_NONE)
        {
            closesocket(sock);
            return OTHER_ERROR;
        }
    }

    // set the SO_REUSEADDR option (a failure here is deliberately ignored)
    int flag_value = 1;
    setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,reinterpret_cast<const char*>(&flag_value),sizeof(int));

    // bind the new socket to the requested port and ip
    if (bind(sock,reinterpret_cast<sockaddr*>(&sa),sizeof(sockaddr_in)) == SOCKET_ERROR)
    {
        // grab the error code before closesocket() can overwrite it
        const int err = WSAGetLastError();
        closesocket(sock);
        // if the port is already bound then return PORTINUSE
        if (err == WSAEADDRINUSE)
            return PORTINUSE;
        else
            return OTHER_ERROR;
    }

    // tell the new socket to listen
    if (listen(sock,SOMAXCONN) == SOCKET_ERROR)
    {
        // grab the error code before closesocket() can overwrite it
        const int err = WSAGetLastError();
        closesocket(sock);
        // if the port is already bound then return PORTINUSE
        if (err == WSAEADDRINUSE)
            return PORTINUSE;
        else
            return OTHER_ERROR;
    }

    // port 0 asked for any free port, so find out which one was bound
    if (port == 0)
    {
        sockaddr_in local_info;
        int length = sizeof(sockaddr_in);
        if (getsockname(sock,reinterpret_cast<sockaddr*>(&local_info),&length) == SOCKET_ERROR)
        {
            closesocket(sock);
            return OTHER_ERROR;
        }
        port = ntohs(local_info.sin_port);
    }

    // initialize a listener object on the heap with the new socket
    try { new_listener = new listener(sock,port,ip); }
    catch(...) { closesocket(sock); return OTHER_ERROR; }

    return 0;
}
// ----------------------------------------------------------------------------------------
int create_connection (
std::unique_ptr<connection>& new_connection,
unsigned short foreign_port,
const std::string& foreign_ip,
unsigned short local_port,
const std::string& local_ip
)
{
new_connection.reset();
connection* temp;
int status = create_connection(temp,foreign_port, foreign_ip, local_port, local_ip);
if (status == 0)
new_connection.reset(temp);
return status;
}
// Creates a TCP socket, binds it to (local_ip, local_port), connects it to
// (foreign_ip, foreign_port) and stores a newly allocated connection object
// in new_connection.
//
// An empty local_ip binds to INADDR_ANY.  If local_port is 0 or local_ip is
// empty, the values actually in use are read back with getsockname() after
// the connect and recorded in the connection object.
//
// Returns 0 on success, PORTINUSE if bind() or connect() fails with
// WSAEADDRINUSE, and OTHER_ERROR for any other failure.  The socket is
// closed on every failure path.
int create_connection (
    connection*& new_connection,
    unsigned short foreign_port,
    const std::string& foreign_ip,
    unsigned short local_port,
    const std::string& local_ip
)
{
    // ensure that WSAStartup has been called and WSACleanup
    // will eventually be called when program ends
    sockets_startup();

    sockaddr_in local_sa;     // local socket structure
    sockaddr_in foreign_sa;   // foreign socket structure
    ZeroMemory(&local_sa,sizeof(sockaddr_in));
    ZeroMemory(&foreign_sa,sizeof(sockaddr_in));

    SOCKET sock = socket (AF_INET, SOCK_STREAM, 0);  // get a new socket
    if (sock == INVALID_SOCKET)
        return OTHER_ERROR;

    // set the foreign socket structure
    foreign_sa.sin_family = AF_INET;
    foreign_sa.sin_port = htons(foreign_port);
    foreign_sa.sin_addr.S_un.S_addr = inet_addr(foreign_ip.c_str());
    // if inet_addr couldn't convert the ip then return an error
    if (foreign_sa.sin_addr.S_un.S_addr == INADDR_NONE)
    {
        closesocket(sock);
        return OTHER_ERROR;
    }

    // set up the local socket structure
    local_sa.sin_family = AF_INET;
    local_sa.sin_port = htons(local_port);
    if (local_ip.empty())
    {
        // Bind to any local IP.  The address is 32 bits wide so it is
        // converted with htonl() rather than the 16 bit htons().
        local_sa.sin_addr.S_un.S_addr = htonl(INADDR_ANY);
    }
    else
    {
        // bind to the specific local ip that was requested
        local_sa.sin_addr.S_un.S_addr = inet_addr(local_ip.c_str());
        // if inet_addr couldn't convert the ip then return an error
        if (local_sa.sin_addr.S_un.S_addr == INADDR_NONE)
        {
            closesocket(sock);
            return OTHER_ERROR;
        }
    }

    // bind the new socket to the requested local port and local ip
    if (bind(sock,reinterpret_cast<sockaddr*>(&local_sa),sizeof(sockaddr_in)) == SOCKET_ERROR)
    {
        // grab the error code before closesocket() can overwrite it
        const int err = WSAGetLastError();
        closesocket(sock);
        // if the port is already bound then return PORTINUSE
        if (err == WSAEADDRINUSE)
            return PORTINUSE;
        else
            return OTHER_ERROR;
    }

    // connect the socket
    if (connect(sock,reinterpret_cast<sockaddr*>(&foreign_sa),sizeof(sockaddr_in)) == SOCKET_ERROR)
    {
        // grab the error code before closesocket() can overwrite it
        const int err = WSAGetLastError();
        closesocket(sock);
        if (err == WSAEADDRINUSE)
            return PORTINUSE;
        else
            return OTHER_ERROR;
    }

    // Determine the local port and IP actually in use.  The caller's values
    // are kept unless they were left unspecified (port 0 / empty ip), in
    // which case a single getsockname() call supplies the missing pieces.
    int used_local_port = local_port;
    std::string used_local_ip = local_ip;
    if (local_port == 0 || local_ip.empty())
    {
        sockaddr_in local_info;
        int length = sizeof(sockaddr_in);
        if (getsockname(sock,reinterpret_cast<sockaddr*>(&local_info),&length) == SOCKET_ERROR)
        {
            closesocket(sock);
            return OTHER_ERROR;
        }

        if (local_port == 0)
            used_local_port = ntohs(local_info.sin_port);

        if (local_ip.empty())
        {
            char* temp = inet_ntoa(local_info.sin_addr);
            // check if inet_ntoa returned an error
            if (temp == NULL)
            {
                closesocket(sock);
                return OTHER_ERROR;
            }
            used_local_ip.assign(temp);
        }
    }

    // set the SO_OOBINLINE option
    int flag_value = 1;
    if (setsockopt(sock,SOL_SOCKET,SO_OOBINLINE,reinterpret_cast<const char*>(&flag_value),sizeof(int)) == SOCKET_ERROR)
    {
        closesocket(sock);
        return OTHER_ERROR;
    }

    // initialize a connection object on the heap with the new socket
    try
    {
        new_connection = new connection (
            sock,
            foreign_port,
            foreign_ip,
            used_local_port,
            used_local_ip
        );
    }
    catch(...) { closesocket(sock); return OTHER_ERROR; }

    return 0;
}
// ----------------------------------------------------------------------------------------
}
#endif // WIN32
#endif // DLIB_SOCKETS_KERNEL_1_CPp_