[CalendarServer-changes] [8044] CalendarServer/trunk/twistedcaldav/directory

source_changes at macosforge.org source_changes at macosforge.org
Fri Sep 2 09:59:15 PDT 2011


Revision: 8044
          http://trac.macosforge.org/projects/calendarserver/changeset/8044
Author:   sagen at apple.com
Date:     2011-09-02 09:59:14 -0700 (Fri, 02 Sep 2011)
Log Message:
-----------
Escape any special LDAP filter characters before querying

Modified Paths:
--------------
    CalendarServer/trunk/twistedcaldav/directory/ldapdirectory.py
    CalendarServer/trunk/twistedcaldav/directory/test/test_ldapdirectory.py

Modified: CalendarServer/trunk/twistedcaldav/directory/ldapdirectory.py
===================================================================
--- CalendarServer/trunk/twistedcaldav/directory/ldapdirectory.py	2011-09-02 01:25:04 UTC (rev 8043)
+++ CalendarServer/trunk/twistedcaldav/directory/ldapdirectory.py	2011-09-02 16:59:14 UTC (rev 8044)
@@ -41,6 +41,8 @@
 ]
 
 import ldap
+from ldap.filter import escape_filter_chars as ldapEsc
+
 try:
     # Note: PAM support is currently untested
     import PAM
@@ -250,11 +252,11 @@
             attrSet.add(self.partitionSchema["partitionIdAttr"])
         self.attrlist = list(attrSet)
 
-        self.typeRDNs = {}
+        self.typeDNs = {}
         for recordType in self.recordTypes():
-            self.typeRDNs[recordType] = ldap.dn.str2dn(
+            self.typeDNs[recordType] = ldap.dn.str2dn(
                 self.rdnSchema[recordType]["rdn"]
-            )
+            ) + self.base
 
         # Create LDAP connection
         self.log_info("Connecting to LDAP %s" % (repr(self.uri),))
@@ -287,7 +289,7 @@
     def listRecords(self, recordType):
 
         # Build base for this record Type
-        base = self.typeRDNs[recordType] + self.base
+        base = self.typeDNs[recordType]
 
         # Build filter
         filterstr = "(!(objectClass=organizationalUnit))"
@@ -417,6 +419,9 @@
                 attrlist=attrlist)
         except ldap.NO_SUCH_OBJECT:
             result = []
+        except ldap.FILTER_ERROR, e:
+            self.log_error("LDAP filter error: %s %s" % (e, filterstr))
+            result = []
         totalTime = time.time() - startTime
         if totalTime > self.warningThresholdSeconds:
             if filterstr and len(filterstr) > 100:
@@ -441,11 +446,11 @@
                 # fault in the members of group of name self.restrictToGroup
 
                 recordType = self.recordType_groups
-                base = self.typeRDNs[recordType] + self.base
+                base = self.typeDNs[recordType]
                 filterstr = "(cn=%s)" % (self.restrictToGroup,)
                 self.log_debug("Retrieving ldap record with base %s and filter %s." %
                     (ldap.dn.dn2str(base), filterstr))
