[CalendarServer-changes] [7125] CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal

source_changes at macosforge.org source_changes at macosforge.org
Mon Mar 7 19:01:55 PST 2011


Revision: 7125
          http://trac.macosforge.org/projects/calendarserver/changeset/7125
Author:   glyph at apple.com
Date:     2011-03-07 19:01:55 -0800 (Mon, 07 Mar 2011)
Log Message:
-----------
Refactor the placeholder/quote arguments into a more generic 'metadata' argument, so that more behavior can be pushed through to different layers more easily.

Modified Paths:
--------------
    CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/syntax.py
    CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/test/test_sqlsyntax.py

Modified: CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/syntax.py
===================================================================
--- CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/syntax.py	2011-03-08 03:01:43 UTC (rev 7124)
+++ CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/syntax.py	2011-03-08 03:01:55 UTC (rev 7125)
@@ -19,9 +19,76 @@
 Syntax wrappers and generators for SQL.
 """
 
+import itertools
+
+from twext.enterprise.ienterprise import POSTGRES_DIALECT
+
 from twext.enterprise.dal.model import Schema, Table, Column, Sequence
 
 
+class ConnectionMetadata(object):
+    """
+    Representation of the metadata about the database connection required to
+    generate some SQL, for a single statement.  Contains information necessary
+    to generate placeholder strings and determine the database dialect.
+
+    @ivar dialect: the database dialect, such as L{POSTGRES_DIALECT}.
+    """
+
+    def __init__(self, dialect):
+        self.dialect = dialect
+
+
+    def placeholder(self):
+        """
+        Generate the placeholder string for the next bound parameter.
+
+        @rtype: C{str}
+        """
+        raise NotImplementedError("See subclasses.")
+
+
+
+class FixedPlaceholder(ConnectionMetadata):
+    """
+    Metadata about a connection which uses a fixed string as its placeholder.
+    """
+
+    def __init__(self, dialect, placeholder):
+        super(FixedPlaceholder, self).__init__(dialect)
+        self._placeholder = placeholder
+
+
+    def placeholder(self):
+        return self._placeholder
+
+
+
+class NumericPlaceholder(ConnectionMetadata):
+    """
+    Metadata about a connection which uses numbered placeholders (C{:1},
+    C{:2}, ...), as in the DB-API 'numeric' paramstyle.
+    """
+
+    def __init__(self, dialect):
+        super(NumericPlaceholder, self).__init__(dialect)
+        # DB-API 'numeric' placeholders are 1-based, so start counting at 1.
+        self._next = itertools.count(1).next
+
+
+    def placeholder(self):
+        return ':' + str(self._next())
+
+
+
+def defaultMetadata():
+    """
+    Generate a default L{ConnectionMetadata}: the Postgres dialect with C{?}
+    as a fixed placeholder.
+    """
+    return FixedPlaceholder(POSTGRES_DIALECT, '?')
+
+
+
 class TableMismatch(Exception):
     """
     A table in a statement did not match with a column.
