[CalendarServer-changes] [8687] CalendarServer/trunk/twext/enterprise/dal

source_changes at macosforge.org source_changes at macosforge.org
Wed Feb 15 11:43:08 PST 2012


Revision: 8687
          http://trac.macosforge.org/projects/calendarserver/changeset/8687
Author:   cdaboo at apple.com
Date:     2012-02-15 11:43:06 -0800 (Wed, 15 Feb 2012)
Log Message:
-----------
DAL syntax extended to handle set operations (e.g. UNION), sub-selects in From, and column aliasing in sub-selects.

Modified Paths:
--------------
    CalendarServer/trunk/twext/enterprise/dal/syntax.py
    CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py

Modified: CalendarServer/trunk/twext/enterprise/dal/syntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/syntax.py	2012-02-15 19:11:05 UTC (rev 8686)
+++ CalendarServer/trunk/twext/enterprise/dal/syntax.py	2012-02-15 19:43:06 UTC (rev 8687)
@@ -38,11 +38,18 @@
 except ImportError:
     cx_Oracle = None
 
+class DALError(Exception):
+    """
+    Base class for exceptions raised by this module; it is also raised
+    directly for API violations.  It signals a serious programming error
+    and should normally never be caught or ignored.
+    """
+
 class ConnectionMetadata(object):
     """
     Representation of the metadata about the database connection required to
     generate some SQL, for a single statement.  Contains information necessary
-    to generate placeholder strings and determine the database dialect.
+    to generate place holder strings and determine the database dialect.
     """
 
     def __init__(self, dialect):
@@ -56,7 +63,7 @@
 
 class FixedPlaceholder(ConnectionMetadata):
     """
-    Metadata about a connection which uses a fixed string as its placeholder.
+    Metadata about a connection which uses a fixed string as its place holder.
     """
 
     def __init__(self, dialect, placeholder):
@@ -96,7 +103,7 @@
 
 
 
-class NotEnoughValues(ValueError):
+class NotEnoughValues(DALError):
     """
     Not enough values were supplied for an L{Insert}.
     """
@@ -250,7 +257,7 @@
     def __init__(self, model):
         if not isinstance(model, self.modelType):
             # make sure we don't get a misleading repr()
-            raise ValueError("type mismatch: %r %r", type(self), model)
+            raise DALError("type mismatch: %r %r", type(self), model)
         self.model = model
 
 
@@ -289,7 +296,7 @@
 
 
     def __nonzero__(self):
