diff options
Diffstat (limited to 'modules/extra/m_ssl.cpp')
-rw-r--r-- | modules/extra/m_ssl.cpp | 202 |
1 files changed, 114 insertions, 88 deletions
diff --git a/modules/extra/m_ssl.cpp b/modules/extra/m_ssl.cpp index 41bc49993..5143d964a 100644 --- a/modules/extra/m_ssl.cpp +++ b/modules/extra/m_ssl.cpp @@ -31,8 +31,6 @@ class SSLSocketIO : public SocketIO public: /* The SSL socket for this socket */ SSL *sslsock; - /* -1 if not, 0 if waiting, 1 if true */ - int connected, accepted; /** Constructor */ @@ -46,12 +44,12 @@ class SSLSocketIO : public SocketIO */ int Recv(Socket *s, char *buf, size_t sz); - /** Really write something to the socket - * @param s The socket - * @param buf What to write - * @return Number of bytes written + /** Write something to the socket + * @param s The socket + * @param buf The data to write + * @param size The length of the data */ - int Send(Socket *s, const Anope::string &buf); + int Send(Socket *s, const char *buf, size_t sz); /** Accept a connection from a socket * @param s The socket @@ -59,11 +57,11 @@ class SSLSocketIO : public SocketIO */ ClientSocket *Accept(ListenSocket *s); - /** Check if a connection has been accepted - * @param s The client socket - * @return -1 on error, 0 to wait, 1 on success + /** Finished accepting a connection from a socket + * @param s The socket + * @return SF_ACCEPTED if accepted, SF_ACCEPTING if still in process, SF_DEAD on error */ - int Accepted(ClientSocket *cs); + SocketFlag FinishAccept(ClientSocket *cs); /** Connect the socket * @param s THe socket @@ -72,11 +70,11 @@ class SSLSocketIO : public SocketIO */ void Connect(ConnectionSocket *s, const Anope::string &target, int port); - /** Check if this socket is connected + /** Called to potentially finish a pending connection * @param s The socket - * @return -1 for error, 0 for wait, 1 for connected + * @return SF_CONNECTED on success, SF_CONNECTING if still pending, and SF_DEAD on error. */ - int Connected(ConnectionSocket *s); + SocketFlag FinishConnect(ConnectionSocket *s); /** Called when the socket is destructing */ @@ -194,7 +192,7 @@ void MySSLService::Init(Socket *s) s->IO = new SSLSocketIO(); } -SSLSocketIO::SSLSocketIO() : connected(-1), accepted(-1) +SSLSocketIO::SSLSocketIO() { this->sslsock = NULL; } @@ -206,9 +204,9 @@ int SSLSocketIO::Recv(Socket *s, char *buf, size_t sz) return i; } -int SSLSocketIO::Send(Socket *s, const Anope::string &buf) +int SSLSocketIO::Send(Socket *s, const char *buf, size_t sz) { - int i = SSL_write(this->sslsock, buf.c_str(), buf.length()); + int i = SSL_write(this->sslsock, buf, sz); TotalWritten += i; return i; } @@ -217,8 +215,20 @@ ClientSocket *SSLSocketIO::Accept(ListenSocket *s) { if (s->IO == &normalSocketIO) throw SocketException("Attempting to accept on uninitialized socket with SSL"); - - ClientSocket *newsocket = normalSocketIO.Accept(s); + + sockaddrs conaddr; + + socklen_t size = sizeof(conaddr); + int newsock = accept(s->GetFD(), &conaddr.sa, &size); + +#ifndef INVALID_SOCKET + const int INVALID_SOCKET = -1; +#endif + + if (newsock < 0 || newsock == INVALID_SOCKET) + throw SocketException("Unable to accept connection: " + Anope::LastError()); + + ClientSocket *newsocket = s->OnAccept(newsock, conaddr); me->service.Init(newsocket); SSLSocketIO *IO = debug_cast<SSLSocketIO *>(newsocket->IO); @@ -231,47 +241,47 @@ ClientSocket *SSLSocketIO::Accept(ListenSocket *s) if (!SSL_set_fd(IO->sslsock, newsocket->GetFD())) throw SocketException("Unable to set SSL fd"); - int ret = SSL_accept(IO->sslsock); - if (ret <= 0) - { - IO->accepted = 0; - int error = SSL_get_error(IO->sslsock, ret); - if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) - { - SocketEngine::MarkWritable(newsocket); - return newsocket; - } - - throw SocketException("Unable to accept new SSL connection: " + Anope::string(ERR_error_string(ERR_get_error(), NULL))); - } + newsocket->SetFlag(SF_ACCEPTING); + this->FinishAccept(newsocket); - IO->accepted = 1; return newsocket; } -int SSLSocketIO::Accepted(ClientSocket *cs) +SocketFlag SSLSocketIO::FinishAccept(ClientSocket *cs) { - SSLSocketIO *IO = debug_cast<SSLSocketIO *>(cs->IO); + if (cs->IO == &normalSocketIO) + throw SocketException("Attempting to finish connect uninitialized socket with SSL"); + else if (cs->HasFlag(SF_ACCEPTED)) + return SF_ACCEPTED; + else if (!cs->HasFlag(SF_ACCEPTING)) + throw SocketException("SSLSocketIO::FinishAccept called for a socket not accepted nor accepting?"); - if (IO->accepted == 0) + SSLSocketIO *IO = debug_cast<SSLSocketIO *>(cs->IO); + + int ret = SSL_accept(IO->sslsock); + if (ret <= 0) { - int ret = SSL_accept(IO->sslsock); - if (ret <= 0) + int error = SSL_get_error(IO->sslsock, ret); + if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) { - int error = SSL_get_error(IO->sslsock, ret); - if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) - { - SocketEngine::MarkWritable(cs); - return 0; - } - - return -1; + SocketEngine::MarkWritable(cs); + return SF_ACCEPTING; + } + else + { + cs->OnError(ERR_error_string(ERR_get_error(), NULL)); + cs->SetFlag(SF_DEAD); + cs->UnsetFlag(SF_ACCEPTING); + return SF_DEAD; } - IO->accepted = 1; - return 0; } - - return IO->accepted; + else + { + cs->SetFlag(SF_ACCEPTED); + cs->UnsetFlag(SF_ACCEPTING); + cs->OnAccept(); + return SF_ACCEPTED; + } } void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &target, int port) @@ -279,62 +289,78 @@ void SSLSocketIO::Connect(ConnectionSocket *s, const Anope::string &target, int if (s->IO == &normalSocketIO) throw SocketException("Attempting to connect uninitialized socket with SSL"); - normalSocketIO.Connect(s, target, port); - - SSLSocketIO *IO = debug_cast<SSLSocketIO *>(s->IO); + s->UnsetFlag(SF_CONNECTING); + s->UnsetFlag(SF_CONNECTED); - IO->sslsock = SSL_new(client_ctx); - if (!IO->sslsock) - throw SocketException("Unable to initialize SSL socket"); - - if (!SSL_set_fd(IO->sslsock, s->GetFD())) - throw SocketException("Unable to set SSL fd"); - - int ret = SSL_connect(IO->sslsock); - if (ret <= 0) + s->conaddr.pton(s->IsIPv6() ? AF_INET6 : AF_INET, target, port); + int c = connect(s->GetFD(), &s->conaddr.sa, s->conaddr.size()); + if (c == -1) { - IO->connected = 0; - int error = SSL_get_error(IO->sslsock, ret); - if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) + if (Anope::LastErrorCode() != EINPROGRESS) + { + s->OnError(Anope::LastError()); + s->SetFlag(SF_DEAD); + return; + } + else { SocketEngine::MarkWritable(s); + s->SetFlag(SF_CONNECTING); return; } - - s->ProcessError(); } - - IO->connected = 1; + else + { + s->SetFlag(SF_CONNECTING); + this->FinishConnect(s); + } } -int SSLSocketIO::Connected(ConnectionSocket *s) +SocketFlag SSLSocketIO::FinishConnect(ConnectionSocket *s) { if (s->IO == &normalSocketIO) - throw SocketException("Connected() called for non ssl socket?"); - - int i = SocketIO::Connected(s); - if (i != 1) - return i; + throw SocketException("Attempting to finish connect uninitialized socket with SSL"); + else if (s->HasFlag(SF_CONNECTED)) + return SF_CONNECTED; + else if (!s->HasFlag(SF_CONNECTING)) + throw SocketException("SSLSocketIO::FinishConnect called for a socket not connected nor connecting?"); SSLSocketIO *IO = debug_cast<SSLSocketIO *>(s->IO); - if (IO->connected == 0) + if (IO->sslsock == NULL) { - int ret = SSL_connect(IO->sslsock); - if (ret <= 0) - { - int error = SSL_get_error(IO->sslsock, ret); - if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) - return 0; + IO->sslsock = SSL_new(client_ctx); + if (!IO->sslsock) + throw SocketException("Unable to initialize SSL socket"); - s->ProcessError(); - return -1; - } - IO->connected = 1; - return 0; // poll for next read/write (which will be real), don't assume ones available + if (!SSL_set_fd(IO->sslsock, s->GetFD())) + throw SocketException("Unable to set SSL fd"); } - return IO->connected; + int ret = SSL_connect(IO->sslsock); + if (ret <= 0) + { + int error = SSL_get_error(IO->sslsock, ret); + if (ret == -1 && (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)) + { + SocketEngine::MarkWritable(s); + return SF_CONNECTING; + } + else + { + s->OnError(ERR_error_string(ERR_get_error(), NULL)); + s->UnsetFlag(SF_CONNECTING); + s->SetFlag(SF_DEAD); + return SF_DEAD; + } + } + else + { + s->UnsetFlag(SF_CONNECTING); + s->SetFlag(SF_CONNECTED); + s->OnConnect(); + return SF_CONNECTED; + } } void SSLSocketIO::Destroy() |