[CalendarServer-changes] [13948] twext/trunk/twext/who

source_changes at macosforge.org source_changes at macosforge.org
Wed Sep 10 20:32:22 PDT 2014


Revision: 13948
          http://trac.calendarserver.org//changeset/13948
Author:   sagen at apple.com
Date:     2014-09-10 20:32:22 -0700 (Wed, 10 Sep 2014)
Log Message:
-----------
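Adds limitResults and timeoutSeconds parameters to the twext.who directory API; greatly speeds up LDAP group expansion by fetching group members with one query per record type (where possible) instead of one query per member.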

Modified Paths:
--------------
    twext/trunk/twext/who/aggregate.py
    twext/trunk/twext/who/directory.py
    twext/trunk/twext/who/idirectory.py
    twext/trunk/twext/who/index.py
    twext/trunk/twext/who/ldap/_service.py
    twext/trunk/twext/who/ldap/test/test_service.py
    twext/trunk/twext/who/opendirectory/_service.py
    twext/trunk/twext/who/test/test_directory.py
    twext/trunk/twext/who/test/test_index.py
    twext/trunk/twext/who/test/test_xml.py

Modified: twext/trunk/twext/who/aggregate.py
===================================================================
--- twext/trunk/twext/who/aggregate.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/aggregate.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -109,10 +109,14 @@
         return d
 
 
-    def recordsFromExpression(self, expression, recordTypes=None, records=None):
+    def recordsFromExpression(
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         return self._gatherFromSubServices(
             "recordsFromExpression", expression, recordTypes=recordTypes,
-            records=None
+            records=None,
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
         )
 
 
@@ -123,41 +127,58 @@
     # want to call their implementations, not bypass them.
 
 
-    def recordsWithFieldValue(self, fieldName, value):
+    def recordsWithFieldValue(
+        self, fieldName, value, limitResults=None, timeoutSeconds=None
+    ):
         return self._gatherFromSubServices(
-            "recordsWithFieldValue", fieldName, value
+            "recordsWithFieldValue", fieldName, value,
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
         )
 
 
-    def recordWithUID(self, uid):
-        return self._oneFromSubServices("recordWithUID", uid)
+    def recordWithUID(self, uid, timeoutSeconds=None):
+        return self._oneFromSubServices(
+            "recordWithUID", uid, timeoutSeconds=timeoutSeconds
+        )
 
 
-    def recordWithGUID(self, guid):
-        return self._oneFromSubServices("recordWithGUID", guid)
+    def recordWithGUID(self, guid, timeoutSeconds=None):
+        return self._oneFromSubServices(
+            "recordWithGUID", guid, timeoutSeconds=timeoutSeconds
+        )
 
 
-    def recordsWithRecordType(self, recordType):
+    def recordsWithRecordType(
+        self, recordType, limitResults=None, timeoutSeconds=None
+    ):
         # Since we know the recordType, we can go directly to the appropriate
         # service.
         for service in self.services:
             if recordType in service.recordTypes():
-                return service.recordsWithRecordType(recordType)
+                return service.recordsWithRecordType(
+                    recordType,
+                    limitResults=limitResults, timeoutSeconds=timeoutSeconds
+                )
         return succeed(())
 
 
-    def recordWithShortName(self, recordType, shortName):
+    def recordWithShortName(self, recordType, shortName, timeoutSeconds=None):
         # Since we know the recordType, we can go directly to the appropriate
         # service.
         for service in self.services:
             if recordType in service.recordTypes():
-                return service.recordWithShortName(recordType, shortName)
+                return service.recordWithShortName(
+                    recordType, shortName, timeoutSeconds=timeoutSeconds
+                )
         return succeed(None)
 
 
-    def recordsWithEmailAddress(self, emailAddress):
+    def recordsWithEmailAddress(
+        self, emailAddress, limitResults=None, timeoutSeconds=None
+    ):
         return self._gatherFromSubServices(
-            "recordsWithEmailAddress", emailAddress
+            "recordsWithEmailAddress", emailAddress,
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
         )
 
 

Modified: twext/trunk/twext/who/directory.py
===================================================================
--- twext/trunk/twext/who/directory.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/directory.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -123,7 +123,8 @@
 
 
     def recordsFromNonCompoundExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         """
         Finds records matching a non-compound expression.
@@ -172,7 +173,8 @@
 
     @inlineCallbacks
     def recordsFromCompoundExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         """
         Finds records matching a compound expression.
@@ -246,7 +248,10 @@
         returnValue(results)
 
 
-    def recordsFromExpression(self, expression, recordTypes=None, records=None):
+    def recordsFromExpression(
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         """
         @note: This interface is the same as
             L{IDirectoryService.recordsFromExpression}, except for the
@@ -254,41 +259,57 @@
         """
         if isinstance(expression, CompoundExpression):
             return self.recordsFromCompoundExpression(
-                expression, recordTypes=recordTypes
+                expression, recordTypes=recordTypes,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
         else:
             return self.recordsFromNonCompoundExpression(
-                expression, recordTypes=recordTypes
+                expression, recordTypes=recordTypes,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
 
 
-    def recordsWithFieldValue(self, fieldName, value, recordTypes=None):
+    def recordsWithFieldValue(
+        self, fieldName, value, recordTypes=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         return self.recordsFromExpression(
             MatchExpression(fieldName, value),
-            recordTypes=recordTypes
+            recordTypes=recordTypes,
+            limitResults=limitResults,
+            timeoutSeconds=timeoutSeconds
         )
 
 
     @inlineCallbacks
-    def recordWithUID(self, uid):
+    def recordWithUID(self, uid, timeoutSeconds=None):
         returnValue(uniqueResult(
-            (yield self.recordsWithFieldValue(FieldName.uid, uid))
+            (yield self.recordsWithFieldValue(
+                FieldName.uid, uid, timeoutSeconds=timeoutSeconds
+            ))
         ))
 
 
     @inlineCallbacks
-    def recordWithGUID(self, guid):
+    def recordWithGUID(self, guid, timeoutSeconds=None):
         returnValue(uniqueResult(
-            (yield self.recordsWithFieldValue(FieldName.guid, guid))
+            (yield self.recordsWithFieldValue(
+                FieldName.guid, guid, timeoutSeconds=timeoutSeconds
+            ))
         ))
 
 
-    def recordsWithRecordType(self, recordType):
-        return self.recordsWithFieldValue(FieldName.recordType, recordType)
+    def recordsWithRecordType(
+        self, recordType, limitResults=None, timeoutSeconds=None
+    ):
+        return self.recordsWithFieldValue(
+            FieldName.recordType, recordType,
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
+        )
 
 
     @inlineCallbacks
-    def recordWithShortName(self, recordType, shortName):
+    def recordWithShortName(self, recordType, shortName, timeoutSeconds=None):
         returnValue(
             uniqueResult(
                 (
@@ -296,17 +317,22 @@
                         MatchExpression(
                             FieldName.shortNames, shortName
                         ),
-                        recordTypes=[recordType]
+                        recordTypes=[recordType],
+                        timeoutSeconds=timeoutSeconds
                     )
                 )
             )
         )
 
 
-    def recordsWithEmailAddress(self, emailAddress):
+    def recordsWithEmailAddress(
+        self, emailAddress, limitResults=None, timeoutSeconds=None
+    ):
         return self.recordsWithFieldValue(
             FieldName.emailAddresses,
             emailAddress,
+            limitResults=limitResults,
+            timeoutSeconds=timeoutSeconds
         )
 
 

Modified: twext/trunk/twext/who/idirectory.py
===================================================================
--- twext/trunk/twext/who/idirectory.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/idirectory.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -277,7 +277,9 @@
         @rtype: iterable of L{NamedConstant}s
         """
 
-    def recordsFromExpression(expression, recordTypes=None):
+    def recordsFromExpression(
+        expression, recordTypes=None, limitResults=None, timeoutSeconds=None
+    ):
         """
         Find records matching an expression.
 
@@ -288,6 +290,13 @@
         @type recordTypes: an iterable of L{NamedConstant}, or None for no
             filtering
 
+        @param limitResults: how many records to limit the results to
+        @type limitResults: an C{integer} or None if no limit desired
+
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching records.
         @rtype: deferred iterable of L{IDirectoryRecord}s
 
@@ -295,7 +304,9 @@
             supported by this directory service.
         """
 
-    def recordsWithFieldValue(fieldName, value):
+    def recordsWithFieldValue(
+        fieldName, value, limitResults=None, timeoutSeconds=None
+    ):
         """
         Find records that have the given field name with the given
         value.
@@ -306,44 +317,68 @@
         @param value: a value to match
         @type value: L{object}
 
+        @param limitResults: how many records to limit the results to
+        @type limitResults: an C{integer} or None if no limit desired
+
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching records.
         @rtype: deferred iterable of L{IDirectoryRecord}s
         """
 
-    def recordWithUID(uid):
+    def recordWithUID(uid, timeoutSeconds=None):
         """
         Find the record that has the given UID.
 
         @param uid: a UID
         @type uid: L{unicode}
 
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching record or C{None} if there is no match.
         @rtype: deferred L{IDirectoryRecord}s or C{None}
         """
 
-    def recordWithGUID(guid):
+    def recordWithGUID(guid, timeoutSeconds=None):
         """
         Find the record that has the given GUID.
 
         @param guid: a GUID
         @type guid: L{UUID}
 
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching record or C{None} if there is no match.
         @rtype: deferred L{IDirectoryRecord}s or C{None}
         """
 
-    def recordsWithRecordType(recordType):
+    def recordsWithRecordType(
+        recordType, limitResults=None, timeoutSeconds=None
+    ):
         """
         Find the records that have the given record type.
 
         @param recordType: a record type
         @type recordType: L{NamedConstant}
 
+        @param limitResults: how many records to limit the results to
+        @type limitResults: an C{integer} or None if no limit desired
+
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching records.
         @rtype: deferred iterable of L{IDirectoryRecord}s
         """
 
-    def recordWithShortName(recordType, shortName):
+    def recordWithShortName(recordType, shortName, timeoutSeconds=None):
         """
         Find the record that has the given record type and short name.
 
@@ -353,17 +388,27 @@
         @param shortName: a short name
         @type shortName: L{unicode}
 
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching record or C{None} if there is no match.
         @rtype: deferred L{IDirectoryRecord}s or C{None}
         """
 
-    def recordsWithEmailAddress(emailAddress):
+    def recordsWithEmailAddress(
+        emailAddress, limitResults=None, timeoutSeconds=None
+    ):
         """
         Find the records that have the given email address.
 
         @param emailAddress: an email address
         @type emailAddress: L{unicode}
 
+        @param timeoutSeconds: how long (in seconds) to let a directory service
+            request to run before giving up
+        @type timeoutSeconds: an C{integer} or None if no limit desired
+
         @return: The matching records.
         @rtype: deferred iterable of L{IDirectoryRecord}s
         """

Modified: twext/trunk/twext/who/index.py
===================================================================
--- twext/trunk/twext/who/index.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/index.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -227,7 +227,8 @@
 
 
     def indexedRecordsFromMatchExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         """
         Finds records in the internal indexes matching a single expression.
@@ -293,11 +294,15 @@
                 if record.recordType not in recordTypes:
                     matchingRecords.remove(record)
 
+        if limitResults is not None:
+            matchingRecords = set(list(matchingRecords)[:limitResults])
+
         return succeed(matchingRecords)
 
 
     def unIndexedRecordsFromMatchExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         """
         Finds records not in the internal indexes matching a single expression.
@@ -335,11 +340,15 @@
                 if expression.match(fieldValue):
                     result.add(record)
 
+        if limitResults is not None:
+            result = set(list(result)[:limitResults])
+
         return succeed(result)
 
 
     def recordsFromNonCompoundExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         """
         This implementation can handle L{MatchExpression} expressions; other
@@ -348,15 +357,18 @@
         if isinstance(expression, MatchExpression):
             if expression.fieldName in self.indexedFields:
                 return self.indexedRecordsFromMatchExpression(
-                    expression, recordTypes=recordTypes, records=records
+                    expression, recordTypes=recordTypes, records=records,
+                    limitResults=limitResults, timeoutSeconds=timeoutSeconds
                 )
             else:
                 return self.unIndexedRecordsFromMatchExpression(
-                    expression, recordTypes=recordTypes, records=records
+                    expression, recordTypes=recordTypes, records=records,
+                    limitResults=limitResults, timeoutSeconds=timeoutSeconds
                 )
         else:
             return BaseDirectoryService.recordsFromNonCompoundExpression(
-                self, expression, recordTypes=recordTypes, records=records
+                self, expression, recordTypes=recordTypes, records=records,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
 
 

Modified: twext/trunk/twext/who/ldap/_service.py
===================================================================
--- twext/trunk/twext/who/ldap/_service.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/ldap/_service.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -27,6 +27,7 @@
 
 import collections
 import ldap
+import ldap.async
 
 from twisted.python.constants import Names, NamedConstant
 from twisted.internet.defer import succeed, inlineCallbacks, returnValue
@@ -47,7 +48,10 @@
     DirectoryService as BaseDirectoryService,
     DirectoryRecord as BaseDirectoryRecord,
 )
-from ..expression import MatchExpression, ExistsExpression, BooleanExpression
+from ..expression import (
+    MatchExpression, ExistsExpression, BooleanExpression,
+    CompoundExpression, Operand, MatchType
+)
 from ..util import ConstantsContainer
 from ._constants import LDAPAttribute, LDAPObjectClass
 from ._util import (
@@ -554,16 +558,24 @@
             connection = None
 
 
-    def _recordsFromQueryString(self, queryString, recordTypes=None):
+    def _recordsFromQueryString(
+        self, queryString, recordTypes=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         return deferToThreadPool(
             reactor, self.threadpool,
             self._recordsFromQueryString_inThread,
             queryString,
             recordTypes,
+            limitResults=limitResults,
+            timeoutSeconds=timeoutSeconds
         )
 
 
-    def _recordsFromQueryString_inThread(self, queryString, recordTypes=None):
+    def _recordsFromQueryString_inThread(
+        self, queryString, recordTypes=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         """
         This method is always called in a thread.
         """
@@ -571,10 +583,16 @@
 
         with DirectoryService.Connection(self) as connection:
 
+
             if recordTypes is None:
                 recordTypes = self.recordTypes()
 
             for recordType in recordTypes:
+
+                if limitResults is not None:
+                    if limitResults < 1:
+                        break
+
                 try:
                     rdn = self._recordTypeSchemas[recordType].relativeDN
                 except KeyError:
@@ -586,23 +604,35 @@
                     ldap.dn.str2dn(self._baseDN.lower())
                 )
                 self.log.debug(
-                    "Performing LDAP query: {rdn} {query} {recordType}",
+                    "Performing LDAP query: {rdn} {query} {recordType}{limit}{timeout}",
                     rdn=rdn,
                     query=queryString,
-                    recordType=recordType
+                    recordType=recordType,
+                    limit=" limit={}".format(limitResults) if limitResults else "",
+                    timeout=" timeout={}".format(timeoutSeconds) if timeoutSeconds else "",
                 )
                 try:
-                    reply = connection.search_s(
+                    s = ldap.async.List(connection)
+                    s.startSearch(
                         ldap.dn.dn2str(rdn),
                         ldap.SCOPE_SUBTREE,
                         queryString,
-                        attrlist=self._attributesToFetch
+                        attrList=self._attributesToFetch,
+                        timeout=timeoutSeconds if timeoutSeconds else -1,
+                        sizelimit=limitResults if limitResults else 0
                     )
+                    s.processResults()
 
+                except ldap.SIZELIMIT_EXCEEDED, e:
+                    self.log.debug("LDAP result limit exceeded: {}".format(limitResults,))
+
+                except ldap.TIMELIMIT_EXCEEDED, e:
+                    self.log.warn("LDAP timeout exceeded: {} seconds".format(timeoutSeconds,))
+
                 except ldap.FILTER_ERROR as e:
                     self.log.error(
-                        "Unable to perform query {0!r}: {1}"
-                        .format(queryString, e)
+                        "Unable to perform query {query!r}: {err}",
+                        query=queryString, err=e
                     )
                     raise LDAPQueryError("Unable to perform query", e)
 
@@ -610,10 +640,49 @@
                     # self.log.warn("RDN {rdn} does not exist, skipping", rdn=rdn)
                     continue
 
-                records.extend(
-                    self._recordsFromReply(reply, recordType=recordType)
+                except ldap.INVALID_SYNTAX, e:
+                    self.log.error(
+                        "LDAP invalid syntax {query!r}: {err}",
+                        query=queryString, err=e
+                    )
+                    continue
+
+                except ldap.SERVER_DOWN:
+                    self.log.error(
+                        "LDAP server unavailable"
+                    )
+                    continue
+
+                except Exception, e:
+                    self.log.error(
+                        "LDAP error {query!r}: {err}",
+                        query=queryString, err=e
+                    )
+                    continue
+
+                reply = [resultItem for resultType, resultItem in s.allResults]
+
+                newRecords = self._recordsFromReply(reply, recordType=recordType)
+
+                self.log.debug(
+                    "Records from LDAP query ({rdn} {query} {recordType}): {count}",
+                    rdn=rdn,
+                    query=queryString,
+                    recordType=recordType,
+                    count=len(newRecords)
                 )
 
+                if limitResults is not None:
+                    limitResults = limitResults - len(newRecords)
+
+                records.extend(newRecords)
+
+        self.log.debug(
+            "LDAP result count ({query}): {count}",
+            query=queryString,
+            count=len(records)
+        )
+
         return records
 
 
@@ -758,7 +827,8 @@
 
 
     def recordsFromNonCompoundExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None, limitResults=None,
+        timeoutSeconds=None
     ):
         if isinstance(expression, MatchExpression):
             queryString = ldapQueryStringFromMatchExpression(
@@ -766,7 +836,8 @@
                 self._fieldNameToAttributesMap, self._recordTypeSchemas
             )
             return self._recordsFromQueryString(
-                queryString, recordTypes=recordTypes
+                queryString, recordTypes=recordTypes,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
 
         elif isinstance(expression, ExistsExpression):
@@ -775,7 +846,8 @@
                 self._fieldNameToAttributesMap, self._recordTypeSchemas
             )
             return self._recordsFromQueryString(
-                queryString, recordTypes=recordTypes
+                queryString, recordTypes=recordTypes,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
 
         elif isinstance(expression, BooleanExpression):
@@ -784,16 +856,19 @@
                 self._fieldNameToAttributesMap, self._recordTypeSchemas
             )
             return self._recordsFromQueryString(
-                queryString, recordTypes=recordTypes
+                queryString, recordTypes=recordTypes,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
 
         return BaseDirectoryService.recordsFromNonCompoundExpression(
-            self, expression, records=records
+            self, expression, records=records, limitResults=limitResults,
+            timeoutSeconds=timeoutSeconds
         )
 
 
     def recordsFromCompoundExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         if not expression.expressions:
             return succeed(())
@@ -803,17 +878,21 @@
             self._fieldNameToAttributesMap, self._recordTypeSchemas
         )
         return self._recordsFromQueryString(
-            queryString, recordTypes=recordTypes
+            queryString, recordTypes=recordTypes,
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
         )
 
 
-    def recordsWithRecordType(self, recordType):
+    def recordsWithRecordType(
+        self, recordType, limitResults=None, timeoutSeconds=None
+    ):
         queryString = ldapQueryStringFromExistsExpression(
             ExistsExpression(self.fieldName.uid),
             self._fieldNameToAttributesMap, self._recordTypeSchemas
         )
         return self._recordsFromQueryString(
-            queryString, recordTypes=[recordType]
+            queryString, recordTypes=[recordType],
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
         )
 
 
@@ -839,12 +918,66 @@
     @inlineCallbacks
     def members(self):
 
+        members = set()
+
         if self.recordType != self.service.recordType.group:
             returnValue(())
 
-        members = set()
-        for dn in getattr(self, "memberDNs", []):
-            record = yield self.service._recordWithDN(dn)
+        # Scan through the memberDNs, grouping them by record type (which we
+        # deduce by their RDN).  If we have a fieldname that corresponds to
+        # the most specific slice of the DN, we can bundle that into a
+        # single CompoundExpression to fault in all the DNs belonging to the
+        # same base RDN, reducing the number of requests from 1-per-member to
+        # 1-per-record-type.  Any memberDNs we can't group in this way are
+        # simply faulted in by DN at the end.
+
+        fieldValuesByRecordType = {}
+        # dictionary key = recordType, value = tuple(fieldName, value)
+
+        faultByDN = []
+        # the DNs we need to fault in individually
+
+        for dnStr in getattr(self, "memberDNs", []):
+            try:
+                recordType = recordTypeForDN(
+                    self.service._baseDN, self.service._recordTypeSchemas, dnStr
+                )
+                dn = ldap.dn.str2dn(dnStr.lower())
+                attrName, value, ignored = dn[0][0]
+                fieldName = self.service._attributeToFieldNameMap[attrName][0]
+                fieldValuesByRecordType.setdefault(recordType, []).append((fieldName, value))
+                continue
+
+            except:
+                # For whatever reason we can't group this DN in with the others
+                # so we'll add it to faultByDN just below
+                pass
+
+            # have to fault in by dn
+            faultByDN.append(dnStr)
+
+        for recordType, fieldValue in fieldValuesByRecordType.iteritems():
+            if fieldValue:
+                matchExpressions = []
+                for fieldName, value in fieldValue:
+                    matchExpressions.append(
+                        MatchExpression(
+                            fieldName,
+                            value.decode("utf-8"),
+                            matchType=MatchType.equals
+                        )
+                    )
+            expression = CompoundExpression(
+                matchExpressions,
+                Operand.OR
+            )
+            for record in (yield self.service.recordsFromCompoundExpression(
+                expression, recordTypes=[recordType]
+            )):
+                members.add(record)
+
+        for dnStr in faultByDN:
+            record = yield self.service._recordWithDN(dnStr)
             members.add(record)
 
         returnValue(members)
@@ -863,6 +996,7 @@
         return self.service._authenticateUsernamePassword(self.dn, password)
 
 
+
 def normalizeDNstr(dnStr):
     """
     Convert to lowercase and remove extra whitespace

Modified: twext/trunk/twext/who/ldap/test/test_service.py
===================================================================
--- twext/trunk/twext/who/ldap/test/test_service.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/ldap/test/test_service.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -24,13 +24,21 @@
 
 import ldap
 
-try:
-    from mockldap import MockLdap
-    from mockldap.filter import (
-        Test as MockLDAPFilterTest,
-        UnsupportedOp as MockLDAPUnsupportedOp,
-    )
-except ImportError:
+
+# FIXME:
+MOCKLDAP_SUPPORTS_LDAP_ASYNC = False
+
+
+if MOCKLDAP_SUPPORTS_LDAP_ASYNC:
+    try:
+        from mockldap import MockLdap
+        from mockldap.filter import (
+            Test as MockLDAPFilterTest,
+            UnsupportedOp as MockLDAPUnsupportedOp,
+        )
+    except ImportError:
+        MockLdap = None
+else:
     MockLdap = None
 
 from twext.python.types import MappingProxyType

Modified: twext/trunk/twext/who/opendirectory/_service.py
===================================================================
--- twext/trunk/twext/who/opendirectory/_service.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/opendirectory/_service.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -392,7 +392,9 @@
         )
 
 
-    def _queryFromCompoundExpression(self, expression, recordTypes=None, local=False):
+    def _queryFromCompoundExpression(
+        self, expression, recordTypes=None, local=False, limitResults=None
+    ):
         """
         Form an OpenDirectory query from a compound expression.
 
@@ -442,7 +444,10 @@
             matchType = ODMatchType.any.value
 
         attributes = [a.value for a in ODAttribute.iterconstants()]
-        maxResults = 0
+        if limitResults is None:
+            maxResults = 0
+        else:
+            maxResults = limitResults
 
         query, error = ODQuery.queryWithNode_forRecordTypes_attribute_matchType_queryValues_returnAttributes_maximumResults_error_(
             node,
@@ -468,7 +473,7 @@
 
 
     def _queryFromMatchExpression(
-        self, expression, recordTypes=None, local=False
+        self, expression, recordTypes=None, local=False, limitResults=None
     ):
         """
         Form an OpenDirectory query from a match expression.
@@ -503,8 +508,12 @@
             caseInsensitive = 0x0
 
         fetchAttributes = [a.value for a in ODAttribute.iterconstants()]
-        maxResults = 0
 
+        if limitResults is None:
+            maxResults = 0
+        else:
+            maxResults = limitResults
+
         # For OpenDirectory, use guid for uid:
         if expression.fieldName is self.fieldName.uid:
             expression.fieldName = self.fieldName.guid
@@ -644,13 +653,17 @@
 
 
     @inlineCallbacks
-    def _recordsFromQuery(self, query):
+    def _recordsFromQuery(self, query, timeoutSeconds=None):
         """
         Executes a query and generates directory records from it.
 
         @param query: A query.
         @type query: L{ODQuery}
 
+        @param timeoutSeconds: number of seconds after which the request
+            should timeout (currently unused)
+        @type timeoutSeconds: C{integer}
+
         @return: The records produced by executing the query.
         @rtype: list of L{DirectoryRecord}
         """
@@ -707,11 +720,17 @@
         returnValue(result)
 
 
-    def recordsFromNonCompoundExpression(self, expression, recordTypes=None, records=None):
+    def recordsFromNonCompoundExpression(
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         if isinstance(expression, MatchExpression):
             try:
-                query = self._queryFromMatchExpression(expression, recordTypes=recordTypes)
-                return self._recordsFromQuery(query)
+                query = self._queryFromMatchExpression(
+                    expression, recordTypes=recordTypes, limitResults=limitResults)
+                return self._recordsFromQuery(
+                    query, timeoutSeconds=timeoutSeconds
+                )
 
             except QueryNotSupportedError:
                 pass  # Let the superclass try
@@ -720,12 +739,16 @@
                 return succeed([])
 
         return BaseDirectoryService.recordsFromNonCompoundExpression(
-            self, expression
+            self, expression,
+            limitResults=limitResults, timeoutSeconds=timeoutSeconds
         )
 
 
     @inlineCallbacks
-    def recordsFromCompoundExpression(self, expression, recordTypes=None, records=None):
+    def recordsFromCompoundExpression(
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         """
         Returns records matching the CompoundExpression.  Because the
         local node doesn't perform Compound queries in a case insensitive
@@ -736,13 +759,16 @@
         """
 
         try:
-            query = self._queryFromCompoundExpression(expression, recordTypes=recordTypes)
+            query = self._queryFromCompoundExpression(
+                expression, recordTypes=recordTypes, limitResults=limitResults
+            )
 
         except QueryNotSupportedError:
             returnValue(
                 (
                     yield BaseDirectoryService.recordsFromCompoundExpression(
-                        self, expression, recordTypes=recordTypes
+                        self, expression, recordTypes=recordTypes,
+                        limitResults=limitResults, timeoutSeconds=timeoutSeconds
                     )
                 )
             )
@@ -752,7 +778,8 @@
         if self.localNode is not None:
 
             localRecords = yield self.localRecordsFromCompoundExpression(
-                expression, recordTypes=recordTypes
+                expression, recordTypes=recordTypes,
+                limitResults=limitResults, timeoutSeconds=timeoutSeconds
             )
             for localRecord in localRecords:
                 if localRecord not in results:
@@ -762,7 +789,10 @@
 
 
     @inlineCallbacks
-    def localRecordsFromCompoundExpression(self, expression, recordTypes=None):
+    def localRecordsFromCompoundExpression(
+        self, expression, recordTypes=None,
+        limitResults=None, timeoutSeconds=None
+    ):
         """
         Takes a CompoundExpression, and recursively goes through each
         MatchExpression, passing those specifically to the local node, and
@@ -780,17 +810,22 @@
 
             if isinstance(subExpression, CompoundExpression):
                 subRecords = yield self.localRecordsFromCompoundExpression(
-                    subExpression, recordTypes=recordTypes
+                    subExpression, recordTypes=recordTypes,
+                    limitResults=limitResults, timeoutSeconds=timeoutSeconds
                 )
 
             elif isinstance(subExpression, MatchExpression):
                 try:
                     subQuery = self._queryFromMatchExpression(
-                        subExpression, recordTypes=recordTypes, local=True
+                        subExpression, recordTypes=recordTypes, local=True,
+                        limitResults=limitResults
                     )
                 except UnsupportedRecordTypeError:
                     continue
-                subRecords = yield self._recordsFromQuery(subQuery)
+                subRecords = yield self._recordsFromQuery(
+                    subQuery,
+                    timeoutSeconds=timeoutSeconds
+                )
 
             else:
                 raise QueryNotSupportedError(
@@ -841,27 +876,35 @@
 
 
     @inlineCallbacks
-    def recordWithUID(self, uid):
+    def recordWithUID(self, uid, timeoutSeconds=None):
         returnValue(firstResult(
-            (yield self.recordsWithFieldValue(BaseFieldName.uid, uid))
+            (yield self.recordsWithFieldValue(
+                BaseFieldName.uid, uid, timeoutSeconds=timeoutSeconds
+            ))
         ))
 
 
     @inlineCallbacks
-    def recordWithGUID(self, guid):
+    def recordWithGUID(self, guid, timeoutSeconds=None):
         returnValue(firstResult(
-            (yield self.recordsWithFieldValue(BaseFieldName.guid, guid))
+            (yield self.recordsWithFieldValue(
+                BaseFieldName.guid, guid, timeoutSeconds=timeoutSeconds
+            ))
         ))
 
 
     @inlineCallbacks
-    def recordWithShortName(self, recordType, shortName):
+    def recordWithShortName(self, recordType, shortName, timeoutSeconds=None):
         try:
             query = self._queryFromMatchExpression(
                 MatchExpression(self.fieldName.shortNames, shortName),
-                recordTypes=(recordType,)
+                recordTypes=(recordType,),
+                limitResults=1
             )
-            results = yield self._recordsFromQuery(query)
+            results = yield self._recordsFromQuery(
+                query,
+                timeoutSeconds=timeoutSeconds
+            )
 
             try:
                 record = firstResult(results)
@@ -877,7 +920,7 @@
         except QueryNotSupportedError:
             # Let the superclass try
             returnValue((yield BaseDirectoryService.recordWithShortName(
-                self, recordType, shortName)))
+                self, recordType, shortName, timeoutSeconds=timeoutSeconds)))
 
         except UnsupportedRecordTypeError:
             returnValue(None)

Modified: twext/trunk/twext/who/test/test_directory.py
===================================================================
--- twext/trunk/twext/who/test/test_directory.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/test/test_directory.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -71,7 +71,8 @@
 
 
     def recordsFromNonCompoundExpression(
-        self, expression, recordTypes=None, records=None
+        self, expression, recordTypes=None, records=None,
+        limitResults=None, timeoutSeconds=None
     ):
         """
         This implementation handles three expressions:

Modified: twext/trunk/twext/who/test/test_index.py
===================================================================
--- twext/trunk/twext/who/test/test_index.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/test/test_index.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -351,6 +351,34 @@
 
 
     @inlineCallbacks
+    def test_unIndexedRecordsFromMatchExpression_limitResults(self):
+        """
+        Make sure limitResults does limit results.
+        """
+        service = self.noLoadServicePopulated()
+
+        records = yield service.unIndexedRecordsFromMatchExpression(
+            MatchExpression(
+                BaseFieldName.fullNames, u"A",
+                MatchType.startsWith
+            ),
+            recordTypes=None,
+            limitResults=1
+        )
+        self.assertEquals(len(records), 1)
+
+        records = yield service.unIndexedRecordsFromMatchExpression(
+            MatchExpression(
+                BaseFieldName.fullNames, u"A",
+                MatchType.startsWith
+            ),
+            recordTypes=None,
+            limitResults=1000
+        )
+        self.assertEquals(len(records), 2)
+
+
+    @inlineCallbacks
     def _test_recordsFromNonCompoundExpression(self, expression):
         service = self.noLoadServicePopulated()
         yield service.recordsFromNonCompoundExpression(expression)

Modified: twext/trunk/twext/who/test/test_xml.py
===================================================================
--- twext/trunk/twext/who/test/test_xml.py	2014-09-10 14:52:49 UTC (rev 13947)
+++ twext/trunk/twext/who/test/test_xml.py	2014-09-11 03:32:22 UTC (rev 13948)
@@ -185,7 +185,32 @@
         self.assertRecords(records, (u"__sagen__", u"__dre__"))
 
 
+    @inlineCallbacks
+    def test_limitResults(self):
+        """
+        Make sure limitResults does limit results.
+        """
 
+        service = self.service()
+
+        records = (
+            yield service.recordsWithRecordType(
+                service.recordType.user,
+                limitResults=3
+            )
+        )
+        self.assertEquals(len(records), 3)
+
+        records = (
+            yield service.recordsWithRecordType(
+                service.recordType.user,
+                limitResults=1000
+            )
+        )
+        self.assertEquals(len(records), 9)
+
+
+
 class DirectoryServiceRealmTestMixIn(object):
     def test_realmNameImmutable(self):
         def setRealmName():
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20140910/13085b3b/attachment-0001.html>


More information about the calendarserver-changes mailing list