[CalendarServer-changes] [8719] CalendarServer/trunk

source_changes at macosforge.org source_changes at macosforge.org
Thu Feb 16 11:00:18 PST 2012


Revision: 8719
          http://trac.macosforge.org/projects/calendarserver/changeset/8719
Author:   cdaboo at apple.com
Date:     2012-02-16 11:00:17 -0800 (Thu, 16 Feb 2012)
Log Message:
-----------
Internal API changes to use a QueryGenerator object to maintain persistent state whilst building up a single query.

Modified Paths:
--------------
    CalendarServer/trunk/twext/enterprise/dal/syntax.py
    CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py
    CalendarServer/trunk/txdav/common/datastore/sql_tables.py

Modified: CalendarServer/trunk/twext/enterprise/dal/syntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/syntax.py	2012-02-16 18:56:25 UTC (rev 8718)
+++ CalendarServer/trunk/twext/enterprise/dal/syntax.py	2012-02-16 19:00:17 UTC (rev 8719)
@@ -20,17 +20,17 @@
 """
 
 from itertools import count, repeat
+from functools import partial
 from operator import eq, ne
 
 from zope.interface import implements
 
 from twisted.internet.defer import succeed
 
+from twext.enterprise.dal.model import Schema, Table, Column, Sequence
 from twext.enterprise.ienterprise import POSTGRES_DIALECT, ORACLE_DIALECT
 from twext.enterprise.ienterprise import IDerivedParameter
-
 from twext.enterprise.util import mapOracleOutputType
-from twext.enterprise.dal.model import Schema, Table, Column, Sequence
 
 try:
     import cx_Oracle
@@ -45,29 +45,24 @@
     normally never be caught or ignored.
     """
 
-class ConnectionMetadata(object):
+class QueryPlaceholder(object):
     """
-    Representation of the metadata about the database connection required to
-    generate some SQL, for a single statement.  Contains information necessary
-    to generate place holder strings and determine the database dialect.
+    Representation of the placeholders required to generate some SQL, for a
+    single statement.  Contains the information necessary to generate
+    placeholder strings in the connection's parameter style.
     """
 
-    def __init__(self, dialect):
-        self.dialect = dialect
-
-
     def placeholder(self):
         raise NotImplementedError("See subclasses.")
 
 
 
-class FixedPlaceholder(ConnectionMetadata):
+class FixedPlaceholder(QueryPlaceholder):
     """
-    Metadata about a connection which uses a fixed string as its place holder.
+    Fixed string used as the place holder.
     """
 
-    def __init__(self, dialect, placeholder):
-        super(FixedPlaceholder, self).__init__(dialect)
+    def __init__(self, placeholder):
         self._placeholder = placeholder
 
 
@@ -76,10 +71,12 @@
 
 
 
-class NumericPlaceholder(ConnectionMetadata):
+class NumericPlaceholder(QueryPlaceholder):
+    """
+    Numeric counter used as the place holder.
+    """
 
-    def __init__(self, dialect):
-        super(NumericPlaceholder, self).__init__(dialect)
+    def __init__(self):
         self._next = count(1).next
 
 
@@ -88,14 +85,33 @@
 
 
 
-def defaultMetadata():
+def defaultPlaceholder():
     """
-    Generate a default L{ConnectionMetadata}
+    Generate a default L{QueryPlaceholder}
     """
-    return FixedPlaceholder(POSTGRES_DIALECT, '?')
+    return FixedPlaceholder('?')
 
 
 
+class QueryGenerator(object):
+    """
+    Maintains various pieces of transient information needed when building a
+    query. This includes the SQL dialect, the format of the place holder and
+    an automated id generator.
+    """
+    
+    def __init__(self, dialect=None, placeholder=None):
+        self.dialect = dialect if dialect else POSTGRES_DIALECT
+        if placeholder is None:
+            placeholder = defaultPlaceholder()
+        self.placeholder = placeholder
+
+        self.generatedID = count(1).next
+    
+    def nextGeneratedID(self):
+        return "genid_%d" % (self.generatedID(),)
+
+
 class TableMismatch(Exception):
     """
     A table in a statement did not match with a column.
@@ -117,18 +133,18 @@
     """
 
     _paramstyles = {
-        'pyformat': lambda dialect: FixedPlaceholder(dialect, "%s"),
+        'pyformat': partial(FixedPlaceholder, "%s"),
         'numeric': NumericPlaceholder
     }
 
 
-    def toSQL(self, metadata=None):
-        if metadata is None:
-            metadata = defaultMetadata()
-        return self._toSQL(metadata)
+    def toSQL(self, queryGenerator=None):
+        if queryGenerator is None:
+            queryGenerator = QueryGenerator()
+        return self._toSQL(queryGenerator)
 
 
-    def _extraVars(self, txn, metadata):
+    def _extraVars(self, txn, queryGenerator):
         """
         A hook for subclasses to provide additional keyword arguments to the
         C{bind} call when L{_Statement.on} is executed.  Currently this is used
@@ -138,7 +154,7 @@
         return {}
 
 
-    def _extraResult(self, result, outvars, metadata):
+    def _extraResult(self, result, outvars, queryGenerator):
         """
         A hook for subclasses to manipulate the results of 'on', after they've
         been retrieved by the database but before they've been given to
