Example #1
def do_full_backup(options):
    options.full = True
    dest = os.path.join(options.repository, gen_filename(options))
    if os.path.exists(dest):
        raise WouldOverwriteFiles('Cannot overwrite existing file: %s' % dest)
    # Find the file position of the last completed transaction.
    fs = FileStorage(options.file, read_only=True)
    # Note that the FileStorage ctor calls read_index() which scans the file
    # and returns "the position just after the last valid transaction record".
    # getSize() then returns this position, which is exactly what we want,
    # because we only want to copy stuff from the beginning of the file to the
    # last valid transaction record.
    pos = fs.getSize()
    # Save the storage index into the repository
    index_file = os.path.join(options.repository,
                              gen_filename(options, '.index'))
    log('writing index')
    fs._index.save(pos, index_file)
    fs.close()
    log('writing full backup: %s bytes to %s', pos, dest)
    sum = copyfile(options, dest, 0, pos)
    # Write the data file for this full backup
    datfile = os.path.splitext(dest)[0] + '.dat'
    fp = open(datfile, 'w')
    print >> fp, dest, 0, pos, sum
    fp.flush()
    os.fsync(fp.fileno())
    fp.close()
    if options.killold:
        delete_old_backups(options)
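gen_filename(), copyfile(), log() and the WouldOverwriteFiles exception come from the surrounding script (this helper mirrors ZODB's repozo.py). A minimal sketch of how it might be driven, with options as a plain attribute bag; the attribute values here are hypothetical, and real repozo builds this object from command-line flags:

import argparse

# Hypothetical driver for do_full_backup(); the attribute names are the
# ones the function reads, the values are examples only.
options = argparse.Namespace(
    file='/var/db/Data.fs',        # FileStorage to back up
    repository='/var/backups',     # directory receiving backup files
    full=False,                    # the function flips this to True
    killold=False,                 # prune older backups afterwards?
)
do_full_backup(options)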
Example #2
    def test_pack_with_1_day(self):
        from ZODB.DB import DB
        from ZODB.FileStorage import FileStorage
        import time
        import transaction
        from relstorage.zodbpack import main

        storage = FileStorage(self.db_fn, create=True)
        db = DB(storage)
        conn = db.open()
        conn.root()['x'] = 1
        transaction.commit()
        oid = b'\0' * 8
        state, serial = storage.load(oid, '')
        time.sleep(0.1)
        conn.root()['x'] = 2
        transaction.commit()
        conn.close()
        self.assertEqual(state, storage.loadSerial(oid, serial))
        db.close()
        storage = None

        main(['', '--days=1', self.cfg_fn])

        # packing should not have removed the old state.
        storage = FileStorage(self.db_fn)
        self.assertEqual(state, storage.loadSerial(oid, serial))
        storage.close()
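The fixture names self.db_fn and self.cfg_fn are created in a setUp() that is not shown. A plausible reconstruction, assuming zodbpack reads a ZConfig file naming the storage(s) to pack; the exact config schema is an assumption here:

import os
import tempfile

def setUp(self):
    # Temp path for the test FileStorage; FileStorage(create=True) recreates it.
    fd, self.db_fn = tempfile.mkstemp(suffix='.fs')
    os.close(fd)
    os.remove(self.db_fn)
    # Assumed zodbpack config: a ZConfig document listing storages to pack.
    fd, self.cfg_fn = tempfile.mkstemp(suffix='.conf')
    os.write(fd, ('<filestorage>\n    path %s\n</filestorage>\n'
                  % self.db_fn).encode('ascii'))
    os.close(fd)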
Example #3
def do_incremental_backup(options, reposz, repofiles):
    options.full = False
    dest = os.path.join(options.repository, gen_filename(options))
    if os.path.exists(dest):
        raise WouldOverwriteFiles('Cannot overwrite existing file: %s' % dest)
    # Find the file position of the last completed transaction.
    fs = FileStorage(options.file, read_only=True)
    # Note that the FileStorage ctor calls read_index() which scans the file
    # and returns "the position just after the last valid transaction record".
    # getSize() then returns this position, which is exactly what we want,
    # because we only want to copy stuff from the beginning of the file to the
    # last valid transaction record.
    pos = fs.getSize()
    log('writing index')
    index_file = os.path.join(options.repository,
                              gen_filename(options, '.index'))
    fs._index.save(pos, index_file)
    fs.close()
    log('writing incremental: %s bytes to %s',  pos-reposz, dest)
    sum = copyfile(options, dest, reposz, pos - reposz)
    # The first file in repofiles points to the last full backup.  Use this to
    # get the .dat file and append the information for this incremental to
    # that file.
    fullfile = repofiles[0]
    datfile = os.path.splitext(fullfile)[0] + '.dat'
    # This .dat file better exist.  Let the exception percolate if not.
    fp = open(datfile, 'a')
    print >> fp, dest, reposz, pos, sum
    fp.flush()
    os.fsync(fp.fileno())
    fp.close()
Example #4
def do_incremental_backup(options, reposz, repofiles):
    options.full = False
    dest = os.path.join(options.repository, gen_filename(options))
    if os.path.exists(dest):
        raise WouldOverwriteFiles('Cannot overwrite existing file: %s' % dest)
    # Find the file position of the last completed transaction.
    fs = FileStorage(options.file, read_only=True)
    # Note that the FileStorage ctor calls read_index() which scans the file
    # and returns "the position just after the last valid transaction record".
    # getSize() then returns this position, which is exactly what we want,
    # because we only want to copy stuff from the beginning of the file to the
    # last valid transaction record.
    pos = fs.getSize()
    log('writing index')
    index_file = os.path.join(options.repository,
                              gen_filename(options, '.index'))
    fs._index.save(pos, index_file)
    fs.close()
    log('writing incremental: %s bytes to %s',  pos-reposz, dest)
    sum = copyfile(options, dest, reposz, pos - reposz)
    # The first file in repofiles points to the last full backup.  Use this to
    # get the .dat file and append the information for this incremental to
    # that file.
    fullfile = repofiles[0]
    datfile = os.path.splitext(fullfile)[0] + '.dat'
    # This .dat file better exist.  Let the exception percolate if not.
    fp = open(datfile, 'a')
    print >> fp, dest, reposz, pos, sum
    fp.flush()
    os.fsync(fp.fileno())
    fp.close()
Example #5
    def tearDown(self):
        self.storage.close()
        if self.recovered is not None:
            self.recovered.close()
        temp = FileStorage(self.dest)
        temp.close()
        ZODB.tests.util.TestCase.tearDown(self)
Example #6
class PackerTests(StorageTestBase):

    def setUp(self):
        self.started = 0

    def start(self):
        self.started = 1
        self.path = tempfile.mktemp(suffix=".fs")
        self._storage = FileStorage(self.path)
        self.db = ZODB.DB(self._storage)
        self.do_updates()
        self.pid, self.exit = forker.start_zeo_server(self._storage, self.addr)

    def do_updates(self):
        for i in range(100):
            self._dostore()

    def tearDown(self):
        if not self.started:
            return
        self.db.close()
        self._storage.close()
        self.exit.close()
        try:
            os.kill(self.pid, 9)
        except os.error:
            pass
        try:
            os.waitpid(self.pid, 0)
        except os.error, err:
            ##print "waitpid failed", err
            pass
        removefs(self.path)
Example #7
def do_full_backup(options):
    options.full = True
    dest = os.path.join(options.repository, gen_filename(options))
    if os.path.exists(dest):
        raise WouldOverwriteFiles('Cannot overwrite existing file: %s' % dest)
    # Find the file position of the last completed transaction.
    fs = FileStorage(options.file, read_only=True)
    # Note that the FileStorage ctor calls read_index() which scans the file
    # and returns "the position just after the last valid transaction record".
    # getSize() then returns this position, which is exactly what we want,
    # because we only want to copy stuff from the beginning of the file to the
    # last valid transaction record.
    pos = fs.getSize()
    # Save the storage index into the repository
    index_file = os.path.join(options.repository,
                              gen_filename(options, '.index'))
    log('writing index')
    fs._index.save(pos, index_file)
    fs.close()
    log('writing full backup: %s bytes to %s', pos, dest)
    sum = copyfile(options, dest, 0, pos)
    # Write the data file for this full backup
    datfile = os.path.splitext(dest)[0] + '.dat'
    fp = open(datfile, 'w')
    print >> fp, dest, 0, pos, sum
    fp.flush()
    os.fsync(fp.fileno())
    fp.close()
    if options.killold:
        delete_old_backups(options)
Example #8
    def tearDown(self):
        self.storage.close()
        if self.recovered is not None:
            self.recovered.close()
        temp = FileStorage(self.dest)
        temp.close()
        ZODB.tests.util.TestCase.tearDown(self)
Example #9
    def test_pack_with_1_day(self):
        from ZODB.DB import DB
        from ZODB.FileStorage import FileStorage
        from ZODB.POSException import POSKeyError
        import time
        import transaction
        from relstorage.zodbpack import main

        storage = FileStorage(self.db_fn, create=True)
        db = DB(storage)
        conn = db.open()
        conn.root()['x'] = 1
        transaction.commit()
        oid = b('\0' * 8)
        state, serial = storage.load(oid, b(''))
        time.sleep(0.1)
        conn.root()['x'] = 2
        transaction.commit()
        conn.close()
        self.assertEqual(state, storage.loadSerial(oid, serial))
        db.close()
        storage = None

        main(['', '--days=1', self.cfg_fn])

        # packing should not have removed the old state.
        storage = FileStorage(self.db_fn)
        self.assertEqual(state, storage.loadSerial(oid, serial))
        storage.close()
Example #10
class DbAdapter:
    def __init__(self, path="data.db"):
        self.path = path

    def connect(self):
        self.storage = FileStorage(self.path)
        self.db = DB(self.storage)
        self.conn = self.db.open()
        return self.conn.root()

    def begin_transaction(self):
        transaction.begin()

    def commit(self):
        transaction.commit()

    def rollback(self):
        transaction.abort()

    def disconnect(self):
        self.conn.close()
        self.db.close()
        self.storage.close()
        if os.path.exists(self.path + ".lock"):
            os.remove(self.path + ".lock")
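A typical round trip with this adapter might look as follows; PersistentMapping stands in for any Persistent subclass you would store under the root:

from persistent.mapping import PersistentMapping

adapter = DbAdapter('data.db')
root = adapter.connect()
adapter.begin_transaction()
try:
    root['settings'] = PersistentMapping({'theme': 'dark'})
    adapter.commit()
except Exception:
    adapter.rollback()
    raise
finally:
    adapter.disconnect()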
Example #11
class HistoryFreeFromFileStorage(
        RelStorageTestBase,
        UndoableRecoveryStorage,
):

    keep_history = False

    def setUp(self):
        self.open(create=1)
        self._storage.zap_all()
        self._dst = self._storage
        self._storage = FileStorage("Source.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return self._dst

    def compare(self, src, dest):
        # The dest storage has a truncated copy of src, so
        # use compare_truncated() instead of compare_exact().
        self.compare_truncated(src, dest)
Example #12
    def test_pack_defaults(self):
        from ZODB.DB import DB
        from ZODB.FileStorage import FileStorage
        from ZODB.POSException import POSKeyError
        import time
        import transaction
        from relstorage.zodbpack import main

        storage = FileStorage(self.db_fn, create=True)
        db = DB(storage)
        conn = db.open()
        conn.root()['x'] = 1
        transaction.commit()
        oid = b'\0' * 8
        state, serial = storage.load(oid, '')
        time.sleep(0.1)
        conn.root()['x'] = 2
        transaction.commit()
        conn.close()
        self.assertEqual(state, storage.loadSerial(oid, serial))
        db.close()
        storage = None

        main(['', self.cfg_fn])

        # packing should have removed the old state.
        storage = FileStorage(self.db_fn)
        self.assertRaises(POSKeyError, storage.loadSerial, oid, serial)
        storage.close()
Example #13
class HistoryFreeFromFileStorage(
        RelStorageTestBase,
        UndoableRecoveryStorage,
        ):

    keep_history = False

    def setUp(self):
        self.open(create=1)
        self._storage.zap_all()
        self._dst = self._storage
        self._storage = FileStorage("Source.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return self._dst

    def compare(self, src, dest):
        # The dest storage has a truncated copy of src, so
        # use compare_truncated() instead of compare_exact().
        self.compare_truncated(src, dest)
Example #14
class Base(object):

    def __init__(self, path, authkey):
        if not os.path.exists(path):
            os.makedirs(path)
        self._path = path
        self.authkey = authkey

        path = os.path.join(path, 'graph.fs')
        self.storage = FileStorage(path)
        self.db = DB(self.storage)

    def path(self):
        return self._path

    def process(self, connection):
        (func, args) = connection
        self.connection = func(*args)

    def recv(self):
        return self.connection.recv()

    def send(self, message):
        self.connection.send(message)
        self.connection.close()

    def open(self):
        return self.db.open()

    def close(self):
        transaction.get().abort()
        self.db.close()
        self.storage.close()
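process() expects a (factory, args) pair that builds an IPC endpoint; a multiprocessing Client connecting to a matching Listener fits that shape. A sketch under that assumption (the address and authkey are made up):

from multiprocessing.connection import Client

base = Base('/tmp/graphdb', authkey=b'secret')
base.process((Client, (('localhost', 6000),)))  # self.connection = Client(...)
request = base.recv()     # read one message from the peer
base.send('ack')          # reply, then close the connection
base.close()              # abort any pending txn, close DB and storage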
Example #15
    def tearDown(self):
        self.storage.close()
        if self.recovered is not None:
            self.recovered.close()
        self.storage.cleanup()
        temp = FileStorage(self.dest)
        temp.close()
        temp.cleanup()
Example #16
    def tearDown(self):
        fsetup = functional.FunctionalTestSetup()
        # close the filestorage files now by calling the original
        # close on our storage instance
        FileStorage.close(fsetup.base_storage)
        fsetup.base_storage = self.original
        fsetup.tearDown()
        fsetup.tearDownCompletely()
Example #17
    def tearDown(self):
        self.storage.close()
        if self.recovered is not None:
            self.recovered.close()
        self.storage.cleanup()
        temp = FileStorage(self.dest)
        temp.close()
        temp.cleanup()
Example #18
    def checkBackwardTimeTravelWithRevertWhenStale(self):
        # If revert_when_stale is true, when the database
        # connection is stale (such as through failover to an
        # asynchronous slave that is not fully up to date), the poller
        # should notice that backward time travel has occurred and
        # invalidate all objects that have changed in the interval.
        self._storage = self.make_storage(revert_when_stale=True)

        import os
        import shutil
        import tempfile
        from ZODB.FileStorage import FileStorage

        db = DB(self._storage)
        try:
            transaction.begin()
            c = db.open()
            r = c.root()
            r["alpha"] = PersistentMapping()
            transaction.commit()

            # To simulate failover to an out of date async slave, take
            # a snapshot of the database at this point, change some
            # object, then restore the database to its earlier state.

            d = tempfile.mkdtemp()
            try:
                transaction.begin()
                fs = FileStorage(os.path.join(d, "Data.fs"))
                fs.copyTransactionsFrom(c._storage)

                r["beta"] = PersistentMapping()
                transaction.commit()
                self.assertTrue("beta" in r)

                c._storage.zap_all(reset_oid=False, slow=True)
                c._storage.copyTransactionsFrom(fs)

                fs.close()
            finally:
                shutil.rmtree(d)

            # r should still be in the cache.
            self.assertTrue("beta" in r)

            # Now sync, which will call poll_invalidations().
            c.sync()

            # r should have been invalidated
            self.assertEqual(r._p_changed, None)

            # r should be reverted to its earlier state.
            self.assertFalse("beta" in r)

        finally:
            db.close()
Example #19
    def tearDown(self):
        fsetup = functional.FunctionalTestSetup(self.config_file)
        # close the filestorage files now by calling the original
        # close on our storage instance
        FileStorage.close(fsetup.base_storage)
        # replace the storage with the original, so functionalsetup
        # can do what it wants with it
        fsetup.base_storage = self.original
        fsetup.tearDown()
        fsetup.tearDownCompletely()
Example #20
    def checkBackwardTimeTravelWithRevertWhenStale(self):
        # If revert_when_stale is true, when the database
        # connection is stale (such as through failover to an
        # asynchronous slave that is not fully up to date), the poller
        # should notice that backward time travel has occurred and
        # invalidate all objects that have changed in the interval.
        self._storage = self.make_storage(revert_when_stale=True)

        import os
        import shutil
        import tempfile
        from ZODB.FileStorage import FileStorage
        db = DB(self._storage)
        try:
            transaction.begin()
            c = db.open()
            r = c.root()
            r['alpha'] = PersistentMapping()
            transaction.commit()

            # To simulate failover to an out of date async slave, take
            # a snapshot of the database at this point, change some
            # object, then restore the database to its earlier state.

            d = tempfile.mkdtemp()
            try:
                transaction.begin()
                fs = FileStorage(os.path.join(d, 'Data.fs'))
                fs.copyTransactionsFrom(c._storage)

                r['beta'] = PersistentMapping()
                transaction.commit()
                self.assertTrue('beta' in r)

                c._storage.zap_all(reset_oid=False, slow=True)
                c._storage.copyTransactionsFrom(fs)

                fs.close()
            finally:
                shutil.rmtree(d)

            # r should still be in the cache.
            self.assertTrue('beta' in r)

            # Now sync, which will call poll_invalidations().
            c.sync()

            # r should have been invalidated
            self.assertEqual(r._p_changed, None)

            # r should be reverted to its earlier state.
            self.assertFalse('beta' in r)

        finally:
            db.close()
Example #21
class MyZODB(object):
    def __init__(self, path):
        self.storage = FileStorage(path)
        self.db = DB(self.storage)
        self.connection = self.db.open()
        self.dbroot = self.connection.root()

    def close(self):
        self.connection.close()
        self.db.close()
        self.storage.close()
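A usage sketch in the classic ZODB-guide style this wrapper comes from:

import transaction

db = MyZODB('Data.fs')
db.dbroot['counter'] = db.dbroot.get('counter', 0) + 1
transaction.commit()      # persist the change
db.close()                # close connection, DB and storage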
Example #22
class TestBackend:

  def __init__(self, filename, mode):

    self.mode = mode

    if mode == "w":
      self.storage = FileStorage(filename)
      db = DB(self.storage)
      connection = db.open()
      self.test_db_items = connection.root()

    elif mode == "r":
      self.storage = FileStorage(filename)
      db = DB(self.storage)
      connection = db.open()
      self.test_db_items = connection.root()

      self.next_rec_num = 0   # Initialise next record counter
      self.num_records = len(self.test_db_items)

  def __setitem__(self, key, value):

    self.test_db_items[key] = value

  def __getitem__(self, key):

    return self.test_db_items[str(key)]

  def __len__(self):

    return len(self.test_db_items)

  def first(self):

    return self.test_db_items[0]

  def iteritems(self):

    while self.next_rec_num < self.num_records:
      value = self.test_db_items[self.next_rec_num]
      self.next_rec_num += 1

      yield value

  def close(self):
    transaction.commit()
    self.storage.close()

  def getTestDBItems(self):
    return self.test_db_items.values()
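iteritems() walks integer keys starting from 0, so a write/read round trip under that convention might look like this (the file name is arbitrary):

backend = TestBackend('test.fs', 'w')
backend[0] = 'first record'
backend[1] = 'second record'
backend.close()                    # commits the transaction, closes storage

backend = TestBackend('test.fs', 'r')
for value in backend.iteritems():  # yields records 0..num_records-1
    print(value)
backend.close()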
Example #23
    def checkBackwardTimeTravel(self):
        # When a failover event causes the storage to switch to an
        # asynchronous slave that is not fully up to date, the poller
        # should notice that backward time travel has occurred and
        # handle the situation by invalidating all objects that have
        # changed in the interval. (Currently, we simply invalidate all
        # objects when backward time travel occurs.)
        import os
        import shutil
        import tempfile
        from ZODB.FileStorage import FileStorage
        db = DB(self._storage)
        try:
            c = db.open()
            r = c.root()
            r['alpha'] = PersistentMapping()
            transaction.commit()

            # To simulate failover to an out of date async slave, take
            # a snapshot of the database at this point, change some
            # object, then restore the database to its earlier state.

            d = tempfile.mkdtemp()
            try:
                fs = FileStorage(os.path.join(d, 'Data.fs'))
                fs.copyTransactionsFrom(c._storage)

                r['beta'] = PersistentMapping()
                transaction.commit()
                self.assertTrue('beta' in r)

                c._storage.zap_all()
                c._storage.copyTransactionsFrom(fs)

                fs.close()
            finally:
                shutil.rmtree(d)

            # r should still be in the cache.
            self.assertTrue('beta' in r)

            # Now sync, which will call poll_invalidations().
            c.sync()

            # r should have been invalidated
            self.assertEqual(r._p_changed, None)

            # r should be reverted to its earlier state.
            self.assertFalse('beta' in r)

        finally:
            db.close()
Example #24
class EventCollection(object):
    """
    Structure to store an ensemble of events to disk and utilities to
    iterate through the events.
    """

    events_since_save = 0
    storage = None
    db = None
    connection = None
    store = None

    def __init__(self, filename):
        self.filename = filename
        self.open()

    def __enter__(self):
        # Return self so the `with` target is the collection itself.
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def open(self):
        self.storage = FileStorage(self.filename)
        self.db = DB(self.storage)
        self.connection = self.db.open()
        self.store = self.connection.root()
        self.events_since_save = 0
        return self

    def close(self):
        self.connection.close()
        self.storage.close()

    def new_key(self):
        return max(self.store.keys())+1 if self.store.keys() else 0

    def save(self):
        transaction.commit()

    def events(self):
        for key in self.store.keys():
            yield self.store[key]

    def add_event(self, event):
        self.store[self.new_key()] = event
        self.events_since_save += 1
        if self.events_since_save > 10000:
            print "Saving..."
            self.events_since_save = 0
            self.save()
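With __enter__ returning self (fixed above), the collection works as a context manager; a short usage sketch:

with EventCollection('events.fs') as events:
    events.add_event({'type': 'login', 'user': 'alice'})
    events.save()                    # commit explicitly; auto-save also kicks in
    for event in events.events():    # iterate stored events by key
        print(event)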
Example #25
    def checkBackwardTimeTravel(self):
        # When a failover event causes the storage to switch to an
        # asynchronous slave that is not fully up to date, the poller
        # should notice that backward time travel has occurred and
        # handle the situation by invalidating all objects that have
        # changed in the interval. (Currently, we simply invalidate all
        # objects when backward time travel occurs.)
        import os
        import shutil
        import tempfile
        from ZODB.FileStorage import FileStorage
        db = DB(self._storage)
        try:
            c = db.open()
            r = c.root()
            r['alpha'] = PersistentMapping()
            transaction.commit()

            # To simulate failover to an out of date async slave, take
            # a snapshot of the database at this point, change some
            # object, then restore the database to its earlier state.

            d = tempfile.mkdtemp()
            try:
                fs = FileStorage(os.path.join(d, 'Data.fs'))
                fs.copyTransactionsFrom(c._storage)

                r['beta'] = PersistentMapping()
                transaction.commit()
                self.assertTrue('beta' in r)

                c._storage.zap_all()
                c._storage.copyTransactionsFrom(fs)

                fs.close()
            finally:
                shutil.rmtree(d)

            # r should still be in the cache.
            self.assertTrue('beta' in r)

            # Now sync, which will call poll_invalidations().
            c.sync()

            # r should have been invalidated
            self.assertEqual(r._p_changed, None)

            # r should be reverted to its earlier state.
            self.assertFalse('beta' in r)

        finally:
            db.close()
Example #26
    def checkBackwardTimeTravelWithoutRevertWhenStale(self):
        # If revert_when_stale is false (the default), when the database
        # connection is stale (such as through failover to an
        # asynchronous slave that is not fully up to date), the poller
        # should notice that backward time travel has occurred and
        # raise a ReadConflictError.
        self._storage = self.make_storage(revert_when_stale=False)

        import os
        import shutil
        import tempfile
        from ZODB.FileStorage import FileStorage

        db = DB(self._storage)
        try:
            c = db.open()
            r = c.root()
            r["alpha"] = PersistentMapping()
            transaction.commit()

            # To simulate failover to an out of date async slave, take
            # a snapshot of the database at this point, change some
            # object, then restore the database to its earlier state.

            d = tempfile.mkdtemp()
            try:
                fs = FileStorage(os.path.join(d, "Data.fs"))
                fs.copyTransactionsFrom(c._storage)

                r["beta"] = PersistentMapping()
                transaction.commit()
                self.assertTrue("beta" in r)

                c._storage.zap_all(reset_oid=False, slow=True)
                c._storage.copyTransactionsFrom(fs)

                fs.close()
            finally:
                shutil.rmtree(d)

            # Sync, which will call poll_invalidations().
            c.sync()

            # Try to load an object, which should cause ReadConflictError.
            r._p_deactivate()
            self.assertRaises(ReadConflictError, lambda: r["beta"])

        finally:
            db.close()
Example #27
    def checkBackwardTimeTravelWithoutRevertWhenStale(self):
        # If revert_when_stale is false (the default), when the database
        # connection is stale (such as through failover to an
        # asynchronous slave that is not fully up to date), the poller
        # should notice that backward time travel has occurred and
        # raise a ReadConflictError.
        self._storage = self.make_storage(revert_when_stale=False)

        import os
        import shutil
        import tempfile
        from ZODB.FileStorage import FileStorage
        db = DB(self._storage)
        try:
            c = db.open()
            r = c.root()
            r['alpha'] = PersistentMapping()
            transaction.commit()

            # To simulate failover to an out of date async slave, take
            # a snapshot of the database at this point, change some
            # object, then restore the database to its earlier state.

            d = tempfile.mkdtemp()
            try:
                fs = FileStorage(os.path.join(d, 'Data.fs'))
                fs.copyTransactionsFrom(c._storage)

                r['beta'] = PersistentMapping()
                transaction.commit()
                self.assertTrue('beta' in r)

                c._storage.zap_all(reset_oid=False)
                c._storage.copyTransactionsFrom(fs)

                fs.close()
            finally:
                shutil.rmtree(d)

            # Sync, which will call poll_invalidations().
            c.sync()

            # Try to load an object, which should cause ReadConflictError.
            r._p_deactivate()
            self.assertRaises(ReadConflictError, lambda: r['beta'])

        finally:
            db.close()
Example #28
class HistoryFreeToFileStorage(RelStorageTestBase, BasicRecoveryStorage):

    keep_history = False

    def setUp(self):
        self._storage = self.make_storage()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return FileStorage("Dest.fs")
Example #29
class HistoryFreeToFileStorage(RelStorageTestBase, BasicRecoveryStorage):
    # pylint:disable=abstract-method,too-many-ancestors
    keep_history = False

    def setUp(self):
        super(HistoryFreeToFileStorage, self).setUp()
        self._storage = self.make_storage()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._dst.close()
        self._dst.cleanup()
        super(HistoryFreeToFileStorage, self).tearDown()

    def new_dest(self):
        return FileStorage('Dest.fs')
Example #30
class HistoryFreeToFileStorage(RelStorageTestBase,
                               BasicRecoveryStorage):
    # pylint:disable=abstract-method,too-many-ancestors
    keep_history = False

    def setUp(self):
        self._storage = self.make_storage()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return FileStorage('Dest.fs')
Example #31
class HistoryPreservingToFileStorage(RelStorageTestBase,
                                     UndoableRecoveryStorage):
    # pylint:disable=too-many-ancestors,abstract-method,too-many-locals
    keep_history = True

    def setUp(self):
        super(HistoryPreservingToFileStorage, self).setUp()
        self._storage = self.make_storage()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._dst.close()
        self._dst.cleanup()
        super(HistoryPreservingToFileStorage, self).tearDown()

    def new_dest(self):
        return FileStorage('Dest.fs')
Example #32
class HistoryPreservingFromFileStorage(
        RelStorageTestBase,
        UndoableRecoveryStorage,
):

    keep_history = True

    def setUp(self):
        self._dst = self.make_storage()
        self._storage = FileStorage("Source.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return self._dst
Example #33
class HistoryFreeToFileStorage(
        RelStorageTestBase,
        BasicRecoveryStorage,
):

    keep_history = False

    def setUp(self):
        self._storage = self.make_storage()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return FileStorage('Dest.fs')
Example #34
class HistoryPreservingFromFileStorage(
        RelStorageTestBase,
        UndoableRecoveryStorage,
    ):

    keep_history = True

    def setUp(self):
        self._dst = self.make_storage()
        self._storage = FileStorage("Source.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return self._dst
Example #35
class ZODBConnector(object):
    # from ZODB.FileStorage import FileStorage
    # FileStorage('cache.fs')
    path = '/var/filestorage/cache.fs'

    def __init__(self, path=None):
        if path is None:
            path = self.path

        try:
            self.storage = FileStorage(path)
            self.db = DB(self.storage)
            self.connection = self.db.open()
            self.root = self.connection.root()
        except Exception as e:
            logger.error("ZODB CONNECTOR ERROR: " + str(e))

    def getData(self, name):
        try:
            item = self.root[name]
            transaction.commit()
            return item
        except Exception as e:
            logger.error("GETTING DATA DB --- RETURNING NONE!!!! Data: " +
                         str(e))
            return None

    def setData(self, name, value):
        try:
            self.root[name] = value
            transaction.commit()
        except Exception as e:
            transaction.abort()
            logger.error("Set Data Error: " + str(e))

    def close(self):
        try:
            self.connection.close()
            self.db.close()
            self.storage.close()
        except Exception as e:
            logger.error("CLOSING DB ERROR" + str(e))
Example #36
class HistoryPreservingToFileStorage(
        RelStorageTestBase,
        UndoableRecoveryStorage,
        ):

    keep_history = True

    def setUp(self):
        self.open(create=1)
        self._storage.zap_all()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return FileStorage('Dest.fs')
Example #37
class HistoryPreservingToFileStorage(
        RelStorageTestBase,
        UndoableRecoveryStorage,
):

    keep_history = True

    def setUp(self):
        self.open(create=1)
        self._storage.zap_all()
        self._dst = FileStorage("Dest.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return FileStorage('Dest.fs')
Example #38
class HistoryFreeFromFileStorage(RelStorageTestBase,
                                 UndoableRecoveryStorage):
    # pylint:disable=abstract-method,too-many-ancestors
    keep_history = False

    def setUp(self):
        self._dst = self._storage
        self._storage = FileStorage("Source.fs", create=True)

    def tearDown(self):
        self._storage.close()
        self._dst.close()
        self._storage.cleanup()
        self._dst.cleanup()

    def new_dest(self):
        return self._dst

    def compare(self, src, dest):
        # The dest storage has a truncated copy of src, so
        # use compare_truncated() instead of compare_exact().
        self.compare_truncated(src, dest)
Example #39
class PackerTests(StorageTestBase):
    def setUp(self):
        self.started = 0

    def start(self):
        self.started = 1
        self.path = tempfile.mktemp(suffix=".fs")
        self._storage = FileStorage(self.path)
        self.db = ZODB.DB(self._storage)
        self.do_updates()
        self.pid, self.exit = forker.start_zeo_server(self._storage, self.addr)

    def do_updates(self):
        for i in range(100):
            self._dostore()

    def tearDown(self):
        if not self.started:
            return
        self.db.close()
        self._storage.close()
        self.exit.close()
        try:
            os.kill(self.pid, 9)
        except os.error:
            pass
        try:
            os.waitpid(self.pid, 0)
        except os.error, err:
            ##print "waitpid failed", err
            pass
        for ext in '', '.old', '.lock', '.index', '.tmp':
            path = self.path + ext
            try:
                os.remove(path)
            except os.error:
                pass
Example #40
class CacheDB(object):
    """
    A caching ZODB
    """
    def __init__(self, filename):
        filename = str(filename)
        self.__storage = FileStorage(filename)
        self.__db = zodb.DB(self.__storage)
        self.__connection = self.__db.open()

    @property
    def storage(self):
        return self.__storage

    @property
    def db(self):
        return self.__db

    @property
    def connection(self):
        return self.__connection

    @property
    def root(self):
        return self.connection.root

    def commit(self):
        zodb_transact.commit()

    def close(self):
        self.__connection.close()
        self.__db.close()
        self.__storage.close()
        self.__connection = None
        self.__db = None
        self.__storage = None
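The class refers to zodb and zodb_transact, presumably module-level aliases; a sketch of the assumed imports plus a round trip:

import ZODB as zodb                      # assumed alias used by CacheDB
import transaction as zodb_transact     # assumed alias used by commit()
from ZODB.FileStorage import FileStorage

cache = CacheDB('cache.fs')
root = cache.connection.root()           # conn.root() is the root mapping
root['answer'] = 42
cache.commit()
cache.close()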
Example #41
    for o, v in opts:
        if o == '-v':
            VERBOSE += 1
        if o == '-f':
            FSPATH = v
        if o == '-t':
            TXN_INTERVAL = int(v)
        if o == '-p':
            PACK_INTERVAL = int(v)
        if o == '-n':
            LIMIT = int(v)
#        if o == '-T':
#            INDEX = make_old_index

    if len(args) != 1:
        print "Expected on argument"
        print __doc__
        sys.exit(2)
    dir = args[0]

    fs = FileStorage(FSPATH)
    db = ZODB.DB(fs)
    cn = db.open()
    rt = cn.root()
    dir = os.path.join(os.getcwd(), dir)
    print dir
    main(db, rt, dir)
    cn.close()
    fs.close()
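The option loop above runs over opts from getopt; the enclosing code is presumably something like this (the module globals and the flag string are inferred from the handlers; the defaults are placeholders):

import getopt
import sys

VERBOSE = 0
FSPATH = 'Data.fs'
TXN_INTERVAL = 20
PACK_INTERVAL = 100
LIMIT = None

try:
    opts, args = getopt.getopt(sys.argv[1:], 'vf:t:p:n:T')
except getopt.GetoptError:
    sys.exit(2)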
Example #42
class HistoryTests(unittest.TestCase):

    def setUp(self):
        # set up a zodb
        # we can't use DemoStorage here 'cos it doesn't support History
        self.dir = tempfile.mkdtemp()
        fs_path = os.path.join(self.dir, 'testHistory.fs')
        self.s = FileStorage(fs_path, create=True)
        self.connection = ZODB.DB(self.s).open()
        r = self.connection.root()
        a = Application()
        r['Application'] = a
        self.root = a
        # create a python script
        a['test'] = HistoryItem()
        self.hi = hi = a.test
        # commit some changes
        hi.title = 'First title'
        t = transaction.get()
        # undo note made by Application instantiation above.
        t.description = None
        t.note(u'Change 1')
        t.commit()
        time.sleep(0.02)  # wait at least one Windows clock tick
        hi.title = 'Second title'
        t = transaction.get()
        t.note(u'Change 2')
        t.commit()
        time.sleep(0.02)  # wait at least one Windows clock tick
        hi.title = 'Third title'
        t = transaction.get()
        t.note(u'Change 3')
        t.commit()

    def tearDown(self):
        # get rid of ZODB
        transaction.abort()
        self.connection.close()
        self.s.close()
        del self.root
        del self.connection
        del self.s
        shutil.rmtree(self.dir)

    def test_manage_change_history(self):
        r = self.hi.manage_change_history()
        self.assertEqual(len(r), 3)  # three transactions
        for i in range(3):
            entry = r[i]
            # check no new keys show up without testing
            self.assertEqual(len(entry.keys()), 6)
            # the transactions are in newest-first order
            self.assertEqual(entry['description'], 'Change %i' % (3 - i))
            self.assertTrue('key' in entry)
            # lets not assume the size will stay the same forever
            self.assertTrue('size' in entry)
            self.assertTrue('tid' in entry)
            self.assertTrue('time' in entry)
            if i:
                # check times are increasing
                self.assertTrue(entry['time'] < r[i - 1]['time'])
            self.assertEqual(entry['user_name'], '')

    def test_manage_historyCopy(self):
        # we assume this works 'cos it's tested above
        r = self.hi.manage_change_history()
        # now we do the copy
        self.hi.manage_historyCopy(keys=[r[2]['key']])
        # do a commit, just like ZPublisher would
        transaction.commit()
        # check the body is as it should be, we assume
        # (hopefully not foolishly)
        # that all other attributes will behave the same
        self.assertEqual(self.hi.title,
                         'First title')
Example #43
def _gen_testdb(outfs_path, zext):
    xtime_reset()

    ext = ext4subj
    if not zext:
        def ext(subj): return {}

    logging.basicConfig()

    # generate random changes to objects hooked to top-level root by a/b/c/... key
    random.seed(0)

    namev = [_ for _ in "abcdefg"]
    Niter = 2
    for i in range(Niter):
        stor = FileStorage(outfs_path, create=(i == 0))
        db   = DB(stor)
        conn = db.open()
        root = conn.root()
        assert root._p_oid == p64(0), repr(root._p_oid)

        for j in range(25):
            name = random.choice(namev)
            if name in root:
                obj = root[name]
            else:
                root[name] = obj = Object(None)

            obj.value = "%s%i.%i" % (name, i, j)

            commit(u"user%i.%i" % (i,j), u"step %i.%i" % (i, j), ext(name))

        # undo a transaction one step before a latest one a couple of times
        for j in range(2):
            # XXX undoLog, despite what its interface says:
            #   https://github.com/zopefoundation/ZODB/blob/2490ae09/src/ZODB/interfaces.py#L472
            # just returns log of all transactions in specified range:
            #   https://github.com/zopefoundation/ZODB/blob/2490ae09/src/ZODB/FileStorage/FileStorage.py#L1008
            #   https://github.com/zopefoundation/ZODB/blob/2490ae09/src/ZODB/FileStorage/FileStorage.py#L2103
            # so we retry undoing next log's txn on conflict.
            for ul in db.undoLog(1, 20):
                try:
                    db.undo(ul["id"])
                    commit(u"root%i.%i\nYour\nMagesty " % (i, j),
                           u"undo %i.%i\nmore detailed description\n\nzzz ..." % (i, j) + "\t"*(i+j),
                           ext("undo %s" % ul["id"]))
                except UndoError:
                    transaction.abort()
                    continue

                break

        # delete an object
        name = random.choice(list(root.keys()))
        obj = root[name]
        root[name] = Object("%s%i*" % (name, i))
        # NOTE user/ext are kept empty on purpose - to also test this case
        commit(u"", u"predelete %s" % unpack64(obj._p_oid), {})

        # XXX obj in db could be changed by above undo, but ZODB does not automatically
        # propagate undo changes to live objects - so obj._p_serial can be stale.
        # Get serial via history.
        obj_tid_lastchange = db.history(obj._p_oid)[0]['tid']

        txn = precommit(u"root%i\nYour\nRoyal\nMagesty' " % i +
                            ''.join(chr(_) for _ in range(32)),     # <- NOTE all control characters
                        u"delete %i\nalpha beta gamma'delta\"lambda\n\nqqq ..." % i,
                        ext("delete %s" % unpack64(obj._p_oid)))
        # at low level stor requires ZODB.IStorageTransactionMetaData not txn (ITransaction)
        txn_stormeta = TransactionMetaData(txn.user, txn.description, txn.extension)
        stor.tpc_begin(txn_stormeta)
        stor.deleteObject(obj._p_oid, obj_tid_lastchange, txn_stormeta)
        stor.tpc_vote(txn_stormeta)
        # TODO different txn status vvv
        # XXX vvv it does the thing, but py fs iterator treats this txn as EOF
        #if i != Niter-1:
        #    stor.tpc_finish(txn_stormeta)
        stor.tpc_finish(txn_stormeta)

        # close db & rest not to get conflict errors after we touched stor
        # directly a bit. everything will be reopened on next iteration.
        conn.close()
        db.close()
        stor.close()
Example #44
        self.balance -= amount


storage = FileStorage('Data.fs')
db = DB(storage)
connection = db.open()
root = connection.root()

acclist=[]
a1=Account()
acclist.append(a1)
a2=Account()
acclist.append(a2)
a2.deposit(10.0)
acclist1=copy.deepcopy(acclist)
root['account-list-1']=acclist1
transaction.commit()

acclist[0].deposit(11.0)
acclist[1].deposit(22.0)
a3 = Account()
acclist.append(a3)
acclist2=copy.deepcopy(acclist)
root['account-list-2']=acclist2
transaction.commit()

connection.close()
db.close()
storage.close()
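The fragment above begins mid-method; a plausible Account class and the imports the script assumes (Persistent so that attribute changes are tracked by ZODB):

import copy
import transaction
from persistent import Persistent
from ZODB import DB
from ZODB.FileStorage import FileStorage

class Account(Persistent):
    balance = 0.0

    def deposit(self, amount):
        self.balance += amount

    def withdraw(self, amount):
        self.balance -= amount    # the line the fragment starts with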

Example #45
class ATHistoryAwareTests(unittest.TestCase):
    def setUp(self):
        # Set up a ZODB and Application object. We can't use DemoStorage
        # as it doesn't support the history() API.
        self._dir = tempfile.mkdtemp()
        self._storage = FileStorage(
            os.path.join(self._dir, 'test_athistoryaware.fs'),
            create=True)
        self._connection = ZODB.DB(self._storage).open()
        root = self._connection.root()
        root['Application'] = OFS.Application.Application()
        self.app = root['Application']

        # Our basic testing object
        self.app.object = DummyObject()
        self.object = self.app.object
        t = transaction.get()
        t.description = None  # clear initial transaction note
        t.note('Transaction 1')
        t.setUser('User 1')
        t.commit()

        # Alter object and annotations over several transactions
        annotations = self.object.__annotations__
        self.object.foo = 'baz'
        annotations[KEY1].spam = 'python'
        t = transaction.get()
        t.note('Transaction 2')
        t.setUser('User 2')
        t.commit()

        annotations[KEY3] = DummyAnnotation()
        t = transaction.get()
        t.note('Transaction 3')
        t.setUser('User 3')
        t.commit()

        del annotations[KEY3]
        annotations[KEY2].spam = 'lumberjack'
        t = transaction.get()
        t.note('Transaction 4')
        t.setUser('User 4')
        t.commit()

        self.object.foo = 'mit'
        annotations[KEY1].spam = 'trout'
        t = transaction.get()
        t.note('Transaction 5')
        t.setUser('User 5')
        t.commit()

    def tearDown(self):
        transaction.abort()
        del self.app
        self._connection.close()
        del self._connection
        self._storage.close()
        del self._storage
        shutil.rmtree(self._dir)

    def test_historyMetadata(self):
        """Each revision entry has unique metadata"""
        for i, entry in enumerate(self.object.getHistories()):
            # History is returned in reverse order, so Transaction 5 is first
            self.assertEqual(entry[2], 'Transaction %d' % (5 - i))
            self.assertEqual(entry[3], '/ User %d' % (5 - i))

    def test_objectContext(self):
        """Objects are returned with an acquisition context"""
        for entry in self.object.getHistories():
            self.assertEqual(entry[0].aq_parent, self.app)

    def test_simpleAttributes(self):
        """Simple, non-persistent attributes are tracked"""
        foo_history = (e[0].foo for e in self.object.getHistories())
        expected = ('mit', 'baz', 'baz', 'baz', 'bar')
        self.assertEqual(tuple(foo_history), expected)

    def test_annotation(self):
        """Persistent subkeys of the __annotations__ object"""
        key1_history = (e[0].__annotations__[KEY1].spam
                        for e in self.object.getHistories())
        expected = ('trout', 'python', 'python', 'python', 'eggs')
        self.assertEqual(tuple(key1_history), expected)

        key2_history = (e[0].__annotations__[KEY2].spam
                        for e in self.object.getHistories())
        expected = ('lumberjack', 'lumberjack', 'eggs', 'eggs', 'eggs')
        self.assertEqual(tuple(key2_history), expected)

    def test_annotationlifetime(self):
        """Addition and deletion of subkeys is tracked"""
        key3_history = (bool(KEY3 in e[0].__annotations__)
                        for e in self.object.getHistories())
        expected = (False, False, True, False, False)
        self.assertEqual(tuple(key3_history), expected)

    def test_maxReturned(self):
        history = list(self.object.getHistories(max=2))
        self.assertEqual(len(history), 2)
Example #46
class TestBase(unittest.TestCase):
    loglist = []

    def setUp(self):
        """
        (Based on ZODB.ConflictResolution.txt.) Create the database for the
        tests and set up the connections.
        Think of `conn_A` (connection A) as one thread, and `conn_B` 
        (connection B) as a concurrent thread.
        """

        self.testdir = tempfile.mkdtemp()
        self.storage = FileStorage(os.path.join(self.testdir, 'Data.fs'))
        self.db = ZODB.DB(self.storage)

        self.tm_A = transaction.TransactionManager()
        self.conn_A = self.db.open(transaction_manager=self.tm_A)
        p_ConnA = self.conn_A.root()['p'] = PCounter()
        self.tm_A.commit()

        self.tm_B = transaction.TransactionManager()
        self.conn_B = self.db.open(transaction_manager=self.tm_B)
        p_ConnB = self.conn_B.root()['p']
        assert p_ConnA._p_oid == p_ConnB._p_oid

        self.tm_C = transaction.TransactionManager()
        self.conn_C = self.db.open(transaction_manager=self.tm_C)
        p_ConnC = self.conn_C.root()['p']
        assert p_ConnA._p_oid == p_ConnC._p_oid
 
    def tearDown(self):
        """ close and delete.
        """
        self.db.close()
        self.storage.close()
        removeDirectory(self.testdir)

    def getLog(self, continue_from_here=""):
        """ Read the log file.
        """
        self.logCE.handlers[0].flush()
        #f = open(self.logCE.handlers[0].baseFilename, "r")
        f = open(self.logfile, "r")
        text = f.read()
        f.close()
        return text[len(continue_from_here):]

    def configureCE(self,
                    CELogger_LOGFILE='conflict_error_test.log',
                    CELogger_FIRST_CHANGE_ONLY=True,
                    CELogger_RAISE_CONFLICTERRORPREVIEW=False,
                    CELogger_ACTIVE=True):
        """ configure ClinflictErrorLooger
        """
        self.logfile = os.path.join(self.testdir, CELogger_LOGFILE)
        self.logCE = do_enable(self.logfile)
        self.logCE.level = logging.DEBUG
        conflictLogger.config(
                 log=self.logCE,
                 FIRST_CHANGE_ONLY=CELogger_FIRST_CHANGE_ONLY,
                 RAISE_CONFLICTERRORPREVIEW=CELogger_RAISE_CONFLICTERRORPREVIEW)
Example #47
    for o, v in opts:
        if o == '-v':
            VERBOSE += 1
        if o == '-f':
            FSPATH = v
        if o == '-t':
            TXN_INTERVAL = int(v)
        if o == '-p':
            PACK_INTERVAL = int(v)
        if o == '-n':
            LIMIT = int(v)
#        if o == '-T':
#            INDEX = make_old_index

    if len(args) != 1:
        print "Expected on argument"
        print __doc__
        sys.exit(2)
    dir = args[0]

    fs = FileStorage(FSPATH)
    db = ZODB.DB(fs)
    cn = db.open()
    rt = cn.root()
    dir = os.path.join(os.getcwd(), dir)
    print dir
    main(db, rt, dir)
    cn.close()
    fs.close()
Example #48
class DatabaseProvider(DefaultProvider):

    def __init__(self, path, backoff = None):
        """
        Initializes DatabaseProvider.

        This provider requires closing database after using (call close function).

        :param path: path to a database file
        :type path: str
        :param backoff: (optional) backoff provider
        """
        DefaultProvider.__init__(self, backoff)

        self.storage = FileStorage(path + '.fs')
        self.db = DB(self.storage)
        self.connection = self.db.open()
        self.root = self.connection.root()

        if not self.root:
            self.__dictionary_init()

    def __check_connection(self):
        if self.root is None:
            raise LookupError("Database connection is closed!")

    def close(self):
        """
        Function close connection to database.

        Call this before destroying DatabaseProvider object to avoid issues with database file access.
        """
        self.connection.close()
        self.db.close()
        self.storage.close()
        self.root = None

    def save_model(self, conf):
        """
        Inserts new data into database.

        Get new data using WikiProvider and get it using get_model method.

        :param conf: new data returned by WikiProvider get_model method
        """
        self.__check_connection()

        for type in conf:
            for baseword in conf[type]:
                self.__save(conf[type][baseword], baseword, type)

    def _get_word(self, conf):
        '''
        Returns word or throw KeyError, if there is no information
        about word in database
        '''
        self.__check_connection()
        return self.__get_word(conf[2], conf[1])

    def _get_conf(self, word):
        '''
        Returns word configuration or KeyError, if there is no
        information about word in database
        '''
        self.__check_connection()
        return self.__get_conf_preview(word)

    def __dictionary_init(self):
        '''
           Initialization of database dictionaries.
        '''
        self.root['przymiotnik'] = PersistentMapping()
        self.root['rzeczownik'] = PersistentMapping()
        self.root['czasownik'] = PersistentMapping()
        self.root['czasownik']['word'] = PersistentMapping()
        self.root['przymiotnik']['word'] = PersistentMapping()
        self.root['rzeczownik']['word'] = PersistentMapping()
        transaction.commit()

    def __save(self, dict, base_word, type):
        '''
            Save object to database in Bartosz Alchimowicz convention
        '''
        self.root[type]['word'][base_word] = dict
        transaction.commit()

    def __get_conf(self, base_word):
        '''
            Get the configuration of a word which is in the database
        '''
        for word_type in ['rzeczownik', 'czasownik', 'przymiotnik']:
            for word in self.root[word_type]['word'].keys():
                if word == base_word:
                    return self.root[word_type]['word'][word]

        raise KeyError("There is no such a word in Database")

    def __get_conf_preview(self, word):

        # rzeczownik
        dictionary = self.root['rzeczownik']['word']

        for base_word in dictionary.keys():
            for przypadek in dictionary[base_word]['przypadek'].keys():
                for liczba in dictionary[base_word]['przypadek'][przypadek]['liczba'].keys():
                    if dictionary[base_word]['przypadek'][przypadek]['liczba'][liczba] == word:
                        return [('rzeczownik', base_word,
                                {'przypadek' : przypadek,
                                 'liczba' : liczba })]
        # przymiotnik
        dictionary = self.root['przymiotnik']['word']

        for base_word in dictionary.keys():
            for stopien in dictionary[base_word]['stopień'].keys():
                for przypadek in dictionary[base_word]['stopień'][stopien]['przypadek'].keys():
                    for liczba in dictionary[base_word]['stopień'][stopien]['przypadek'][przypadek]['liczba'].keys():
                        for rodzaj in dictionary[base_word]['stopień'][stopien]['przypadek'][przypadek]['liczba'][liczba]['rodzaj'].keys():
                            if dictionary[base_word]['stopień'][stopien]['przypadek'][przypadek]['liczba'][liczba]['rodzaj'][rodzaj] == word:
                                return [('przymiotnik', base_word,
                                        {'stopień' : stopien,
                                         'liczba' : liczba,
                                         'rodzaj' : rodzaj})]
        # czasownik
        dictionary = self.root['czasownik']['word']

        for base_word in dictionary.keys():
            for aspekt in dictionary[base_word]['aspekt'].keys():
                for forma in dictionary[base_word]['aspekt'][aspekt]['forma'].keys():
                    for liczba in dictionary[base_word]['aspekt'][aspekt]['forma'][forma]['liczba'].keys():
                        for osoba in dictionary[base_word]['aspekt'][aspekt]['forma'][forma]['liczba'][liczba]['osoba'].keys():
                            if forma == 'czas przeszły':
                                for rodzaj in dictionary[base_word]['aspekt'][aspekt]['forma'][forma]['liczba'][liczba]['osoba'][osoba]['rodzaj'].keys():
                                    if dictionary[base_word]['aspekt'][aspekt]['forma'][forma]['liczba'][liczba]['osoba'][osoba]['rodzaj'][rodzaj] == word:
                                        return [('czasownik', base_word,
                                                {'aspekt' : aspekt,
                                                'forma' : forma,
                                                'liczba' : liczba,
                                                'osoba' : osoba,
                                                'rodzaj' : rodzaj})]
                            else:
                                if dictionary[base_word]['aspekt'][aspekt]['forma'][forma]['liczba'][liczba]['osoba'][osoba] == word:
                                        return [('czasownik', base_word,
                                                {'aspekt' : aspekt,
                                                'forma' : forma,
                                                'liczba' : liczba,
                                                'osoba' : osoba})]
        raise LookupError("configuration not found")


    def __get_word(self, conf, base_word):
        '''
            Search all database and get word
        '''
        try:
            return self.root['rzeczownik']['word'][base_word]['przypadek'][conf['przypadek']]['liczba'][conf['liczba']]
        except KeyError:
            try:
                return self.root['przymiotnik']['word'][base_word]['stopień'][conf['stopień']]['przypadek'][conf['przypadek']]['liczba'][conf['liczba']]['rodzaj'][conf['rodzaj']]
            except KeyError:
                try:
                    if conf['forma'] == 'czas teraźniejszy':
                        return self.root['czasownik']['word'][base_word]['aspekt'][conf['aspekt']]['forma'][conf['forma']]['liczba'][conf['liczba']]['osoba'][conf['osoba']]
                    else:
                        return self.root['czasownik']['word'][base_word]['aspekt'][conf['aspekt']]['forma'][conf['forma']]['liczba'][conf['liczba']]['osoba'][conf['osoba']]['rodzaj'][conf['rodzaj']]
                except KeyError:
                    raise KeyError("There is no such word in Database")
Exemple #49
0
class Objectbase:
	'''
		Class Objectbase - interface to an object database
	'''
	def __init__( self, dbfile = None ):
		'''
			Initializes the connection to the database located in 'dbfile'. If
			'dbfile' is not supplied, only initializes and sets 'opened' to False.
			If the database does not exist, a new one named 'dbfile' is created.
		'''
		if dbfile:
			self.st     = FileStorage( dbfile )
			self.db     = DB( self.st )
			self.cn     = self.db.open()
			self.root   = self.cn.root()
			self.opened = True
		else:
			self.opened = False
			
	def selectByName( self, name ):
		'''
			Returns the object with the key 'name'. If the database doesn't contain
			such an object, raises ValueError.
		'''
		if self.opened:
			if name in self.root:
				return self.root[ name ]
			else:
				error = 'This database has no object with the name ' + name
				raise ValueError, error
		else:
			error = 'Database is closed'
			raise ValueError, error

	def selectByType( self, typ ):
		'''
			Searches the database for a given type and returns a dictionary of
			key:object pairs which match this type. If the database doesn't contain
			any object of the given type, returns an empty dictionary.
		'''
		if self.opened:
			objects = {}
			for i in self.root:
				if issubclass( type( self.root[ i ] ), typ ):
					objects[ i ] = self.root[ i ]
			return objects
		else:
			error = 'Database is closed'
			raise ValueError, error
		
	def delete( self, name, cascade = False ):
		'''
			Deletes an object from the database by its key (name). If the database
			doesn't contain such an object, raises ValueError. If the parameter
			'cascade' is True, all relations of this object held by other objects
			are deleted as well. The cascade option is not fully implemented: it
			does not roll back the transaction if an error is raised partway
			through (e.g. when trying to remove the '1' side of a 1:1 or 1:N
			relation). Use it only if the object being removed takes part in no
			1:1 relation and is not the '1' side of a 1:N relation.
		'''
		error = ''
		if self.opened:
			try:
				if not cascade:
					if False: # NOTE implement a way to check if a object has relations to other objects
						error = 'This object has relations to other objects in the objectbase'
						raise ValueError, error
					else:
						del self.root[ name ]
				else:
					for i in self.root[ name ].relations:
						for j in i[ 0 ]:
							j.removeRelation( self.root[ name ] )
					del self.root[ name ]
			except KeyError:
				error = 'This database has no object with the name ' + name
				print error
				raise ValueError, error
		else:
			error = 'Database is closed'
			raise ValueError, error
	
	def update( self, name, object ):
		'''
			Updates the object with the key 'name' to match the value of 'object'.
			If the database doesn't contain such an object, raises ValueError. It
			also raises ValueError if the new object's type is not a subclass of
			the original object's type.
		'''
		if self.opened:
			try:
				if issubclass( type( self.root[ name ] ), type( object ) ):
					del self.root[ name ]
					self.root[ name ] = object
				else:
					error = 'You cannot change the type of the object ' + name + ' to ' + str( type( object ) )
					raise ValueError, error
			except KeyError:
				error = 'This database has no object with the name ' + name
				raise ValueError, error
		else:
			error = 'Database is closed'
			raise ValueError, error
	
	def insert( self, name, object ):
		'''
			Inserts a new object with the key 'name' into the database.
		'''
		if self.opened:
			self.root[ name ] = object
		else:
			error = 'Database is closed'
			raise ValueError, error
	
	def close( self ):
		'''
			Closes the database.
		'''
		if self.opened:
			transaction.commit()
			self.cn.close()
			self.db.close()
			self.st.close()
			self.opened = False
		else:
			error = 'Database is already closed'
			raise ValueError, error
		
	def open( self, dbfile ):
		'''
			Opens a database stored in 'dbfile' or creates a new one.
		'''
		if not self.opened:
			self.st     = FileStorage( dbfile )
			self.db     = DB( self.st )
			self.cn     = self.db.open()
			self.root   = self.cn.root()
			self.opened = True		
		else:
			error = 'Database is already open'
			raise ValueError, error
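
# A minimal usage sketch for Objectbase, assuming a picklable class Cat and a
# storage file 'test.fs' (both hypothetical). Objectbase commits only in
# close(); an explicit transaction.abort() is shown as one way to back out of
# a failed cascade delete, which delete() itself does not roll back.
import transaction

class Cat:
	relations = []

ob = Objectbase( 'test.fs' )
ob.insert( 'cat', Cat() )
print ob.selectByName( 'cat' )
try:
	ob.delete( 'cat', cascade = True )
except ValueError:
	transaction.abort()	# discard the partially applied cascade
ob.close()	# commits pending changes and closes the storage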
Exemple #50
0
class node2vec:
    r_deleted = {}
    sentences = {}
    sentences_array = []
    degree = []
    r_types = []
    n_types = []
    r_types_d = []
    r_desv = {}
    n_types_d = []
    m_vectors = []
    m_points = []
    angle_matrix = []
    plotw = 800
    ploth = 500
    mode = "normal"

    def __init__(self, bd, port, user, pss, label, ns, nd, l, m, traversals,
                 iteraciones):
        self.nodes = []
        self.ndim = nd
        self.bd = bd
        self.port = port
        self.user = user
        self.pss = pss
        self.label = label
        self.ns = ns
        self.w_size = l
        self.mode = m
        self.iteraciones = iteraciones

        # Setting up Neo4j DB
        neo4j.authenticate("http://*****:*****@localhost:" +
                                                   str(self.port) +
                                                   "/db/data/")
        batches = 100

        if not os.path.exists("models/" + self.bd +
                              ".npy") or not os.path.exists("models/" +
                                                            self.bd +
                                                            "l-degree.npy"):
            print "Conecting to BD..."
            nn = neo4j.CypherQuery(
                self.graph_db, "match n return count(n) as cuenta1").execute()
            self.numnodes = nn[0].cuenta1
            self.sentences_array = []
            # Float division, so partial batches are not dropped.
            nb = self.numnodes / float(batches)
            count = -1
            self.degree = []
            for i in range(1, int(nb) + 1):
                count += 1
                consulta = "match (n)-[r]-(m) where n." + self.label + " <> '' return n,count(r) as d, n." + self.label + ", collect(m." + self.label + ") as collect skip " + str(
                    batches * count) + " limit " + str(batches)
                cuenta = neo4j.CypherQuery(self.graph_db, consulta).execute()
                print "\r" + str(float((i / nb) * 100)) + "%"
                for cuenta1 in cuenta:
                    name = cuenta1['n.' + label].replace(" ", "_")
                    context = []
                    # Extracting context (relations)
                    for s in cuenta1['collect']:
                        if type(s) is list:
                            for x in s:
                                context.append(str(x).replace(" ", "_"))
                        else:
                            if s:
                                context.append(str(s).replace(" ", "_"))
                    # Extracting context (properties)
                    for t in cuenta1['n']:
                        s = cuenta1['n'][t]
                        if type(s) is list:
                            for x in s:
                                context.append(str(x).replace(" ", "_"))
                        else:
                            if s:
                                context.append(str(s).replace(" ", "_"))
                    if len(context) >= l - 1 and cuenta1.d is not None:
                        sentence = context
                        sentence.insert(0, name)
                        self.sentences_array.append(sentence)
                        self.degree.append(cuenta1.d)

            np.save("models/" + self.bd, self.sentences_array)
            np.save("models/" + self.bd + "l-degree", self.degree)
        else:
            self.sentences_array = np.load("models/" + self.bd + ".npy")
            self.degree = np.load("models/" + self.bd + "l-degree.npy")
        for s in self.sentences_array:
            self.sentences[s[0]] = s[1:]
        print "models/" + self.bd + ".npy"

    def learn(self, m, ts, d, it):
        num_cores = multiprocessing.cpu_count()
        print "numCores = " + str(num_cores)
        self.path = "models/" + self.bd + str(self.ndim) + "d-" + str(
            self.ns) + "w" + str(self.w_size) + "l" + m
        if d:
            # delete_rels removes, from self.sentences_array, the relations we
            # will later query for, before training; it returns the new dump
            # with those relations removed and a list of the removed relations.
            self.learn(m, 0, False, it)
            self.get_rels([])
            sents, self.r_deleted = delete_rels(self.sentences_array,
                                                self.r_types, ts)
            self.path = self.path + "del" + str(ts) + "-"
        else:
            sents = self.sentences_array
        self.path = self.path + str(it) + ".npy"
        print "Learning:" + self.path
        print "CCCC!"
        if not os.path.exists(self.path):
            print "Entra"
            entrada = []
            results = Parallel(n_jobs=num_cores, backend="threading")(
                delayed(generate_sample)(self.mode, sents, self.degree,
                                         self.w_size, i)
                for i in range(1, self.ns))
            for r in results:
                entrada.append(r)
            self.w2v = word2vec.Word2Vec(entrada,
                                         size=self.ndim,
                                         window=self.w_size,
                                         min_count=1,
                                         workers=num_cores,
                                         sg=0)
            self.w2v.save(self.path)
            print "TERMINO"
        else:
            self.w2v = word2vec.Word2Vec.load(self.path)
        self.get_nodes()
        self.get_rels([])
        self.delete_props()

    def get_rels(self, traversals):
        if not os.path.exists("models/" + self.bd + "-trels.p"):
            f = open("models/" + self.bd + "-trels.p", "w")
            consulta = neo4j.CypherQuery(
                self.graph_db, "match (n)-[r]->(m) return n." + self.label +
                " as s,m." + self.label +
                " as t ,r,type(r) as tipo,labels(m) as tipot").execute()
            todas = []
            for c in consulta:
                todas.append([c.s, c.tipo, c.t, c.tipot])
            pickle.dump(todas, f)
        else:
            f = open("models/" + self.bd + "-trels.p", "r")
            todas = pickle.load(f)
        links = dict()
        for l in todas:
            link = dict()
            if l[0] and l[1] and l[2]:
                link["tipo"] = l[1]
                link["s"] = l[0].replace(" ", "_")
                link["t"] = l[2].replace(" ", "_")
                link["tipot"] = l[3][0].replace(" ", "_")
                if link["s"] in self.w2v and link["t"] in self.w2v:
                    link["v"] = self.w2v[link["t"]] - self.w2v[link["s"]]
                    if not link["tipo"] in links:
                        links[link["tipo"]] = []
                    links[link["tipo"]].append(link)
        self.r_types = links

    def r_analysis(self):
        print "Relation Types Analysis"
        if not self.r_types:
            self.get_rels([])
        self.m_vectors = {}
        for t in self.r_types:
            vectors = []
            rels = self.r_types[t]
            for r in rels:
                if (r["s"] in self.w2v) and (r["t"] in self.w2v):
                    vectors.append(self.w2v[r["t"]] - self.w2v[r["s"]])
            vector_medio = np.mean(vectors, axis=0)
            self.m_vectors[t] = np.mean(vectors, axis=0)
            media = 0
            for v in vectors:
                media = media + angle(v, vector_medio)
            media = media / len(vectors)
            self.r_desv[t] = media
        print "Mean Vector Angles"
        self.angle_matrix = dict()
        for i, t in enumerate(self.r_types):
            self.angle_matrix[t] = dict()
            for j, x in enumerate(self.r_types):
                self.angle_matrix[t][x] = angle(self.m_vectors[t],
                                                self.m_vectors[x])
                if x not in self.angle_matrix:
                    self.angle_matrix[x] = dict()
                self.angle_matrix[x][t] = angle(self.m_vectors[t],
                                                self.m_vectors[x])

    def get_nodes(self):
        if not os.path.exists("models/" + self.bd + "-tnodes.p"):
            f = open("models/" + self.bd + "-tnodes.p", "w")
            consulta = neo4j.CypherQuery(
                self.graph_db, "match (n) return n." + self.label +
                " as name,labels(n) as tipos").execute()
            nodes = dict()
            for node in consulta:
                if node.name and node.tipos != []:
                    name = node.name.replace(" ", "_")
                    for tipo in node.tipos:
                        if not tipo in nodes:
                            nodes[tipo] = []
                        nodes[tipo].append(name)
            self.n_types = nodes
            pickle.dump(nodes, f)
        else:
            f = open("models/" + self.bd + "-tnodes.p", "r")
            self.n_types = pickle.load(f)

    def n_analysis(self):
        print "Node Type Analysis"
        if not self.n_types:
            self.get_nodes()
        self.m_points = dict()
        self.n_types_d = dict()
        for nt in self.n_types:
            points = []
            for node in self.n_types[nt]:
                if node in self.w2v:
                    points.append(self.w2v[node])
            if len(points) > 0:
                punto_medio = [0] * len(points[0])

                for p in points:
                    for idx, d in enumerate(p):
                        punto_medio[idx] = punto_medio[idx] + d
                for idx, d in enumerate(punto_medio):
                    punto_medio[idx] = punto_medio[idx] / len(points)
                if nt not in self.m_points:
                    self.m_points[nt] = punto_medio
                #print "-------------------"+nt+"-------------------"
                #print "Number of Nodes: "+ str(len(points))
                dev = 0
                for p in points:
                    dev = dev + scipy.spatial.distance.euclidean(
                        punto_medio, p)**2
                dev = math.sqrt((dev / len(points)))

                #print "Standard Deviation:"+str(dev)
                if nt not in self.n_types_d:
                    self.n_types_d[nt] = dev
            #print "Variance:"+str(np.var(points))

        #print "Distancia entre los puntos medios"
        #for i,t in enumerate(self.m_points):
        #for j,x in enumerate(self.m_points):
        #if i <> j:
        #print t+" vs. "+x
        #print scipy.spatial.distance.euclidean(self.m_points[t] , self.m_points[x])

    def analysis(self):
        self.n_analysis()
        self.r_analysis()

    def similares(self, nodo, positives, negatives, top_n, filtrado):
        # New version: use the nodes_pos and nodes_type structures with a scikit-learn KNN.
        clf = neighbors.KNeighborsClassifier(
            top_n, "uniform", n_jobs=multiprocessing.cpu_count())
        clf.fit(self.nodes_pos, self.nodes_type)
        my_list = clf.kneighbors(positives[0], top_n, False)

        # Old version: used word2vec, so it worked with all the properties.
        # NOTE: this overwrites the KNN result computed just above.
        my_list = self.w2v.most_similar(positives, negatives, topn=top_n)
        result = []
        for m in my_list:
            if m[0] != nodo:
                result.append(m)
        return result

    def predice(self, nodo, rel, fast, top_n, filtrado):
        if not fast:
            votos = []
            for r in self.r_types[rel]:
                other = r["s"]
                if (r["s"] == nodo):
                    other = r["t"]
                # NOTE: self.label (not the undefined global 'label') and rel
                # (not rel[0]) are assumed here; the original referenced names
                # that do not exist in this scope.
                p2 = neo4j.CypherQuery(
                    self.graph_db, "match (n)-[:" + rel + "]-(m) where n." +
                    self.label + ' = "' + other + '" return m.' +
                    self.label).execute()
                print p2
                if len(p2) > 0:
                    for p in p2:
                        prop2 = p["m." + self.label]
                    prop2 = prop2.replace(" ", "_")
                    other = other.replace(" ", "_")
                    if other in self.w2v and prop2 in self.w2v:
                        # NOTE: argument order assumed to match similares();
                        # the original call passed too few arguments.
                        prop1 = self.similares(nodo, [nodo, other], [prop2],
                                               top_n, filtrado)[0][0]
                        votos.append(prop1)
            return max(set(votos), key=votos.count)
        if fast:
            sim = self.similares(nodo,
                                 [self.w2v[nodo] + self.m_vectors[str(rel)]],
                                 [], top_n, filtrado)
            f = []
            for s in sim:
                f.append(s[0])
            if len(f) > 0:
                return f
            else:
                return ""

    def aciertos_rel(self, rel, label, fast, string):
        print "jeje"
        if not os.path.exists("models/" + self.bd + str(self.ndim) + "d-" +
                              str(self.ns) + "w" + str(self.w_size) +
                              self.mode + "-lpr-" + rel + string + ".p"):
            print "ta"
            parcial = 0
            total = 0
            cuenta_misc = 0
            for d in self.r_deleted[rel]:
                print "analizando relacion"
                print rel
                rs = d["s"]
                cuenta_misc += 1
                print rs
                print rs in self.w2v
                print rs in self.sentences
                if rs in self.w2v and not '"' in rs:
                    total = total + 1
                    # NOTE: arguments adjusted to match predice()'s signature
                    # (nodo, rel, fast, top_n, filtrado); the original call did
                    # not match it, and top_n=10 here is an assumed value.
                    nbs = self.predice(rs, rel, fast, 10, None)
                    if d["t"] in nbs:
                        rank = nbs.index(d["t"]) + 1
                        print "hit: " + d["t"] + " at rank " + str(rank)
                        parcial += 1.0 / float(rank)
                    print parcial
                    print total
            if total > 0:
                result = float(parcial) / float(total)
            else:
                result = 0
            f = open(
                "models/" + self.bd + str(self.ndim) + "d-" + str(self.ns) +
                "w" + str(self.w_size) + self.mode + "-lpr-" + rel + string +
                ".p", "w")
            pickle.dump(result, f)
        else:
            f = open(
                "models/" + self.bd + str(self.ndim) + "d-" + str(self.ns) +
                "w" + str(self.w_size) + self.mode + "-lpr-" + rel + string +
                ".p", "r")
            result = pickle.load(f)
        return result

    def link_prediction_ratio(self):
        ratiosf = {}
        for r in self.r_types:
            ratiosf[r] = self.aciertos_rel(r, self.label, True, "")

        xname = []
        yname = []
        alpha = []
        color = []
        ratio = []
        names = []
        for r in self.r_types:
            names.append(r)
            xname.append(r)
            yname.append("Ratio")
            alpha.append(ratiosf[r] / 100)
            ratio.append(ratiosf[r])
            color.append('black')
        source = ColumnDataSource(data=dict(
            xname=xname, yname=yname, colors=color, alphas=alpha,
            ratios=ratio))
        p = figure(title="Link Prediction Ratios",
                   x_axis_location="above",
                   tools="resize,hover,save",
                   x_range=xname,
                   y_range=["Ratio"])
        p.rect('xname',
               'yname',
               0.9,
               0.9,
               source=source,
               color='colors',
               alpha='alphas',
               line_color=None)
        p.grid.grid_line_color = None
        p.axis.axis_line_color = None
        p.axis.major_tick_line_color = None
        p.axis.major_label_text_font_size = "5pt"
        p.axis.major_label_standoff = 0
        p.xaxis.major_label_orientation = np.pi / 3
        hover = p.select(dict(type=HoverTool))
        hover.tooltips = OrderedDict([
            ('link type and method', '@yname, @xname'),
            ('link prediction ratio', '@ratios'),
        ])
        return p

    def ntype(self, n):
        for t in self.n_types:
            if n in self.n_types[t]:
                return t

    def connectZODB(self):
        print "connnecting"
        if not os.path.exists(self.bd + '.fs'):
            self.storage = FileStorage(self.bd + '.fs')
            self.db = DB(self.storage)
            self.connection = self.db.open()
            self.root = self.connection.root()
            self.root = PersistentDict()
        else:
            self.storage = FileStorage(self.bd + '.fs')
            self.db = DB(self.storage)
            self.connection = self.db.open()
            self.root = self.connection.root()

    def disconnectZODB(self):
        print "grabando!"
        transaction.commit()
        self.connection.close()
        self.db.close()
        self.storage.close()

    # Create nodes_pos with only the node vectors (avoiding property
    # representations) and nodes_type with the type of each node.

    def delete_props(self):
        self.nodes_pos = []
        self.nodes_type = []
        self.nodes_name = []
        for t in self.n_types:
            for n in self.n_types[t]:
                if n in self.w2v:
                    self.nodes_pos.append(self.w2v[n])
                    self.nodes_type.append(t)
                    self.nodes_name.append(n)
        print len(self.nodes_pos)
        print len(self.nodes_type)
        self.nodes_pos = list(self.nodes_pos)
        self.nodes_type = list(self.nodes_type)

    # Get the mean vector of the requested traversal.

    def get_vtraversal(self, traversal):
        traversals = self.get_traversals(traversal, 1)
        total = 0
        suma = None
        for t in traversals:
            if t["t"] in self.w2v and t["s"] in self.w2v:
                total += 1
                vector = self.w2v[t["t"]] - self.w2v[t["s"]]
                if suma is None:
                    suma = vector
                else:
                    suma += vector
        return suma / total

    # Get a random sample of traversals of the given type (the fraction
    # 0 < ts < 1 of all of them).

    def get_traversals(self, traversal, ts):
        if not os.path.exists("models/" + self.bd + "-trav-" + traversal +
                              ".p"):
            f = open("models/" + self.bd + "-trav-" + traversal + ".p", "w")
            consulta = neo4j.CypherQuery(
                self.graph_db, "match (n)" + traversal + "(m) return n." +
                self.label + " as s,m." + self.label +
                " as t ,labels(m) as tipot").execute()
            todas = []
            for c in consulta:
                todas.append({"s": c.s, "t": c.t, "tipot": c.tipot[0]})
            pickle.dump(todas, f)
        else:
            f = open("models/" + self.bd + "-trav-" + traversal + ".p", "r")
            todas = pickle.load(f)
        for t in todas:
            t["s"] = t["s"].replace(" ", "_")
            t["t"] = t["t"].replace(" ", "_")
            t["tipot"] = t["tipot"].replace(" ", "_")
        finales = random.sample(todas, int(len(todas) * ts))
        return finales

    def entity_retrieval(self, node, rel_type, target_t):
        temp_pos = []
        temp_name = []
        linkstopredictV = []
        self.r_analysis()
        for idx, e in enumerate(self.nodes_type):
            if e == target_t:
                temp_pos.append(self.nodes_pos[idx])
                temp_name.append(self.nodes_name[idx])
        if len(temp_pos) < 1000:
            ks = len(temp_pos)
        else:
            ks = 1000
        clasificador = neighbors.KNeighborsClassifier(
            ks, "uniform", n_jobs=multiprocessing.cpu_count())
        clasificador.fit(temp_pos, temp_name)
        linkstopredictV.append(self.w2v[node] + self.m_vectors[str(rel_type)])
        nbs = clasificador.kneighbors(linkstopredictV, ks, False)
        result = []
        for e in nbs[0]:
            result.append(temp_name[e])
        return result
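
# A minimal usage sketch for the node2vec class above (hypothetical connection
# parameters; a running Neo4j instance and a models/ directory are assumed by
# the class itself):
n2v = node2vec(bd="mygraph", port=7474, user="neo4j", pss="secret",
               label="name", ns=1000, nd=100, l=5, m="normal",
               traversals=[], iteraciones=1)
n2v.learn("normal", 0, False, 0)   # trains the word2vec model or loads it from disk
n2v.analysis()                     # node- and relation-type statistics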
class Indexer:

    filestorage = database = connection = root = None

    def __init__(self, datafs, writable=0, trans=0, pack=0):
        self.trans_limit = trans
        self.pack_limit = pack
        self.trans_count = 0
        self.pack_count = 0
        self.stopdict = get_stopdict()
        self.mh = mhlib.MH()
        self.filestorage = FileStorage(datafs, read_only=(not writable))
        self.database = DB(self.filestorage)
        self.connection = self.database.open()
        self.root = self.connection.root()
        try:
            self.index = self.root["index"]
        except KeyError:
            self.index = self.root["index"] = TextIndex()
        try:
            self.docpaths = self.root["docpaths"]
        except KeyError:
            self.docpaths = self.root["docpaths"] = IOBTree()
        try:
            self.doctimes = self.root["doctimes"]
        except KeyError:
            self.doctimes = self.root["doctimes"] = IIBTree()
        try:
            self.watchfolders = self.root["watchfolders"]
        except KeyError:
            self.watchfolders = self.root["watchfolders"] = {}
        self.path2docid = OIBTree()
        for docid in self.docpaths.keys():
            path = self.docpaths[docid]
            self.path2docid[path] = docid
        try:
            self.maxdocid = max(self.docpaths.keys())
        except ValueError:
            self.maxdocid = 0
        print len(self.docpaths), "Document ids"
        print len(self.path2docid), "Pathnames"
        print self.index.lexicon.length(), "Words"

    def dumpfreqs(self):
        lexicon = self.index.lexicon
        index = self.index.index
        assert isinstance(index, OkapiIndex)
        L = []
        for wid in lexicon.wids():
            freq = 0
            for f in index._wordinfo.get(wid, {}).values():
                freq += f
            L.append((freq, wid, lexicon.get_word(wid)))
        L.sort()
        L.reverse()
        for freq, wid, word in L:
            print "%10d %10d %s" % (wid, freq, word)

    def dumpwids(self):
        lexicon = self.index.lexicon
        index = self.index.index
        assert isinstance(index, OkapiIndex)
        for wid in lexicon.wids():
            freq = 0
            for f in index._wordinfo.get(wid, {}).values():
                freq += f
            print "%10d %10d %s" % (wid, freq, lexicon.get_word(wid))

    def dumpwords(self):
        lexicon = self.index.lexicon
        index = self.index.index
        assert isinstance(index, OkapiIndex)
        for word in lexicon.words():
            wid = lexicon.get_wid(word)
            freq = 0
            for f in index._wordinfo.get(wid, {}).values():
                freq += f
            print "%10d %10d %s" % (wid, freq, word)

    def close(self):
        self.root = None
        if self.connection is not None:
            self.connection.close()
            self.connection = None
        if self.database is not None:
            self.database.close()
            self.database = None
        if self.filestorage is not None:
            self.filestorage.close()
            self.filestorage = None

    def interact(self, nbest=NBEST, maxlines=MAXLINES):
        try:
            import readline
        except ImportError:
            pass
        text = ""
        top = 0
        results = []
        while 1:
            try:
                line = raw_input("Query: ")
            except EOFError:
                print "\nBye."
                break
            line = line.strip()
            if line.startswith("/"):
                self.specialcommand(line, results, top - nbest)
                continue
            if line:
                text = line
                top = 0
            else:
                if not text:
                    continue
            try:
                results, n = self.timequery(text, top + nbest)
            except KeyboardInterrupt:
                raise
            except:
                reportexc()
                text = ""
                continue
            if len(results) <= top:
                if not n:
                    print "No hits for %r." % text
                else:
                    print "No more hits for %r." % text
                text = ""
                continue
            print "[Results %d-%d from %d" % (top+1, min(n, top+nbest), n),
            print "for query %s]" % repr(text)
            self.formatresults(text, results, maxlines, top, top+nbest)
            top += nbest

    def specialcommand(self, line, results, first):
        assert line.startswith("/")
        line = line[1:]
        if not line:
            n = first
        else:
            try:
                n = int(line) - 1
            except:
                print "Huh?"
                return
        if n < 0 or n >= len(results):
            print "Out of range"
            return
        docid, score = results[n]
        path = self.docpaths[docid]
        i = path.rfind("/")
        assert i > 0
        folder = path[:i]
        n = path[i+1:]
        cmd = "show +%s %s" % (folder, n)
        if os.getenv("DISPLAY"):
            os.system("xterm -e  sh -c '%s | less' &" % cmd)
        else:
            os.system(cmd)

    def query(self, text, nbest=NBEST, maxlines=MAXLINES):
        results, n = self.timequery(text, nbest)
        if not n:
            print "No hits for %r." % text
            return
        print "[Results 1-%d from %d]" % (len(results), n)
        self.formatresults(text, results, maxlines)

    def timequery(self, text, nbest):
        t0 = time.time()
        c0 = time.clock()
        results, n = self.index.query(text, nbest)
        t1 = time.time()
        c1 = time.clock()
        print "[Query time: %.3f real, %.3f user]" % (t1-t0, c1-c0)
        return results, n

    def formatresults(self, text, results, maxlines=MAXLINES,
                      lo=0, hi=sys.maxint):
        stop = self.stopdict.has_key
        words = [w for w in re.findall(r"\w+\*?", text.lower()) if not stop(w)]
        pattern = r"\b(" + "|".join(words) + r")\b"
        pattern = pattern.replace("*", ".*") # glob -> re syntax
        prog = re.compile(pattern, re.IGNORECASE)
        print '='*70
        rank = lo
        qw = self.index.query_weight(text)
        for docid, score in results[lo:hi]:
            rank += 1
            path = self.docpaths[docid]
            score = 100.0*score/qw
            print "Rank:    %d   Score: %d%%   File: %s" % (rank, score, path)
            path = os.path.join(self.mh.getpath(), path)
            try:
                fp = open(path)
            except (IOError, OSError), msg:
                print "Can't open:", msg
                continue
            msg = mhlib.Message("<folder>", 0, fp)
            for header in "From", "To", "Cc", "Bcc", "Subject", "Date":
                h = msg.getheader(header)
                if h:
                    print "%-8s %s" % (header+":", h)
            text = self.getmessagetext(msg)
            if text:
                print
                nleft = maxlines
                for part in text:
                    for line in part.splitlines():
                        if prog.search(line):
                            print line
                            nleft -= 1
                            if nleft <= 0:
                                break
                    if nleft <= 0:
                        break
            print '-'*70
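
# A minimal usage sketch for Indexer (hypothetical Data.fs path; NBEST,
# MAXLINES, get_stopdict and the index classes come from the surrounding
# module, and an MH mail setup is assumed):
ix = Indexer('Data.fs', writable=0)
ix.query('zodb transaction')   # prints the best-ranked matching messages
ix.close()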
Exemple #52
0
            f2id = 'f-%d-%d' % (i, j)
            folder2 = Dummy("/plone/%s/%s" % (f1id, f2id))
            count += 1
            index.index_object(count, folder2)

            for k in range(f3):
                f3id = 'f-%d-%d-%d' % (i, j, k)
                folder3 = Dummy("/plone/%s/%s/%s" % (f1id, f2id, f3id))
                count += 1
                index.index_object(count, folder3)

                for m in range(500):
                    docid = 'f-%d-%d-%d-%d' % (i, j, k, m)
                    doc = Dummy("/plone/%s/%s/%s/%s" % (f1id, f2id, f3id, docid))
                    count += 1
                    index.index_object(count, doc)
    print 'Created %s entries' % count

buildTree(index, 20, 20)

plone = Root('plone')
plone.index = index
root['plone'] = plone

transaction.commit()
conn.close()

db.close()
storage.close()
#storage.cleanup() # For removing all files
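
# The snippet above runs after the usual ZODB boilerplate; a minimal sketch of
# that setup (hypothetical file name; Root, Dummy and index come from the
# truncated part of the example):
from ZODB.FileStorage import FileStorage
from ZODB.DB import DB
import transaction

storage = FileStorage('Data.fs')
db = DB(storage)
conn = db.open()
root = conn.root()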
class QueueConflictTests(unittest.TestCase):
    def _setAlternativePolicy(self):
        # Apply the alternative conflict resolution policy
        self.queue._conflict_policy = ALTERNATIVE_POLICY
        self.queue._p_jar.transaction_manager.commit()
        self.queue2._p_jar.sync()

        self.assertEquals(self.queue._conflict_policy, ALTERNATIVE_POLICY)
        self.assertEquals(self.queue2._conflict_policy, ALTERNATIVE_POLICY)

    def _insane_update(self, queue, uid, etype):
        # Queue update method that allows insane state changes, needed
        # to provoke pathological queue states
        data = queue._data
        current = data.get(uid)
        if current is not None:
            generation, current = current

            if ((current is ADDED or current is CHANGED_ADDED)
                    and etype is CHANGED):
                etype = CHANGED_ADDED
        else:
            generation = 0

        data[uid] = generation + 1, etype

        queue._p_changed = 1

    def openDB(self):
        from ZODB.FileStorage import FileStorage
        from ZODB.DB import DB
        self.dir = tempfile.mkdtemp()
        self.storage = FileStorage(os.path.join(self.dir,
                                                'testQCConflicts.fs'))
        self.db = DB(self.storage)

    def setUp(self):
        self.openDB()
        queue = CatalogEventQueue()

        tm1 = transaction.TransactionManager()
        self.conn1 = self.db.open(transaction_manager=tm1)
        r1 = self.conn1.root()
        r1["queue"] = queue
        del queue
        self.queue = r1["queue"]
        tm1.commit()

        tm2 = transaction.TransactionManager()
        self.conn2 = self.db.open(transaction_manager=tm2)
        r2 = self.conn2.root()
        self.queue2 = r2["queue"]
        ignored = dir(self.queue2)  # unghostify

    def tearDown(self):
        transaction.abort()
        del self.queue
        del self.queue2
        if self.storage is not None:
            self.storage.close()
            self.storage.cleanup()
            shutil.rmtree(self.dir)

    def test_rig(self):
        # Test the test rig
        self.assertEqual(self.queue._p_serial, self.queue2._p_serial)

    def test_simpleConflict(self):
        # Using the first connection, index 10 paths
        for n in range(10):
            self.queue.update('/f%i' % n, ADDED)
        self.queue._p_jar.transaction_manager.commit()

        # After this run, the first connection's queuecatalog has 10
        # entries, the second has none.
        self.assertEqual(len(self.queue), 10)
        self.assertEqual(len(self.queue2), 0)

        # Using the second connection, index the other 10 folders
        for n in range(10):
            self.queue2.update('/g%i' % n, ADDED)

        # Now both connections' queuecatalogs have 10 entries each, but
        # for different objects.
        self.assertEqual(len(self.queue), 10)
        self.assertEqual(len(self.queue2), 10)

        # Now we commit. Conflict resolution on the catalog queue should
        # kick in because both connections have changes. Since none of the
        # events collide, we should end up with 20 entries in our catalogs.
        self.queue2._p_jar.transaction_manager.commit()
        self.queue._p_jar.sync()
        self.queue2._p_jar.sync()
        self.assertEqual(len(self.queue), 20)
        self.assertEqual(len(self.queue2), 20)

    def test_unresolved_add_after_something(self):
        # If an event has already been queued for an object and we then try
        # to commit an ADDED event, a conflict is encountered.

        # Mutilate the logger so we don't see complaints about the
        # conflict we are about to provoke
        from Products.QueueCatalog.QueueCatalog import logger
        logger.disabled = 1

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        self.queue2.update('/f0', ADDED)
        self.queue2.update('/f0', CHANGED)
        self.queue2._p_jar.transaction_manager.commit()

        self._insane_update(self.queue, '/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        self._insane_update(self.queue2, '/f0', ADDED)
        self.assertRaises(ConflictError,
                          self.queue2._p_jar.transaction_manager.commit)

        # cleanup the logger
        logger.disabled = 0

    def test_resolved_add_after_nonremoval(self):
        # If a non-removal event has already been queued for an object and we
        # then commit an ADDED event while the conflict resolution policy is
        # NOT the SAFE_POLICY, we won't get a conflict.
        self._setAlternativePolicy()

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        self.queue2.update('/f0', ADDED)
        self.queue2.update('/f0', CHANGED)
        self.queue2._p_jar.transaction_manager.commit()

        self._insane_update(self.queue, '/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        # If we had a conflict, this would blow up
        self._insane_update(self.queue2, '/f0', ADDED)
        self.queue2._p_jar.transaction_manager.commit()

        # After the conflict has been resolved, we expect the queues to
        # contain a CHANGED_ADDED event.
        self.queue._p_jar.sync()
        self.queue2._p_jar.sync()
        self.assertEquals(len(self.queue), 1)
        self.assertEquals(len(self.queue2), 1)
        event1 = self.queue.getEvent('/f0')
        event2 = self.queue2.getEvent('/f0')
        self.failUnless(event1 == event2 == CHANGED_ADDED)

    def test_resolved_add_after_removal(self):
        # If a REMOVED event is encountered for an object and we are trying to
        # commit an ADDED event while the conflict resolution policy is
        # NOT the SAFE_POLICY, we won't get a conflict.
        self._setAlternativePolicy()

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        self.queue2.update('/f0', ADDED)
        self.queue2.update('/f0', CHANGED)
        self.queue2._p_jar.transaction_manager.commit()

        self.queue.update('/f0', REMOVED)
        self.queue._p_jar.transaction_manager.commit()

        # If we had a conflict, this would blow up
        self._insane_update(self.queue2, '/f0', ADDED)
        self.queue2._p_jar.transaction_manager.commit()

        # After the conflict has been resolved, we expect the queue to
        # contain a REMOVED event.
        self.queue._p_jar.sync()
        self.queue2._p_jar.sync()
        self.assertEquals(len(self.queue), 1)
        self.assertEquals(len(self.queue2), 1)
        event1 = self.queue.getEvent('/f0')
        event2 = self.queue2.getEvent('/f0')
        self.failUnless(event1 == event2 == REMOVED)

    def test_unresolved_new_old_current_all_different(self):
        # If the events we get from the current, new and old states are
        # all different, we throw in the towel in the form of a conflict.
        # This test relies on the fact that no OLD state is de-facto treated
        # as a state.

        # Mutilate the logger so we don't see complaints about the
        # conflict we are about to provoke
        from Products.QueueCatalog.QueueCatalog import logger
        logger.disabled = 1

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        # This commit should now raise a conflict
        self._insane_update(self.queue2, '/f0', REMOVED)
        self.assertRaises(ConflictError,
                          self.queue2._p_jar.transaction_manager.commit)

        # cleanup the logger
        logger.disabled = 0

    def test_resolved_new_old_current_all_different(self):
        # If the events we get from the current, new and old states are
        # all different and the SAFE_POLICY conflict resolution policy is
        # not enforced, the conflict resolves without bloodshed.
        # This test relies on the fact that no OLD state is de-facto treated
        # as a state.
        self._setAlternativePolicy()

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        # This commit should not raise a conflict
        self._insane_update(self.queue2, '/f0', REMOVED)
        self.queue2._p_jar.transaction_manager.commit()

        # In this scenario (the incoming new state has a REMOVED event),
        # the new state is disregarded and the old state is used. We are
        # left with a CHANGED_ADDED event. (see queue.update method; ADDED
        # plus CHANGED results in CHANGED_ADDED)
        self.queue._p_jar.sync()
        self.queue2._p_jar.sync()
        self.assertEquals(len(self.queue), 1)
        self.assertEquals(len(self.queue2), 1)
        event1 = self.queue.getEvent('/f0')
        event2 = self.queue2.getEvent('/f0')
        self.failUnless(event1 == event2 == CHANGED_ADDED)

    def test_unresolved_new_old_current_all_different_2(self):
        # If the events we get from the current, new and old states are
        # all different, we throw in the towel in the form of a conflict.
        # This test relies on the fact that no OLD state is de-facto treated
        # as a state.

        # Mutilate the logger so we don't see complaints about the
        # conflict we are about to provoke
        from Products.QueueCatalog.QueueCatalog import logger
        logger.disabled = 1

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        self.queue2.update('/f0', ADDED)
        self.queue2.update('/f0', CHANGED)
        self.queue2._p_jar.transaction_manager.commit()

        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        # This commit should now raise a conflict
        self._insane_update(self.queue2, '/f0', REMOVED)
        self.assertRaises(ConflictError,
                          self.queue2._p_jar.transaction_manager.commit)

        # cleanup the logger
        logger.disabled = 0

    def test_resolved_new_old_current_all_different_2(self):
        # If the events we get from the current, new and old states are
        # all different and the SAFE_POLICY conflict resolution policy is
        # not enforced, the conflict resolves without bloodshed.
        self._setAlternativePolicy()

        self.queue.update('/f0', ADDED)
        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        self.queue2.update('/f0', ADDED)
        self.queue2.update('/f0', CHANGED)
        self.queue2._p_jar.transaction_manager.commit()

        self.queue.update('/f0', CHANGED)
        self.queue._p_jar.transaction_manager.commit()

        # This commit should not raise a conflict
        self._insane_update(self.queue2, '/f0', REMOVED)
        self.queue2._p_jar.transaction_manager.commit()

        # In this scenario (the incoming new state has a REMOVED event),
        # we will take the new state to resolve the conflict, because its
        # generation number is higher than the old state's and the current
        # state's.
        self.queue._p_jar.sync()
        self.queue2._p_jar.sync()
        self.assertEquals(len(self.queue), 1)
        self.assertEquals(len(self.queue2), 1)
        event1 = self.queue.getEvent('/f0')
        event2 = self.queue2.getEvent('/f0')
        self.failUnless(event1 == event2 == REMOVED)
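
# A minimal sketch of the two-connection pattern the tests above use to
# provoke a write conflict on a single persistent object ('db', the stored
# 'queue' object, ADDED and REMOVED all come from the surrounding test
# module):
import transaction

tm1 = transaction.TransactionManager()
tm2 = transaction.TransactionManager()
c1 = db.open(transaction_manager=tm1)
c2 = db.open(transaction_manager=tm2)
c1.root()['queue'].update('/f0', ADDED)
c2.root()['queue'].update('/f0', REMOVED)
tm1.commit()
tm2.commit()   # raises ConflictError unless _p_resolveConflict can resolve it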
Exemple #55
0
class HistoryTests(unittest.TestCase):
    def setUp(self):
        # set up a zodb
        # we can't use DemoStorage here 'cos it doesn't support History
        self.dir = tempfile.mkdtemp()
        self.s = FileStorage(os.path.join(self.dir, 'testHistory.fs'),
                             create=True)
        self.connection = ZODB.DB(self.s).open()
        r = self.connection.root()
        a = Application()
        r['Application'] = a
        self.root = a
        # create a python script
        manage_addPythonScript(a, 'test')
        self.ps = ps = a.test
        # commit some changes
        ps.write('return 1')
        t = transaction.get()
        # Undo the note made by the Application instantiation above.
        t.description = None
        t.note('Change 1')
        t.commit()
        ps.write('return 2')
        t = transaction.get()
        t.note('Change 2')
        t.commit()
        ps.write('return 3')
        t = transaction.get()
        t.note('Change 3')
        t.commit()

    def tearDown(self):
        # get rid of ZODB
        transaction.abort()
        self.connection.close()
        self.s.close()
        del self.root
        del self.connection
        del self.s
        shutil.rmtree(self.dir)

    def test_manage_change_history(self):
        r = self.ps.manage_change_history()
        self.assertEqual(len(r), 3)  # three transactions
        for i in range(3):
            entry = r[i]
            # check no new keys show up without testing
            self.assertEqual(len(entry.keys()), 7)
            # the transactions are in newest-first order
            self.assertEqual(entry['description'], 'Change %i' % (3 - i))
            self.failUnless('key' in entry)
            # lets not assume the size will stay the same forever
            self.failUnless('size' in entry)
            self.failUnless('tid' in entry)
            self.failUnless('time' in entry)
            if i:
                # check times are increasing
                self.failUnless(entry['time'] < r[i - 1]['time'])
            self.assertEqual(entry['user_name'], '')
            self.assertEqual(entry['version'], '')

    def test_manage_historyCopy(self):
        # we assume this works 'cos it's tested above
        r = self.ps.manage_change_history()
        # now we do the copy
        self.ps.manage_historyCopy(keys=[r[2]['key']])
        # do a commit, just like ZPublisher would
        transaction.commit()
        # Check the body is as it should be; we assume (hopefully not
        # foolishly) that all other attributes behave the same.
        self.assertEqual(self.ps._body, 'return 1\n')
Exemple #56
0
class RecoverTest(ZODB.tests.util.TestCase):

    path = None

    def setUp(self):
        ZODB.tests.util.TestCase.setUp(self)
        self.path = 'source.fs'
        self.storage = FileStorage(self.path)
        self.populate()
        self.dest = 'dest.fs'
        self.recovered = None

    def tearDown(self):
        self.storage.close()
        if self.recovered is not None:
            self.recovered.close()
        temp = FileStorage(self.dest)
        temp.close()
        ZODB.tests.util.TestCase.tearDown(self)

    def populate(self):
        db = ZODB.DB(self.storage)
        cn = db.open()
        rt = cn.root()

        # Create a bunch of objects; the Data.fs is about 100KB.
        for i in range(50):
            d = rt[i] = PersistentMapping()
            transaction.commit()
            for j in range(50):
                d[j] = "a" * j
            transaction.commit()

    def damage(self, num, size):
        self.storage.close()
        # Drop size null bytes into num random spots.
        for i in range(num):
            offset = random.randint(0, self.storage._pos - size)
            f = open(self.path, "a+b")
            f.seek(offset)
            f.write("\0" * size)
            f.close()

    ITERATIONS = 5

    # Run recovery, from self.path to self.dest.  Return whatever
    # recovery printed to stdout, as a string.
    def recover(self):
        orig_stdout = sys.stdout
        faux_stdout = StringIO.StringIO()
        try:
            sys.stdout = faux_stdout
            try:
                ZODB.fsrecover.recover(self.path,
                                       self.dest,
                                       verbose=0,
                                       partial=True,
                                       force=False,
                                       pack=1)
            except SystemExit:
                raise RuntimeError("recover tried to exit")
        finally:
            sys.stdout = orig_stdout
        return faux_stdout.getvalue()
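
    # Aside (illustration only, not part of the original tests): the stdout
    # swap above can be packaged as a context manager.  This assumes
    # `import contextlib` at module level; Python 3 later provided the same
    # thing built in as contextlib.redirect_stdout.
    @contextlib.contextmanager
    def _captured_stdout(self):
        orig, sys.stdout = sys.stdout, StringIO.StringIO()
        try:
            yield sys.stdout
        finally:
            sys.stdout = orig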

    # Caution:  because recovery is robust against many kinds of damage,
    # it's almost impossible for a call to self.recover() to raise an
    # exception.  As a result, these tests may pass even if fsrecover.py
    # is badly broken.  testNoDamage() tries to ensure that, at the very
    # least, recovery doesn't produce any error messages if the input .fs
    # is in fact not damaged.
    def testNoDamage(self):
        output = self.recover()
        self.assert_('error' not in output, output)
        self.assert_('\n0 bytes removed during recovery' in output, output)

        # Verify that the recovered database is identical to the original.
        before = file(self.path, 'rb')
        before_guts = before.read()
        before.close()

        after = file(self.dest, 'rb')
        after_guts = after.read()
        after.close()

        self.assertEqual(before_guts, after_guts,
                         "recovery changed a non-damaged .fs file")

    def testOneBlock(self):
        for i in range(self.ITERATIONS):
            self.damage(1, 1024)
            output = self.recover()
            self.assert_('error' in output, output)
            self.recovered = FileStorage(self.dest)
            self.recovered.close()
            os.remove(self.path)
            os.rename(self.dest, self.path)

    def testFourBlocks(self):
        for i in range(self.ITERATIONS):
            self.damage(4, 512)
            output = self.recover()
            self.assert_('error' in output, output)
            self.recovered = FileStorage(self.dest)
            self.recovered.close()
            os.remove(self.path)
            os.rename(self.dest, self.path)

    def testBigBlock(self):
        for i in range(self.ITERATIONS):
            self.damage(1, 32 * 1024)
            output = self.recover()
            self.assert_('error' in output, output)
            self.recovered = FileStorage(self.dest)
            self.recovered.close()
            os.remove(self.path)
            os.rename(self.dest, self.path)

    def testBadTransaction(self):
        # Find transaction headers and blast them.

        L = self.storage.undoLog()
        r = L[3]
        tid = base64.decodestring(r["id"] + "\n")
        pos1 = self.storage._txn_find(tid, 0)

        r = L[8]
        tid = base64.decodestring(r["id"] + "\n")
        pos2 = self.storage._txn_find(tid, 0)

        self.storage.close()

        # Overwrite the entire header ("r+b", so the seek is honored).
        f = open(self.path, "r+b")
        f.seek(pos1 - 50)
        f.write("\0" * 100)
        f.close()
        output = self.recover()
        self.assert_('error' in output, output)
        self.recovered = FileStorage(self.dest)
        self.recovered.close()
        os.remove(self.path)
        os.rename(self.dest, self.path)

        # Overwrite part of the header.
        f = open(self.path, "r+b")
        f.seek(pos2 + 10)
        f.write("\0" * 100)
        f.close()
        output = self.recover()
        self.assert_('error' in output, output)
        self.recovered = FileStorage(self.dest)
        self.recovered.close()

    # Issue 1846:  When a transaction had 'c' status (not yet committed),
    # the attempt to open a temp file to write the trailing bytes fell
    # into an infinite loop.
    def testUncommittedAtEnd(self):
        # Find a transaction near the end.
        L = self.storage.undoLog()
        r = L[1]
        tid = base64.decodestring(r["id"] + "\n")
        pos = self.storage._txn_find(tid, 0)

        # Overwrite its status with 'c'.
        f = open(self.path, "r+b")
        f.seek(pos + 16)
        current_status = f.read(1)
        self.assertEqual(current_status, ' ')
        f.seek(pos + 16)
        f.write('c')
        f.close()

        # Try to recover.  The original bug was that this never completed --
        # an infinite loop in fsrecover.py.  Also, in the ZODB 3.2 line, a
        # reference to an undefined global masked the infinite loop.
        self.recover()

        # Verify the destination got truncated.
        self.assertEqual(os.path.getsize(self.dest), pos)

        # Get rid of the temp file holding the truncated bytes.
        os.remove(ZODB.fsrecover._trname)
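
Why pos + 16 finds the status byte: in the FileStorage on-disk format every transaction record starts with a 23-byte header laid out as an 8-byte tid, an 8-byte record length, the 1-byte status flag, and three 2-byte lengths for the user, description and extension fields, so the flag sits 16 bytes in. A minimal sketch of reading it back (layout per ZODB.FileStorage.format; the read_txn_header helper is ours):

import struct

TRANS_HDR = ">8sQcHHH"                       # tid, length, status, ulen, dlen, elen
TRANS_HDR_LEN = struct.calcsize(TRANS_HDR)   # 23 bytes

def read_txn_header(path, pos):
    # Return (tid, length, status) for the transaction header at pos.
    # A status of ' ' means committed; the test above forges 'c' to
    # simulate a transaction that never finished committing.
    f = open(path, "rb")
    f.seek(pos)
    data = f.read(TRANS_HDR_LEN)
    f.close()
    tid, length, status, ulen, dlen, elen = struct.unpack(TRANS_HDR, data)
    return tid, length, status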