%PDF- %PDF-
Mini Shell

Mini Shell

Direktori : /usr/lib/python3/dist-packages/twisted/test/
Upload File :
Create Path :
Current File : //usr/lib/python3/dist-packages/twisted/test/test_adbapi.py

# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Tests for twisted.enterprise.adbapi.
"""

import os
import stat
from typing import Dict, Optional

from twisted.enterprise.adbapi import (
    Connection,
    ConnectionLost,
    ConnectionPool,
    Transaction,
)
from twisted.internet import defer, interfaces, reactor
from twisted.python.failure import Failure
from twisted.python.reflect import requireModule
from twisted.trial import unittest

# DDL for the one-column scratch table named "simple".  The pool tests create
# it with runOperation() and drop it again in their tearDown methods.
simple_table_schema = """
CREATE TABLE simple (
  x integer
)
"""


class ADBAPITestBase:
    """
    Test the asynchronous DB-API code.

    This class is a mixin.  L{makeSQLTests} combines it with a
    L{DBTestConnector} subclass, which supplies C{startDB}, C{stopDB},
    C{makePool} and the C{test_failures}, C{can_rollback} and
    C{num_iterations} settings read here, and with L{unittest.TestCase}.

    @ivar openfun_called: Records every connection handed to L{openfun}.
    @ivar dbpool: The connection pool under test, created by L{extraSetUp}.
    """

    # Class-level default kept so the attribute always exists; extraSetUp
    # rebinds a fresh dict per test so entries never leak between tests.
    openfun_called: Dict[object, bool] = {}

    if interfaces.IReactorThreads(reactor, None) is None:
        skip = "ADB-API requires threads, no way to test without them"

    def extraSetUp(self):
        """
        Set up the database and create a connection pool pointing at it.
        """
        # Start from an empty record: with the shared class-level dict,
        # checkOpenfunCalled() could be satisfied by an earlier test's
        # connections even if openfun was never called in this one.
        self.openfun_called = {}
        self.startDB()
        self.dbpool = self.makePool(cp_openfun=self.openfun)
        self.dbpool.start()

    def tearDown(self):
        """
        Drop the scratch table, then close the pool, then stop the database.

        @return: A L{Deferred} which fires when all three steps are done.
        """
        d = self.dbpool.runOperation("DROP TABLE simple")
        d.addCallback(lambda res: self.dbpool.close())
        d.addCallback(lambda res: self.stopDB())
        return d

    def openfun(self, conn):
        """
        Connection-open hook passed to the pool as C{cp_openfun}; remember
        that C{conn} was seen.
        """
        self.openfun_called[conn] = True

    def checkOpenfunCalled(self, conn=None):
        """
        Assert that L{openfun} was called.

        @param conn: If given, the specific connection which must have been
            passed to L{openfun}.  If L{None}, any call is sufficient.
        """
        if conn is None:
            self.assertTrue(self.openfun_called)
        else:
            self.assertIn(conn, self.openfun_called)

    def test_pool(self):
        """
        Run the whole pool scenario as a single chain of callbacks: create
        the table, optionally exercise the failure paths, then insert, read
        back, and delete rows through each of the pool's entry points.
        """
        d = self.dbpool.runOperation(simple_table_schema)
        if self.test_failures:
            d.addCallback(self._testPool_1_1)
            d.addCallback(self._testPool_1_2)
            d.addCallback(self._testPool_1_3)
            d.addCallback(self._testPool_1_4)
            # The failing steps above log errors; discard them so they do
            # not fail the test.
            d.addCallback(lambda res: self.flushLoggedErrors())
        d.addCallback(self._testPool_2)
        d.addCallback(self._testPool_3)
        d.addCallback(self._testPool_4)
        d.addCallback(self._testPool_5)
        d.addCallback(self._testPool_6)
        d.addCallback(self._testPool_7)
        d.addCallback(self._testPool_8)
        d.addCallback(self._testPool_9)
        return d

    def _expectFailure(self, function, *args):
        """
        Call C{function(*args)} and require that it fails.

        @return: A L{Deferred} which fires with L{None} when the call raised
            or errbacked, and which fails the test when the call succeeded.
        """
        d = defer.maybeDeferred(function, *args)
        d.addCallbacks(lambda res: self.fail("no exception"), lambda f: None)
        return d

    def _testPool_1_1(self, res):
        """
        C{runQuery} against a missing table must fail.
        """
        return self._expectFailure(self.dbpool.runQuery, "select * from NOTABLE")

    def _testPool_1_2(self, res):
        """
        C{runOperation} with malformed SQL must fail.
        """
        return self._expectFailure(
            self.dbpool.runOperation, "deletexxx from NOTABLE"
        )

    def _testPool_1_3(self, res):
        """
        C{runInteraction} with a failing interaction must fail.
        """
        return self._expectFailure(self.dbpool.runInteraction, self.bad_interaction)

    def _testPool_1_4(self, res):
        """
        C{runWithConnection} with a failing function must fail.
        """
        return self._expectFailure(
            self.dbpool.runWithConnection, self.bad_withConnection
        )

    def _testPool_2(self, res):
        """
        Verify the table is still empty and that L{openfun} has been called.
        """
        d = self.dbpool.runQuery("select count(1) from simple")

        def _check(row):
            self.assertEqual(int(row[0][0]), 0, "Interaction not rolled back")
            self.checkOpenfunCalled()

        d.addCallback(_check)
        return d

    def _testPool_3(self, res):
        """
        Insert C{num_iterations} rows with C{runOperation} and read them back
        with C{runQuery}.
        """
        # add some rows to simple table (runOperation)
        inserts = [
            self.dbpool.runOperation("insert into simple(x) values(%d)" % i)
            for i in range(self.num_iterations)
        ]
        d = defer.gatherResults(inserts)

        def _select(res):
            # make sure they were added (runQuery)
            return self.dbpool.runQuery("select x from simple order by x")

        d.addCallback(_select)

        def _check(rows):
            self.assertEqual(len(rows), self.num_iterations, "Wrong number of rows")
            for i in range(self.num_iterations):
                self.assertEqual(len(rows[i]), 1, "Wrong size row")
                self.assertEqual(rows[i][0], i, "Values not returned.")

        d.addCallback(_check)
        return d

    def _testPool_4(self, res):
        """
        C{runInteraction} passes the interaction's return value through.
        """
        d = self.dbpool.runInteraction(self.interaction)
        d.addCallback(lambda res: self.assertEqual(res, "done"))
        return d

    def _testPool_5(self, res):
        """
        C{runWithConnection} passes the function's return value through.
        """
        d = self.dbpool.runWithConnection(self.withConnection)
        d.addCallback(lambda res: self.assertEqual(res, "done"))
        return d

    def _testPool_6(self, res):
        """
        Test a withConnection cannot be closed: the steps which follow keep
        using the pool after C{close()} was called on the connection.
        """
        return self.dbpool.runWithConnection(self.close_withConnection)

    def _testPool_7(self, res):
        """
        Give the pool a workout by issuing one query per inserted row.
        """
        ds = [
            self.dbpool.runQuery("select x from simple where x = %d" % i)
            for i in range(self.num_iterations)
        ]
        dlist = defer.DeferredList(ds, fireOnOneErrback=True)

        def _check(result):
            # Each DeferredList entry is (success, rows); rows[0][0] is the
            # single selected value.
            for i in range(self.num_iterations):
                self.assertEqual(result[i][1][0][0], i, "Value not returned")

        dlist.addCallback(_check)
        return dlist

    def _testPool_8(self, res):
        """
        Delete every row that was inserted, one operation per row.
        """
        ds = [
            self.dbpool.runOperation("delete from simple where x = %d" % i)
            for i in range(self.num_iterations)
        ]
        return defer.DeferredList(ds, fireOnOneErrback=True)

    def _testPool_9(self, res):
        """
        Verify the table is empty again, then exercise the synchronous
        connect/disconnect API.
        """
        d = self.dbpool.runQuery("select count(1) from simple")

        def _check(row):
            self.assertEqual(
                int(row[0][0]), 0, "Didn't successfully delete table contents"
            )
            self.checkConnect()

        d.addCallback(_check)
        return d

    def checkConnect(self):
        """Check the connect/disconnect synchronous calls."""
        conn = self.dbpool.connect()
        try:
            self.checkOpenfunCalled(conn)
            curs = conn.cursor()
            try:
                curs.execute("insert into simple(x) values(1)")
                curs.execute("select x from simple")
                res = curs.fetchall()
                self.assertEqual(len(res), 1)
                self.assertEqual(len(res[0]), 1)
                self.assertEqual(res[0][0], 1)
                curs.execute("delete from simple")
                curs.execute("select x from simple")
                self.assertEqual(len(curs.fetchall()), 0)
            finally:
                # Release the cursor even when an assertion above fails.
                curs.close()
        finally:
            # Always hand the connection back so a failure here does not
            # leave it registered with the pool.
            self.dbpool.disconnect(conn)

    def interaction(self, transaction):
        """
        Interaction which checks that the table holds exactly the values
        C{0 .. num_iterations - 1} in order.

        @return: The string C{"done"}.
        """
        transaction.execute("select x from simple order by x")
        for i in range(self.num_iterations):
            row = transaction.fetchone()
            self.assertEqual(len(row), 1, "Wrong size row")
            self.assertEqual(row[0], i, "Value not returned.")
        self.assertIsNone(transaction.fetchone(), "Too many rows")
        return "done"

    def bad_interaction(self, transaction):
        """
        Interaction which always fails by selecting from a missing table.

        When rollback is supported a row is inserted first, so that a later
        emptiness check proves the failed interaction was rolled back.
        """
        if self.can_rollback:
            transaction.execute("insert into simple(x) values(0)")

        transaction.execute("select * from NOTABLE")

    def withConnection(self, conn):
        """
        Connection-level counterpart of L{interaction}: read the rows through
        a cursor obtained from C{conn} and check their values.

        @return: The string C{"done"}.
        """
        curs = conn.cursor()
        try:
            curs.execute("select x from simple order by x")
            for i in range(self.num_iterations):
                row = curs.fetchone()
                self.assertEqual(len(row), 1, "Wrong size row")
                self.assertEqual(row[0], i, "Value not returned.")
        finally:
            curs.close()
        return "done"

    def close_withConnection(self, conn):
        """
        Call C{close()} on the connection handed out by the pool.
        """
        conn.close()

    def bad_withConnection(self, conn):
        """
        Connection-level function which always fails by selecting from a
        missing table; the cursor is closed regardless.
        """
        curs = conn.cursor()
        try:
            curs.execute("select * from NOTABLE")
        finally:
            curs.close()


class ReconnectTestBase:
    """
    Mixin which checks that a pool configured with C{cp_reconnect} recovers
    after its connection has been closed behind its back.

    L{makeSQLTests} combines it with a L{DBTestConnector} subclass, which
    provides C{good_sql}, C{early_reconnect}, C{startDB}, C{stopDB} and
    C{makePool}.
    """

    if interfaces.IReactorThreads(reactor, None) is None:
        skip = "ADB-API requires threads, no way to test without them"

    def extraSetUp(self):
        """
        Skip when the connector has no C{good_sql}.  Otherwise start the
        database, build a single-connection reconnecting pool, and create the
        scratch table.

        @return: The L{Deferred} of the table-creating operation.
        """
        if self.good_sql is None:
            raise unittest.SkipTest("no good sql for reconnect test")
        self.startDB()
        poolOptions = {
            "cp_max": 1,
            "cp_reconnect": True,
            "cp_good_sql": self.good_sql,
        }
        self.dbpool = self.makePool(**poolOptions)
        self.dbpool.start()
        return self.dbpool.runOperation(simple_table_schema)

    def tearDown(self):
        """
        Drop the scratch table, close the pool and stop the database, in
        that order.
        """

        def closePool(ignored):
            return self.dbpool.close()

        def stopDatabase(ignored):
            return self.stopDB()

        dropping = self.dbpool.runOperation("DROP TABLE simple")
        dropping.addCallback(closePool)
        dropping.addCallback(stopDatabase)
        return dropping

    def test_pool(self):
        """
        Query, break the connection, and check the pool works afterwards.
        The "first query after the break fails" step only runs for
        connectors whose C{early_reconnect} is false.
        """
        steps = [self._testPool_1, self._testPool_2]
        if not self.early_reconnect:
            steps.append(self._testPool_3)
        steps.extend([self._testPool_4, self._testPool_5])

        chain = defer.succeed(None)
        for step in steps:
            chain.addCallback(step)
        return chain

    def _testPool_1(self, res):
        """
        The freshly created table is empty.
        """

        def verifyEmpty(rows):
            self.assertTrue(int(rows[0][0]) == 0, "Table not empty")

        counting = self.dbpool.runQuery("select count(1) from simple")
        counting.addCallback(verifyEmpty)
        return counting

    def _testPool_2(self, res):
        """
        Reach into the pool and close its first connection manually.
        """
        openConnections = list(self.dbpool.connections.values())
        openConnections[0].close()

    def _testPool_3(self, res):
        """
        The first query after the connection was closed must fail.
        """

        def unexpectedSuccess(result):
            return self.fail("no exception")

        def swallow(failure):
            return None

        attempt = defer.maybeDeferred(
            self.dbpool.runQuery, "select count(1) from simple"
        )
        attempt.addCallbacks(unexpectedSuccess, swallow)
        return attempt

    def _testPool_4(self, res):
        """
        The pool answers queries again and the table is still empty.
        """

        def verifyEmpty(rows):
            self.assertTrue(int(rows[0][0]) == 0, "Table not empty")

        counting = self.dbpool.runQuery("select count(1) from simple")
        counting.addCallback(verifyEmpty)
        return counting

    def _testPool_5(self, res):
        """
        Bad SQL must fail with something other than L{ConnectionLost}.
        """
        self.flushLoggedErrors()

        def unexpectedSuccess(result):
            return self.fail("no exception")

        def notConnectionLost(failure):
            return self.assertFalse(failure.check(ConnectionLost))

        badQuery = "select * from NOTABLE"  # bad sql
        attempt = defer.maybeDeferred(self.dbpool.runQuery, badQuery)
        attempt.addCallbacks(unexpectedSuccess, notConnectionLost)
        return attempt


class DBTestConnector:
    """
    Base class for objects which know how to detect a relational database
    and build a L{ConnectionPool} connected to it.

    Test cases that use a central, system-wide database need a database
    called DB_NAME, and a user DB_USER with password DB_PASS holding full
    access rights to it.
    """

    # Prefix for the names of generated test case classes.
    TEST_PREFIX: Optional[str] = None

    DB_NAME = "twisted_test"
    DB_USER = "twisted_test"
    DB_PASS = "twisted_test"

    # Directory for database storage; assigned in setUp.
    DB_DIR = None

    # Capability switches which concrete connectors override.
    nulls_ok = True  # nulls supported
    trailing_spaces_ok = True  # trailing spaces in strings preserved
    can_rollback = True  # rollback supported
    test_failures = True  # test bad sql?
    escape_slashes = True  # escape \ in sql?
    good_sql: Optional[str] = ConnectionPool.good_sql
    early_reconnect = True  # cursor() will fail on closed connection
    can_clear = True  # can try to clear out tables when starting

    # How many times the test loops run (lower this for slow db's).
    num_iterations = 50

    def setUp(self):
        """
        Create a fresh storage directory, skip the test when the database is
        unreachable, and otherwise defer to the C{extraSetUp} hook.
        """
        self.DB_DIR = self.mktemp()
        os.mkdir(self.DB_DIR)
        if self.can_connect():
            return self.extraSetUp()
        raise unittest.SkipTest("%s: Cannot access db" % self.TEST_PREFIX)

    def can_connect(self):
        """
        Return true if this database is present on the system and can be
        used in a test.  Subclasses must override this.
        """
        raise NotImplementedError()

    def startDB(self):
        """
        Take any steps needed to bring the database up.  Does nothing here.
        """

    def stopDB(self):
        """
        Bring the database down, if needed.  Does nothing here.
        """

    def makePool(self, **newkw):
        """
        Create a connection pool from L{getPoolArgs}, with C{newkw} added to
        (and taking precedence over) the connector's keyword arguments.
        """
        args, baseKeywords = self.getPoolArgs()
        return ConnectionPool(*args, **{**baseKeywords, **newkw})

    def getPoolArgs(self):
        """
        Return a tuple C{(args, kw)} of the positional and keyword arguments
        to pass to ConnectionPool to connect to this database.  Subclasses
        must override this.
        """
        raise NotImplementedError()


class SQLite3Connector(DBTestConnector):
    """
    Connector backed by the C{sqlite3} module, storing its database as a
    file inside C{DB_DIR}.
    """

    TEST_PREFIX = "SQLite3"
    escape_slashes = False
    num_iterations = 1  # slow

    def can_connect(self):
        """
        Usable whenever the C{sqlite3} module can be imported.
        """
        return requireModule("sqlite3") is not None

    def startDB(self):
        """
        Choose the database file path and delete any file already there.
        """
        path = os.path.join(self.DB_DIR, self.DB_NAME)
        self.database = path
        if os.path.exists(path):
            os.unlink(path)

    def getPoolArgs(self):
        """
        Return the DB-API module name and the connect keywords for the pool.
        """
        keywords = dict(database=self.database, cp_max=1, check_same_thread=False)
        return ("sqlite3",), keywords


class PySQLite2Connector(DBTestConnector):
    """
    Connector backed by the C{pysqlite2.dbapi2} module, storing its database
    as a file inside C{DB_DIR}.
    """

    TEST_PREFIX = "pysqlite2"
    escape_slashes = False
    num_iterations = 1  # slow

    def can_connect(self):
        """
        Usable whenever the C{pysqlite2.dbapi2} module can be imported.
        """
        return requireModule("pysqlite2.dbapi2") is not None

    def startDB(self):
        """
        Choose the database file path and delete any file already there.
        """
        path = os.path.join(self.DB_DIR, self.DB_NAME)
        self.database = path
        if os.path.exists(path):
            os.unlink(path)

    def getPoolArgs(self):
        """
        Return the DB-API module name and the connect keywords for the pool.
        """
        keywords = dict(database=self.database, cp_max=1, check_same_thread=False)
        return ("pysqlite2.dbapi2",), keywords


class PyPgSQLConnector(DBTestConnector):
    """
    Connector that uses the C{pyPgSQL.PgSQL} DB-API module against the
    central database named by C{DB_NAME}.
    """

    TEST_PREFIX = "PyPgSQL"

    def can_connect(self):
        """
        Return C{True} if C{pyPgSQL.PgSQL} can be imported and a connection
        with the test credentials can be opened and closed.

        Only L{Exception} is treated as "not available", so that
        C{KeyboardInterrupt} and C{SystemExit} still stop the test run.
        """
        try:
            from pyPgSQL import PgSQL  # type: ignore[import]
        except Exception:
            return False
        try:
            conn = PgSQL.connect(
                database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS
            )
            conn.close()
        except Exception:
            return False
        return True

    def getPoolArgs(self):
        """
        Return the DB-API module name and the connect keywords for the pool.
        """
        args = ("pyPgSQL.PgSQL",)
        kw = {
            "database": self.DB_NAME,
            "user": self.DB_USER,
            "password": self.DB_PASS,
            "cp_min": 0,
        }
        return args, kw


class PsycopgConnector(DBTestConnector):
    """
    Connector that uses the C{psycopg} DB-API module against the central
    database named by C{DB_NAME}.
    """

    TEST_PREFIX = "Psycopg"

    def can_connect(self):
        """
        Return C{True} if C{psycopg} can be imported and a connection with
        the test credentials can be opened and closed.

        Only L{Exception} is treated as "not available", so that
        C{KeyboardInterrupt} and C{SystemExit} still stop the test run.
        """
        try:
            import psycopg  # type: ignore[import]
        except Exception:
            return False
        try:
            conn = psycopg.connect(
                database=self.DB_NAME, user=self.DB_USER, password=self.DB_PASS
            )
            conn.close()
        except Exception:
            return False
        return True

    def getPoolArgs(self):
        """
        Return the DB-API module name and the connect keywords for the pool.
        """
        args = ("psycopg",)
        kw = {
            "database": self.DB_NAME,
            "user": self.DB_USER,
            "password": self.DB_PASS,
            "cp_min": 0,
        }
        return args, kw


class MySQLConnector(DBTestConnector):
    """
    Connector that uses the C{MySQLdb} DB-API module against the central
    database named by C{DB_NAME}.
    """

    TEST_PREFIX = "MySQL"

    trailing_spaces_ok = False
    # With rollback unsupported, the failing interaction in the pool tests
    # skips its insert so the table stays empty.
    can_rollback = False
    # Makes the reconnect tests include the step which expects the first
    # query on a closed connection to fail.
    early_reconnect = False

    def can_connect(self):
        """
        Return C{True} if C{MySQLdb} can be imported and a connection with
        the test credentials can be opened and closed.

        Only L{Exception} is treated as "not available", so that
        C{KeyboardInterrupt} and C{SystemExit} still stop the test run.
        """
        try:
            import MySQLdb  # type: ignore[import]
        except Exception:
            return False
        try:
            conn = MySQLdb.connect(
                db=self.DB_NAME, user=self.DB_USER, passwd=self.DB_PASS
            )
            conn.close()
        except Exception:
            return False
        return True

    def getPoolArgs(self):
        """
        Return the DB-API module name and the connect keywords for the pool.
        """
        args = ("MySQLdb",)
        kw = {"db": self.DB_NAME, "user": self.DB_USER, "passwd": self.DB_PASS}
        return args, kw


class FirebirdConnector(DBTestConnector):
    """
    Connector that uses the C{kinterbasdb} DB-API module, creating a
    database file inside C{DB_DIR} for each test and dropping it afterwards.
    """

    TEST_PREFIX = "Firebird"

    test_failures = False  # failure testing causes problems
    escape_slashes = False
    good_sql = None  # firebird doesn't handle failed sql well
    can_clear = False  # firebird is not so good

    num_iterations = 5  # slow

    def can_connect(self):
        """
        Return C{True} if C{kinterbasdb} is importable and a database can be
        created and dropped again.

        Only L{Exception} is treated as "not available", so that
        C{KeyboardInterrupt} and C{SystemExit} still stop the test run.
        """
        if requireModule("kinterbasdb") is None:
            return False
        try:
            # Probe by doing a complete create/drop cycle.
            self.startDB()
            self.stopDB()
        except Exception:
            return False
        return True

    def startDB(self):
        """
        Create the database file inside C{DB_DIR}.
        """
        import kinterbasdb  # type: ignore[import]

        # Build the path from the base class constant, not self.DB_NAME, so
        # that calling startDB more than once does not nest paths.
        self.DB_NAME = os.path.join(self.DB_DIR, DBTestConnector.DB_NAME)
        # Open the directory to user, group and others.
        os.chmod(self.DB_DIR, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
        sql = 'create database "%s" user "%s" password "%s"'
        sql %= (self.DB_NAME, self.DB_USER, self.DB_PASS)
        conn = kinterbasdb.create_database(sql)
        conn.close()

    def getPoolArgs(self):
        """
        Return the DB-API module name and the connect keywords for the pool.
        """
        args = ("kinterbasdb",)
        kw = {
            "database": self.DB_NAME,
            "host": "127.0.0.1",
            "user": self.DB_USER,
            "password": self.DB_PASS,
        }
        return args, kw

    def stopDB(self):
        """
        Connect to the database created by L{startDB} and drop it.
        """
        import kinterbasdb

        conn = kinterbasdb.connect(
            database=self.DB_NAME,
            host="127.0.0.1",
            user=self.DB_USER,
            password=self.DB_PASS,
        )
        conn.drop_database()


def makeSQLTests(base, suffix, globals):
    """
    Make a test case for every db connector.

    Connectors whose database is unavailable still get a test case; such
    tests are skipped at run time by L{DBTestConnector.setUp}.

    @param base: Base class for test case. Additional base classes
                 will be a DBConnector subclass and unittest.TestCase
    @param suffix: A suffix used to create test case names. Prefixes
                   are defined in the DBConnector subclasses.
    @param globals: A namespace dictionary, typically the calling module's
                    C{globals()}.  It is updated in place with the generated
                    classes, keyed by class name.
    """
    connectors = [
        PySQLite2Connector,
        SQLite3Connector,
        PyPgSQLConnector,
        PsycopgConnector,
        MySQLConnector,
        FirebirdConnector,
    ]
    tests = {}
    for connclass in connectors:
        name = connclass.TEST_PREFIX + suffix

        class testcase(connclass, base, unittest.TestCase):
            __module__ = connclass.__module__

        testcase.__name__ = name
        # Keep any enclosing scope from the connector's qualified name and
        # replace only its last component.  The separator in __qualname__ is
        # ".", so split on that rather than on whitespace.
        qualifiers = connclass.__qualname__.split(".")[:-1]
        testcase.__qualname__ = ".".join(qualifiers + [name])
        tests[name] = testcase

    globals.update(tests)


# Generates pysqlite2ADBAPITests SQLite3ADBAPITests PyPgSQLADBAPITests
# PsycopgADBAPITests MySQLADBAPITests FirebirdADBAPITests
makeSQLTests(ADBAPITestBase, "ADBAPITests", globals())

# Generates pysqlite2ReconnectTests SQLite3ReconnectTests PyPgSQLReconnectTests
# PsycopgReconnectTests MySQLReconnectTests FirebirdReconnectTests
makeSQLTests(ReconnectTestBase, "ReconnectTests", globals())


class FakePool:
    """
    A minimal stand-in for L{ConnectionPool}, for use in tests.

    @ivar connectionFactory: called with no arguments by C{connect} to
        produce the connection it returns.
    @type connectionFactory: any callable
    """

    noisy = True
    reconnect = True

    def __init__(self, connectionFactory):
        self.connectionFactory = connectionFactory

    def connect(self):
        """
        Build and return a new connection using C{self.connectionFactory}.
        """
        makeConnection = self.connectionFactory
        return makeConnection()

    def disconnect(self, connection):
        """
        Ignore C{connection}; there is nothing to release.
        """
        return None


class ConnectionTests(unittest.TestCase):
    """
    Tests for the L{Connection} class.
    """

    def test_rollbackErrorLogged(self):
        """
        When the underlying connection's C{rollback} raises,
        L{Connection.rollback} raises L{ConnectionLost} and the underlying
        error is logged.
        """

        class ConnectionRollbackRaise:
            def rollback(self):
                raise RuntimeError("problem!")

        connection = Connection(FakePool(ConnectionRollbackRaise))
        self.assertRaises(ConnectionLost, connection.rollback)

        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(logged), 1)
        firstError = logged[0]
        self.assertEqual(firstError.value.args[0], "problem!")


class TransactionTests(unittest.TestCase):
    """
    Tests for the L{Transaction} class.
    """

    def test_reopenLogErrorIfReconnect(self):
        """
        When creating the cursor fails inside L{Transaction.reopen}, the
        transaction reconnects and the error is logged.
        """

        class ConnectionCursorRaise:
            # Number of failures raised so far; only the first cursor()
            # call fails.
            count = 0

            def reconnect(self):
                """
                Nothing to do.
                """

            def cursor(self):
                if self.count != 0:
                    return None
                self.count += 1
                raise RuntimeError("problem!")

        connection = ConnectionCursorRaise()
        transaction = Transaction(FakePool(None), connection)
        transaction.reopen()

        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(logged), 1)
        firstError = logged[0]
        self.assertEqual(firstError.value.args[0], "problem!")


class NonThreadPool:
    """
    A thread pool substitute which runs the work immediately in the calling
    thread.
    """

    def callInThreadWithCallback(self, onResult, f, *a, **kw):
        """
        Call C{f(*a, **kw)} right away and report the outcome to C{onResult}.

        @param onResult: called as C{onResult(True, returnValue)} when C{f}
            returns, or C{onResult(False, failure)} when C{f} raises an
            L{Exception}.
        """
        try:
            outcome = (True, f(*a, **kw))
        except Exception:
            # Failure() with no arguments captures the active exception.
            outcome = (False, Failure())
        onResult(*outcome)


class DummyConnectionPool(ConnectionPool):
    """
    A testable L{ConnectionPool}: its work is dispatched through a
    L{NonThreadPool}, and its constructor skips L{ConnectionPool.__init__}.
    """

    # Shared by all instances; runs submitted calls synchronously.
    threadpool = NonThreadPool()

    def __init__(self):
        """
        Don't forward init call.

        Only C{_reactor} is set; tests assign any other attributes they need
        (such as C{connectionFactory}) directly on the instance.
        """
        self._reactor = reactor


class EventReactor:
    """
    A partial L{IReactorCore} implementation offering only simple
    event-related methods.

    @ivar _running: A C{bool} saying whether the reactor pretends to have
        been started already.

    @ivar triggers: A C{list} of the system event triggers still pending.
    """

    def __init__(self, running):
        self.triggers = []
        self._running = running

    def callWhenRunning(self, function):
        """
        Call C{function} now if "running"; otherwise queue it as an
        after-startup trigger and return that trigger's handle.
        """
        if not self._running:
            return self.addSystemEventTrigger("after", "startup", function)
        function()
        return None

    def addSystemEventTrigger(self, phase, event, trigger):
        """
        Record a trigger and return its handle, the tuple
        C{(phase, event, trigger)}.
        """
        handle = phase, event, trigger
        self.triggers.append(handle)
        return handle

    def removeSystemEventTrigger(self, handle):
        """
        Forget a trigger previously returned by L{addSystemEventTrigger}.
        """
        self.triggers.remove(handle)


class ConnectionPoolTests(unittest.TestCase):
    """
    Unit tests for L{ConnectionPool}.
    """

    def _assertProblemLogged(self, ignored=None):
        """
        Assert that exactly one L{RuntimeError} was logged and that its
        message is C{"problem!"}.  Usable directly or as a callback.
        """
        logged = self.flushLoggedErrors(RuntimeError)
        self.assertEqual(len(logged), 1)
        self.assertEqual(logged[0].value.args[0], "problem!")

    def test_runWithConnectionRaiseOriginalError(self):
        """
        L{ConnectionPool.runWithConnection} fails with the function's own
        exception when the rollback fails too; the rollback error is only
        logged.
        """

        class ConnectionRollbackRaise:
            def __init__(self, pool):
                pass

            def rollback(self):
                raise RuntimeError("problem!")

        def raisingFunction(connection):
            raise ValueError("foo")

        pool = DummyConnectionPool()
        pool.connectionFactory = ConnectionRollbackRaise

        failing = self.assertFailure(
            pool.runWithConnection(raisingFunction), ValueError
        )
        failing.addCallback(self._assertProblemLogged)
        return failing

    def test_closeLogError(self):
        """
        L{ConnectionPool._close} logs exceptions.
        """

        class ConnectionCloseRaise:
            def close(self):
                raise RuntimeError("problem!")

        DummyConnectionPool()._close(ConnectionCloseRaise())
        self._assertProblemLogged()

    def test_runWithInteractionRaiseOriginalError(self):
        """
        L{ConnectionPool.runInteraction} fails with the interaction's own
        exception when the rollback fails too; the rollback error is only
        logged.
        """

        class ConnectionRollbackRaise:
            def __init__(self, pool):
                pass

            def rollback(self):
                raise RuntimeError("problem!")

        class DummyTransaction:
            def __init__(self, pool, connection):
                pass

        def raisingFunction(transaction):
            raise ValueError("foo")

        pool = DummyConnectionPool()
        pool.connectionFactory = ConnectionRollbackRaise
        pool.transactionFactory = DummyTransaction

        failing = self.assertFailure(pool.runInteraction(raisingFunction), ValueError)
        failing.addCallback(self._assertProblemLogged)
        return failing

    def test_unstartedClose(self):
        """
        Calling L{ConnectionPool.close} before L{ConnectionPool.start} has
        been called cancels the pool's startup event.
        """
        fakeReactor = EventReactor(False)
        pool = ConnectionPool("twisted.test.test_adbapi", cp_reactor=fakeReactor)
        # A startup trigger is pending ...
        expected = [("after", "startup", pool._start)]
        self.assertEqual(fakeReactor.triggers, expected)
        pool.close()
        # ... and closing removes it.
        self.assertFalse(fakeReactor.triggers)

    def test_startedClose(self):
        """
        Calling L{ConnectionPool.close} on a started pool, other than from
        its shutdown trigger, cancels that shutdown trigger.
        """
        fakeReactor = EventReactor(True)
        pool = ConnectionPool("twisted.test.test_adbapi", cp_reactor=fakeReactor)
        # A shutdown trigger is pending ...
        expected = [("during", "shutdown", pool.finalClose)]
        self.assertEqual(fakeReactor.triggers, expected)
        pool.close()
        # ... and closing removes it.
        self.assertFalse(fakeReactor.triggers)

Zerion Mini Shell 1.0