[CalendarServer-changes] [13367] twext/trunk

source_changes at macosforge.org source_changes at macosforge.org
Thu Apr 24 09:10:35 PDT 2014


Revision: 13367
          http://trac.calendarserver.org//changeset/13367
Author:   cdaboo at apple.com
Date:     2014-04-24 09:10:35 -0700 (Thu, 24 Apr 2014)
Log Message:
-----------
Use the latest sqlparse, which supports Oracle syntax. Fix parsing and schema comparison to work with
more complex PostgreSQL and Oracle behaviors.

Modified Paths:
--------------
    twext/trunk/setup.py
    twext/trunk/twext/enterprise/dal/model.py
    twext/trunk/twext/enterprise/dal/parseschema.py
    twext/trunk/twext/enterprise/dal/test/test_parseschema.py

Modified: twext/trunk/setup.py
===================================================================
--- twext/trunk/setup.py	2014-04-24 04:27:28 UTC (rev 13366)
+++ twext/trunk/setup.py	2014-04-24 16:10:35 UTC (rev 13367)
@@ -38,6 +38,7 @@
     return modules + setuptools_find_packages()
 
 
+
 def version():
     """
     Compute the version number.
@@ -142,7 +143,7 @@
 
 extras_requirements = {
     # Database Abstraction Layer
-    "DAL": ["sqlparse==0.1.2"],
+    "DAL": ["sqlparse>=0.1.11"],
 
     # LDAP
     "LDAP": ["python-ldap"],
@@ -177,6 +178,7 @@
         pass
 
 
+
 #
 # Run setup
 #

Modified: twext/trunk/twext/enterprise/dal/model.py
===================================================================
--- twext/trunk/twext/enterprise/dal/model.py	2014-04-24 04:27:28 UTC (rev 13366)
+++ twext/trunk/twext/enterprise/dal/model.py	2014-04-24 16:10:35 UTC (rev 13367)
@@ -59,11 +59,12 @@
 
     def __eq__(self, other):
         """
-        Compare equal to other L{SQLTypes} with matching name and length.
+        Compare equal to other L{SQLTypes} with matching name and length. The name is
+        normalized so we can compare schema from different types of DB implementations.
         """
         if not isinstance(other, SQLType):
             return NotImplemented
-        return (self.name, self.length) == (other.name, other.length)
+        return (self.normalizedName(), self.length) == (other.normalizedName(), other.length)
 
 
     def __ne__(self, other):
@@ -87,7 +88,21 @@
         return "<SQL Type: %r%s>" % (self.name, lendesc)
 
 
+    def normalizedName(self):
+        """
+        Map type names to standard names.
+        """
+        return {
+            "nchar": "char",
+            "varchar2": "varchar",
+            "nvarchar2": "varchar",
+            "clob": "text",
+            "nclob": "text",
+            "boolean": "integer",
+        }.get(self.name, self.name)
 
+
+
 class Constraint(object):
     """
     A constraint on a set of columns.
@@ -131,11 +146,13 @@
 
 
 
-class ProcedureCall(object):
+class ProcedureCall(FancyEqMixin):
     """
     An invocation of a stored procedure or built-in function.
     """
 
+    compareAttributes = 'name args'.split()
+
     def __init__(self, name, args):
         _checkstr(name)
         self.name = name
@@ -161,6 +178,19 @@
 
 
 