@@ -43,16 +98,33 @@
     """
 
     _paramstyles = {
-        'pyformat': ('%s', lambda s: s.replace("%", "%%"))
+        'pyformat': lambda dialect: FixedPlaceholder(dialect, "%s"),
+        'qmark': lambda dialect: FixedPlaceholder(dialect, "?"),
+        #'numeric': NumericPlaceholder
     }
 
+
+    def toSQL(self, metadata=None):
+        """
+        Generate the SQL for this statement.
+
+        @param metadata: the L{ConnectionMetadata} used to generate
+            placeholders; if C{None}, L{defaultMetadata} is used.
+
+        @rtype: L{SQLFragment}
+        """
+        if metadata is None:
+            metadata = defaultMetadata()
+        return self._toSQL(metadata)
+
+
     def on(self, txn, raiseOnZeroRowCount=None, **kw):
         """
         Execute this statement on a given L{IAsyncTransaction} and return the
         resulting L{Deferred}.
         """
-        placeholder, quote = self._paramstyles[txn.paramstyle]
-        fragment = self.toSQL(placeholder, quote).bind(**kw)
+        metadata = self._paramstyles[txn.paramstyle](txn.dialect)
+        fragment = self.toSQL(metadata).bind(**kw)
         return txn.execSQL(fragment.text, fragment.parameters,
                            raiseOnZeroRowCount)
 
@@ -141,10 +204,10 @@
         return list(ac())
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         result = SQLFragment(self.name)
         result.append(_inParens(
-            _commaJoined(_convert(arg).subSQL(placeholder, quote, allTables)
+            _commaJoined(_convert(arg).subSQL(metadata, allTables)
                          for arg in self.args)))
         return result
 
@@ -159,8 +222,8 @@
         return []
 
 
-    def subSQL(self, placeholder, quote, allTables):
-        return SQLFragment(placeholder, [self.value])
+    def subSQL(self, metadata, allTables):
+        return SQLFragment(metadata.placeholder(), [self.value])
 
 
 
@@ -173,7 +236,7 @@
         self.name = name
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         return SQLFragment(self.name)
 
 
@@ -236,7 +299,7 @@
 
     modelType = Sequence
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         """
         Convert to an SQL fragment.
         """
@@ -257,7 +320,7 @@
         return Join(self, type, otherTableSyntax, on)
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         """
         For use in a 'from' clause.
         """
@@ -319,18 +382,18 @@
         self.on = on
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         stmt = SQLFragment()
-        stmt.append(self.leftSide.subSQL(placeholder, quote, allTables))
+        stmt.append(self.leftSide.subSQL(metadata, allTables))
         stmt.text += ' '
         if self.type:
             stmt.text += self.type
             stmt.text += ' '
         stmt.text += 'join '
-        stmt.append(self.rightSide.subSQL(placeholder, quote, allTables))
+        stmt.append(self.rightSide.subSQL(metadata, allTables))
         if self.type != 'cross':
             stmt.text += ' on '
-            stmt.append(self.on.subSQL(placeholder, quote, allTables))
+            stmt.append(self.on.subSQL(metadata, allTables))
         return stmt
 
 
@@ -358,7 +421,7 @@
         return [self]
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         # XXX This, and 'model', could in principle conflict with column names.
         # Maybe do something about that.
         for tableSyntax in allTables:
@@ -379,8 +442,8 @@
         self.b = b
 
 
-    def _subexpression(self, expr, placeholder, quote, allTables):
-        result = expr.subSQL(placeholder, quote, allTables)
+    def _subexpression(self, expr, metadata, allTables):
+        result = expr.subSQL(metadata, allTables)
         if self.op not in ('and', 'or') and isinstance(expr, Comparison):
             result = _inParens(result)
         return result
@@ -408,9 +471,9 @@
         super(NullComparison, self).__init__(a, op, None)
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         sqls = SQLFragment()
-        sqls.append(self.a.subSQL(placeholder, quote, allTables))
+        sqls.append(self.a.subSQL(metadata, allTables))
         sqls.text += " is "
         if self.op != "=":
             sqls.text += "not "
@@ -429,16 +492,16 @@
         return self.a.allColumns() + self.b.allColumns()
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         stmt = SQLFragment()
-        result = self._subexpression(self.a, placeholder, quote, allTables)
+        result = self._subexpression(self.a, metadata, allTables)
         if isinstance(self.a, CompoundComparison) and self.a.op == 'or' and self.op == 'and':
             result = _inParens(result)
         stmt.append(result)
 
         stmt.text += ' %s ' % (self.op,)
 
-        result = self._subexpression(self.b, placeholder, quote, allTables)
+        result = self._subexpression(self.b, metadata, allTables)
         if isinstance(self.b, CompoundComparison) and self.b.op == 'or' and self.op == 'and':
             result = _inParens(result)
         stmt.append(result)
@@ -456,8 +519,8 @@
 
 class _AllColumns(object):
 
-    def subSQL(self, placeholder, quote, allTables):
-        return SQLFragment(quote('*'))
+    def subSQL(self, metadata, allTables):
+        return SQLFragment('*')
 
 ALL_COLUMNS = _AllColumns()
 
@@ -469,7 +532,7 @@
         self.columns = columns
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         first = True
         cstatement = SQLFragment()
         for column in self.columns:
@@ -477,7 +540,7 @@
                 first = False
             else:
                 cstatement.append(SQLFragment(", "))
-            cstatement.append(column.subSQL(placeholder, quote, allTables))
+            cstatement.append(column.subSQL(metadata, allTables))
         return cstatement
 
 
@@ -499,8 +562,8 @@
         self.columns = columns
 
 
-    def subSQL(self, placeholder, quote, allTables):
-        return _inParens(_commaJoined(c.subSQL(placeholder, quote, allTables)
+    def subSQL(self, metadata, allTables):
+        return _inParens(_commaJoined(c.subSQL(metadata, allTables)
                                       for c in self.columns))
 
 
@@ -550,45 +613,45 @@
         return CompoundComparison(other, '=', self)
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         """
         @return: a 'select' statement with placeholders and arguments
 
         @rtype: L{SQLFragment}
         """
-        stmt = SQLFragment(quote("select "))
+        stmt = SQLFragment("select ")
         if self.Distinct:
             stmt.text += "distinct "
         allTables = self.From.tables()
-        stmt.append(self.columns.subSQL(placeholder, quote, allTables))
-        stmt.text += quote(" from ")
-        stmt.append(self.From.subSQL(placeholder, quote, allTables))
+        stmt.append(self.columns.subSQL(metadata, allTables))
+        stmt.text += " from "
+        stmt.append(self.From.subSQL(metadata, allTables))
         if self.Where is not None:
-            wherestmt = self.Where.subSQL(placeholder, quote, allTables)
-            stmt.text += quote(" where ")
+            wherestmt = self.Where.subSQL(metadata, allTables)
+            stmt.text += " where "
             stmt.append(wherestmt)
         if self.GroupBy is not None:
-            stmt.text += quote(" group by ")
+            stmt.text += " group by "
             fst = True
             for subthing in self.GroupBy:
                 if fst:
                     fst = False
                 else:
                     stmt.text += ', '
-                stmt.append(subthing.subSQL(placeholder, quote, allTables))
+                stmt.append(subthing.subSQL(metadata, allTables))
         if self.Having is not None:
-            havingstmt = self.Having.subSQL(placeholder, quote, allTables)
-            stmt.text += quote(" having ")
+            havingstmt = self.Having.subSQL(metadata, allTables)
+            stmt.text += " having "
             stmt.append(havingstmt)
         if self.OrderBy is not None:
-            stmt.text += quote(" order by ")
+            stmt.text += " order by "
             fst = True
             for subthing in self.OrderBy:
                 if fst:
                     fst = False
                 else:
                     stmt.text += ', '
-                stmt.append(subthing.subSQL(placeholder, quote, allTables))
+                stmt.append(subthing.subSQL(metadata, allTables))
             if self.Ascending is not None:
                 if self.Ascending:
                     kw = " asc"
@@ -596,17 +659,16 @@
                     kw = " desc"
                 stmt.append(SQLFragment(kw))
         if self.ForUpdate:
-            stmt.text += quote(" for update")
+            stmt.text += " for update"
         if self.Limit is not None:
-            stmt.text += quote(" limit ")
-            stmt.append(Constant(self.Limit).subSQL(placeholder, quote,
-                                                    allTables))
+            stmt.text += " limit "
+            stmt.append(Constant(self.Limit).subSQL(metadata, allTables))
         return stmt
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         result = SQLFragment("(")
-        result.append(self.toSQL(placeholder, quote))
+        result.append(self.toSQL(metadata))
         result.append(SQLFragment(")"))
         return result
 
@@ -659,8 +721,8 @@
         self.subfragments = subfragments
 
 
-    def subSQL(self, placeholder, quote, allTables):
-        return _commaJoined(f.subSQL(placeholder, quote, allTables)
+    def subSQL(self, metadata, allTables):
+        return _commaJoined(f.subSQL(metadata, allTables)
                             for f in self.subfragments)
 
 
@@ -686,7 +748,7 @@
                     (', '.join([c.name for c in unspecified])))
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         """
         @return: a 'insert' statement with placeholders and arguments
 