@@ -151,10 +167,10 @@
         @param outvars: a dictionary of extra variables returned by
             C{self._extraVars}.
 
-        @param metadata: information about the connection where the statement
+        @param queryGenerator: dialect and placeholder state used to build the statement that
             was executed.
 
-        @type metadata: L{ConnectionMetadata} (a subclass thereof)
+        @type queryGenerator: L{QueryGenerator}
 
         @return: the result to be returned from L{_Statement.on}.
 
@@ -181,14 +197,14 @@
         @rtype: a L{Deferred} firing a C{list} of records (C{tuple}s or
             C{list}s)
         """
-        metadata = self._paramstyles[txn.paramstyle](txn.dialect)
-        outvars = self._extraVars(txn, metadata)
+        queryGenerator = QueryGenerator(txn.dialect, self._paramstyles[txn.paramstyle]())
+        outvars = self._extraVars(txn, queryGenerator)
         kw.update(outvars)
-        fragment = self.toSQL(metadata).bind(**kw)
+        fragment = self.toSQL(queryGenerator).bind(**kw)
         result = txn.execSQL(fragment.text, fragment.parameters,
                              raiseOnZeroRowCount)
-        result = self._extraResult(result, outvars, metadata)
-        if metadata.dialect == ORACLE_DIALECT and result:
+        result = self._extraResult(result, outvars, queryGenerator)
+        if queryGenerator.dialect == ORACLE_DIALECT and result:
             result.addCallback(self._fixOracleNulls)
         return result
 
@@ -317,7 +333,6 @@
     def Contains(self, other):
         return CompoundComparison(self, "like", CompoundComparison(Constant('%'), '||', CompoundComparison(Constant(other), '||', Constant('%'))))
 
-
 class FunctionInvocation(ExpressionSyntax):
     def __init__(self, function, *args):
         self.function = function
@@ -335,10 +350,10 @@
         return list(ac())
 
 
-    def subSQL(self, metadata, allTables):
-        result = SQLFragment(self.function.nameFor(metadata))
+    def subSQL(self, queryGenerator, allTables):
+        result = SQLFragment(self.function.nameFor(queryGenerator))
         result.append(_inParens(
-            _commaJoined(_convert(arg).subSQL(metadata, allTables)
+            _commaJoined(_convert(arg).subSQL(queryGenerator, allTables)
                          for arg in self.args)))
         return result
 
@@ -353,8 +368,8 @@
         return []
 
 
-    def subSQL(self, metadata, allTables):
-        return SQLFragment(metadata.placeholder(), [self.value])
+    def subSQL(self, queryGenerator, allTables):
+        return SQLFragment(queryGenerator.placeholder.placeholder(), [self.value])
 
 
 
@@ -367,7 +382,7 @@
         self.name = name
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         return SQLFragment(self.name)
 
 
@@ -382,8 +397,8 @@
         self.oracleName = oracleName
 
 
-    def nameFor(self, metadata):
-        if metadata.dialect == ORACLE_DIALECT and self.oracleName is not None:
+    def nameFor(self, queryGenerator):
+        if queryGenerator.dialect == ORACLE_DIALECT and self.oracleName is not None:
             return self.oracleName
         return self.name
 
@@ -439,11 +454,11 @@
 
     modelType = Sequence
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         """
         Convert to an SQL fragment.
         """
-        if metadata.dialect == ORACLE_DIALECT:
+        if queryGenerator.dialect == ORACLE_DIALECT:
             fmt = "%s.nextval"
         else:
             fmt = "nextval('%s')"
@@ -490,14 +505,14 @@
         return Join(self, type, otherTableSyntax, on)
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         """
         Generate the L{SQLFragment} for this table's identification; this is
         for use in a 'from' clause.
         """
         # XXX maybe there should be a specific method which is only invoked
         # from the FROM clause, that only tables and joins would implement?
-        return SQLFragment(_nameForDialect(self.model.name, metadata.dialect))
+        return SQLFragment(_nameForDialect(self.model.name, queryGenerator.dialect))
 
 
     def __getattr__(self, attr):
@@ -561,12 +576,12 @@
     self-join.
     """
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         """
         Return an L{SQLFragment} with a string of the form C{'mytable myalias'}
         suitable for use in a FROM clause.
         """
-        result = super(TableAlias, self).subSQL(metadata, allTables)
+        result = super(TableAlias, self).subSQL(queryGenerator, allTables)
         result.append(SQLFragment(" " + self._aliasName(allTables)))
         return result
 
@@ -617,18 +632,18 @@
         self.on = on
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         stmt = SQLFragment()
-        stmt.append(self.leftSide.subSQL(metadata, allTables))
+        stmt.append(self.leftSide.subSQL(queryGenerator, allTables))
         stmt.text += ' '
         if self.type:
             stmt.text += self.type
             stmt.text += ' '
         stmt.text += 'join '
