[CalendarServer-changes] [8687] CalendarServer/trunk/twext/enterprise/dal
source_changes at macosforge.org
source_changes at macosforge.org
Wed Feb 15 11:43:08 PST 2012
Revision: 8687
http://trac.macosforge.org/projects/calendarserver/changeset/8687
Author: cdaboo at apple.com
Date: 2012-02-15 11:43:06 -0800 (Wed, 15 Feb 2012)
Log Message:
-----------
Extend the DAL syntax to handle set operations (e.g. UNION), sub-selects in the FROM clause, and column aliasing within sub-selects.
Modified Paths:
--------------
CalendarServer/trunk/twext/enterprise/dal/syntax.py
CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py
Modified: CalendarServer/trunk/twext/enterprise/dal/syntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/syntax.py 2012-02-15 19:11:05 UTC (rev 8686)
+++ CalendarServer/trunk/twext/enterprise/dal/syntax.py 2012-02-15 19:43:06 UTC (rev 8687)
@@ -38,11 +38,18 @@
except ImportError:
cx_Oracle = None
+class DALError(Exception):
+    """
+    Root of the exception hierarchy for the DAL syntax module.
+
+    It may also be raised directly to report misuse of the API.  An instance
+    of this exception indicates a programming error in the calling code; it
+    is not meant to be caught and suppressed during normal operation.
+    """
+
class ConnectionMetadata(object):
"""
Representation of the metadata about the database connection required to
generate some SQL, for a single statement. Contains information necessary
to generate placeholder strings and determine the database dialect.
"""
def __init__(self, dialect):
@@ -56,7 +63,7 @@
class FixedPlaceholder(ConnectionMetadata):
"""
Metadata about a connection which uses a fixed string as its placeholder.
"""
def __init__(self, dialect, placeholder):
@@ -96,7 +103,7 @@
-class NotEnoughValues(ValueError):
+class NotEnoughValues(DALError):
+ # NOTE(review): this class used to derive from ValueError; DALError derives
+ # only from Exception, so code that catches ValueError to handle this error
+ # will no longer do so -- confirm no caller relies on that.
"""
Not enough values were supplied for an L{Insert}.
"""
@@ -250,7 +257,7 @@
def __init__(self, model):
+ """
+ @param model: the model object to wrap; must be an instance of
+ C{self.modelType}.
+ @raise DALError: if C{model} is not an instance of C{self.modelType}.
+ """
if not isinstance(model, self.modelType):
# make sure we don't get a misleading repr()
- raise ValueError("type mismatch: %r %r", type(self), model)
+ raise DALError("type mismatch: %r %r" % (type(self), model))
self.model = model
@@ -289,7 +296,7 @@
def __nonzero__(self):
+ """
+ Refuse to be evaluated as a Python boolean.
+
+ Operators such as C{==} and C{>} on SQL expressions build SQL syntax
+ objects rather than truth values, so using one in a Python C{if} is a
+ programming error.
+
+ @raise DALError: always.
+ """
- raise ValueError(
+ raise DALError(
"SQL expressions should not be tested for truth value in Python.")
@@ -353,7 +360,7 @@
class NamedValue(ExpressionSyntax):
"""
- A constant within the database; something pre-defined, such as
+ A constant within the database; something predefined, such as
CURRENT_TIMESTAMP.
"""
def __init__(self, name):
@@ -702,7 +709,37 @@
return self.model.table.name + '.' + name
+class ResultAliasSyntax(ExpressionSyntax):
+    """
+    An expression in a select list that is given a name in the result, i.e.
+    the C{expression alias} form.  Use L{columnReference} to refer to the
+    aliased value from an enclosing query.
+    """
+
+    def __init__(self, expression, alias):
+        """
+        @param expression: the expression whose result is being named.
+        @param alias: the name to give that result.
+        """
+        self.expression = expression
+        self.alias = alias
+
+
+    def columnReference(self):
+        """
+        @return: a new L{AliasReferenceSyntax} which renders as just the alias.
+        """
+        return AliasReferenceSyntax(self)
+
+
+    def allColumns(self):
+        # Naming the result adds no columns beyond those of the expression.
+        return self.expression.allColumns()
+
+
+    def subSQL(self, metadata, allTables):
+        """
+        Render the wrapped expression followed by a space and the alias.
+        """
+        aliasFragment = SQLFragment(" %s" % (self.alias,))
+        fragment = SQLFragment()
+        fragment.append(self.expression.subSQL(metadata, allTables))
+        fragment.append(aliasFragment)
+        return fragment
+
+
+class AliasReferenceSyntax(ExpressionSyntax):
+    """
+    A reference, by name, to a value aliased by a L{ResultAliasSyntax}; it
+    renders as the bare alias.
+    """
+
+    def __init__(self, resultAlias):
+        """
+        @param resultAlias: the L{ResultAliasSyntax} being referred to.
+        """
+        self.resultAlias = resultAlias
+
+
+    def allColumns(self):
+        # Delegate to the aliased expression.
+        return self.resultAlias.allColumns()
+
+
+    def subSQL(self, metadata, allTables):
+        aliasName = self.resultAlias.alias
+        return SQLFragment(aliasName)
+
+
class AliasedColumnSyntax(ColumnSyntax):
"""
An L{AliasedColumnSyntax} is like a L{ColumnSyntax}, but it generates SQL
@@ -878,7 +915,74 @@
return self.columns
+class SetExpression(object):
+    """
+    A UNION, INTERSECT, or EXCEPT construct used inside a SELECT.
+
+    Subclasses supply L{setOpSQL} to render the actual set operator.
+    """
+
+    OPTYPE_ALL = "all"
+    OPTYPE_DISTINCT = "distinct"
+
+    def __init__(self, selects, optype=None):
+        """
+        @param selects: a single Select, or a sequence (or other iterable) of
+            Selects; each one is combined with the preceding query using
+            this set operator.
+        @type selects: C{list}, C{tuple} or L{Select}
+        @param optype: whether to use the ALL or DISTINCT modifier: C{None}
+            uses neither, otherwise L{OPTYPE_ALL} or L{OPTYPE_DISTINCT}.
+        @type optype: C{str}
+
+        @raise DALError: if C{selects} is empty or contains anything other
+            than a L{Select}, or if C{optype} is not one of the values above.
+        """
+        if isinstance(selects, Select):
+            selects = (selects,)
+        # Materialize once: a generator would otherwise be exhausted by the
+        # validation below and then render nothing in subSQL().
+        self.selects = tuple(selects)
+        self.optype = optype
+
+        if not self.selects:
+            raise DALError("Must have SELECT statements in a set expression")
+        for select in self.selects:
+            if not isinstance(select, Select):
+                raise DALError("Must have SELECT statements in a set expression")
+        # Checked once, independently of how many selects were given.
+        if self.optype not in (None, SetExpression.OPTYPE_ALL,
+                               SetExpression.OPTYPE_DISTINCT):
+            raise DALError("Must have either 'all' or 'distinct' in a set expression")
+
+
+    def setOpSQL(self, metadata):
+        """
+        @return: an L{SQLFragment} naming the set operator, padded with
+            spaces.  Must be overridden by subclasses.
+        """
+        raise NotImplementedError(
+            "%s must override setOpSQL" % (self.__class__.__name__,))
+
+
+    def subSQL(self, metadata, allTables):
+        """
+        Render each select preceded by the set operator and, when requested,
+        the ALL / DISTINCT modifier.
+        """
+        # The modifier is the same for every select; decide it once.
+        if self.optype == SetExpression.OPTYPE_ALL:
+            modifier = "ALL "
+        elif self.optype == SetExpression.OPTYPE_DISTINCT:
+            modifier = "DISTINCT "
+        else:
+            modifier = None
+        result = SQLFragment()
+        for select in self.selects:
+            result.append(self.setOpSQL(metadata))
+            if modifier is not None:
+                result.append(SQLFragment(modifier))
+            result.append(select.subSQL(metadata, allTables))
+        return result
+
+
+    def allColumns(self):
+        return []
+
+class Union(SetExpression):
+    """
+    Set operation combining the queries with SQL C{UNION}.
+    """
+
+    def setOpSQL(self, metadata):
+        # The dialect is not consulted: UNION is emitted unconditionally.
+        return SQLFragment(" UNION ")
+
+class Intersect(SetExpression):
+    """
+    Set operation combining the queries with SQL C{INTERSECT}.
+    """
+
+    def setOpSQL(self, metadata):
+        # The dialect is not consulted: INTERSECT is emitted unconditionally.
+        return SQLFragment(" INTERSECT ")
+
+class Except(SetExpression):
+    """
+    Set difference: rows of the preceding query that do not appear in the
+    given selects.  Rendered as C{EXCEPT} for PostgreSQL and C{MINUS} for
+    Oracle; other dialects are rejected.
+    """
+
+    def setOpSQL(self, metadata):
+        dialect = metadata.dialect
+        if dialect == POSTGRES_DIALECT:
+            keyword = " EXCEPT "
+        elif dialect == ORACLE_DIALECT:
+            keyword = " MINUS "
+        else:
+            raise NotImplementedError("Unsupported dialect")
+        return SQLFragment(keyword)
+
class Select(_Statement):
"""
'select' statement.
@@ -886,7 +990,8 @@
def __init__(self, columns=None, Where=None, From=None, OrderBy=None,
GroupBy=None, Limit=None, ForUpdate=False, NoWait=False, Ascending=None,
- Having=None, Distinct=False):
+ Having=None, Distinct=False, As=None,
+ SetExpression=None):
self.From = From
self.Where = Where
self.Distinct = Distinct
@@ -898,18 +1003,25 @@
self.GroupBy = GroupBy
self.Limit = Limit
self.Having = Having
+ self.SetExpression = SetExpression
+
if columns is None:
columns = ALL_COLUMNS
else:
if not _columnsMatchTables(columns, From.tables()):
raise TableMismatch()
-
columns = _SomeColumns(columns)
self.columns = columns
+
self.ForUpdate = ForUpdate
self.NoWait = NoWait
self.Ascending = Ascending
+ self.As = As
+ # A FROM that uses a sub-select will need the AS alias name
+ if isinstance(self.From, Select):
+ if self.From.As is None:
+ self.From.As = ""
def __eq__(self, other):
"""
@@ -926,7 +1038,11 @@
@rtype: L{SQLFragment}
"""
- stmt = SQLFragment("select ")
+ if self.SetExpression is not None:
+ stmt = SQLFragment("(")
+ else:
+ stmt = SQLFragment()
+ stmt.append(SQLFragment("select "))
if self.Distinct:
stmt.text += "distinct "
allTables = self.From.tables()
@@ -950,6 +1066,9 @@
havingstmt = self.Having.subSQL(metadata, allTables)
stmt.text += " having "
stmt.append(havingstmt)
+ if self.SetExpression is not None:
+ stmt.append(SQLFragment(")"))
+ stmt.append(self.SetExpression.subSQL(metadata, allTables))
if self.OrderBy is not None:
stmt.text += " order by "
fst = True
@@ -986,6 +1105,13 @@
result = SQLFragment("(")
result.append(self.toSQL(metadata))
result.append(SQLFragment(")"))
+ if self.As is not None:
+ if self.As == "":
+ if not hasattr(metadata, "generated_table_aliases"):
+ metadata.generated_table_aliases = 1
+ self.As = "alias_%d" % (metadata.generated_table_aliases,)
+ metadata.generated_table_aliases += 1
+ result.append(SQLFragment(" %s" % (self.As,)))
return result
@@ -1007,6 +1133,22 @@
for column in self.columns.columns:
yield column
+    def tables(self):
+        """
+        Determine the tables this statement draws on, so that a L{Select}
+        can itself be used as the C{From} of another L{Select}.
+
+        @return: a C{list} of L{TableSyntax}.  With all columns selected this
+            is simply C{self.From.tables()}; otherwise each table model
+            appears once, in first-seen order (tables of plain result columns
+            first, then those of the C{From} clause).
+        """
+        if self.columns is ALL_COLUMNS:
+            return self.From.tables()
+        models = []
+        seen = set()
+        def addModel(model):
+            # Keep a deterministic order; a bare set iterates arbitrarily.
+            if model not in seen:
+                seen.add(model)
+                models.append(model)
+        for column in self.columns.columns:
+            # Only plain columns contribute a table here; expressions such as
+            # aggregates or ResultAliasSyntax are skipped.
+            if isinstance(column, ColumnSyntax):
+                addModel(column.model.table)
+        for table in self.From.tables():
+            addModel(table.model)
+        return [TableSyntax(model) for model in models]
+
def _commaJoined(stmts):
first = True
@@ -1094,7 +1236,7 @@
stmt.text += ' into '
params = []
retvals = self._returnAsList()
- for n, v in enumerate(retvals):
+ for n, _ignore_v in enumerate(retvals):
params.append(
Constant(Parameter("oracle_out_" + str(n)))
.subSQL(metadata, allTables)
@@ -1124,7 +1266,7 @@
def _extraResult(self, result, outvars, metadata):
if metadata.dialect == ORACLE_DIALECT and self.Return is not None:
def processIt(shouldBeNone):
- result = [[v.value for k, v in outvars]]
+ result = [[v.value for _ignore_k, v in outvars]]
return result
return result.addCallback(processIt)
else:
@@ -1206,7 +1348,7 @@
stmt.append(TableSyntax(tableModel).subSQL(metadata, allTables))
stmt.append(SQLFragment(" "))
stmt.append(_inParens(_commaJoined(
- [c.subSQL(metadata, allTables) for (c, v) in
+ [c.subSQL(metadata, allTables) for (c, _ignore_v) in
sortedColumns])))
stmt.append(SQLFragment(" values "))
stmt.append(_inParens(_commaJoined(
Modified: CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py
===================================================================
--- CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py 2012-02-15 19:11:05 UTC (rev 8686)
+++ CalendarServer/trunk/twext/enterprise/dal/test/test_sqlsyntax.py 2012-02-15 19:43:06 UTC (rev 8687)
@@ -1,5 +1,5 @@
##
-# Copyright (c) 2010 Apple Inc. All rights reserved.
+# Copyright (c) 2010-2012 Apple Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -23,8 +23,9 @@
from twext.enterprise.dal.syntax import (
Select, Insert, Update, Delete, Lock, SQLFragment,
TableMismatch, Parameter, Max, Len, NotEnoughValues,
- Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction
-)
+ Savepoint, RollbackToSavepoint, ReleaseSavepoint, SavepointAction,
+ Union, Intersect, Except, SetExpression, DALError,
+ ResultAliasSyntax)
from twext.enterprise.dal.syntax import Function
@@ -153,7 +154,7 @@
def sampleComparison():
if self.schema.FOO.BAR > self.schema.FOO.BAZ:
return 'comparison should not succeed'
- self.assertRaises(ValueError, sampleComparison)
+ self.assertRaises(DALError, sampleComparison)
def test_compareWithNULL(self):
@@ -485,6 +486,184 @@
[173, 7]))
+    def test_setSelects(self):
+        """
+        Selects joined by L{Union}, L{Intersect} and L{Except} render as set
+        operations between parenthesized queries, covering the ALL and
+        DISTINCT modifiers, several selects in one expression, nesting, the
+        Oracle spelling of EXCEPT and a trailing ORDER BY.
+        """
+        foo = self.schema.FOO
+
+        # Simple UNION
+        union = Select(
+            From=foo,
+            Where=(foo.BAR == 1),
+            SetExpression=Union(
+                Select(From=foo, Where=(foo.BAR == 2)),
+            ),
+        )
+        self.assertEquals(
+            union.toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?)", [1, 2]))
+
+        # Simple INTERSECT ALL
+        intersectAll = Select(
+            From=foo,
+            Where=(foo.BAR == 1),
+            SetExpression=Intersect(
+                Select(From=foo, Where=(foo.BAR == 2)),
+                optype=SetExpression.OPTYPE_ALL,
+            ),
+        )
+        self.assertEquals(
+            intersectAll.toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO where BAR = ?) INTERSECT ALL (select * from FOO where BAR = ?)", [1, 2]))
+
+        # Multiple EXCEPTs, not nested, Postgres dialect
+        flatExcept = Select(
+            From=foo,
+            SetExpression=Except(
+                (
+                    Select(From=foo, Where=(foo.BAR == 2)),
+                    Select(From=foo, Where=(foo.BAR == 3)),
+                ),
+                optype=SetExpression.OPTYPE_DISTINCT,
+            ),
+        )
+        self.assertEquals(
+            flatExcept.toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO) EXCEPT DISTINCT (select * from FOO where BAR = ?) EXCEPT DISTINCT (select * from FOO where BAR = ?)", [2, 3]))
+
+        # Nested EXCEPTs, Oracle dialect
+        nestedExcept = Select(
+            From=foo,
+            SetExpression=Except(
+                Select(
+                    From=foo,
+                    Where=(foo.BAR == 2),
+                    SetExpression=Except(
+                        Select(From=foo, Where=(foo.BAR == 3)),
+                    ),
+                ),
+            ),
+        )
+        self.assertEquals(
+            nestedExcept.toSQL(FixedPlaceholder(ORACLE_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO) MINUS ((select * from FOO where BAR = ?) MINUS (select * from FOO where BAR = ?))", [2, 3]))
+
+        # UNION with order by
+        orderedUnion = Select(
+            From=foo,
+            Where=(foo.BAR == 1),
+            SetExpression=Union(
+                Select(From=foo, Where=(foo.BAR == 2)),
+            ),
+            OrderBy=foo.BAR,
+        )
+        self.assertEquals(
+            orderedUnion.toSQL(FixedPlaceholder(POSTGRES_DIALECT, "?")),
+            SQLFragment(
+                "(select * from FOO where BAR = ?) UNION (select * from FOO where BAR = ?) order by BAR", [1, 2]))
+
+
+    def test_simpleSubSelects(self):
+        """
+        A L{Select} passed as the C{From} of another L{Select} is rendered as
+        a parenthesized sub-select followed by a table alias: a generated
+        C{alias_N} name when none is given, or the name passed via C{As}.
+        """
+        self.assertEquals(
+            Select(
+                [Max(self.schema.BOZ.QUX)],
+                From=Select([self.schema.BOZ.QUX], From=self.schema.BOZ),
+            ).toSQL(),
+            SQLFragment(
+                "select max(QUX) from (select QUX from BOZ) alias_1"))
+
+        self.assertEquals(
+            Select(
+                [Max(self.schema.BOZ.QUX)],
+                From=Select([self.schema.BOZ.QUX], From=self.schema.BOZ,
+                            As="alias_BAR"),
+            ).toSQL(),
+            SQLFragment(
+                "select max(QUX) from (select QUX from BOZ) alias_BAR"))
+
+
+    def test_setSubSelects(self):
+        """
+        A L{Select} carrying a L{SetExpression} can itself serve as the
+        C{From} sub-select of another L{Select}; the whole set operation is
+        parenthesized and given a generated alias.
+        """
+        foo = self.schema.FOO
+
+        # Simple UNION
+        unioned = Select(
+            [foo.BAR],
+            From=foo,
+            Where=(foo.BAR == 1),
+            SetExpression=Union(
+                Select([foo.BAR], From=foo, Where=(foo.BAR == 2)),
+            ),
+        )
+        outer = Select([Max(foo.BAR)], From=unioned)
+        self.assertEquals(
+            outer.toSQL(),
+            SQLFragment(
+                "select max(BAR) from ((select BAR from FOO where BAR = ?) UNION (select BAR from FOO where BAR = ?)) alias_1", [1, 2]))
+
+
+    def test_selectColumnAliases(self):
+        """
+        L{ResultAliasSyntax} names a result column, both at the top level and
+        inside a sub-select whose alias is referenced from the outer query via
+        L{ResultAliasSyntax.columnReference}.
+        """
+        boz = self.schema.BOZ
+
+        plain = Select([ResultAliasSyntax(boz.QUX, "BOZ_QUX")], From=boz)
+        self.assertEquals(
+            plain.toSQL(),
+            SQLFragment("select QUX BOZ_QUX from BOZ"))
+
+        aggregate = Select(
+            [ResultAliasSyntax(Max(boz.QUX), "MAX_QUX")], From=boz)
+        self.assertEquals(
+            aggregate.toSQL(),
+            SQLFragment("select max(QUX) MAX_QUX from BOZ"))
+
+        # The same alias object is used inside the sub-select and, through
+        # columnReference(), in the outer select list.
+        cases = [
+            (Max, "MAX_QUX",
+             "select MAX_QUX from (select max(QUX) MAX_QUX from BOZ) alias_1"),
+            (Len, "LEN_QUX",
+             "select LEN_QUX from (select character_length(QUX) LEN_QUX from BOZ) alias_1"),
+        ]
+        for function, name, expected in cases:
+            alias = ResultAliasSyntax(function(boz.QUX), name)
+            nested = Select([alias.columnReference()],
+                            From=Select([alias], From=boz))
+            self.assertEquals(nested.toSQL(), SQLFragment(expected))
+
+
def test_inSubSelect(self):
"""
L{ColumnSyntax.In} returns a sub-expression using the SQL 'in' syntax.
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20120215/f2e93ddb/attachment-0001.html>
More information about the calendarserver-changes
mailing list