[CalendarServer-changes] [1672] CalendarServer/trunk/twistedcaldav

source_changes at macosforge.org source_changes at macosforge.org
Fri Jul 13 14:04:58 PDT 2007


Revision: 1672
          http://trac.macosforge.org/projects/calendarserver/changeset/1672
Author:   cdaboo at apple.com
Date:     2007-07-13 14:04:58 -0700 (Fri, 13 Jul 2007)

Log Message:
-----------
Merged branches/users/cdaboo/digestdb-fix-1654 to trunk.

Modified Paths:
--------------
    CalendarServer/trunk/twistedcaldav/directory/digest.py
    CalendarServer/trunk/twistedcaldav/sql.py

Added Paths:
-----------
    CalendarServer/trunk/twistedcaldav/test/test_sql.py

Modified: CalendarServer/trunk/twistedcaldav/directory/digest.py
===================================================================
--- CalendarServer/trunk/twistedcaldav/directory/digest.py	2007-07-13 20:32:22 UTC (rev 1671)
+++ CalendarServer/trunk/twistedcaldav/directory/digest.py	2007-07-13 21:04:58 UTC (rev 1672)
@@ -18,16 +18,22 @@
 
 from twistedcaldav.sql import AbstractSQLDatabase
 
+from twisted.cred import error
+from twisted.python import log
 from twisted.web2.auth.digest import DigestCredentialFactory
+from twisted.web2.auth.digest import DigestedCredentials
 
 from zope.interface import implements, Interface
 
 import cPickle as pickle
-from twisted.cred import error
-from twisted.web2.auth.digest import DigestedCredentials
+import os
 import time
-import os
 
+try:
+    from sqlite3 import OperationalError
+except ImportError:
+    from pysqlite2.dbapi2 import OperationalError
+
 """
 Overrides twisted.web2.auth.digest to allow specifying a qop value as a configuration parameter.
 Also adds an sqlite-based credentials cache that is multi-process safe.
@@ -151,76 +157,117 @@
     dbFilename = ".db.digestcredentialscache"
     dbFormatVersion = "2"
 
+    exceptionLimit = 10
+
     def __init__(self, path):
         db_path = os.path.join(path, DigestCredentialsDB.dbFilename)
         if os.path.exists(db_path):
             os.remove(db_path)
-        super(DigestCredentialsDB, self).__init__(db_path)
-        self.db = {}
+        super(DigestCredentialsDB, self).__init__(db_path, autocommit=True)
+        self.exceptions = 0
     
     def has_key(self, key):
         """
         See IDigestCredentialsDatabase.
         """
-        for ignore_key in self._db_execute(
-            "select KEY from DIGESTCREDENTIALS where KEY = :1",
-            key
-        ):
-            return True
-        else:
-            return False
+        try:
+            for ignore_key in self._db_execute(
+                "select KEY from DIGESTCREDENTIALS where KEY = :1",
+                key
+            ):
+                return True
+            else:
+                return False
+            self.exceptions = 0
+        except OperationalError, e:
+            self.exceptions += 1
+            if self.exceptions >= self.exceptionLimit:
+                self._db_close()
+                log.err("Reset digest credentials database connection: %s" % (e,))
+            raise
 
     def set(self, key, value):
         """
         See IDigestCredentialsDatabase.
         """
-        self._delete_from_db(key)
-        pvalue = pickle.dumps(value)
-        self._add_to_db(key, pvalue)
-        self._db_commit()
+        try:
+            pvalue = pickle.dumps(value)
+            self._set_in_db(key, pvalue)
+            self.exceptions = 0
+        except OperationalError, e:
+            self.exceptions += 1
+            if self.exceptions >= self.exceptionLimit:
+                self._db_close()
+                log.err("Reset digest credentials database connection: %s" % (e,))
+            raise
 
     def get(self, key):
         """
         See IDigestCredentialsDatabase.
         """
-        for pvalue in self._db_execute(
-            "select VALUE from DIGESTCREDENTIALS where KEY = :1",
-            key
-        ):
-            return pickle.loads(str(pvalue[0]))
-        else:
-            return None
+        try:
+            for pvalue in self._db_execute(
+                "select VALUE from DIGESTCREDENTIALS where KEY = :1",
+                key
+            ):
+                self.exceptions = 0
+                return pickle.loads(str(pvalue[0]))
+            else:
+                self.exceptions = 0
+                return None
+        except OperationalError, e:
+            self.exceptions += 1
+            if self.exceptions >= self.exceptionLimit:
+                self._db_close()
+                log.err("Reset digest credentials database connection: %s" % (e,))
+            raise
 
     def delete(self, key):
         """
         See IDigestCredentialsDatabase.
         """
-        self._delete_from_db(key)
-        self._db_commit()
+        try:
+            self._delete_from_db(key)
+            self.exceptions = 0
+        except OperationalError, e:
+            self.exceptions += 1
+            if self.exceptions >= self.exceptionLimit:
+                self._db_close()
+                log.err("Reset digest credentials database connection: %s" % (e,))
+            raise
 
     def keys(self):
         """
         See IDigestCredentialsDatabase.
         """
-        result = []
-        for key in self._db_execute("select KEY from DIGESTCREDENTIALS"):
-            result.append(str(key[0]))
-        
-        return result
+        try:
+            result = []
+            for key in self._db_execute("select KEY from DIGESTCREDENTIALS"):
+                result.append(str(key[0]))
+            
+            self.exceptions = 0
+            return result
+        except OperationalError, e:
+            self.exceptions += 1
+            if self.exceptions >= self.exceptionLimit:
+                self._db_close()
+                log.err("Reset digest credentials database connection: %s" % (e,))
+            raise
 
-    def _add_to_db(self, key, value):
+    def _set_in_db(self, key, value):
         """