-        stmt.append(self.rightSide.subSQL(metadata, allTables))
+        stmt.append(self.rightSide.subSQL(queryGenerator, allTables))
         if self.type != 'cross':
             stmt.text += ' on '
-            stmt.append(self.on.subSQL(metadata, allTables))
+            stmt.append(self.on.subSQL(queryGenerator, allTables))
         return stmt
 
 
@@ -678,11 +693,11 @@
         return [self]
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         # XXX This, and 'model', could in principle conflict with column names.
         # Maybe do something about that.
         name = self.model.name
-        if metadata.dialect == ORACLE_DIALECT and name.lower() in _KEYWORDS:
+        if queryGenerator.dialect == ORACLE_DIALECT and name.lower() in _KEYWORDS:
             name = '"%s"' % (name,)
 
         if self._alwaysQualified:
@@ -711,20 +726,25 @@
 
 class ResultAliasSyntax(ExpressionSyntax):
     
-    def __init__(self, expression, alias):
+    def __init__(self, expression, alias=None):
         self.expression = expression
         self.alias = alias
 
+    def aliasName(self, queryGenerator):
+        if self.alias is None:
+            self.alias = queryGenerator.nextGeneratedID()
+        return self.alias
+
     def columnReference(self):
         return AliasReferenceSyntax(self)
 
     def allColumns(self):
         return self.expression.allColumns()
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         result = SQLFragment()
-        result.append(self.expression.subSQL(metadata, allTables))
-        result.append(SQLFragment(" %s" % (self.alias,)))
+        result.append(self.expression.subSQL(queryGenerator, allTables))
+        result.append(SQLFragment(" %s" % (self.aliasName(queryGenerator),)))
         return result
 
 
@@ -736,8 +756,8 @@
     def allColumns(self):
         return self.resultAlias.allColumns()
 
-    def subSQL(self, metadata, allTables):
-        return SQLFragment(self.resultAlias.alias)
+    def subSQL(self, queryGenerator, allTables):
+        return SQLFragment(self.resultAlias.aliasName(queryGenerator))
 
 
 class AliasedColumnSyntax(ColumnSyntax):
@@ -771,8 +791,8 @@
         self.b = b
 
 
-    def _subexpression(self, expr, metadata, allTables):
-        result = expr.subSQL(metadata, allTables)
+    def _subexpression(self, expr, queryGenerator, allTables):
+        result = expr.subSQL(queryGenerator, allTables)
         if self.op not in ('and', 'or') and isinstance(expr, Comparison):
             result = _inParens(result)
         return result
@@ -800,9 +820,9 @@
         super(NullComparison, self).__init__(a, op, None)
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         sqls = SQLFragment()
-        sqls.append(self.a.subSQL(metadata, allTables))
+        sqls.append(self.a.subSQL(queryGenerator, allTables))
         sqls.text += " is "
         if self.op != "=":
             sqls.text += "not "
@@ -821,13 +841,13 @@
         return self.a.allColumns() + self.b.allColumns()
 
 