-        raise ValueError(
+        raise DALError(  # reached by e.g. "if col1 > col2:"; comparisons build SQL
             "SQL expressions should not be tested for truth value in Python.")
 
 
@@ -353,7 +360,7 @@
 
 class NamedValue(ExpressionSyntax):
     """
-    A constant within the database; something pre-defined, such as
+    A constant within the database; something predefined, such as
     CURRENT_TIMESTAMP.
     """
     def __init__(self, name):
@@ -702,7 +709,37 @@
         return self.model.table.name + '.' + name
 
 
+class ResultAliasSyntax(ExpressionSyntax):
+    """A select-list item rendered as C{<expression> <alias>}."""
+    def __init__(self, expression, alias):
+        self.expression = expression
+        self.alias = alias  # emitted verbatim (unquoted) after the expression
 
+    def columnReference(self):  # refer to this alias from an enclosing query
+        return AliasReferenceSyntax(self)
+
+    def allColumns(self):
+        return self.expression.allColumns()
+
+    def subSQL(self, metadata, allTables):
+        result = SQLFragment()
+        result.append(self.expression.subSQL(metadata, allTables))
+        result.append(SQLFragment(" %s" % (self.alias,)))  # no "AS" keyword
+        return result
+
+
+class AliasReferenceSyntax(ExpressionSyntax):
+    """Reference to a L{ResultAliasSyntax}, rendered as just its alias name."""
+    def __init__(self, resultAlias):
+        self.resultAlias = resultAlias  # the ResultAliasSyntax being referred to
+
+    def allColumns(self):
+        return self.resultAlias.allColumns()
+
+    def subSQL(self, metadata, allTables):
+        return SQLFragment(self.resultAlias.alias)
+
+
 class AliasedColumnSyntax(ColumnSyntax):
     """
     An L{AliasedColumnSyntax} is like a L{ColumnSyntax}, but it generates SQL
@@ -878,7 +915,74 @@
         return self.columns
 
 
+class SetExpression(object):
+    """
+    A UNION, INTERSECT, or EXCEPT construct used inside a SELECT.
+    """
+    
+    OPTYPE_ALL = "all"
+    OPTYPE_DISTINCT = "distinct"
 
+    def __init__(self, selects, optype=None):
+        """
+        
+        @param selects: a single Select or a list of Selects
+        @type selects: C{list} or L{Select}
+        @param optype: whether to use the ALL, DISTINCT constructs: C{None} use neither, OPTYPE_ALL, or OPTYPE_DISTINCT
+        @type optype: C{str}
+        """
+        
+        if isinstance(selects, Select):
+            selects = (selects,)
+        self.selects = selects
+        self.optype = optype
+        
+        for select in self.selects:
+            if not isinstance(select, Select):
+                raise DALError("Must have SELECT statements in a set expression")
+        if self.optype not in (None, SetExpression.OPTYPE_ALL, SetExpression.OPTYPE_DISTINCT,):
+            raise DALError("Must have either 'all' or 'distinct' in a set expression")
+
+    def subSQL(self, metadata, allTables):
+        result = SQLFragment()
+        for select in self.selects:
+            result.append(self.setOpSQL(metadata))
+            if self.optype == SetExpression.OPTYPE_ALL:
+                result.append(SQLFragment("ALL "))
+            elif self.optype == SetExpression.OPTYPE_DISTINCT:
+                result.append(SQLFragment("DISTINCT "))
+            result.append(select.subSQL(metadata, allTables))
+        return result
+
+    def allColumns(self):
+        return []
+
+class Union(SetExpression):
+    """
+    A UNION construct used inside a SELECT.
+    """
+    def setOpSQL(self, metadata):  # same keyword for every dialect
+        return SQLFragment(" UNION ")
+
+class Intersect(SetExpression):
+    """
+    An INTERSECT construct used inside a SELECT.
+    """
+    def setOpSQL(self, metadata):  # same keyword for every dialect
+        return SQLFragment(" INTERSECT ")
+
+class Except(SetExpression):
+    """
+    An EXCEPT construct used inside a SELECT; Oracle spells it MINUS.
+    """
+    def setOpSQL(self, metadata):  # keyword depends on metadata.dialect
+        if metadata.dialect == POSTGRES_DIALECT:
+            return SQLFragment(" EXCEPT ")
+        elif metadata.dialect == ORACLE_DIALECT:
+            return SQLFragment(" MINUS ")
+        else:
+            raise NotImplementedError("Unsupported dialect")
+
 class Select(_Statement):
     """
     'select' statement.
@@ -886,7 +990,8 @@
 
     def __init__(self, columns=None, Where=None, From=None, OrderBy=None,
                  GroupBy=None, Limit=None, ForUpdate=False, NoWait=False, Ascending=None,
-                 Having=None, Distinct=False):
+                 Having=None, Distinct=False, As=None,
+                 SetExpression=None):
         self.From = From
         self.Where = Where
         self.Distinct = Distinct
@@ -898,18 +1003,25 @@
         self.GroupBy = GroupBy
         self.Limit = Limit
         self.Having = Having
+        self.SetExpression = SetExpression
+
         if columns is None:
             columns = ALL_COLUMNS
         else:
             if not _columnsMatchTables(columns, From.tables()):
                 raise TableMismatch()
-
             columns = _SomeColumns(columns)
         self.columns = columns
+        
         self.ForUpdate = ForUpdate
         self.NoWait = NoWait
         self.Ascending = Ascending
+        self.As = As
 
+        # A FROM that uses a sub-select will need the AS alias name
+        if isinstance(self.From, Select):
+            if self.From.As is None:
+                self.From.As = ""
 
     def __eq__(self, other):
         """
@@ -926,7 +1038,11 @@
 
         @rtype: L{SQLFragment}
         """
-        stmt = SQLFragment("select ")
+        if self.SetExpression is not None:
+            stmt = SQLFragment("(")
+        else:
+            stmt = SQLFragment()
+        stmt.append(SQLFragment("select "))
         if self.Distinct:
             stmt.text += "distinct "
         allTables = self.From.tables()
@@ -950,6 +1066,9 @@
             havingstmt = self.Having.subSQL(metadata, allTables)
             stmt.text += " having "
             stmt.append(havingstmt)
+        if self.SetExpression is not None:
+            stmt.append(SQLFragment(")"))
+            stmt.append(self.SetExpression.subSQL(metadata, allTables))
         if self.OrderBy is not None:
             stmt.text += " order by "
             fst = True
@@ -986,6 +1105,13 @@
         result = SQLFragment("(")
         result.append(self.toSQL(metadata))
         result.append(SQLFragment(")"))
+        if self.As is not None:
+            if self.As == "":
+                if not hasattr(metadata, "generated_table_aliases"):
+                    metadata.generated_table_aliases = 1
+                self.As = "alias_%d" % (metadata.generated_table_aliases,)
+                metadata.generated_table_aliases += 1
+            result.append(SQLFragment(" %s" % (self.As,)))
         return result
 
 
@@ -1007,6 +1133,22 @@
             for column in self.columns.columns:
                 yield column
 