-        Insert the specified entry into the database.
+        Insert the specified entry into the database, replacing any that might already exist.
 
         @param key:   the key to add.
         @param value: the value to add.
         """
-        self._db_execute(
+        self._db().execute(
             """
-            insert into DIGESTCREDENTIALS (KEY, VALUE)
+            insert or replace into DIGESTCREDENTIALS (KEY, VALUE)
             values (:1, :2)
-            """, key, value
+            """, (key, value,)
         )
+        self._db_commit()
        
     def _delete_from_db(self, key):
         """
@@ -228,7 +275,8 @@
 
         @param key: the key to remove.
         """
-        self._db_execute("delete from DIGESTCREDENTIALS where KEY = :1", key)
+        self._db().execute("delete from DIGESTCREDENTIALS where KEY = :1", (key,))
+        self._db_commit()
     
     def _db_version(self):
         """
@@ -254,7 +302,7 @@
         q.execute(
             """
             create table DIGESTCREDENTIALS (
-                KEY         text,
+                KEY         text unique,
                 VALUE       text
             )
             """

Modified: CalendarServer/trunk/twistedcaldav/sql.py
===================================================================
--- CalendarServer/trunk/twistedcaldav/sql.py	2007-07-13 20:32:22 UTC (rev 1671)
+++ CalendarServer/trunk/twistedcaldav/sql.py	2007-07-13 21:04:58 UTC (rev 1672)
@@ -38,13 +38,16 @@
     A generic SQL database.
     """
 
-    def __init__(self, dbpath):
+    def __init__(self, dbpath, autocommit=False):
         """
         
         @param dbpath: the path where the db file is stored.
         @type dbpath: str
+        @param autocommit: C{True} if auto-commit mode is desired, C{False} otherwise
+        @type autocommit: bool
         """
         self.dbpath = dbpath
+        self.autocommit = autocommit
 
     def _db_version(self):
         """
@@ -65,7 +68,10 @@
         """
         if not hasattr(self, "_db_connection"):
             db_filename = self.dbpath
-            self._db_connection = sqlite.connect(db_filename)
+            if self.autocommit:
+                self._db_connection = sqlite.connect(db_filename, isolation_level=None)
+            else:
+                self._db_connection = sqlite.connect(db_filename)
 
             #
             # Set up the schema
@@ -179,6 +185,11 @@
         """
         pass
 
+    def _db_close(self):
+        if hasattr(self, "_db_connection"):
+            self._db_connection.close()
+            del self._db_connection
+
     def _db_values_for_sql(self, sql, *query_params):
         """
         Execute an SQL query and obtain the resulting values.