-    def subSQL(self, metadata, allTables):
-        if ( metadata.dialect == ORACLE_DIALECT
+    def subSQL(self, queryGenerator, allTables):
+        if ( queryGenerator.dialect == ORACLE_DIALECT
              and isinstance(self.b, Constant) and self.b.value == ''
              and self.op in ('=', '!=') ):
-            return NullComparison(self.a, self.op).subSQL(metadata, allTables)
+            return NullComparison(self.a, self.op).subSQL(queryGenerator, allTables)
         stmt = SQLFragment()
-        result = self._subexpression(self.a, metadata, allTables)
+        result = self._subexpression(self.a, queryGenerator, allTables)
         if (isinstance(self.a, CompoundComparison)
             and self.a.op == 'or' and self.op == 'and'):
             result = _inParens(result)
@@ -835,7 +855,7 @@
 
         stmt.text += ' %s ' % (self.op,)
 
-        result = self._subexpression(self.b, metadata, allTables)
+        result = self._subexpression(self.b, queryGenerator, allTables)
         if (isinstance(self.b, CompoundComparison)
             and self.b.op == 'or' and self.op == 'and'):
             result = _inParens(result)
@@ -863,7 +883,7 @@
 
 class _AllColumns(object):
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         return SQLFragment('*')
 
 ALL_COLUMNS = _AllColumns()
@@ -876,7 +896,7 @@
         self.columns = columns
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         first = True
         cstatement = SQLFragment()
         for column in self.columns:
@@ -884,7 +904,7 @@
                 first = False
             else:
                 cstatement.append(SQLFragment(", "))
-            cstatement.append(column.subSQL(metadata, allTables))
+            cstatement.append(column.subSQL(queryGenerator, allTables))
         return cstatement
 
 
@@ -906,8 +926,8 @@
         self.columns = columns
 
 
-    def subSQL(self, metadata, allTables):
-        return _inParens(_commaJoined(c.subSQL(metadata, allTables)
+    def subSQL(self, queryGenerator, allTables):
+        return _inParens(_commaJoined(c.subSQL(queryGenerator, allTables)
                                       for c in self.columns))
 
 
@@ -943,15 +963,15 @@
         if self.optype not in (None, SetExpression.OPTYPE_ALL, SetExpression.OPTYPE_DISTINCT,):
             raise DALError("Must have either 'all' or 'distinct' in a set expression")
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         result = SQLFragment()
         for select in self.selects:
-            result.append(self.setOpSQL(metadata))
+            result.append(self.setOpSQL(queryGenerator))
             if self.optype == SetExpression.OPTYPE_ALL:
                 result.append(SQLFragment("ALL "))
             elif self.optype == SetExpression.OPTYPE_DISTINCT:
                 result.append(SQLFragment("DISTINCT "))
-            result.append(select.subSQL(metadata, allTables))
+            result.append(select.subSQL(queryGenerator, allTables))
         return result
 
     def allColumns(self):
@@ -961,24 +981,24 @@
     """
     A UNION construct used inside a SELECT.
     """
-    def setOpSQL(self, metadata):
+    def setOpSQL(self, queryGenerator):
         return SQLFragment(" UNION ")
 
 class Intersect(SetExpression):
     """
     An INTERSECT construct used inside a SELECT.
     """
-    def setOpSQL(self, metadata):
+    def setOpSQL(self, queryGenerator):
         return SQLFragment(" INTERSECT ")
 
 class Except(SetExpression):
     """
     An EXCEPT construct used inside a SELECT.
     """
-    def setOpSQL(self, metadata):
-        if metadata.dialect == POSTGRES_DIALECT:
+    def setOpSQL(self, queryGenerator):
+        if queryGenerator.dialect == POSTGRES_DIALECT:
             return SQLFragment(" EXCEPT ")
-        elif metadata.dialect == ORACLE_DIALECT:
+        elif queryGenerator.dialect == ORACLE_DIALECT:
             return SQLFragment(" MINUS ")
         else:
             raise NotImplementedError("Unsupported dialect")
@@ -1032,7 +1052,7 @@
         return CompoundComparison(other, '=', self)
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         """
         @return: a 'select' statement with placeholders and arguments
 
@@ -1046,11 +1066,11 @@
         if self.Distinct:
             stmt.text += "distinct "
         allTables = self.From.tables()
-        stmt.append(self.columns.subSQL(metadata, allTables))
+        stmt.append(self.columns.subSQL(queryGenerator, allTables))
         stmt.text += " from "
-        stmt.append(self.From.subSQL(metadata, allTables))
+        stmt.append(self.From.subSQL(queryGenerator, allTables))
         if self.Where is not None:
-            wherestmt = self.Where.subSQL(metadata, allTables)
+            wherestmt = self.Where.subSQL(queryGenerator, allTables)
             stmt.text += " where "
             stmt.append(wherestmt)
         if self.GroupBy is not None:
@@ -1061,14 +1081,14 @@
                     fst = False
                 else:
                     stmt.text += ', '
-                stmt.append(subthing.subSQL(metadata, allTables))
+                stmt.append(subthing.subSQL(queryGenerator, allTables))
         if self.Having is not None:
-            havingstmt = self.Having.subSQL(metadata, allTables)
+            havingstmt = self.Having.subSQL(queryGenerator, allTables)
             stmt.text += " having "
             stmt.append(havingstmt)
         if self.SetExpression is not None:
             stmt.append(SQLFragment(")"))
-            stmt.append(self.SetExpression.subSQL(metadata, allTables))
+            stmt.append(self.SetExpression.subSQL(queryGenerator, allTables))
         if self.OrderBy is not None:
             stmt.text += " order by "
             fst = True
@@ -1077,7 +1097,7 @@
                     fst = False
                 else:
                     stmt.text += ', '
-                stmt.append(subthing.subSQL(metadata, allTables))
+                stmt.append(subthing.subSQL(queryGenerator, allTables))
             if self.Ascending is not None:
                 if self.Ascending:
                     kw = " asc"
@@ -1089,8 +1109,8 @@
             if self.NoWait:
                 stmt.text += " nowait"
         if self.Limit is not None:
-            limitConst = Constant(self.Limit).subSQL(metadata, allTables)
-            if metadata.dialect == ORACLE_DIALECT:
+            limitConst = Constant(self.Limit).subSQL(queryGenerator, allTables)
+            if queryGenerator.dialect == ORACLE_DIALECT:
                 wrapper = SQLFragment("select * from (")
                 wrapper.append(stmt)
                 wrapper.append(SQLFragment(") where ROWNUM <= "))
@@ -1101,16 +1121,13 @@
         return stmt
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         result = SQLFragment("(")
-        result.append(self.toSQL(metadata))
+        result.append(self.toSQL(queryGenerator))
         result.append(SQLFragment(")"))
         if self.As is not None:
             if self.As == "":
-                if not hasattr(metadata, "generated_table_aliases"):
-                    metadata.generated_table_aliases = 1
-                self.As = "alias_%d" % (metadata.generated_table_aliases,)
-                metadata.generated_table_aliases += 1
+                self.As = queryGenerator.nextGeneratedID()
             result.append(SQLFragment(" %s" % (self.As,)))
         return result
 
@@ -1197,8 +1214,8 @@
         self.subfragments = subfragments
 
 
-    def subSQL(self, metadata, allTables):
-        return _commaJoined(f.subSQL(metadata, allTables)
+    def subSQL(self, queryGenerator, allTables):
+        return _commaJoined(f.subSQL(queryGenerator, allTables)
                             for f in self.subfragments)
 
 
@@ -1208,15 +1225,15 @@
     Common functionality of Insert/Update/Delete statements.
     """
 
-    def _returningClause(self, metadata, stmt, allTables):
+    def _returningClause(self, queryGenerator, stmt, allTables):
         """
         Add a dialect-appropriate 'returning' clause to the end of the given SQL
         statement.
 
-        @param metadata: describes the database we are generating the statement
+        @param queryGenerator: describes the database we are generating the statement
             for.
 
-        @type metadata: L{ConnectionMetadata}
+        @type queryGenerator: L{QueryGenerator}
 
         @param stmt: the SQL fragment generated without the 'returning' clause
         @type stmt: L{SQLFragment}
@@ -1231,15 +1248,15 @@
             retclause = _CommaList(retclause)
         if retclause is not None:
             stmt.text += ' returning '
-            stmt.append(retclause.subSQL(metadata, allTables))
-            if metadata.dialect == ORACLE_DIALECT:
+            stmt.append(retclause.subSQL(queryGenerator, allTables))
+            if queryGenerator.dialect == ORACLE_DIALECT:
                 stmt.text += ' into '
                 params = []
                 retvals = self._returnAsList()
                 for n, _ignore_v in enumerate(retvals):
                     params.append(
                         Constant(Parameter("oracle_out_" + str(n)))
-                        .subSQL(metadata, allTables)
+                        .subSQL(queryGenerator, allTables)
                     )
                 stmt.append(_commaJoined(params))
         return stmt
@@ -1252,19 +1269,19 @@
             return self.Return
 
 
-    def _extraVars(self, txn, metadata):
+    def _extraVars(self, txn, queryGenerator):
         if self.Return is None:
             return []
         result = []
         rvars = self._returnAsList()
-        if metadata.dialect == ORACLE_DIALECT:
+        if queryGenerator.dialect == ORACLE_DIALECT:
             for n, v in enumerate(rvars):
                 result.append(("oracle_out_" + str(n), _OracleOutParam(v)))
         return result
 
 
-    def _extraResult(self, result, outvars, metadata):
-        if metadata.dialect == ORACLE_DIALECT and self.Return is not None:
+    def _extraResult(self, result, outvars, queryGenerator):
+        if queryGenerator.dialect == ORACLE_DIALECT and self.Return is not None:
             def processIt(shouldBeNone):
                 result = [[v.value for _ignore_k, v in outvars]]
                 return result
@@ -1323,7 +1340,7 @@
                     (', '.join([c.name for c in unspecified])))
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         """
         @return: a 'insert' statement with placeholders and arguments
 
@@ -1332,7 +1349,7 @@
         columnsAndValues = self.columnMap.items()
         tableModel = columnsAndValues[0][0].model.table
         specifiedColumnModels = [x.model for x in self.columnMap.keys()]
-        if metadata.dialect == ORACLE_DIALECT:
+        if queryGenerator.dialect == ORACLE_DIALECT:
             # See test_nextSequenceDefaultImplicitExplicitOracle.
             for column in tableModel.columns:
                 if isinstance(column.default, Sequence):
@@ -1345,16 +1362,16 @@
                                key=lambda (c, v): c.model.name)
         allTables = []
         stmt = SQLFragment('insert into ')
-        stmt.append(TableSyntax(tableModel).subSQL(metadata, allTables))
+        stmt.append(TableSyntax(tableModel).subSQL(queryGenerator, allTables))
         stmt.append(SQLFragment(" "))
         stmt.append(_inParens(_commaJoined(
-            [c.subSQL(metadata, allTables) for (c, _ignore_v) in
+            [c.subSQL(queryGenerator, allTables) for (c, _ignore_v) in
              sortedColumns])))
         stmt.append(SQLFragment(" values "))
         stmt.append(_inParens(_commaJoined(
-            [_convert(v).subSQL(metadata, allTables)
+            [_convert(v).subSQL(queryGenerator, allTables)
              for (c, v) in sortedColumns])))
-        return self._returningClause(metadata, stmt, allTables)
+        return self._returningClause(queryGenerator, stmt, allTables)
 
 
 
@@ -1383,7 +1400,7 @@
         self.Return = Return
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         """
         @return: a 'insert' statement with placeholders and arguments
 
@@ -1395,20 +1412,20 @@
         result = SQLFragment('update ')
         result.append(
             TableSyntax(sortedColumns[0][0].model.table).subSQL(
-                metadata, allTables)
+                queryGenerator, allTables)
         )
         result.text += ' set '
         result.append(
             _commaJoined(
-                [c.subSQL(metadata, allTables).append(
-                    SQLFragment(" = ").subSQL(metadata, allTables)
-                ).append(_convert(v).subSQL(metadata, allTables))
+                [c.subSQL(queryGenerator, allTables).append(
+                    SQLFragment(" = ").subSQL(queryGenerator, allTables)
+                ).append(_convert(v).subSQL(queryGenerator, allTables))
                     for (c, v) in sortedColumns]
             )
         )
         result.append(SQLFragment( ' where '))
-        result.append(self.Where.subSQL(metadata, allTables))
-        return self._returningClause(metadata, result, allTables)
+        result.append(self.Where.subSQL(queryGenerator, allTables))
+        return self._returningClause(queryGenerator, result, allTables)
 
 
 
@@ -1426,15 +1443,15 @@
         self.Return = Return
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         result = SQLFragment()
         allTables = self.From.tables()
         result.text += 'delete from '
-        result.append(self.From.subSQL(metadata, allTables))
+        result.append(self.From.subSQL(queryGenerator, allTables))
         if self.Where is not None:
             result.text += ' where '
-            result.append(self.Where.subSQL(metadata, allTables))
-        return self._returningClause(metadata, result, allTables)
+            result.append(self.Where.subSQL(queryGenerator, allTables))
+        return self._returningClause(queryGenerator, result, allTables)
 
 
 
@@ -1465,9 +1482,9 @@
         return cls(table, 'exclusive')
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         return SQLFragment('lock table ').append(
-            self.table.subSQL(metadata, [self.table])).append(
+            self.table.subSQL(queryGenerator, [self.table])).append(
             SQLFragment(' in %s mode' % (self.mode,)))
 
 
@@ -1481,7 +1498,7 @@
         self.name = name
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         return SQLFragment('savepoint %s' % (self.name,))
 
 
@@ -1494,7 +1511,7 @@
         self.name = name
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         return SQLFragment('rollback to savepoint %s' % (self.name,))
 
 
@@ -1507,7 +1524,7 @@
         self.name = name
 
 
-    def _toSQL(self, metadata):
+    def _toSQL(self, queryGenerator):
         return SQLFragment('release savepoint %s' % (self.name,))
 
 
@@ -1588,7 +1605,7 @@
         return self.__class__.__name__ + repr((self.text, self.parameters))
 
 
-    def subSQL(self, metadata, allTables):
+    def subSQL(self, queryGenerator, allTables):
         return self
 
 

Modified: CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py	2012-02-16 18:56:25 UTC (rev 8718)
+++ CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py	2012-02-16 19:00:17 UTC (rev 8719)
@@ -18,29 +18,26 @@
 Tests for L{twext.enterprise.dal.syntax}
 """
 
+from twext.enterprise.dal import syntax
 from twext.enterprise.dal.parseschema import addSQLToSchema
-from twext.enterprise.dal import syntax
 from twext.enterprise.dal.syntax import (
     Select, Insert, Update, Delete, Lock, SQLFragment,
     TableMismatch, Parameter, Max, Len, NotEnoughValues,
     Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction,
     Union, Intersect, Except, SetExpression, DALError,
-    ResultAliasSyntax)
-
+    ResultAliasSyntax, Count, QueryGenerator)
+from twext.enterprise.dal.syntax import FixedPlaceholder, NumericPlaceholder
 from twext.enterprise.dal.syntax import Function
-
-from twext.enterprise.dal.syntax import FixedPlaceholder, NumericPlaceholder
+from twext.enterprise.dal.syntax import SchemaSyntax
+from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper
 from twext.enterprise.ienterprise import POSTGRES_DIALECT, ORACLE_DIALECT
+from twext.enterprise.test.test_adbapi2 import ConnectionPoolHelper
+from twext.enterprise.test.test_adbapi2 import NetworkedPoolHelper
 from twext.enterprise.test.test_adbapi2 import resultOf
 from twisted.internet.defer import succeed
-from twext.enterprise.dal.test.test_parseschema import SchemaTestHelper
-from twext.enterprise.dal.syntax import SchemaSyntax
-from twext.enterprise.test.test_adbapi2 import ConnectionPoolHelper
-from twext.enterprise.test.test_adbapi2 import NetworkedPoolHelper
 from twisted.trial.unittest import TestCase
 
 
-
 class _FakeTransaction(object):
     """
     An L{IAsyncTransaction} that provides the relevant metadata for SQL
@@ -128,7 +125,7 @@
         """
         self.assertEquals(Select(From=self.schema.FOO,
                                  Where=self.schema.FOO.BAR == 1).toSQL(
-                                 FixedPlaceholder(POSTGRES_DIALECT, "$$")),
+                                 QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("$$"))),
                           SQLFragment("select * from FOO where BAR = $$", [1]))
 
 
@@ -193,13 +190,13 @@
         self.assertEquals(Select(
             From=self.schema.FOO,
             Where=self.schema.FOO.BAR == ''
-        ).toSQL(NumericPlaceholder(ORACLE_DIALECT)),
+        ).toSQL(QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())),
             SQLFragment(
                 "select * from FOO where BAR is null", []))
         self.assertEquals(Select(
             From=self.schema.FOO,
             Where=self.schema.FOO.BAR != ''
-        ).toSQL(NumericPlaceholder(ORACLE_DIALECT)),
+        ).toSQL(QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())),
             SQLFragment(
                 "select * from FOO where BAR is not null", []))
 
