Ejemplo n.º 1
0
def test_save_load_many_steps(topology_class, n_steps):
    """Multi-step variant of the save/load test.

    Some parts of a node's state (for instance a random number that was generated) are not written
    into memory blocks right away; they only become visible there after several steps.

    This test targets exactly that kind of state. It saves a topology, runs ``n_steps`` steps, loads
    the saved snapshot back, runs the same number of steps again, and requires both runs to end in
    the same state.
    """
    topo = topology_class()
    topo.step()
    snapshot = TAState(topo)

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Persist the state reached after the first step.
        persist = Saver(tmp_dir)
        topo.save(persist)
        persist.save()

        # Advance the topology; its state must now differ from the snapshot.
        for _step in range(n_steps):
            topo.step()
        expected_final = TAState(topo)
        assert expected_final != snapshot

        # Roll back to the persisted state while the directory still exists.
        topo.load(Loader(tmp_dir))

    restored = TAState(topo)
    assert restored == snapshot

    # Replaying the same number of steps must reproduce the state observed before loading.
    for _step in range(n_steps):
        topo.step()
    replayed = TAState(topo)
    assert replayed == expected_final
Ejemplo n.º 2
0
def test_save_load(topology_class):
    """Round-trip a topology through Saver/Loader and check that its state is restored.

    Steps:
    1. Build the topology and run a single step.
    2. Record its state as the reference ("before" state).
    3. Save the topology into a temporary directory.
    4. Modify the topology.
    5. Confirm the modification really changed the state.
    6. Load the saved topology back.
    7. Confirm the state equals the reference again.
    """
    topo = topology_class()
    topo.step()
    reference = TAState(topo)

    with tempfile.TemporaryDirectory() as tmp_dir:
        persist = Saver(tmp_dir)
        topo.save(persist)
        persist.save()

        change_topology(topo)
        changed = TAState(topo)
        assert changed != reference

        # Load while the temporary directory still exists.
        topo.load(Loader(tmp_dir))

    restored = TAState(topo)
    assert reference == restored
Ejemplo n.º 3
0
def test_save_load():
    """A unit loaded from another unit's saved data must expose the same output."""
    cpu_creator = AllocatingCreator('cpu')

    source_unit = RandomUnitStub(cpu_creator)
    target_unit = RandomUnitStub(cpu_creator)

    with TemporaryDirectory() as tmp_dir:
        persist = Saver(tmp_dir)
        source_unit.save(persist)
        persist.save()

        # Load while the temporary directory still exists.
        target_unit.load(Loader(tmp_dir))

    assert same(source_unit.output, target_unit.output)
Ejemplo n.º 4
0
def test_save_load():
    """Save one node's data, load it into a second node, and compare their unit outputs."""
    saved_node = RandomNodeStub()
    loaded_node = RandomNodeStub()

    cpu_creator = AllocatingCreator('cpu')
    # Allocate in the same order as the nodes were created.
    for stub in (saved_node, loaded_node):
        stub.allocate_memory_blocks(cpu_creator)

    with TemporaryDirectory() as tmp_dir:
        persist = Saver(tmp_dir)
        saved_node.save(persist)
        persist.save()

        # Load while the temporary directory still exists.
        loaded_node.load(Loader(tmp_dir))

    assert same(saved_node._unit.output, loaded_node._unit.output)
Ejemplo n.º 5
0
    def stop(self):
        """Finish the current run and hand its measurements over to the manager.

        Drops the step iterator and then, depending on the configuration flags:
        saves the measurement cache, saves the topology under ``<cache_folder>/saved_models``
        and calculates the run results. Finally the run's measurement manager is passed to
        ``self._manager.measurement_manager.add_results``.
        """
        # No further steps will be taken.
        self._iterator = None

        if self._save_cache:
            self._run_measurement_manager.save_cache()

        if self._save_model_after_run:
            models_dir = os.path.join(self._run_measurement_manager.cache_folder, 'saved_models')
            model_saver = Saver(models_dir)
            # The run index is appended to the topology name, presumably so that models saved by
            # different runs do not overwrite each other.
            saved_name = self.topology.name + f"_{self._run_idx}"
            self.topology.save(model_saver, saved_name)
            model_saver.save()

        if self._calculate_statistics:
            self.controller.calculate_run_results()

        self._manager.measurement_manager.add_results(self._run_measurement_manager)