-                result = self.ldap.search_s(ldap.dn.dn2str(base),
+                result = self.timedSearch(ldap.dn.dn2str(base),
                     ldap.SCOPE_SUBTREE, filterstr=filterstr, attrlist=self.attrlist)
 
                 if len(result) == 1:
@@ -492,7 +497,7 @@
                 continue
 
             recordType = self.recordType_groups
-            base = self.typeRDNs[recordType] + self.base
+            base = self.typeDNs[recordType]
             filterstr = "(%s=%s)" % (self.rdnSchema["guidAttr"], groupGUID)
 
             self.log_debug("Retrieving ldap record with base %s and filter %s." %
@@ -749,7 +754,7 @@
         guidAttr = self.rdnSchema["guidAttr"]
         for recordType in recordTypes:
             # Build base for this record Type
-            base = self.typeRDNs[recordType] + self.base
+            base = self.typeDNs[recordType]
 
             # Build filter
             filterstr = "(!(objectClass=organizationalUnit))"
@@ -768,7 +773,7 @@
                 filterstr = "(&%s(%s=%s))" % (
                     filterstr,
                     self.rdnSchema[recordType]["mapping"]["recordName"],
-                    indexKey
+                    ldapEsc(indexKey)
                 )
 
             elif indexType == self.INDEX_TYPE_CUA:
@@ -780,10 +785,10 @@
                         filterstr,
                         self.rdnSchema[recordType]["attr"],
                         email.partition("@")[0],
-                        email
+                        ldapEsc(email)
                     )
                 else:
-                    filterstr = "(&%s(mail=%s))" % (filterstr, email)
+                    filterstr = "(&%s(mail=%s))" % (filterstr, ldapEsc(email))
 
             elif indexType == self.INDEX_TYPE_AUTHID:
                 return
@@ -852,7 +857,7 @@
         guidAttr = self.rdnSchema["guidAttr"]
         for recordType in recordTypes:
 
-            base = self.typeRDNs[recordType] + self.base
+            base = self.typeDNs[recordType]
 
             if fields[0][0] == "dn":
                 # DN's are not an attribute that can be searched on by filter
@@ -915,7 +920,7 @@
 
         recordsByAlias = {}
 
-        groupsDN = self.typeRDNs[self.recordType_groups] + self.base
+        groupsDN = self.typeDNs[self.recordType_groups]
         memberIdAttr = self.groupSchema["memberIdAttr"]
 
         # First time through the loop we search using the attribute
@@ -959,7 +964,8 @@
                         # Members are identified by dn so we can take a short
                         # cut:  we know we only need to examine groups, and
                         # those will be children of the groups DN
-                        if not dnContainedIn(memberAlias, groupsDN):
+                        if not dnContainedIn(ldap.dn.str2dn(memberAlias),
+                            groupsDN):
                             continue
                     if memberAlias not in recordsByAlias:
                         valuesToFetch.add(memberAlias)
@@ -971,12 +977,21 @@
 
         returnValue(recordsByAlias.values())
 
+    def recordTypeForDN(self, dn):
+        """
+        Examine a dn to determine which recordType it belongs to
+        """
+        for recordType in self.recordTypes():
+            base = self.typeDNs[recordType]
+            if dnContainedIn(dn, base):
+                return recordType
+        return None
 
+
 def dnContainedIn(child, parent):
     """
     Return True if child dn is contained within parent dn, otherwise False.
     """
-    child = ldap.dn.str2dn(child)
     return child[-len(parent):] == parent
 
 
@@ -996,10 +1011,12 @@
         ldapField = mapping.get(field, None)
         if ldapField:
             if matchType == "starts-with":
-                value = "%s*" % (value,)
+                value = "%s*" % (ldapEsc(value),)
             elif matchType == "contains":
-                value = "*%s*" % (value,)
+                value = "*%s*" % (ldapEsc(value),)
             # otherwise it's an exact match
+            else:
+                value = ldapEsc(value)
             converted.append("(%s=%s)" % (ldapField, value))
 
     if len(converted) == 0:
@@ -1081,46 +1098,41 @@
 
         for memberId in self._memberIds:
 
-            for recordType in self.service.recordTypes():
+            if memberIdAttr:
 
-                if memberIdAttr:
-                    base = self.service.base
-                    filterstr = "(%s=%s)" % (memberIdAttr, memberId)
-                    self.log_debug("Retrieving subtree of %s with filter %s" %
-                        (ldap.dn.dn2str(base), filterstr),
-                        system="LdapDirectoryService")
-                    result = self.service.timedSearch(ldap.dn.dn2str(base),
-                        ldap.SCOPE_SUBTREE, filterstr=filterstr, attrlist=self.service.attrlist)
+                base = self.service.base
+                filterstr = "(%s=%s)" % (memberIdAttr, ldapEsc(memberId))
+                self.log_debug("Retrieving subtree of %s with filter %s" %
+                    (ldap.dn.dn2str(base), filterstr),
+                    system="LdapDirectoryService")
+                result = self.service.timedSearch(ldap.dn.dn2str(base),
+                    ldap.SCOPE_SUBTREE, filterstr=filterstr,
+                    attrlist=self.service.attrlist)
 
-                else:
-                    self.log_debug("Retrieving %s." % memberId,
-                        system="LdapDirectoryService")
-                    result = self.service.timedSearch(memberId,
-                        ldap.SCOPE_BASE, attrlist=self.service.attrlist)
+            else: # using DN
 
-                if result:
-                    # TODO: what about duplicates?
+                self.log_debug("Retrieving %s." % memberId,
+                    system="LdapDirectoryService")
+                result = self.service.timedSearch(memberId,
+                    ldap.SCOPE_BASE, attrlist=self.service.attrlist)
 
-                    dn, attrs = result.pop()
-                    self.log_debug("Retrieved: %s %s" % (dn,attrs))
+            if result:
 
-                    if recordType == self.service.recordType_users:
-                        shortName = self.service._getUniqueLdapAttribute(attrs,
-                            "uid", "userid")
-                    elif recordType in (
-                        self.service.recordType_groups,
-                        self.service.recordType_resources,
-                        self.service.recordType_locations
-                    ):
-                        shortName = self.service._getUniqueLdapAttribute(attrs,
-                            "cn")
+                dn, attrs = result.pop()
+                self.log_debug("Retrieved: %s %s" % (dn,attrs))
+                recordType = self.service.recordTypeForDN(ldap.dn.str2dn(dn))
+                if recordType is None:
+                    self.log_error("Unable to map %s to a record type" % (dn,))
+                    continue
 
-                    record = self.service.recordWithShortName(recordType,
-                        shortName)
-                    if record:
-                        results.append(record)
-                        break
+                shortName = self.service._getUniqueLdapAttribute(attrs,
+                    self.service.rdnSchema[recordType]["mapping"]["recordName"])
 
+                record = self.service.recordWithShortName(recordType,
+                    shortName)
+                if record:
+                    results.append(record)
+
         return results
 
     def groups(self):
@@ -1135,7 +1147,7 @@
         """ Fault in the groups of which this record is a member """
 
         recordType = self.service.recordType_groups
-        base = self.service.typeRDNs[recordType] + self.service.base
+        base = self.service.typeDNs[recordType]
 
         membersAttrs = []
         if self.service.groupSchema["membersAttr"]:

Modified: CalendarServer/trunk/twistedcaldav/directory/test/test_ldapdirectory.py
===================================================================
--- CalendarServer/trunk/twistedcaldav/directory/test/test_ldapdirectory.py	2011-09-02 01:25:04 UTC (rev 8043)
+++ CalendarServer/trunk/twistedcaldav/directory/test/test_ldapdirectory.py	2011-09-02 16:59:14 UTC (rev 8044)
@@ -52,6 +52,17 @@
                 },
                 {
                     "fields" : [
+                        ("fullName", "mor(", True, u"starts-with"),
+                        ("emailAddresses", "mor)", True, u"contains"),
+                        ("firstName", "mor*", True, u"exact"),
+                        ("lastName", "mor\\", True, u"starts-with")
+                    ],
+                    "operand" : "or",
+                    "recordType" : None,
+                    "expected" : "(|(cn=mor\\28*)(mail=*mor\\29*)(givenName=mor\\2a)(sn=mor\\5c*))"
+                },
+                {
+                    "fields" : [
                         ("fullName", "mor", True, u"starts-with"),
                     ],
                     "operand" : "or",
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20110902/447db817/attachment-0001.html>


More information about the calendarserver-changes mailing list