@@ -502,7 +499,7 @@
                         Where=(self.schema.FOO.BAR == 2),
                     ),
                 ),
-            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?)", [1, 2]))
 
@@ -518,7 +515,7 @@
                     ),
                     optype=SetExpression.OPTYPE_ALL
                 ),
-            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "(select * from FOO where BAR = ?) INTERSECT ALL (select * from FOO where BAR = ?)", [1, 2]))
 
@@ -539,7 +536,7 @@
                     ),
                     optype=SetExpression.OPTYPE_DISTINCT,
                 ),
-            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "(select * from FOO) EXCEPT DISTINCT (select * from FOO where BAR = ?) EXCEPT DISTINCT (select * from FOO where BAR = ?)", [2, 3]))
 
@@ -559,7 +556,7 @@
                         ),
                     ),
                 ),
-            ).toSQL(FixedPlaceholder(ORACLE_DIALECT, "?")),
+            ).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "(select * from FOO) MINUS ((select * from FOO where BAR = ?) MINUS (select * from FOO where BAR = ?))", [2, 3]))
 
@@ -575,7 +572,7 @@
                     ),
                 ),
                 OrderBy=self.schema.FOO.BAR,
-            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            ).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?) order by BAR", [1, 2]))
 
@@ -591,10 +588,18 @@
                 From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ))
             ).toSQL(),
             SQLFragment(
-                "select max(QUX) from (select QUX from BOZ) alias_1"))
+                "select max(QUX) from (select QUX from BOZ) genid_1"))
 
         self.assertEquals(
             Select(
+                [Count(self.schema.BOZ.QUX)],
+                From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ))
+            ).toSQL(),
+            SQLFragment(
+                "select count(QUX) from (select QUX from BOZ) genid_1"))
+
+        self.assertEquals(
+            Select(
                 [Max(self.schema.BOZ.QUX)],
                 From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ, As="alias_BAR")),
             ).toSQL(),