Ejemplo n.º 6
0
class PersistableSaver:
    """Saves and loads persistables to/from a fixed per-adapter directory.

    Used for the train/test splitting in ExperimentTemplate.
    """

    _saver: Saver
    _loader: Loader
    _persistence_path: str

    def __init__(self, adapter_name: str):
        """
        Args:
            adapter_name: name of the sub-directory of ``<cwd>/data/stored`` used for persistence
        """
        self._persistence_path = self._get_persistence_location(adapter_name)

        self._saver = Saver(self._persistence_path)
        # The existence of a saved model is checked in load_data_into, not here: this object is
        # also used purely for saving, in which case the directory legitimately does not exist yet
        # and reporting an error at construction time would be a false alarm.
        self._loader = Loader(self._persistence_path)

    def save_data_of(self, persistable: Persistable):
        """
        Saves a persistable to the default location
        Args:
            persistable: the persistable to be saved
        """
        persistable.save(self._saver)
        self._saver.save()
        logger.info('Persistable saved')

    def load_data_into(self, persistable: Persistable):
        """
        Loads a persistable from the default location.

        Best effort: a missing saved model is logged as an error, and a FileNotFoundError raised
        while loading is logged (with traceback) instead of propagated. Other exceptions propagate.
        Args:
            persistable: the persistable into which the data will be loaded
        """
        if not os.path.exists(self._persistence_path):
            logger.error(
                f"There is no saved model at location {self._persistence_path}")

        try:
            persistable.load(self._loader)
        except FileNotFoundError:
            logger.exception("Loading of persistable failed")
        else:
            # Report success only when the load actually succeeded.
            logger.info('Persistable loaded')

    @staticmethod
    def _get_persistence_location(adapter_name: str):
        """Return ``<cwd>/data/stored/<adapter_name>``."""
        return os.path.join(os.getcwd(), 'data', 'stored', adapter_name)
Ejemplo n.º 7
0
def test_graph_save_load():
    """Saving one graph and loading it into another must make their node outputs equal.

    ``graph2`` is stepped like ``graph`` and then has its node outputs overwritten with random
    values, so the final comparison can only pass if loading actually restored the saved data.
    """
    graph = create_graph()
    graph2 = create_graph()

    graph.step()
    graph2.step()

    # One shared range drives both the scrambling and the checks, so every node whose output is
    # scrambled is also verified after loading. Previously node 2 was scrambled but never checked.
    node_indices = range(3)

    for i in node_indices:
        graph2.nodes[i].outputs.output.tensor.random_()

    with TemporaryDirectory() as folder:
        saver = Saver(folder)
        graph.save(saver)
        saver.save()

        # Load while the temporary directory still exists.
        loader = Loader(folder)
        graph2.load(loader)

    for i in node_indices:
        assert same(graph.nodes[i].outputs.output.tensor, graph2.nodes[i].outputs.output.tensor)
Ejemplo n.º 8
0
    def test_serialization(self):
        """Test that if the node is saved, changed and then loaded, it still computes the expected result.

        Skipped when ``self.skip_test_serialization()`` returns a truthy value.
        """
        # The docstring must be the first statement to be a real docstring (``__doc__``); placed
        # after the skip check it would only be a no-op string expression.
        if self.skip_test_serialization():
            pytest.skip()

        node, sources = self._prepare_node()

        with tempfile.TemporaryDirectory() as directory:
            # Save the node state right after initialization.
            saver = Saver(directory)
            node.save(saver)
            saver.save()

            # Run steps (without checking results) and apply the extra modification, so that the
            # node's state diverges from the saved one.
            self._run_node_for_steps(node, sources, check_results=False)
            self._change_node_before_load(node)

            # Load the initial state while the directory still exists.
            loader = Loader(directory)
            node.load(loader)

        # Run steps again and check they still produce the desired results.
        self._run_node_for_steps(node, sources)