+def listIfNone(x):
+    return [] if x is None else x
+
+
+
+def stringIfNone(x, attr=None):
+    if attr:
+        return "" if x is None else getattr(x, attr)
+    else:
+        return "" if x is None else x
+
+
+
 class Column(FancyEqMixin, object):
     """
     A column from a table.
@@ -215,14 +245,23 @@
 
         results = []
 
-        # TODO: sql_dump does not do types write now - so ignore this
-        # if self.type != other.type:
-        #     results.append(
-        #         "Table: %s, mismatched column type: %s"
-        #         % (self.table.name, self.name)
-        #     )
-
-        # TODO: figure out how to compare default, references and deleteAction
+        if self.name != other.name:
+            results.append("Table: %s, column names %s and %s do not match" % (self.table.name, self.name, other.name,))
+        if self.type != other.type:
+            results.append("Table: %s, column name %s type mismatch" % (self.table.name, self.name,))
+        if self.default != other.default:
+            # Some DBs don't allow sequence as a default
+            if (
+                isinstance(self.default, Sequence) and other.default == NO_DEFAULT or
+                self.default == NO_DEFAULT and isinstance(other.default, Sequence)
+            ):
+                pass
+            else:
+                results.append("Table: %s, column name %s default mismatch" % (self.table.name, self.name,))
+        if stringIfNone(self.references, "name") != stringIfNone(other.references, "name"):
+            results.append("Table: %s, column name %s references mismatch" % (self.table.name, self.name,))
+        if self.deleteAction != other.deleteAction:
+            results.append("Table: %s, column name %s delete action mismatch" % (self.table.name, self.name,))
         return results
 
 
@@ -333,7 +372,18 @@
         for name in set(myColumns.keys()) & set(otherColumns.keys()):
             results.extend(myColumns[name].compare(otherColumns[name]))
 
-        # TODO: figure out how to compare schemaRows
+        if not all([len(a.compare(b)) == 0 for a, b in zip(
+            listIfNone(self.primaryKey),
+            listIfNone(other.primaryKey),
+        )]):
+            results.append("Table: %s, mismatched primary key" % (self.name,))
+
+        for myRow, otherRow in zip(self.schemaRows, other.schemaRows):
+            myRows = dict([(column.name, value) for column, value in myRow.items()])
+            otherRows = dict([(column.name, value) for column, value in otherRow.items()])
+            if myRows != otherRows:
+                results.append("Table: %s, mismatched schema rows: %s" % (self.name, myRows))
+
         return results
 
 
@@ -392,7 +442,7 @@
         self.constraints.append(Check(protoExpression, name))
 
 
-    def insertSchemaRow(self, values):
+    def insertSchemaRow(self, values, columns=None):
         """
         A statically-defined row was inserted as part of the schema itself.
         This is used for tables that want to track static enumerations, for
@@ -405,9 +455,12 @@
 
         @param values: a C{list} of data items, one for each column in this
             table's current list of L{Column}s.
+        @param columns: a C{list} of column names to insert into. If C{None}
+            then use all table columns.
         """
         row = {}
-        for column, value in zip(self.columns, values):
+        columns = self.columns if columns is None else [self.columnNamed(name) for name in columns]
+        for column, value in zip(columns, values):
             row[column] = value
         self.schemaRows.append(row)
 

Modified: twext/trunk/twext/enterprise/dal/parseschema.py
===================================================================
--- twext/trunk/twext/enterprise/dal/parseschema.py	2014-04-24 04:27:28 UTC (rev 13366)
+++ twext/trunk/twext/enterprise/dal/parseschema.py	2014-04-24 16:10:35 UTC (rev 13367)
@@ -209,17 +209,23 @@
                     idx.addColumn(idx.table.columnNamed(columnName))
 
             elif createType == u"FUNCTION":
-                FunctionModel(
-                    schema,
-                    stmt.token_next(2, True).value.encode("utf-8")
-                )
+                parseFunction(schema, stmt)
 
         elif stmt.get_type() == "INSERT":
             insertTokens = iterSignificant(stmt)
             expect(insertTokens, ttype=Keyword.DML, value="INSERT")
             expect(insertTokens, ttype=Keyword, value="INTO")
 
-            tableName = expect(insertTokens, cls=Identifier).get_name()
+            token = insertTokens.next()
+
+            if isinstance(token, Function):
+                [tableName, columnArgs] = iterSignificant(token)
+                tableName = tableName.get_name()
+                columns = namesInParens(columnArgs)
+            else:
+                tableName = token.get_name()
+                columns = None
+
             expect(insertTokens, ttype=Keyword, value="VALUES")
 
             values = expect(insertTokens, cls=Parenthesis)
@@ -238,16 +244,13 @@
                     [ident.ttype](ident.value)
                 )
 
-            schema.tableNamed(tableName).insertSchemaRow(rowData)
+            schema.tableNamed(tableName).insertSchemaRow(rowData, columns=columns)
 
         elif stmt.get_type() == "CREATE OR REPLACE":
             createType = stmt.token_next(1, True).value.upper()
 
             if createType == u"FUNCTION":
-                FunctionModel(
-                    schema,
-                    stmt.token_next(2, True).token_first(True).token_first(True).value.encode("utf-8")
-                )
+                parseFunction(schema, stmt)
 
         else:
             print("unknown type:", stmt.get_type())
@@ -256,6 +259,25 @@
 
 
 
+def parseFunction(schema, stmt):
+    """
+    A FUNCTION may or may not have an argument list, so we need to account for
+    both possibilities.
+    """
+    fn_name = stmt.token_next(2, True)
+    if isinstance(fn_name, Function):
+        [fn_name, _ignore_args] = iterSignificant(fn_name)
+        fn_name = fn_name.get_name()
+    else:
+        fn_name = fn_name.get_name()
+
+    FunctionModel(
+        schema,
+        fn_name.encode("utf-8"),
+    )
+
+
+
 class _ColumnParser(object):
     """
     Stateful parser for the things between commas.
