def test_block_change():
    # A write made inside a `with` block on one handle must not be visible
    # through a second handle on the same file until the block exits.
    with use_temporary_filename('tests/test_catches_block_change.pkl') as path:
        reader = PersistentOrderedDict(path)
        reader['a'] = 1
        only_a = [('a', 1)]
        assert list(reader.items()) == only_a

        writer = PersistentOrderedDict(path)
        with writer:
            writer['b'] = 2
            # Still inside the block: `reader` must see only the original entry.
            assert list(reader.items()) == only_a

        assert list(reader.items()) == only_a + [('b', 2)]
# Example #2
0
def test_block_change():
    # Changes made by `second` inside its `with` block become visible to
    # `first` only once the block has been exited.
    temp_name = 'tests/test_catches_block_change.pkl'
    with use_temporary_filename(temp_name) as file_path:
        first = PersistentOrderedDict(file_path)
        first['a'] = 1
        assert list(first.items()) == [('a', 1)]

        second = PersistentOrderedDict(file_path)
        with second:
            second['b'] = 2
            # Inside the block the other handle must still see only 'a'.
            assert list(first.items()) == [('a', 1)]
        # After leaving the block, 'b' is visible through the first handle.
        assert list(first.items()) == [('a', 1), ('b', 2)]
# Example #3
0
class ExperimentRecordInfo(object):
    """
    Metadata about one experiment record, kept in a PersistentOrderedDict that
    is backed by a '.pkl' file. Optionally, a '.txt' file with the same stem
    is rewritten with a readable dump of all fields whenever a field is set.
    """

    def __init__(self, file_path, write_text_version=True):
        """
        :param file_path: Path of the backing pickle; its extension must be '.pkl'.
        :param write_text_version: When True, set_field also rewrites '<stem>.txt'.
        """
        stem, extension = os.path.splitext(file_path)
        assert extension == '.pkl', 'Your file-path must be a pickle'
        if write_text_version:
            self._text_path = stem + '.txt'
        else:
            self._text_path = None
        self.persistent_obj = PersistentOrderedDict(file_path=file_path)

    def has_field(self, field):
        """Return whether ``field`` (a member of ExpInfoFields) is stored."""
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        return field in self.persistent_obj

    def get_field(self, field):
        """
        :param field: A member of ExperimentRecordInfo.FIELDS
        :return: The value stored under that field (lookup errors from the
            underlying store propagate).
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        return self.persistent_obj[field]

    def get_status_field(self):
        """Return the stored status, or ExpStatusOptions.CORRUPT when none is stored."""
        if not self.has_field(ExpInfoFields.STATUS):
            return ExpStatusOptions.CORRUPT
        return self.persistent_obj[ExpInfoFields.STATUS]

    def set_field(self, field, value):
        """
        Store ``value`` under ``field``, then regenerate the text mirror if enabled.

        :param field: A member of ExpInfoFields.
        :param value: The value to store; for ExpInfoFields.STATUS it must be in ExpStatusOptions.
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        is_status = field == ExpInfoFields.STATUS
        if is_status:
            assert value in ExpStatusOptions, 'Status value must be in: {}'.format(ExpStatusOptions)
        # The write goes through the store's context manager.
        with self.persistent_obj as store:
            store[field] = value
        if self._text_path is None:
            return
        with open(self._text_path, 'w') as text_file:
            text_file.write(self.get_text())

    def add_note(self, note):  # Currently unused
        """Append ``note`` to the stored NOTES list, creating the list if it is absent."""
        if self.has_field(ExpInfoFields.NOTES):
            notes = self.get_field(ExpInfoFields.NOTES) + [note]
        else:
            notes = [note]
        self.set_field(ExpInfoFields.NOTES, notes)

    def get_text(self):
        """
        Render every stored field as a 'name: text' line, joined by newlines.

        If the record has no VERSION field (treated as an old-format record),
        keys are formatted as-is; otherwise each key's ``.value`` is used.
        """
        legacy = ExpInfoFields.VERSION not in self.persistent_obj  # Old version... we must adapt
        lines = []
        for key, _ in self.persistent_obj.items():
            label = key if legacy else key.value
            lines.append('{}: {}'.format(label, self.get_field_text(key)))
        return '\n'.join(lines)

    def get_field_text(self, field, replacement_if_none=''):
        """
        Return a display form of the value stored under ``field``.

        :param field: A member of ExpInfoFields.
        :param replacement_if_none: Returned when the field is not stored.
        :return: For STATUS, the status's ``.value``; for ARGS, a list of
            'k=v' strings; otherwise ``str(value)``.
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        if not self.has_field(field):
            return replacement_if_none
        stored = self.get_field(field)
        if field is ExpInfoFields.STATUS:
            return stored.value
        if field is ExpInfoFields.ARGS:
            return ['{}={}'.format(name, val) for name, val in stored]
        return str(stored)
def test_persistent_ordered_dict():
    # Each newly opened handle on the file must see every earlier write, in
    # insertion order.
    with use_temporary_filename('tests/podtest.pkl') as path:
        writer = PersistentOrderedDict(path)
        assert list(writer.items()) == []

        expected = [('a', [1, 2, 3]), ('b', [4, 5, 6]), ('c', [7, 8])]
        for key, value in expected:
            writer[key] = value

        second_view = PersistentOrderedDict(path)
        assert list(second_view.items()) == expected

        writer['e'] = 11

        third_view = PersistentOrderedDict(path)
        assert list(third_view.items()) == expected + [('e', 11)]
# Example #5
0
def test_persistent_ordered_dict():
    # Round-trip check: contents written through one handle are read back
    # identically (and in order) by handles constructed afterwards.
    with use_temporary_filename('tests/podtest.pkl') as file_path:
        original = PersistentOrderedDict(file_path)
        assert not list(original.items())

        original['a'] = [1, 2, 3]
        original['b'] = [4, 5, 6]
        original['c'] = [7, 8]
        first_three = [('a', [1, 2, 3]), ('b', [4, 5, 6]), ('c', [7, 8])]

        reloaded = PersistentOrderedDict(file_path)
        assert list(reloaded.items()) == first_three

        original['e'] = 11

        reloaded_again = PersistentOrderedDict(file_path)
        assert list(reloaded_again.items()) == first_three + [('e', 11)]
def test_catches_modifications():
    # A write through any handle on the file must be visible through every
    # other handle on that file.
    with use_temporary_filename('tests/test_catches_modifications.pkl') as path:
        left = PersistentOrderedDict(path)
        left['a'] = 1

        right = PersistentOrderedDict(path)
        assert right['a'] == 1

        left['a'] = 2
        assert right['a'] == 2

        right['a'] = 3
        assert left['a'] == 3

        right['b'] = 4
        assert list(left.items()) == [('a', 3), ('b', 4)]

        # Passing `items` at construction updates the stored contents:
        # 'b' is overridden and 'c' is appended, while 'a' is kept.
        third = PersistentOrderedDict(path, items=[('b', 5), ('c', 6)])
        merged = [('a', 3), ('b', 5), ('c', 6)]
        assert list(left.items()) == list(right.items()) == list(third.items()) == merged
# Example #7
0
def test_catches_modifications():
    # Two (later three) handles on the same file must stay in sync no matter
    # which handle performs the write.
    path_name = 'tests/test_catches_modifications.pkl'
    with use_temporary_filename(path_name) as file_path:
        one = PersistentOrderedDict(file_path)
        one['a'] = 1

        two = PersistentOrderedDict(file_path)
        assert two['a'] == 1

        # Each (writer, reader, value) step: write through one handle, read through the other.
        for writer, reader, value in ((one, two, 2), (two, one, 3)):
            writer['a'] = value
            assert reader['a'] == value

        two['b'] = 4
        assert list(one.items()) == [('a', 3), ('b', 4)]

        three = PersistentOrderedDict(file_path, items=[('b', 5), ('c', 6)])
        want = [('a', 3), ('b', 5), ('c', 6)]
        assert list(one.items()) == list(two.items()) == list(three.items()) == want
# Example #8
0
class ExperimentRecordInfo(object):
    """
    Persistent metadata for a single experiment record.

    Fields are stored in a PersistentOrderedDict backed by a '.pkl' file. If
    requested, every set_field call also rewrites a '.txt' file next to it
    containing a readable dump of all fields.
    """

    def __init__(self, file_path, write_text_version=True):
        """
        :param file_path: Path of the backing pickle (must end in '.pkl').
        :param write_text_version: If True, keep '<stem>.txt' updated on each set_field.
        """
        stem, extension = os.path.splitext(file_path)
        assert extension == '.pkl', 'Your file-path must be a pickle'
        self._text_path = None
        if write_text_version:
            self._text_path = stem + '.txt'
        self.persistent_obj = PersistentOrderedDict(file_path=file_path)

    def has_field(self, field):
        """Return whether ``field`` (an ExpInfoFields member) is stored."""
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        return field in self.persistent_obj

    def get_field(self, field, default=ERROR_FLAG):
        """
        :param field: A member of ExperimentRecordInfo.FIELDS
        :param default: Returned when the lookup raises KeyError. If left as
            ERROR_FLAG, the KeyError is re-raised instead.
        :return: The info for that field.
        """
        try:
            found = self.persistent_obj[field]
        except KeyError:
            if default is not ERROR_FLAG:
                return default
            raise
        return found

    def get_status_field(self):
        """Return the stored status, falling back to ExpStatusOptions.CORRUPT when absent."""
        if not self.has_field(ExpInfoFields.STATUS):
            return ExpStatusOptions.CORRUPT
        return self.persistent_obj[ExpInfoFields.STATUS]

    def set_field(self, field, value):
        """
        Store ``value`` under ``field`` and, if enabled, rewrite the text mirror.

        :param field: A member of ExpInfoFields.
        :param value: New value; for ExpInfoFields.STATUS it must be in ExpStatusOptions.
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        if field == ExpInfoFields.STATUS:
            assert value in ExpStatusOptions, 'Status value must be in: {}'.format(
                ExpStatusOptions)
        self.persistent_obj[field] = value
        if self._text_path is None:
            return
        with open(self._text_path, 'w') as text_file:
            text_file.write(self.get_text())

    def add_note(self, note):  # Currently unused
        """Append ``note`` to the stored NOTES list, starting a new list if none exists."""
        if self.has_field(ExpInfoFields.NOTES):
            updated = self.get_field(ExpInfoFields.NOTES) + [note]
        else:
            updated = [note]
        self.set_field(ExpInfoFields.NOTES, updated)

    def get_notes(self):
        """Return the stored notes, or a new empty list when no notes are stored."""
        if self.has_field(ExpInfoFields.NOTES):
            return self.get_field(ExpInfoFields.NOTES)
        return []

    def get_text(self):
        """
        Render all stored fields as newline-separated 'name: text' lines.

        Without a VERSION field (treated as an old-format record), keys are
        printed as-is; otherwise each key's ``.value`` is printed.
        """
        legacy = ExpInfoFields.VERSION not in self.persistent_obj  # Old version... we must adapt
        return '\n'.join([
            '{}: {}'.format(key if legacy else key.value, self.get_field_text(key))
            for key, _ in self.persistent_obj.items()
        ])

    def get_field_text(self, field, replacement_if_none=''):
        """
        Return a display form of the value stored under ``field``.

        :param field: A member of ExpInfoFields.
        :param replacement_if_none: Returned when the field is not stored.
        :return: For STATUS, its ``.value``; for ARGS, a list of 'k=v' strings
            built from ``load_serialized_args``; otherwise ``str(value)``.
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        if not self.has_field(field):
            return replacement_if_none
        stored = self.get_field(field)
        if field is ExpInfoFields.STATUS:
            return stored.value
        if field is ExpInfoFields.ARGS:
            pairs = load_serialized_args(stored)
            return ['{}={}'.format(name, val) for name, val in pairs]
        return str(stored)
# Example #9
0
class ExperimentRecordInfo(object):
    """
    Experiment-record metadata stored in a pickle-backed PersistentOrderedDict.

    When ``write_text_version`` is enabled, a sibling '.txt' file is rewritten
    with a human-readable dump of all fields after every set_field call.
    """

    def __init__(self, file_path, write_text_version=True):
        """
        :param file_path: Path to the backing pickle file; must have a '.pkl' extension.
        :param write_text_version: Whether to maintain a '.txt' mirror with the same stem.
        """
        root, suffix = os.path.splitext(file_path)
        assert suffix == '.pkl', 'Your file-path must be a pickle'
        self._text_path = (root + '.txt') if write_text_version else None
        self.persistent_obj = PersistentOrderedDict(file_path=file_path)

    def has_field(self, field):
        """Return whether a value is stored under ``field`` (an ExpInfoFields member)."""
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        return field in self.persistent_obj

    def get_field(self, field, default = ERROR_FLAG):
        """
        :param field: A member of ExperimentRecordInfo.FIELDS
        :param default: Default to return if the lookup raises KeyError (if left
            unspecified, the KeyError propagates).
        :return: The info for that field.
        """
        if default is ERROR_FLAG:
            # No fallback requested: let any KeyError reach the caller.
            return self.persistent_obj[field]
        try:
            return self.persistent_obj[field]
        except KeyError:
            return default

    def get_status_field(self):
        """Return the recorded status; ExpStatusOptions.CORRUPT if no status is stored."""
        if not self.has_field(ExpInfoFields.STATUS):
            return ExpStatusOptions.CORRUPT
        return self.persistent_obj[ExpInfoFields.STATUS]

    def set_field(self, field, value):
        """
        Persist ``value`` under ``field``; afterwards refresh the '.txt' mirror if enabled.

        :param field: A member of ExpInfoFields.
        :param value: The value to store; STATUS values must be in ExpStatusOptions.
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        if field == ExpInfoFields.STATUS:
            assert value in ExpStatusOptions, 'Status value must be in: {}'.format(ExpStatusOptions)
        self.persistent_obj[field] = value
        if self._text_path is not None:
            rendered = self.get_text()
            with open(self._text_path, 'w') as out:
                out.write(rendered)

    def add_note(self, note):  # Currently unused
        """Add ``note`` to the end of the stored NOTES list (created if missing)."""
        notes = [note]
        if self.has_field(ExpInfoFields.NOTES):
            notes = self.get_field(ExpInfoFields.NOTES) + notes
        self.set_field(ExpInfoFields.NOTES, notes)

    def get_notes(self):
        """Return the list of stored notes, or ``[]`` when none have been recorded."""
        if not self.has_field(ExpInfoFields.NOTES):
            return []
        return self.get_field(ExpInfoFields.NOTES)

    def get_text(self):
        """
        Build a newline-joined 'name: text' dump of every stored field.

        A record lacking the VERSION field is treated as old-format and its keys
        are printed directly; otherwise each key's ``.value`` is printed.
        """
        use_raw_keys = ExpInfoFields.VERSION not in self.persistent_obj  # Old version... we must adapt
        rows = []
        for key, _ in self.persistent_obj.items():
            name = key if use_raw_keys else key.value
            rows.append('{}: {}'.format(name, self.get_field_text(key)))
        return '\n'.join(rows)

    def get_field_text(self, field, replacement_if_none=''):
        """
        Return a display form of the value stored under ``field``.

        :param field: A member of ExpInfoFields.
        :param replacement_if_none: Value returned when ``field`` is not stored.
        :return: STATUS -> the status's ``.value``; ARGS -> list of 'k=v'
            strings from ``load_serialized_args``; anything else -> ``str(value)``.
        """
        assert field in ExpInfoFields, 'Field must be a member of ExperimentRecordInfo.FIELDS'
        if not self.has_field(field):
            return replacement_if_none
        raw = self.get_field(field)
        if field is ExpInfoFields.STATUS:
            return raw.value
        if field is ExpInfoFields.ARGS:
            return ['{}={}'.format(k, v) for k, v in load_serialized_args(raw)]
        return str(raw)