Copied: CalendarServer/trunk/twistedcaldav/test/test_sql.py (from rev 1671, CalendarServer/branches/users/cdaboo/digestdb-fix-1654/twistedcaldav/test/test_sql.py)
===================================================================
--- CalendarServer/trunk/twistedcaldav/test/test_sql.py	                        (rev 0)
+++ CalendarServer/trunk/twistedcaldav/test/test_sql.py	2007-07-13 21:04:58 UTC (rev 1672)
@@ -0,0 +1,139 @@
+##
+# Copyright (c) 2007 Apple Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# DRI: Cyrus Daboo, cdaboo at apple.com
+##
+
+from twistedcaldav.sql import AbstractSQLDatabase
+
+import twistedcaldav.test.util
+
+class SQL (twistedcaldav.test.util.TestCase):
+    """
+    Test abstract SQL DB class
+    """
+    
+    class TestDB(AbstractSQLDatabase):
+        
+        def __init__(self, path, autocommit=False):
+            super(SQL.TestDB, self).__init__(path, autocommit=autocommit)
+
+        def _db_version(self):
+            """
+            @return: the schema version assigned to this index.
+            """
+            return 1
+            
+        def _db_type(self):
+            """
+            @return: the collection type assigned to this index.
+            """
+            return "TESTTYPE"
+            
+        def _db_init_data_tables(self, q):
+            """
+            Initialise the underlying database tables.
+            @param q:           a database cursor to use.
+            """
+    
+            #
+            # TESTTYPE table
+            #
+            q.execute(
+                """
+                create table TESTTYPE (
+                    KEY         text unique,
+                    VALUE       text
+                )
+                """
+            )
+
+    def test_connect(self):
+        """
+        Connect to database and create table
+        """
+        db = SQL.TestDB(self.mktemp())
+        self.assertFalse(hasattr(db, "_db_connection"))
+        self.assertTrue(db._db() is not None)
+        self.assertTrue(db._db_connection is not None)
+
+    def test_connect_autocommit(self):
+        """
+        Connect to database and create table
+        """
+        db = SQL.TestDB(self.mktemp(), autocommit=True)
+        self.assertFalse(hasattr(db, "_db_connection"))
+        self.assertTrue(db._db() is not None)
+        self.assertTrue(db._db_connection is not None)
+
+    def test_readwrite(self):
+        """
+        Add a record, search for it
+        """
+        db = SQL.TestDB(self.mktemp())
+        db._db().execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", ("FOO", "BAR",))
+        db._db_commit()
+        q = db._db().execute("SELECT * from TESTTYPE")
+        items = [i for i in q.fetchall()]
+        self.assertEqual(items, [("FOO", "BAR")])
+
+    def test_readwrite_autocommit(self):
+        """
+        Add a record, search for it
+        """
+        db = SQL.TestDB(self.mktemp(), autocommit=True)
+        db._db().execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", ("FOO", "BAR",))
+        q = db._db().execute("SELECT * from TESTTYPE")
+        items = [i for i in q.fetchall()]
+        self.assertEqual(items, [("FOO", "BAR")])
+
+    def test_readwrite_cursor(self):
+        """
+        Add a record, search for it
+        """
+        db = SQL.TestDB(self.mktemp())
+        db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR")
+        items = db._db_execute("SELECT * from TESTTYPE")
+        self.assertEqual(items, [("FOO", "BAR")])
+
+    def test_readwrite_cursor_autocommit(self):
+        """
+        Add a record, search for it
+        """
+        db = SQL.TestDB(self.mktemp(), autocommit=True)
+        db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR")
+        items = db._db_execute("SELECT * from TESTTYPE")
+        self.assertEqual(items, [("FOO", "BAR")])
+
+    def test_readwrite_rollback(self):
+        """
+        Add a record, search for it
+        """
+        db = SQL.TestDB(self.mktemp())
+        db._db_execute("INSERT into TESTTYPE (KEY, VALUE) values (:1, :2)", "FOO", "BAR")
+        db._db_rollback()
+        items = db._db_execute("SELECT * from TESTTYPE")
+        self.assertEqual(items, [])
+
+    def test_close(self):
+        """
+        Close database
+        """
+        db = SQL.TestDB(self.mktemp())
+        self.assertFalse(hasattr(db, "_db_connection"))
+        self.assertTrue(db._db() is not None)
+        db._db_close()
+        self.assertFalse(hasattr(db, "_db_connection"))
+        db._db_close()

-------------- next part --------------
An HTML attachment was scrubbed...
URL: http://lists.macosforge.org/pipermail/calendarserver-changes/attachments/20070713/d1ce2a69/attachment.html


More information about the calendarserver-changes mailing list