@@ -698,18 +760,18 @@
         stmt = SQLFragment('insert into ')
         stmt.append(
             TableSyntax(sortedColumns[0][0].model.table)
-            .subSQL(placeholder, quote, allTables))
+            .subSQL(metadata, allTables))
         stmt.append(SQLFragment(" "))
         stmt.append(_inParens(_commaJoined(
-            [c.subSQL(placeholder, quote, allTables) for (c, v) in
+            [c.subSQL(metadata, allTables) for (c, v) in
              sortedColumns])))
         stmt.append(SQLFragment(" values "))
         stmt.append(_inParens(_commaJoined(
-            [_convert(v).subSQL(placeholder, quote, allTables)
+            [_convert(v).subSQL(metadata, allTables)
              for (c, v) in sortedColumns])))
         if self.Return is not None:
             stmt.text += ' returning '
-            stmt.append(self.Return.subSQL(placeholder, quote, allTables))
+            stmt.append(self.Return.subSQL(metadata, allTables))
         return stmt
 
 
@@ -741,7 +803,7 @@
         self.Return = Return
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         """
         @return: a 'insert' statement with placeholders and arguments
 
@@ -753,22 +815,22 @@
         result = SQLFragment('update ')
         result.append(
             TableSyntax(sortedColumns[0][0].model.table).subSQL(
-                placeholder, quote, allTables)
+                metadata, allTables)
         )
         result.text += ' set '
         result.append(
             _commaJoined(
-                [c.subSQL(placeholder, quote, allTables).append(
-                    SQLFragment(" = ").subSQL(placeholder, quote, allTables)
-                ).append(_convert(v).subSQL(placeholder, quote, allTables))
+                [c.subSQL(metadata, allTables).append(
+                    SQLFragment(" = ").subSQL(metadata, allTables)
+                ).append(_convert(v).subSQL(metadata, allTables))
                     for (c, v) in sortedColumns]
             )
         )
         result.append(SQLFragment( ' where '))
-        result.append(self.Where.subSQL(placeholder, quote, allTables))
+        result.append(self.Where.subSQL(metadata, allTables))
         if self.Return is not None:
             result.append(SQLFragment(' returning '))
-            result.append(self.Return.subSQL(placeholder, quote, allTables))
+            result.append(self.Return.subSQL(metadata, allTables))
         return result
 
 
@@ -785,21 +847,21 @@
         self.Using = Using
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         result = SQLFragment()
         allTables = self.From.tables()
         if self.Using is not None:
             allTables += self.Using.tables()
-        result.text += quote('delete from ')
-        result.append(self.From.subSQL(placeholder, quote, allTables))
+        result.text += 'delete from '
+        result.append(self.From.subSQL(metadata, allTables))
         if self.Using is not None:
             result.text += ' using '
-            result.append(self.Using.subSQL(placeholder, quote, allTables))
-        result.text += quote(' where ')
-        result.append(self.Where.subSQL(placeholder, quote, allTables))
+            result.append(self.Using.subSQL(metadata, allTables))
+        result.text += ' where '
+        result.append(self.Where.subSQL(metadata, allTables))
         if self.Return is not None:
             result.append(SQLFragment(' returning '))
-            result.append(self.Return.subSQL(placeholder, quote, allTables))
+            result.append(self.Return.subSQL(metadata, allTables))
         return result
 
 
@@ -819,11 +881,13 @@
         return cls(table, 'exclusive')
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         return SQLFragment('lock table ').append(
-            self.table.subSQL(placeholder, quote, [self.table])).append(
+            self.table.subSQL(metadata, [self.table])).append(
             SQLFragment(' in %s mode' % (self.mode,)))
 
+
+
 class Savepoint(_Statement):
     """
     An SQL 'savepoint' statement.
@@ -833,7 +897,7 @@
         self.name = name
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         return SQLFragment('savepoint %s' % (self.name,))
 
 
@@ -846,7 +910,7 @@
         self.name = name
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         return SQLFragment('rollback to savepoint %s' % (self.name,))
 
 
@@ -859,24 +923,30 @@
         self.name = name
 
 
-    def toSQL(self, placeholder="?", quote=lambda x: x):
+    def _toSQL(self, metadata):
         return SQLFragment('release savepoint %s' % (self.name,))
 
 
+
 class SavepointAction(object):
-    
+
     def __init__(self, name):
         self._name = name
-    
+
+
     def acquire(self, txn):
         return Savepoint(self._name).on(txn)
 
+
     def rollback(self, txn):
         return RollbackToSavepoint(self._name).on(txn)
 
+
     def release(self, txn):
         return ReleaseSavepoint(self._name).on(txn)
 
+
+
 class SQLFragment(object):
     """
     Combination of SQL text and arguments; a statement which may be executed
@@ -922,7 +992,7 @@
         return self.__class__.__name__ + repr((self.text, self.parameters))
 
 
-    def subSQL(self, placeholder, quote, allTables):
+    def subSQL(self, metadata, allTables):
         return self
 
 

Modified: CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/test/test_sqlsyntax.py
===================================================================
--- CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/test/test_sqlsyntax.py	2011-03-08 03:01:43 UTC (rev 7124)
+++ CalendarServer/branches/users/glyph/oracle/twext/enterprise/dal/test/test_sqlsyntax.py	2011-03-08 03:01:55 UTC (rev 7125)
@@ -26,8 +26,30 @@
 , Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction)
 
 from twext.enterprise.dal.syntax import FunctionInvocation
+
+from twext.enterprise.dal.syntax import FixedPlaceholder
+from twext.enterprise.ienterprise import POSTGRES_DIALECT
 from twisted.trial.unittest import TestCase
 
+
+
+class _FakeTransaction(object):
+    """
+    A fake L{IAsyncTransaction} which provides only the metadata
+    (C{paramstyle} and C{dialect}) relevant to SQL generation.
+    """
+
+    def __init__(self, paramstyle, dialect=POSTGRES_DIALECT):
+        """
+        @param paramstyle: the DB-API paramstyle to report, such as
+            C{'pyformat'}.
+        @param dialect: the database dialect to report.
+        """
+        self.paramstyle = paramstyle
+        self.dialect = dialect
+
+
+
 class GenerationTests(TestCase):
     """
     Tests for syntactic helpers to generate SQL queries.
@@ -65,16 +87,16 @@
                           SQLFragment("select * from FOO where BAR = ?", [1]))
 
 
-    def test_quotingAndPlaceholder(self):
+    def test_alternateMetadata(self):
         """
         L{Select} generates a 'select' statement with the specified placeholder
-        syntax and quoting function.
+        syntax when explicitly given L{ConnectionMetadata} which specifies a
+        placeholder.
         """
         self.assertEquals(Select(From=self.schema.FOO,
                                  Where=self.schema.FOO.BAR == 1).toSQL(
-                                 placeholder="*",
-                                 quote=lambda partial: partial.replace("*", "**")),
-                          SQLFragment("select ** from FOO where BAR = *", [1]))
+                                 FixedPlaceholder(POSTGRES_DIALECT, "$$")),
+                          SQLFragment("select * from FOO where BAR = $$", [1]))
 
 
     def test_columnComparison(self):
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20110307/61461b4e/attachment-0001.html>


More information about the calendarserver-changes mailing list