@@ -625,7 +630,7 @@
                 )
             ).toSQL(),
             SQLFragment(
-                "select max(BAR) from ((select BAR from FOO where BAR = ?) UNION (select BAR from FOO where BAR = ?)) alias_1", [1, 2]))
+                "select max(BAR) from ((select BAR from FOO where BAR = ?) UNION (select BAR from FOO where BAR = ?)) genid_1", [1, 2]))
 
     def test_selectColumnAliases(self):
         """
@@ -640,28 +645,28 @@
 
         self.assertEquals(
             Select(
-                [ResultAliasSyntax(Max(self.schema.BOZ.QUX), "MAX_QUX")],
+                [ResultAliasSyntax(Max(self.schema.BOZ.QUX))],
                 From=self.schema.BOZ
             ).toSQL(),
-            SQLFragment("select max(QUX) MAX_QUX from BOZ"))
+            SQLFragment("select max(QUX) genid_1 from BOZ"))
 
-        alias = ResultAliasSyntax(Max(self.schema.BOZ.QUX), "MAX_QUX")
+        alias = ResultAliasSyntax(Max(self.schema.BOZ.QUX))
         self.assertEquals(
             Select([alias.columnReference()],
                 From=Select(
                     [alias],
                     From=self.schema.BOZ)
             ).toSQL(),
-            SQLFragment("select MAX_QUX from (select max(QUX) MAX_QUX from BOZ) alias_1"))
+            SQLFragment("select genid_1 from (select max(QUX) genid_1 from BOZ) genid_2"))
 
-        alias = ResultAliasSyntax(Len(self.schema.BOZ.QUX), "LEN_QUX")
+        alias = ResultAliasSyntax(Len(self.schema.BOZ.QUX))
         self.assertEquals(
             Select([alias.columnReference()],
                 From=Select(
                     [alias],
                     From=self.schema.BOZ)
             ).toSQL(),
-            SQLFragment("select LEN_QUX from (select character_length(QUX) LEN_QUX from BOZ) alias_1"))
+            SQLFragment("select genid_1 from (select character_length(QUX) genid_1 from BOZ) genid_2"))
 
 
     def test_inSubSelect(self):
@@ -842,7 +847,7 @@
             Insert({self.schema.FOO.BAR: 40,
                     self.schema.FOO.BAZ: 50},
                    Return=(self.schema.FOO.BAR, self.schema.FOO.BAZ)).toSQL(
-                       NumericPlaceholder(ORACLE_DIALECT)
+                       QueryGenerator(ORACLE_DIALECT, NumericPlaceholder())
                    ),
             SQLFragment(
                 "insert into FOO (BAR, BAZ) values (:1, :2) returning BAR, BAZ"
@@ -875,7 +880,7 @@
         self.assertEquals(
             Insert({self.schema.LEVELS.ACCESS: 1,
                     self.schema.LEVELS.USERNAME:
-                    "hi"}).toSQL(FixedPlaceholder(ORACLE_DIALECT, "?")),
+                    "hi"}).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 'insert into LEVELS ("ACCESS", USERNAME) values (?, ?)',
                 [1, "hi"])
@@ -883,7 +888,7 @@
         self.assertEquals(
             Insert({self.schema.LEVELS.ACCESS: 1,
                     self.schema.LEVELS.USERNAME:
-                    "hi"}).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+                    "hi"}).toSQL(QueryGenerator(POSTGRES_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 'insert into LEVELS (ACCESS, USERNAME) values (?, ?)',
                 [1, "hi"])
@@ -1055,7 +1060,7 @@
         self.assertEquals(
             Select([self.schema.FOO.BAR],
                    From=self.schema.FOO,
-                   Limit=123).toSQL(FixedPlaceholder(ORACLE_DIALECT, "?")),
+                   Limit=123).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "select * from (select BAR from FOO) "
                 "where ROWNUM <= ?", [123])
@@ -1108,7 +1113,7 @@
         self.assertEquals(
             Insert({self.schema.BOZ.QUX:
                     self.schema.A_SEQ}).toSQL(
-                        FixedPlaceholder(ORACLE_DIALECT, "?")),
+                        QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
             SQLFragment("insert into BOZ (QUX) values (A_SEQ.nextval)", []))
 
 
@@ -1125,7 +1130,7 @@
         )
         self.assertEquals(
             Insert({self.schema.DFLTR.a: 'hello'}).toSQL(
-                FixedPlaceholder(ORACLE_DIALECT, "?")
+                QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))
             ),
             SQLFragment("insert into DFLTR (a, b) values "
                         "(?, A_SEQ.nextval)", ['hello']),
@@ -1134,7 +1139,7 @@
         self.assertEquals(
             Insert({self.schema.DFLTR.a: 'hello',
                     self.schema.DFLTR.b: self.schema.A_SEQ}).toSQL(
-                FixedPlaceholder(ORACLE_DIALECT, "?")
+                QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))
             ),
             SQLFragment("insert into DFLTR (a, b) values "
                         "(?, A_SEQ.nextval)", ['hello']),
@@ -1308,7 +1313,7 @@
         )
         vvl = self.schema.veryveryveryveryveryveryveryverylong
         self.assertEquals(
-            Insert({vvl.foo: 1}).toSQL(FixedPlaceholder(ORACLE_DIALECT, "?")),
+            Insert({vvl.foo: 1}).toSQL(QueryGenerator(ORACLE_DIALECT, FixedPlaceholder("?"))),
             SQLFragment(
                 "insert into veryveryveryveryveryveryveryve (foo) values "
                 "(?)", [1]

Modified: CalendarServer/trunk/txdav/common/datastore/sql_tables.py
===================================================================
--- CalendarServer/trunk/txdav/common/datastore/sql_tables.py	2012-02-16 18:56:25 UTC (rev 8718)
+++ CalendarServer/trunk/txdav/common/datastore/sql_tables.py	2012-02-16 19:00:17 UTC (rev 8719)
@@ -1,6 +1,6 @@
 # -*- test-case-name: txdav.common.datastore.test.test_sql_tables -*-
 ##
-# Copyright (c) 2010 Apple Inc. All rights reserved.
+# Copyright (c) 2010-2012 Apple Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@
 """
 
 from twisted.python.modules import getModule
-from twext.enterprise.dal.syntax import SchemaSyntax
+from twext.enterprise.dal.syntax import SchemaSyntax, QueryGenerator
 from twext.enterprise.dal.model import NO_DEFAULT
 from twext.enterprise.dal.model import Sequence, ProcedureCall
 from twext.enterprise.dal.syntax import FixedPlaceholder
@@ -297,7 +297,7 @@
 
         out.write('\n);\n\n')
 
-        fakeMeta = FixedPlaceholder(ORACLE_DIALECT, '%s')
+        fakeQueryGenerator = QueryGenerator(ORACLE_DIALECT, FixedPlaceholder('%s'))
         def quoted(x):
             if isinstance(x, (str, unicode)):
                 return ''.join(["'", x.replace("'", "''"), "'"])
@@ -309,7 +309,7 @@
                 [(getattr(table, cmodel.name), val)
                  for (cmodel, val) in row.items()]
             )
-            fragment = Insert(cmap).toSQL(fakeMeta)
+            fragment = Insert(cmap).toSQL(fakeQueryGenerator)
             out.write(
                 fragment.text % tuple([quoted(param)
                                        for param in fragment.parameters]),
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20120216/c61fadce/attachment-0001.html>


More information about the calendarserver-changes mailing list