@@ -326,22 +348,6 @@
             return self.parseConstraint(maybeIdent)
 
 
-    def namesInParens(self, parens):
-        parens = iterSignificant(parens)
-        expect(parens, ttype=Punctuation, value="(")
-        idorids = parens.next()
-
-        if isinstance(idorids, Identifier):
-            idnames = [idorids.get_name()]
-        elif isinstance(idorids, IdentifierList):
-            idnames = [x.get_name() for x in idorids.get_identifiers()]
-        else:
-            raise ViolatedExpectation("identifier or list", repr(idorids))
-
-        expect(parens, ttype=Punctuation, value=")")
-        return idnames
-
-
     def readExpression(self, parens):
         """
         Read a given expression from a Parenthesis object.  (This is currently
@@ -369,7 +375,7 @@
             funcName = expect(parens, ttype=Keyword).value.encode("ascii")
             rhs = FunctionSyntax(funcName)(*[
                 ColumnSyntax(self.table.columnNamed(x)) for x in
-                self.namesInParens(expect(parens, cls=Parenthesis))
+                namesInParens(expect(parens, cls=Parenthesis))
             ])
             result = CompoundComparison(lhs, op, rhs)
 
@@ -404,10 +410,10 @@
 
         if constraintType.match(Keyword, "PRIMARY"):
             expect(self, ttype=Keyword, value="KEY")
-            names = self.namesInParens(expect(self, cls=Parenthesis))
+            names = namesInParens(expect(self, cls=Parenthesis))
             self.table.primaryKey = [self.table.columnNamed(n) for n in names]
         elif constraintType.match(Keyword, "UNIQUE"):
-            names = self.namesInParens(expect(self, cls=Parenthesis))
+            names = namesInParens(expect(self, cls=Parenthesis))
             self.table.tableConstraint(Constraint.UNIQUE, names)
         elif constraintType.match(Keyword, "CHECK"):
             self.table.checkConstraint(self.readExpression(self.next()), ident)
@@ -519,7 +525,7 @@
                             defaultValue.referringColumns.append(theColumn)
                         else:
                             defaultValue = ProcedureCall(
-                                thingo.encode("utf-8"), parens
+                                thingo.encode("utf-8"), namesInParens(parens),
                             )
 
                     elif theDefault.ttype == Number.Integer:
@@ -546,6 +552,17 @@
                     elif theDefault.ttype == String.Single:
                         defaultValue = _destringify(theDefault.value)
 
+                    # Oracle format for current timestamp mapped to postgres variant
+                    elif (
+                        theDefault.ttype == Keyword and
+                        theDefault.value.lower() == "current_timestamp"
+                    ):
+                        expect(self, ttype=Keyword, value="at")
+                        expect(self, ttype=None, value="time")
+                        expect(self, ttype=None, value="zone")
+                        expect(self, ttype=String.Single, value="'UTC'")
+                        defaultValue = ProcedureCall("timezone", [u"UTC", u"CURRENT_TIMESTAMP"])
+
                     else:
                         raise RuntimeError(
                             "not sure what to do: default %r"
@@ -635,11 +652,35 @@
         return token.get_name()
     elif token.ttype == Name:
         return token.value
+    elif token.ttype == String.Single:
+        return _destringify(token.value)
+    elif token.ttype == Keyword:
+        return token.value
     else:
         raise ViolatedExpectation("identifier or name", repr(token))
 
 
 
+def namesInParens(parens):
+    parens = iterSignificant(parens)
+    expect(parens, ttype=Punctuation, value="(")
+    idorids = parens.next()
+
+    if isinstance(idorids, Identifier):
+        idnames = [idorids.get_name()]
+    elif isinstance(idorids, IdentifierList):
+        idnames = [nameOrIdentifier(x) for x in idorids.get_identifiers()]
+    elif idorids.ttype == String.Single:
+        idnames = [nameOrIdentifier(idorids)]
+    else:
+        expectSingle(idorids, ttype=Punctuation, value=")")
+        return []
+
+    expect(parens, ttype=Punctuation, value=")")
+    return idnames
+
+
+
 def expectSingle(nextval, ttype=None, value=None, cls=None):
     """
     Expect some properties from retrieved value.

