summaryrefslogtreecommitdiff
path: root/modules/extra/m_ssl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'modules/extra/m_ssl.cpp')
-rw-r--r--modules/extra/m_ssl.cpp152
1 files changed, 82 insertions, 70 deletions
diff --git a/modules/extra/m_ssl.cpp b/modules/extra/m_ssl.cpp
index da6361769..30e548150 100644
--- a/modules/extra/m_ssl.cpp
+++ b/modules/extra/m_ssl.cpp
@@ -28,9 +28,16 @@ class MySSLService : public SSLService
class SSLSocketIO : public SocketIO
{
+ /** Check whether this socket has a pending connect() or accept()
+ * @return 0 if neither, -1 if connect/accept fails, -2 to wait more
+ */
+ int CheckState();
+
public:
/* The SSL socket for this socket */
SSL *sslsock;
+ /* -1 if not, 0 if waiting, 1 if true */
+ int connected, accepted;
/** Constructor
*/
@@ -42,27 +49,27 @@ class SSLSocketIO : public SocketIO
* @param sz How much to read
* @return Number of bytes received
*/
- int Recv(Socket *s, char *buf, size_t sz) const;
+ int Recv(Socket *s, char *buf, size_t sz);
/** Really write something to the socket
* @param s The socket
* @param buf What to write
* @return Number of bytes written
*/
- int Send(Socket *s, const Anope::string &buf) const;
+ int Send(Socket *s, const Anope::string &buf);
/** Accept a connection from a socket
* @param s The socket
+ * @return The new socket
*/
- void Accept(ListenSocket *s);
+ ClientSocket *Accept(ListenSocket *s);
/** Connect the socket
* @param s THe socket
* @param target IP to connect to
* @param port to connect to
- * @param bindip IP to bind to, if any
*/
- void Connect(ConnectionSocket *s, const Anope::string &target, int port, const Anope::string &bindip = "");
+ void Connect(ConnectionSocket *s, const Anope::string &target, int port);
/** Called when the socket is destructing
*/
@@ -144,41 +151,27 @@ class SSLModule : public Module
~SSLModule()
{
+ for (std::map<int, Socket *>::const_iterator it = SocketEngine::Sockets.begin(), it_end = SocketEngine::Sockets.end(); it != it_end;)
+ {
+ Socket *s = it->second;
+ ++it;
+
+ if (dynamic_cast<SSLSocketIO *>(s->IO))
+ delete s;
+ }
+
SSL_CTX_free(client_ctx);
SSL_CTX_free(server_ctx);
}
- EventReturn OnPreServerConnect(Uplink *u, int Number)
+ void OnPreServerConnect()
{
ConfigReader config;
- if (config.ReadFlag("uplink", "ssl", "no", Number - 1))
+ if (config.ReadFlag("uplink", "ssl", "no", CurrentUplink))
{
- DNSRecord req = DNSManager::BlockingQuery(uplink_server->host, uplink_server->ipv6 ? DNS_QUERY_AAAA : DNS_QUERY_A);
-
- if (!req)
- Log() << "Unable to connect to server " << uplink_server->host << ":" << uplink_server->port << " using SSL: Invalid hostname/IP";
- else
- {
- try
- {
- new UplinkSocket(uplink_server->ipv6);
- this->service.Init(UplinkSock);
- UplinkSock->Connect(req.result, uplink_server->port, Config->LocalHost);
-
- Log() << "Connected to server " << Number << " (" << u->host << ":" << u->port << ") with SSL";
- return EVENT_ALLOW;
- }
- catch (const SocketException &ex)
- {
- Log() << "Unable to connect with SSL to server " << Number << " (" << u->host << ":" << u->port << "), " << ex.GetReason();
- }
- }
-
- return EVENT_STOP;
+ this->service.Init(UplinkSock);
}
-
- return EVENT_CONTINUE;
}
};
@@ -194,39 +187,67 @@ void MySSLService::Init(Socket *s)
s->IO = new SSLSocketIO();
}
-SSLSocketIO::SSLSocketIO()
+int SSLSocketIO::CheckState()
+{
+ if (this->connected == 0 || this->accepted == 0)
+ {
+ int ret;
+ if (this->connected == 0)
+ ret = SSL_connect(this->sslsock);
+ else if (this->accepted == 0)
+ ret = SSL_accept(this->sslsock);
+ if (ret <= 0)
+ {
+ int error = SSL_get_error(this->sslsock, ret);
+
+ if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE))
+ // Wait more
+ return -2;
+ return -1;
+ }
+
+ if (this->connected == 0)
+ this->connected = 1;
+ else if (this->accepted == 0)
+ this->accepted = 1;
+ }
+
+ return 0;
+}
+
+SSLSocketIO::SSLSocketIO() : connected(-1), accepted(-1)
{
this->sslsock = NULL;
}
-int SSLSocketIO::Recv(Socket *s, char *buf, size_t sz) const
+int SSLSocketIO::Recv(Socket *s, char *buf, size_t sz)
{
- int i = SSL_read(this->sslsock, buf, sz);
+ int i = this->CheckState();
+ if (i < 0)
+ return i;
+
+ i = SSL_read(this->sslsock, buf, sz);
TotalRead += i;
return i;
}
-int SSLSocketIO::Send(Socket *s, const Anope::string &buf) const
+int SSLSocketIO::Send(Socket *s, const Anope::string &buf)
{
- int i = SSL_write(this->sslsock, buf.c_str(), buf.length());
+ int i = this->CheckState();
+ if (i < 0)
+ return i;
+
+ i = SSL_write(this->sslsock, buf.c_str(), buf.length());
TotalWritten += i;
return i;
}
-void SSLSocketIO::Accept(ListenSocket *s)
+ClientSocket *SSLSocketIO::Accept(ListenSocket *s)
{
- sockaddrs conaddr;
-
- socklen_t size = conaddr.size();
- int newsock = accept(s->GetFD(), &conaddr.sa, &size);
-
-#ifndef INVALID_SOCKET
-# define INVALID_SOCKET -1
-#endif
- if (newsock <= 0 || newsock == INVALID_SOCKET)
- throw SocketException("Unable to accept SSL socket: " + Anope::LastError());
-
- ClientSocket *newsocket = s->OnAccept(newsock, conaddr);
+ if (s->IO == &normalSocketIO)
+ throw SocketException("Attempting to accept on uninitialized socket with SSL");
+
+ ClientSocket *newsocket = normalSocketIO.Accept(s);
me->service.Init(newsocket);
SSLSocketIO *IO = debug_cast<SSLSocketIO *>(newsocket->IO);
@@ -236,25 +257,22 @@ void SSLSocketIO::Accept(ListenSocket *s)
SSL_set_accept_state(IO->sslsock);
- if (!SSL_set_fd(IO->sslsock, newsock))
+ if (!SSL_set_fd(IO->sslsock, newsocket->GetFD()))
throw SocketException("Unable to set SSL fd");
- int ret = SSL_accept(IO->sslsock);
- if (ret <= 0)
- {
- int error = SSL_get_error(IO->sslsock, ret);
-
- if (ret != -1 || (error != SSL_ERROR_WANT_READ && error != SSL_ERROR_WANT_READ))
- throw SocketException("Unable to accept new SSL connection: " + Anope::string(ERR_error_string(ERR_get_error(), NULL)));
- }
+ IO->accepted = 0;
+ if (this->CheckState() == -1)
+ throw SocketException("Unable to accept new SSL connection: " + Anope::string(ERR_error_string(ERR_get_error(), NULL)));
+
+ return newsocket;
}
-void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &TargetHost, int Port, const Anope::string &BindHost)
+void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &target, int port)
{
if (s->IO == &normalSocketIO)
- throw SocketException("Attempting to connect uninitialized socket with SQL");
+ throw SocketException("Attempting to connect uninitialized socket with SSL");
- normalSocketIO.Connect(s, TargetHost, Port, BindHost);
+ normalSocketIO.Connect(s, target, port);
SSLSocketIO *IO = debug_cast<SSLSocketIO *>(s->IO);
@@ -265,15 +283,9 @@ void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &TargetHost,
if (!SSL_set_fd(IO->sslsock, s->GetFD()))
throw SocketException("Unable to set SSL fd");
- int ret = SSL_connect(IO->sslsock);
-
- if (ret <= 0)
- {
- int error = SSL_get_error(IO->sslsock, ret);
-
- if (ret != -1 || (error != SSL_ERROR_WANT_READ && error != SSL_ERROR_WANT_READ))
- throw SocketException("Unable to connect to server: " + Anope::string(ERR_error_string(ERR_get_error(), NULL)));
- }
+ IO->connected = 0;
+ if (this->CheckState() == -1)
+ throw SocketException("Unable to connect to server: " + Anope::string(ERR_error_string(ERR_get_error(), NULL)));
}
void SSLSocketIO::Destroy()