Revision: 15780 http://trac.calendarserver.org//changeset/15780 Author: sagen@apple.com Date: 2016-08-02 16:24:05 -0700 (Tue, 02 Aug 2016) Log Message: ----------- Use a connection pool for the LDAP bind calls we use for authentication (in addition to the already existing pool used for other LDAP queries) Modified Paths: -------------- twext/trunk/twext/who/ldap/_service.py twext/trunk/twext/who/ldap/test/test_service.py Modified: twext/trunk/twext/who/ldap/_service.py =================================================================== --- twext/trunk/twext/who/ldap/_service.py 2016-07-27 21:50:59 UTC (rev 15779) +++ twext/trunk/twext/who/ldap/_service.py 2016-08-02 23:24:05 UTC (rev 15780) @@ -203,7 +203,165 @@ }) +class ConnectionPool(object): + log = Logger() + + def __init__(self, poolName, ds, credentials, connectionMax): + self.poolName = poolName + self.ds = ds + self.credentials = credentials + self.connectionQueue = Queue() + self.connectionCreateLock = RLock() + self.connections = [] + self.activeCount = 0 + self.connectionsCreated = 0 + self.connectionMax = connectionMax + + def getConnection(self): + """ + Get a connection from the connection pool. + This will retrieve a connection from the connection pool L{Queue} + object. + If the L{Queue} is empty, it will check to see whether a new connection + can be created (based on the connection limit), and if so create that + and use it. + If no new connections can be created, it will block on the L{Queue} + until an existing, in-use, connection is put back. + """ + try: + connection = self.connectionQueue.get(block=False) + except Empty: + # Note we use a lock here to prevent a race condition in which + # multiple requests for a new connection could succeed even though + # the connection counts starts out one less than the maximum. + # This can happen because self._connect() can take a while. + self.connectionCreateLock.acquire() + if len(self.connections) < self.connectionMax: + connection = self._connect() + self.connections.append(connection) + self.connectionCreateLock.release() + else: + self.connectionCreateLock.release() + self.ds.poolStats["connection-blocked"] += 1 + connection = self.connectionQueue.get() + + + connectionID = "connection-{}".format( + self.connections.index(connection) + ) + + self.ds.poolStats[connectionID] += 1 + self.activeCount = len(self.connections) - self.connectionQueue.qsize() + self.ds.poolStats["connection-active"] = self.activeCount + self.ds.poolStats["connection-max"] = max( + self.ds.poolStats["connection-max"], self.activeCount + ) + + if self.activeCount > self.connectionMax: + self.log.error( + "Active LDAP connections ({active}) exceeds maximum ({max})", + active=self.activeCount, max=self.connectionMax + ) + return connection + + + def returnConnection(self, connection): + """ + A connection is no longer needed - return it to the pool. + """ + self.connectionQueue.put(connection) + self.activeCount = len(self.connections) - self.connectionQueue.qsize() + + + def failedConnection(self, connection): + """ + A connection has failed; remove it from the list of active connections. + A new one will be created if needed. + """ + self.ds.poolStats["connection-errors"] += 1 + self.connections.remove(connection) + self.activeCount = len(self.connections) - self.connectionQueue.qsize() + + + def _connect(self): + """ + Connect to the directory server. + This will always be called in a thread to prevent blocking. + + @returns: The connection object. + @rtype: L{ldap.ldapobject.LDAPObject} + + @raises: L{LDAPConnectionError} if unable to connect. + """ + + self.log.debug("Connecting to LDAP at {log_source.url}") + connection = self._newConnection() + + if self.credentials is not None: + if IUsernamePassword.providedBy(self.credentials): + try: + connection.simple_bind_s( + self.credentials.username, + self.credentials.password, + ) + self.log.debug( + "Bound to LDAP as {credentials.username}", + credentials=self.credentials + ) + except ( + ldap.INVALID_CREDENTIALS, ldap.INVALID_DN_SYNTAX + ) as e: + self.log.error( + "Unable to bind to LDAP as {credentials.username}", + credentials=self.credentials + ) + raise LDAPBindAuthError( + self.credentials.username, e + ) + + else: + raise LDAPConnectionError( + "Unknown credentials type: {0}" + .format(self.credentials) + ) + + return connection + + + def _newConnection(self): + """ + Create a new LDAP connection and initialize and start TLS if required. + This will always be called in a thread to prevent blocking. + + @returns: The connection object. + @rtype: L{ldap.ldapobject.LDAPObject} + + @raises: L{LDAPConnectionError} if unable to connect. + """ + connection = ldap.initialize(self.ds.url) + + # FIXME: Use trace_file option to wire up debug logging when + # Twisted adopts the new logging stuff. + + for option, value in ( + (ldap.OPT_TIMEOUT, self.ds._timeout), + (ldap.OPT_X_TLS_CACERTFILE, self.ds._tlsCACertificateFile), + (ldap.OPT_X_TLS_CACERTDIR, self.ds._tlsCACertificateDirectory), + (ldap.OPT_DEBUG_LEVEL, self.ds._debug), + ): + if value is not None: + connection.set_option(option, value) + + if self.ds._useTLS: + self.log.debug("Starting TLS for {log_source.url}") + connection.start_tls_s() + + self.connectionsCreated += 1 + + return connection + + # # Directory Service # @@ -350,12 +508,14 @@ self.threadpool.adjustPoolsize( max(threadPoolMax, self.threadpool.max) ) - self.connectionMax = connectionMax - self.connectionCreateLock = RLock() - self.connections = [] - self.connectionQueue = Queue() + + # Separate pools for LDAP queries and LDAP binds. Note, they each get + # half of connectionMax. + self.connectionPools = { + "query": ConnectionPool("query", self, credentials, connectionMax / 2), + "auth": ConnectionPool("auth", self, None, connectionMax / 2), + } self.poolStats = collections.defaultdict(int) - self.activeCount = 0 reactor.callWhenRunning(self.start) reactor.addSystemEventTrigger("during", "shutdown", self.stop) @@ -396,164 +556,23 @@ needed. """ - def __init__(self, ds): - self.ds = ds + def __init__(self, ds, poolName): + self.pool = ds.connectionPools[poolName] def __enter__(self): - self.connection = self.ds._getConnection() + self.connection = self.pool.getConnection() return self.connection def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: - self.ds._returnConnection(self.connection) + self.pool.returnConnection(self.connection) return True else: - self.ds._failedConnection(self.connection) + self.pool.failedConnection(self.connection) return False - def _getConnection(self): - """ - Get a connection from the connection pool. - This will retrieve a connection from the connection pool L{Queue} - object. - If the L{Queue} is empty, it will check to see whether a new connection - can be created (based on the connection limit), and if so create that - and use it. - If no new connections can be created, it will block on the L{Queue} - until an existing, in-use, connection is put back. - """ - try: - connection = self.connectionQueue.get(block=False) - except Empty: - # Note we use a lock here to prevent a race condition in which - # multiple requests for a new connection could succeed even though - # the connection counts starts out one less than the maximum. - # This can happen because self._connect() can take a while. - self.connectionCreateLock.acquire() - if len(self.connections) < self.connectionMax: - connection = self._connect() - self.connections.append(connection) - self.connectionCreateLock.release() - else: - self.connectionCreateLock.release() - self.poolStats["connection-blocked"] += 1 - connection = self.connectionQueue.get() - - connectionID = "connection-{}".format( - self.connections.index(connection) - ) - - self.poolStats[connectionID] += 1 - self.activeCount = len(self.connections) - self.connectionQueue.qsize() - self.poolStats["connection-active"] = self.activeCount - self.poolStats["connection-max"] = max( - self.poolStats["connection-max"], self.activeCount - ) - - if self.activeCount > self.connectionMax: - self.log.error( - "Active LDAP connections ({active}) exceeds maximum ({max})", - active=self.activeCount, max=self.connectionMax - ) - return connection - - - def _returnConnection(self, connection): - """ - A connection is no longer needed - return it to the pool. - """ - self.connectionQueue.put(connection) - self.activeCount = len(self.connections) - self.connectionQueue.qsize() - - - def _failedConnection(self, connection): - """ - A connection has failed; remove it from the list of active connections. - A new one will be created if needed. - """ - self.poolStats["connection-errors"] += 1 - self.connections.remove(connection) - self.activeCount = len(self.connections) - self.connectionQueue.qsize() - - - def _connect(self): - """ - Connect to the directory server. - This will always be called in a thread to prevent blocking. - - @returns: The connection object. - @rtype: L{ldap.ldapobject.LDAPObject} - - @raises: L{LDAPConnectionError} if unable to connect. - """ - - self.log.debug("Connecting to LDAP at {log_source.url}") - connection = self._newConnection() - - if self._credentials is not None: - if IUsernamePassword.providedBy(self._credentials): - try: - connection.simple_bind_s( - self._credentials.username, - self._credentials.password, - ) - self.log.debug( - "Bound to LDAP as {credentials.username}", - credentials=self._credentials - ) - except ( - ldap.INVALID_CREDENTIALS, ldap.INVALID_DN_SYNTAX - ) as e: - self.log.error( - "Unable to bind to LDAP as {credentials.username}", - credentials=self._credentials - ) - raise LDAPBindAuthError( - self._credentials.username, e - ) - - else: - raise LDAPConnectionError( - "Unknown credentials type: {0}" - .format(self._credentials) - ) - - return connection - - - def _newConnection(self): - """ - Create a new LDAP connection and initialize and start TLS if required. - This will always be called in a thread to prevent blocking. - - @returns: The connection object. - @rtype: L{ldap.ldapobject.LDAPObject} - - @raises: L{LDAPConnectionError} if unable to connect. - """ - connection = ldap.initialize(self.url) - - # FIXME: Use trace_file option to wire up debug logging when - # Twisted adopts the new logging stuff. - - for option, value in ( - (ldap.OPT_TIMEOUT, self._timeout), - (ldap.OPT_X_TLS_CACERTFILE, self._tlsCACertificateFile), - (ldap.OPT_X_TLS_CACERTDIR, self._tlsCACertificateDirectory), - (ldap.OPT_DEBUG_LEVEL, self._debug), - ): - if value is not None: - connection.set_option(option, value) - - if self._useTLS: - self.log.debug("Starting TLS for {log_source.url}") - connection.start_tls_s() - - return connection - - def _authenticateUsernamePassword(self, dn, password): """ Open a secondary connection to the LDAP server and try binding to it @@ -587,30 +606,27 @@ @raises: L{LDAPConnectionError} if unable to connect. """ self.log.debug("Authenticating {dn}", dn=dn) - connection = self._newConnection() + with DirectoryService.Connection(self, "auth") as connection: + try: + connection.simple_bind_s(dn, password) + self.log.debug("Authenticated {dn}", dn=dn) + return True + except ( + ldap.INAPPROPRIATE_AUTH, + ldap.INVALID_CREDENTIALS, + ldap.INVALID_DN_SYNTAX, + ): + self.log.debug("Unable to authenticate {dn}", dn=dn) + return False + except ldap.CONSTRAINT_VIOLATION: + self.log.info("Account locked {dn}", dn=dn) + return False + except Exception as e: + self.log.error("Unexpected error {error} trying to authenticate {dn}", error=str(e), dn=dn) + return False - try: - connection.simple_bind_s(dn, password) - self.log.debug("Authenticated {dn}", dn=dn) - return True - except ( - ldap.INAPPROPRIATE_AUTH, - ldap.INVALID_CREDENTIALS, - ldap.INVALID_DN_SYNTAX, - ): - self.log.debug("Unable to authenticate {dn}", dn=dn) - return False - except ldap.CONSTRAINT_VIOLATION: - self.log.info("Account locked {dn}", dn=dn) - return False - except Exception as e: - self.log.error("Unexpected error {error} trying to authenticate {dn}", error=str(e), dn=dn) - return False - finally: - connection.unbind() - def _recordsFromQueryString( self, queryString, recordTypes=None, limitResults=None, timeoutSeconds=None @@ -654,7 +670,7 @@ try: - with DirectoryService.Connection(self) as connection: + with DirectoryService.Connection(self, "query") as connection: for recordType in recordTypes: @@ -836,7 +852,7 @@ try: - with DirectoryService.Connection(self) as connection: + with DirectoryService.Connection(self, "query") as connection: self.log.debug("Performing LDAP DN query: {dn}", dn=dn) Modified: twext/trunk/twext/who/ldap/test/test_service.py =================================================================== --- twext/trunk/twext/who/ldap/test/test_service.py 2016-07-27 21:50:59 UTC (rev 15779) +++ twext/trunk/twext/who/ldap/test/test_service.py 2016-08-02 23:24:05 UTC (rev 15780) @@ -317,28 +317,27 @@ class DirectoryServiceConnectionTestMixIn(object): - @inlineCallbacks + def test_connect_defaults(self): """ Connect with default arguments. """ service = self.service() - connection = yield service._connect() + with TestService.Connection(service, "query") as connection: - self.assertEquals(connection.methods_called(), ["initialize"]) + self.assertEquals(connection.methods_called(), ["initialize"]) - for option in ( - ldap.OPT_TIMEOUT, - ldap.OPT_X_TLS_CACERTFILE, - ldap.OPT_X_TLS_CACERTDIR, - ldap.OPT_DEBUG_LEVEL, - ): - self.assertRaises(KeyError, connection.get_option, option) + for option in ( + ldap.OPT_TIMEOUT, + ldap.OPT_X_TLS_CACERTFILE, + ldap.OPT_X_TLS_CACERTDIR, + ldap.OPT_DEBUG_LEVEL, + ): + self.assertRaises(KeyError, connection.get_option, option) - self.assertFalse(connection.tls_enabled) + self.assertFalse(connection.tls_enabled) - @inlineCallbacks def test_connect_withUsernamePassword_invalid(self): """ Connect with UsernamePassword credentials. @@ -349,14 +348,14 @@ ) service = self.service(credentials=credentials) try: - yield service._connect() + with TestService.Connection(service, "query"): + pass except LDAPBindAuthError: pass else: self.fail("Should have raised LDAPBindAuthError") - @inlineCallbacks def test_connect_withUsernamePassword_valid(self): """ Connect with UsernamePassword credentials. @@ -366,15 +365,14 @@ u"zehcnasw" ) service = self.service(credentials=credentials) - connection = yield service._connect() - self.assertEquals( - connection.methods_called(), - ["initialize", "simple_bind_s"] - ) + with TestService.Connection(service, "query") as connection: + self.assertEquals( + connection.methods_called(), + ["initialize", "simple_bind_s"] + ) - @inlineCallbacks def test_connect_withOptions(self): """ Connect with default arguments. @@ -385,43 +383,42 @@ tlsCACertificateDirectory=FilePath("/path/to/certdir"), _debug=True, ) - connection = yield service._connect() + with TestService.Connection(service, "query") as connection: - self.assertEquals( - connection.methods_called(), - [ - "initialize", - "set_option", "set_option", "set_option", "set_option", - ] - ) + self.assertEquals( + connection.methods_called(), + [ + "initialize", + "set_option", "set_option", "set_option", "set_option", + ] + ) - opt = lambda k: connection.get_option(k) + opt = lambda k: connection.get_option(k) - self.assertEquals(opt(ldap.OPT_TIMEOUT), 18) - self.assertEquals(opt(ldap.OPT_X_TLS_CACERTFILE), "/path/to/cert") - self.assertEquals(opt(ldap.OPT_X_TLS_CACERTDIR), "/path/to/certdir") - self.assertEquals(opt(ldap.OPT_DEBUG_LEVEL), 255) + self.assertEquals(opt(ldap.OPT_TIMEOUT), 18) + self.assertEquals(opt(ldap.OPT_X_TLS_CACERTFILE), "/path/to/cert") + self.assertEquals(opt(ldap.OPT_X_TLS_CACERTDIR), "/path/to/certdir") + self.assertEquals(opt(ldap.OPT_DEBUG_LEVEL), 255) - # Tested in test_connect_defaults, but test again here since we're - # setting SSL options and we want to be sure they don't somehow enable - # SSL implicitly. - self.assertFalse(connection.tls_enabled) + # Tested in test_connect_defaults, but test again here since we're + # setting SSL options and we want to be sure they don't somehow enable + # SSL implicitly. + self.assertFalse(connection.tls_enabled) - @inlineCallbacks def test_connect_withTLS(self): """ Connect with TLS enabled. """ service = self.service(useTLS=True) - connection = yield service._connect() - self.assertEquals( - connection.methods_called(), - ["initialize", "start_tls_s"] - ) + with TestService.Connection(service, "query") as connection: + self.assertEquals( + connection.methods_called(), + ["initialize", "start_tls_s"] + ) - self.assertTrue(connection.tls_enabled) + self.assertTrue(connection.tls_enabled) @@ -457,11 +454,11 @@ # still have a connection in the pool service._recordsFromQueryString_inThread("(this=that)") self.assertEquals(service._retryNumber, 0) - self.assertEquals(len(service.connections), 1) + self.assertEquals(len(service.connectionPools["query"].connections), 1) service._recordWithDN_inThread("cn=test") self.assertEquals(service._retryNumber, 0) - self.assertEquals(len(service.connections), 1) + self.assertEquals(len(service.connectionPools["query"].connections), 1) # Force a search to raise SERVER_DOWN def raiseServerDown(*args, **kwds): @@ -479,7 +476,7 @@ self.fail("Should have raised LDAPQueryError") # Verify the connections are all closed - self.assertEquals(len(service.connections), 0) + self.assertEquals(len(service.connectionPools["query"].connections), 0) # Now try recordWithDN try: @@ -491,10 +488,67 @@ self.fail("Should have raised LDAPQueryError") # Verify the connections are all closed - self.assertEquals(len(service.connections), 0) + self.assertEquals(len(service.connectionPools["query"].connections), 0) + @inlineCallbacks + def test_auth_pool(self): + """ + Verify acquiring connections from the LDAP connection pool will block + when connectionMax is reached, + and that + """ + service = self.service(connectionMax=4) + pool = service.connectionPools["auth"] + + self.assertEquals(0, len(pool.connections)) + self.assertEquals(0, pool.connectionsCreated) + + # Ask for a connection and check the counts + with TestService.Connection(service, "auth"): + self.assertEquals(1, len(pool.connections)) + + self.assertEquals(1, len(pool.connections)) + self.assertEquals(1, pool.connectionsCreated) + + # Ask for two connections and check the counts + with TestService.Connection(service, "auth"): + self.assertEquals(1, len(pool.connections)) + self.assertEquals(1, pool.connectionsCreated) + with TestService.Connection(service, "auth"): + self.assertEquals(2, len(pool.connections)) + self.assertEquals(2, pool.connectionsCreated) + + # Ask for three connections (one more than connectionMax/2) and + # the third will actually block until the returnConnection( ) call + with TestService.Connection(service, "auth"): + self.assertEquals(2, len(pool.connections)) + self.assertEquals(2, pool.connectionsCreated) + with TestService.Connection(service, "auth") as connection: + self.assertEquals(2, len(pool.connections)) + self.assertEquals(2, pool.connectionsCreated) + + # schedule a connection to be returned in 1 second + from twisted.internet import reactor + reactor.callLater(1, pool.returnConnection, connection) + + # For the third connection, I'm using this method so it gets + # requested in a thread, otherwise we'd hang. + yield service._authenticateUsernamePassword( + u"uid=wsanchez,cn=user,{0}".format(self.baseDN), + u"zehcnasw" + ) + + # Proof we bumped up against connection-max: + self.assertEquals(1, service.poolStats["connection-blocked"]) + + self.assertEquals(2, len(pool.connections)) + self.assertEquals(2, pool.connectionsCreated) + self.assertEquals(2, service.poolStats["connection-max"]) + + + class ExtraFiltersTest(BaseTestCase, unittest.TestCase): def test_extraFilters(self):
participants (1)
-
source_changes@macosforge.org