+    def tables(self):
+        """
+        Determine the tables used by the result columns.
+        """
+        if self.columns is ALL_COLUMNS:
+            # All columns of the FROM clause are selected, so the result's
+            # tables are exactly the FROM clause's tables.  Select.__init__
+            # relies on this via From.tables() when this Select is used as
+            # the From of an enclosing Select.
+            return self.From.tables()
+        else:
+            tables = set([column.model.table for column in self.columns.columns if isinstance(column, ColumnSyntax)])
+            for table in self.From.tables():
+                tables.add(table.model)  # TableSyntax -> its table model
+            return [TableSyntax(table) for table in tables]  # order is arbitrary (set)
+        
 
 def _commaJoined(stmts):
     first = True
@@ -1094,7 +1236,7 @@
                 stmt.text += ' into '
                 params = []
                 retvals = self._returnAsList()
-                for n, v in enumerate(retvals):
+                for n, _ignore_v in enumerate(retvals):
                     params.append(
                         Constant(Parameter("oracle_out_" + str(n)))
                         .subSQL(metadata, allTables)
@@ -1124,7 +1266,7 @@
     def _extraResult(self, result, outvars, metadata):
         if metadata.dialect == ORACLE_DIALECT and self.Return is not None:
             def processIt(shouldBeNone):
-                result = [[v.value for k, v in outvars]]
+                result = [[v.value for _ignore_k, v in outvars]]
                 return result
             return result.addCallback(processIt)
         else:
@@ -1206,7 +1348,7 @@
         stmt.append(TableSyntax(tableModel).subSQL(metadata, allTables))
         stmt.append(SQLFragment(" "))
         stmt.append(_inParens(_commaJoined(
-            [c.subSQL(metadata, allTables) for (c, v) in
+            [c.subSQL(metadata, allTables) for (c, _ignore_v) in
              sortedColumns])))
         stmt.append(SQLFragment(" values "))
         stmt.append(_inParens(_commaJoined(

Modified: CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py	2012-02-15 19:11:05 UTC (rev 8686)
+++ CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py	2012-02-15 19:43:06 UTC (rev 8687)
@@ -1,5 +1,5 @@
 ##
-# Copyright (c) 2010 Apple Inc. All rights reserved.
+# Copyright (c) 2010-2012 Apple Inc. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,8 +23,9 @@
 from twext.enterprise.dal.syntax import (
     Select, Insert, Update, Delete, Lock, SQLFragment,
     TableMismatch, Parameter, Max, Len, NotEnoughValues,
-    Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction
-)
+    Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction,
+    Union, Intersect, Except, SetExpression, DALError,
+    ResultAliasSyntax)
 
 from twext.enterprise.dal.syntax import Function
 
@@ -153,7 +154,7 @@
         def sampleComparison():
             if self.schema.FOO.BAR > self.schema.FOO.BAZ:
                 return 'comparison should not succeed'
-        self.assertRaises(ValueError, sampleComparison)
+        self.assertRaises(DALError, sampleComparison)
 
 
     def test_compareWithNULL(self):
@@ -485,6 +486,184 @@
                          [173, 7]))
 
 
+    def test_setSelects(self):
+        """
+        L{Union}, L{Intersect} and L{Except} render set operations on selects.
+        """
+        # Operands are parenthesized; parameters appear in SQL text order.
+        # Simple UNION
+        self.assertEquals(
+            Select(
+                From=self.schema.FOO,
+                Where=(self.schema.FOO.BAR == 1),
+                SetExpression= Union(
+                    Select(
+                        From=self.schema.FOO,
+                        Where=(self.schema.FOO.BAR == 2),
+                    ),
+                ),
+            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?)", [1, 2]))
+
+        # Simple INTERSECT ALL
+        self.assertEquals(
+            Select(
+                From=self.schema.FOO,
+                Where=(self.schema.FOO.BAR == 1),
+                SetExpression=Intersect(
+                    Select(
+                        From=self.schema.FOO,
+                        Where=(self.schema.FOO.BAR == 2),
+                    ),
+                    optype=SetExpression.OPTYPE_ALL
+                ),
+            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO where BAR = ?) INTERSECT ALL (select * from FOO where BAR = ?)", [1, 2]))
+
+        # Multiple EXCEPTs, not nested, Postgres dialect
+        self.assertEquals(
+            Select(
+                From=self.schema.FOO,
+                SetExpression=Except(
+                    (
+                        Select(
+                            From=self.schema.FOO,
+                            Where=(self.schema.FOO.BAR == 2),
+                        ),
+                        Select(
+                            From=self.schema.FOO,
+                            Where=(self.schema.FOO.BAR == 3),
+                        ),
+                    ),
+                    optype=SetExpression.OPTYPE_DISTINCT,
+                ),
+            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO) EXCEPT DISTINCT (select * from FOO where BAR = ?) EXCEPT DISTINCT (select * from FOO where BAR = ?)", [2, 3]))
+
+        # Nested EXCEPTs, Oracle dialect
+        self.assertEquals(
+            Select(
+                From=self.schema.FOO,
+                SetExpression=Except(
+                    Select(
+                        From=self.schema.FOO,
+                        Where=(self.schema.FOO.BAR == 2),
+                        SetExpression=Except(
+                            Select(
+                                From=self.schema.FOO,
+                                Where=(self.schema.FOO.BAR == 3),
+                            ),
+                        ),
+                    ),
+                ),
+            ).toSQL(FixedPlaceholder(ORACLE_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO) MINUS ((select * from FOO where BAR = ?) MINUS (select * from FOO where BAR = ?))", [2, 3]))
+
+        # UNION with order by
+        self.assertEquals(
+            Select(
+                From=self.schema.FOO,
+                Where=(self.schema.FOO.BAR == 1),
+                SetExpression= Union(
+                    Select(
+                        From=self.schema.FOO,
+                        Where=(self.schema.FOO.BAR == 2),
+                    ),
+                ),
+                OrderBy=self.schema.FOO.BAR,
+            ).toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?) order by BAR", [1, 2]))
+
+
+    def test_simpleSubSelects(self):
+        """
+        A L{Select} passed as C{From} renders as a parenthesized sub-select,
+        aliased with C{As} or else with a generated C{alias_N} name.
+        """
+        self.assertEquals(
+            Select(
+                [Max(self.schema.BOZ.QUX)],
+                From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ))
+            ).toSQL(),
+            SQLFragment(
+                "select max(QUX) from (select QUX from BOZ) alias_1"))
+
+        self.assertEquals(
+            Select(
+                [Max(self.schema.BOZ.QUX)],
+                From=(Select([self.schema.BOZ.QUX], From=self.schema.BOZ, As="alias_BAR")),
+            ).toSQL(),
+            SQLFragment(
+                "select max(QUX) from (select QUX from BOZ) alias_BAR"))
+
+
+    def test_setSubSelects(self):
+        """
+        A L{Select} carrying a L{SetExpression} can be used as a From sub-select.
+        """
+        # The whole set expression is parenthesized and given a generated alias.
+        # Simple UNION
+        self.assertEquals(
+            Select(
+                [Max(self.schema.FOO.BAR)],
+                From=Select(
+                    [self.schema.FOO.BAR],
+                    From=self.schema.FOO,
+                    Where=(self.schema.FOO.BAR == 1),
+                    SetExpression= Union(
+                        Select(
+                            [self.schema.FOO.BAR],
+                            From=self.schema.FOO,
+                            Where=(self.schema.FOO.BAR == 2),
+                        ),
+                    ),
+                )
+            ).toSQL(),
+            SQLFragment(
+                "select max(BAR) from ((select BAR from FOO where BAR = ?) UNION (select BAR from FOO where BAR = ?)) alias_1", [1, 2]))
+
+    def test_selectColumnAliases(self):
+        """
+        L{ResultAliasSyntax} aliases select-list items; columnReference() refers to them.
+        """
+        self.assertEquals(
+            Select(
+                [ResultAliasSyntax(self.schema.BOZ.QUX, "BOZ_QUX")],
+                From=self.schema.BOZ
+            ).toSQL(),
+            SQLFragment("select QUX BOZ_QUX from BOZ"))
+
+        self.assertEquals(
+            Select(
+                [ResultAliasSyntax(Max(self.schema.BOZ.QUX), "MAX_QUX")],
+                From=self.schema.BOZ
+            ).toSQL(),
+            SQLFragment("select max(QUX) MAX_QUX from BOZ"))
+
+        alias = ResultAliasSyntax(Max(self.schema.BOZ.QUX), "MAX_QUX")
+        self.assertEquals(
+            Select([alias.columnReference()],
+                From=Select(
+                    [alias],
+                    From=self.schema.BOZ)
+            ).toSQL(),
+            SQLFragment("select MAX_QUX from (select max(QUX) MAX_QUX from BOZ) alias_1"))
+
+        alias = ResultAliasSyntax(Len(self.schema.BOZ.QUX), "LEN_QUX")
+        self.assertEquals(
+            Select([alias.columnReference()],
+                From=Select(
+                    [alias],
+                    From=self.schema.BOZ)
+            ).toSQL(),
+            SQLFragment("select LEN_QUX from (select character_length(QUX) LEN_QUX from BOZ) alias_1"))
+
+
     def test_inSubSelect(self):
         """
         L{ColumnSyntax.In} returns a sub-expression using the SQL 'in' syntax.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20120215/f2e93ddb/attachment-0001.html>


More information about the calendarserver-changes mailing list