[CalendarServer-changes] [9530] CalendarServer/trunk/twext/enterprise/dal

source_changes at macosforge.org source_changes at macosforge.org
Mon Aug 6 10:40:15 PDT 2012


Revision: 9530
          http://trac.macosforge.org/projects/calendarserver/changeset/9530
Author:   cdaboo at apple.com
Date:     2012-08-06 10:40:14 -0700 (Mon, 06 Aug 2012)
Log Message:
-----------
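Clean up the .In(...) DAL syntax: fold ParameterSet into Parameter (a Parameter constructed with a collection of values now acts as a set parameter), and raise DALError when In() is given a Parameter that is not a set.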

Modified Paths:
--------------
    CalendarServer/trunk/twext/enterprise/dal/syntax.py
    CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py

Modified: CalendarServer/trunk/twext/enterprise/dal/syntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/syntax.py	2012-08-06 16:48:20 UTC (rev 9529)
+++ CalendarServer/trunk/twext/enterprise/dal/syntax.py	2012-08-06 17:40:14 UTC (rev 9530)
@@ -324,13 +324,15 @@
             "SQL expressions should not be tested for truth value in Python.")
 
 
-    def In(self, subselect):
+    def In(self, other):
         # Can't be Select.__contains__ because __contains__ gets __nonzero__
         # called on its result by the 'in' syntax.
-        if isinstance(subselect, ParameterSet):
-            return CompoundComparison(self, 'in', ConstantSet(subselect))
+        if isinstance(other, Parameter):
+            if not other.isSet():
+                raise DALError("Parameter in an In(...) expression must be a set of values.")
+            return CompoundComparison(self, 'in', Constant(other))
         else:
-            return CompoundComparison(self, 'in', subselect)
+            return CompoundComparison(self, 'in', other)
 
 
     def StartsWith(self, other):
@@ -380,27 +382,15 @@
 
 
     def subSQL(self, queryGenerator, allTables):
-        return SQLFragment(queryGenerator.placeholder.placeholder(), [self.value])
+        if isinstance(self.value, Parameter) and self.value.isSet():
+            return _inParens(_CommaList(
+                [SQLFragment(queryGenerator.placeholder.placeholder(), [self.value] if ctr == 0 else []) for ctr in range(self.value.len)]
+            ).subSQL(queryGenerator, allTables))
+        else:
+            return SQLFragment(queryGenerator.placeholder.placeholder(), [self.value])
 
 
 
-class ConstantSet(ExpressionSyntax):
-    def __init__(self, value):
-        self.value = value
-
-
-    def allColumns(self):
-        return []
-
-
-    def subSQL(self, queryGenerator, allTables):
-        
-        return _inParens(_CommaList(
-            [SQLFragment(queryGenerator.placeholder.placeholder(), [self.value] if ctr == 0 else []) for ctr in range(self.value.len)]
-        ).subSQL(queryGenerator, allTables))
-
-
-
 class NamedValue(ExpressionSyntax):
     """
     A constant within the database; something predefined, such as
@@ -1612,10 +1602,11 @@
         params = []
         for parameter in self.parameters:
             if isinstance(parameter, Parameter):
-                params.append(kw[parameter.name])
-            elif isinstance(parameter, ParameterSet):
-                for item in kw[parameter.name]:
-                    params.append(item)
+                if parameter.isSet():
+                    for item in kw[parameter.name]:
+                        params.append(item)
+                else:
+                    params.append(kw[parameter.name])
             else:
                 params.append(parameter)
         return SQLFragment(self.text, params)
@@ -1650,8 +1641,11 @@
 
 class Parameter(object):
 
-    def __init__(self, name):
+    def __init__(self, name, values=None):
         self.name = name
+        self.values = values
+        if self.values is not None:
+            self.len = len(values)
 
 
     def __eq__(self, param):
@@ -1670,31 +1664,11 @@
         return 'Parameter(%r)' % (self.name,)
 
 
+    def isSet(self):
+        return hasattr(self, "len")
 
-class ParameterSet(object):
 
-    def __init__(self, name, items):
-        self.name = name
-        self.len = len(items)
 
-
-    def __eq__(self, param):
-        if not isinstance(param, ParameterSet):
-            return NotImplemented
-        return self.name == param.name
-
-
-    def __ne__(self, param):
-        if not isinstance(param, ParameterSet):
-            return NotImplemented
-        return not self.__eq__(param)
-
-
-    def __repr__(self):
-        return 'ParameterSet(%r)' % (self.name,)
-
-
-
 # Common helpers:
 
 # current timestamp in UTC format.  Hack to support standard syntax for this,

Modified: CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py	2012-08-06 16:48:20 UTC (rev 9529)
+++ CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py	2012-08-06 17:40:14 UTC (rev 9530)
@@ -25,7 +25,7 @@
     TableMismatch, Parameter, Max, Len, NotEnoughValues,
     Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction,
     Union, Intersect, Except, SetExpression, DALError,
-    ResultAliasSyntax, Count, QueryGenerator, ParameterSet)
+    ResultAliasSyntax, Count, QueryGenerator)
 from twext.enterprise.dal.syntax import FixedPlaceholder, NumericPlaceholder
 from twext.enterprise.dal.syntax import Function
 from twext.enterprise.dal.syntax import SchemaSyntax
@@ -702,7 +702,7 @@
         
         items = set(('A', 'B'))
         self.assertEquals(
-            Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(ParameterSet("names", items))).toSQL().bind(names=items),
+            Select(From=self.schema.FOO, Where=self.schema.FOO.BAR.In(Parameter("names", items))).toSQL().bind(names=items),
             SQLFragment(
                 "select * from FOO where BAR in (?, ?)", ['A', 'B']))
 
@@ -710,7 +710,7 @@
             Select(
                 From=self.schema.FOO,
                 Where=(self.schema.FOO.BAZ == Parameter('P1')).And(
-                    self.schema.FOO.BAR.In(ParameterSet("names", items)
+                    self.schema.FOO.BAR.In(Parameter("names", items)
                 ))
             ).toSQL().bind(P1="P1", names=items),
             SQLFragment(
@@ -720,7 +720,7 @@
         self.assertEquals(
             Select(
                 From=self.schema.FOO,
-                Where=(self.schema.FOO.BAR.In(ParameterSet("names", items)).And(
+                Where=(self.schema.FOO.BAR.In(Parameter("names", items)).And(
                     self.schema.FOO.BAZ == Parameter('P2')
                 ))
             ).toSQL().bind(P2="P2", names=items),
@@ -732,7 +732,7 @@
             Select(
                 From=self.schema.FOO,
                 Where=(self.schema.FOO.BAZ == Parameter('P1')).Or(
-                    self.schema.FOO.BAR.In(ParameterSet("names", items)).And(
+                    self.schema.FOO.BAR.In(Parameter("names", items)).And(
                         self.schema.FOO.BAZ == Parameter('P2')
                 ))
             ).toSQL().bind(P1="P1", P2="P2", names=items),
@@ -740,7 +740,10 @@
                 "select * from FOO where BAZ = ? or BAR in (?, ?) and BAZ = ?", ['P1', 'A', 'B', 'P2']),
         )
 
+        # Check that a set argument is required
+        self.assertRaises(DALError, self.schema.FOO.BAR.In, Parameter("names"))
 
+
     def test_max(self):
         """
         L{Max}C{(column)} produces an object in the 'columns' clause that
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20120806/2e89f4cb/attachment-0001.html>


More information about the calendarserver-changes mailing list