Modified: twext/trunk/twext/enterprise/dal/test/test_parseschema.py
===================================================================
--- twext/trunk/twext/enterprise/dal/test/test_parseschema.py	2014-04-24 04:27:28 UTC (rev 13366)
+++ twext/trunk/twext/enterprise/dal/test/test_parseschema.py	2014-04-24 16:10:35 UTC (rev 13367)
@@ -19,7 +19,7 @@
 and L{twext.enterprise.dal.parseschema}.
 """
 
-from twext.enterprise.dal.model import Schema
+from twext.enterprise.dal.model import Schema, ProcedureCall
 from twext.enterprise.dal.syntax import CompoundComparison, ColumnSyntax
 
 try:
@@ -176,6 +176,29 @@
         self.assertEquals(table.columnNamed("f").default, None)
 
 
+    def test_defaultFunctionColumns(self):
+        """
+        Parsing a 'default' column with a function call in it will return
+        that function as the 'default' attribute of the Column object.
+        """
+        s = self.schemaFromString(
+            """
+            create table a (
+                b1 integer default tz(),
+                b2 integer default tz('UTC'),
+                b3 integer default tz('UTC', 'GMT'),
+                b4 integer default timezone('UTC', CURRENT_TIMESTAMP),
+                b5 integer default CURRENT_TIMESTAMP at time zone 'UTC'
+            );
+            """)
+        table = s.tableNamed("a")
+        self.assertEquals(table.columnNamed("b1").default, ProcedureCall("tz", []))
+        self.assertEquals(table.columnNamed("b2").default, ProcedureCall("tz", ["UTC"]))
+        self.assertEquals(table.columnNamed("b3").default, ProcedureCall("tz", ["UTC", "GMT"]))
+        self.assertEquals(table.columnNamed("b4").default, ProcedureCall("timezone", ["UTC", "CURRENT_TIMESTAMP"]))
+        self.assertEquals(table.columnNamed("b5").default, ProcedureCall("timezone", ["UTC", "CURRENT_TIMESTAMP"]))
+
+
     def test_needsValue(self):
         """
         Columns with defaults, or with a 'not null' constraint don't need a
@@ -333,7 +356,7 @@
             """
             create table a1 (b1 integer primary key);
             create table c2 (d2 integer references a1 on delete cascade);
-            create table e3 (f3 integer references a1 on delete set null);
+            create table ee3 (f3 integer references a1 on delete set null);
             create table g4 (h4 integer references a1 on delete set default);
             """
         )
@@ -346,7 +369,7 @@
             "cascade"
         )
         self.assertEquals(
-            s.tableNamed("e3").columnNamed("f3").deleteAction,
+            s.tableNamed("ee3").columnNamed("f3").deleteAction,
             "set null"
         )
         self.assertEquals(
@@ -388,7 +411,7 @@
             """
             create table q (b integer); -- noise
             create table a (b integer primary key, c integer);
-            create table z (c integer, unique(c) );
+            create table z (c integer, unique (c) );
 
             create unique index idx_a_c on a(c);
             create index idx_a_b_c on a (c, b);
@@ -417,13 +440,47 @@
     RETURN i + 1;
 END;
 $$ LANGUAGE plpgsql;
+CREATE FUNCTION autoincrement RETURNS integer AS $$
+BEGIN
+    RETURN 1;
+END;
+$$ LANGUAGE plpgsql;
 CREATE OR REPLACE FUNCTION decrement(i integer) RETURNS integer AS $$
 BEGIN
     RETURN i - 1;
 END;
 $$ LANGUAGE plpgsql;
+CREATE OR REPLACE FUNCTION autodecrement (i integer) RETURNS integer AS $$
+BEGIN
+    RETURN i - 1;
+END;
+$$ LANGUAGE plpgsql;
             """
         )
         self.assertTrue(s.functionNamed("increment") is not None)
         self.assertTrue(s.functionNamed("decrement") is not None)
         self.assertRaises(KeyError, s.functionNamed, "merge")
+
+
+    def test_insert(self):
+        """
+        An 'insert' statement will add an L{schemaRows} to an L{Table}.
+        """
+        s = self.schemaFromString(
+            """
+            create table alpha (beta integer, gamma integer not null);
+
+            insert into alpha values (1, 2);
+            insert into alpha (gamma, beta) values (3, 4);
+            """
+        )
+        self.assertTrue(s.tableNamed("alpha") is not None)
+        self.assertEqual(len(s.tableNamed("alpha").schemaRows), 2)
+        rows = [[(column.name, value) for column, value in sorted(row.items(), key=lambda x:x[0])] for row in s.tableNamed("alpha").schemaRows]
+        self.assertEqual(
+            rows,
+            [
+                [("beta", 1), ("gamma", 2)],
+                [("beta", 4), ("gamma", 3)],
+            ]
+        )
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20140424/a39e5981/attachment-0001.html>


More information about the calendarserver-changes mailing list