diff options
Diffstat (limited to 'modules/extra/m_ssl.cpp')
-rw-r--r-- | modules/extra/m_ssl.cpp | 152 |
1 files changed, 82 insertions, 70 deletions
diff --git a/modules/extra/m_ssl.cpp b/modules/extra/m_ssl.cpp index da6361769..30e548150 100644 --- a/modules/extra/m_ssl.cpp +++ b/modules/extra/m_ssl.cpp @@ -28,9 +28,16 @@ class MySSLService : public SSLService class SSLSocketIO : public SocketIO { + /** Check whether this socket has a pending connect() or accept() + * @return 0 if neither, -1 if connect/accept fails, -2 to wait more + */ + int CheckState(); + public: /* The SSL socket for this socket */ SSL *sslsock; + /* -1 if not, 0 if waiting, 1 if true */ + int connected, accepted; /** Constructor */ @@ -42,27 +49,27 @@ class SSLSocketIO : public SocketIO * @param sz How much to read * @return Number of bytes received */ - int Recv(Socket *s, char *buf, size_t sz) const; + int Recv(Socket *s, char *buf, size_t sz); /** Really write something to the socket * @param s The socket * @param buf What to write * @return Number of bytes written */ - int Send(Socket *s, const Anope::string &buf) const; + int Send(Socket *s, const Anope::string &buf); /** Accept a connection from a socket * @param s The socket + * @return The new socket */ - void Accept(ListenSocket *s); + ClientSocket *Accept(ListenSocket *s); /** Connect the socket * @param s THe socket * @param target IP to connect to * @param port to connect to - * @param bindip IP to bind to, if any */ - void Connect(ConnectionSocket *s, const Anope::string &target, int port, const Anope::string &bindip = ""); + void Connect(ConnectionSocket *s, const Anope::string &target, int port); /** Called when the socket is destructing */ @@ -144,41 +151,27 @@ class SSLModule : public Module ~SSLModule() { + for (std::map<int, Socket *>::const_iterator it = SocketEngine::Sockets.begin(), it_end = SocketEngine::Sockets.end(); it != it_end;) + { + Socket *s = it->second; + ++it; + + if (dynamic_cast<SSLSocketIO *>(s->IO)) + delete s; + } + SSL_CTX_free(client_ctx); SSL_CTX_free(server_ctx); } - EventReturn OnPreServerConnect(Uplink *u, int Number) + void OnPreServerConnect() { ConfigReader config; - if (config.ReadFlag("uplink", "ssl", "no", Number - 1)) + if (config.ReadFlag("uplink", "ssl", "no", CurrentUplink)) { - DNSRecord req = DNSManager::BlockingQuery(uplink_server->host, uplink_server->ipv6 ? DNS_QUERY_AAAA : DNS_QUERY_A); - - if (!req) - Log() << "Unable to connect to server " << uplink_server->host << ":" << uplink_server->port << " using SSL: Invalid hostname/IP"; - else - { - try - { - new UplinkSocket(uplink_server->ipv6); - this->service.Init(UplinkSock); - UplinkSock->Connect(req.result, uplink_server->port, Config->LocalHost); - - Log() << "Connected to server " << Number << " (" << u->host << ":" << u->port << ") with SSL"; - return EVENT_ALLOW; - } - catch (const SocketException &ex) - { - Log() << "Unable to connect with SSL to server " << Number << " (" << u->host << ":" << u->port << "), " << ex.GetReason(); - } - } - - return EVENT_STOP; + this->service.Init(UplinkSock); } - - return EVENT_CONTINUE; } }; @@ -194,39 +187,67 @@ void MySSLService::Init(Socket *s) s->IO = new SSLSocketIO(); } -SSLSocketIO::SSLSocketIO() +int SSLSocketIO::CheckState() +{ + if (this->connected == 0 || this->accepted == 0) + { + int ret; + if (this->connected == 0) + ret = SSL_connect(this->sslsock); + else if (this->accepted == 0) + ret = SSL_accept(this->sslsock); + if (ret <= 0) + { + int error = SSL_get_error(this->sslsock, ret); + + if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) + // Wait more + return -2; + return -1; + } + + if (this->connected == 0) + this->connected = 1; + else if (this->accepted == 0) + this->accepted = 1; + } + + return 0; +} + +SSLSocketIO::SSLSocketIO() : connected(-1), accepted(-1) { this->sslsock = NULL; } -int SSLSocketIO::Recv(Socket *s, char *buf, size_t sz) const +int SSLSocketIO::Recv(Socket *s, char *buf, size_t sz) { - int i = SSL_read(this->sslsock, buf, sz); + int i = this->CheckState(); + if (i < 0) + return i; + + i = SSL_read(this->sslsock, buf, sz); TotalRead += i; return i; } -int SSLSocketIO::Send(Socket *s, const Anope::string &buf) const +int SSLSocketIO::Send(Socket *s, const Anope::string &buf) { - int i = SSL_write(this->sslsock, buf.c_str(), buf.length()); + int i = this->CheckState(); + if (i < 0) + return i; + + i = SSL_write(this->sslsock, buf.c_str(), buf.length()); TotalWritten += i; return i; } -void SSLSocketIO::Accept(ListenSocket *s) +ClientSocket *SSLSocketIO::Accept(ListenSocket *s) { - sockaddrs conaddr; - - socklen_t size = conaddr.size(); - int newsock = accept(s->GetFD(), &conaddr.sa, &size); - -#ifndef INVALID_SOCKET -# define INVALID_SOCKET -1 -#endif - if (newsock <= 0 || newsock == INVALID_SOCKET) - throw SocketException("Unable to accept SSL socket: " + Anope::LastError()); - - ClientSocket *newsocket = s->OnAccept(newsock, conaddr); + if (s->IO == &normalSocketIO) + throw SocketException("Attempting to accept on uninitialized socket with SSL"); + + ClientSocket *newsocket = normalSocketIO.Accept(s); me->service.Init(newsocket); SSLSocketIO *IO = debug_cast<SSLSocketIO *>(newsocket->IO); @@ -236,25 +257,22 @@ void SSLSocketIO::Accept(ListenSocket *s) SSL_set_accept_state(IO->sslsock); - if (!SSL_set_fd(IO->sslsock, newsock)) + if (!SSL_set_fd(IO->sslsock, newsocket->GetFD())) throw SocketException("Unable to set SSL fd"); - int ret = SSL_accept(IO->sslsock); - if (ret <= 0) - { - int error = SSL_get_error(IO->sslsock, ret); - - if (ret != -1 || (error != SSL_ERROR_WANT_READ && error != SSL_ERROR_WANT_READ)) - throw SocketException("Unable to accept new SSL connection: " + Anope::string(ERR_error_string(ERR_get_error(), NULL))); - } + IO->accepted = 0; + if (this->CheckState() == -1) + throw SocketException("Unable to accept new SSL connection: " + Anope::string(ERR_error_string(ERR_get_error(), NULL))); + + return newsocket; } -void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &TargetHost, int Port, const Anope::string &BindHost) +void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &target, int port) { if (s->IO == &normalSocketIO) - throw SocketException("Attempting to connect uninitialized socket with SQL"); + throw SocketException("Attempting to connect uninitialized socket with SSL"); - normalSocketIO.Connect(s, TargetHost, Port, BindHost); + normalSocketIO.Connect(s, target, port); SSLSocketIO *IO = debug_cast<SSLSocketIO *>(s->IO); @@ -265,15 +283,9 @@ void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &TargetHost, if (!SSL_set_fd(IO->sslsock, s->GetFD())) throw SocketException("Unable to set SSL fd"); - int ret = SSL_connect(IO->sslsock); - - if (ret <= 0) - { - int error = SSL_get_error(IO->sslsock, ret); - - if (ret != -1 || (error != SSL_ERROR_WANT_READ && error != SSL_ERROR_WANT_READ)) - throw SocketException("Unable to connect to server: " + Anope::string(ERR_error_string(ERR_get_error(), NULL))); - } + IO->connected = 0; + if (this->CheckState() == -1) + throw SocketException("Unable to connect to server: " + Anope::string(ERR_error_string(ERR_get_error(), NULL))); } void SSLSocketIO::Destroy() |