コード例 #1
0
ファイル: test_topologies.py プロジェクト: jvitku/torchsim
def test_save_load_many_steps(topology_class, n_steps):
    """Check that save/load also round-trips state that only surfaces after several steps.

    Part of a node's state (for example, a random number that was generated) may not be written
    into its memory blocks right away; it only shows up in the memory blocks after a number of
    steps. Comparing snapshots immediately after loading would miss such state, so this test also
    replays ``n_steps`` steps after the load and compares against the original trajectory.

    Outline:
        1. Build the topology, step it once and take a snapshot.
        2. Save it into a temporary directory.
        3. Step it ``n_steps`` times, take a second snapshot and require that the state changed.
        4. Load the saved state and require that it equals the first snapshot.
        5. Step ``n_steps`` times again and require that the state equals the second snapshot.
    """
    topology = topology_class()
    topology.step()
    snapshot_saved = TAState(topology)

    with tempfile.TemporaryDirectory() as tmp_dir:
        saver = Saver(tmp_dir)
        topology.save(saver)
        saver.save()

        # Advance past the saved point; this is the trajectory the reloaded topology must reproduce.
        for _ in range(n_steps):
            topology.step()
        snapshot_stepped = TAState(topology)
        assert snapshot_stepped != snapshot_saved

        # Load while the temporary directory still exists.
        topology.load(Loader(tmp_dir))

    assert TAState(topology) == snapshot_saved

    # Replaying the same number of steps must land on the same state as before the load.
    for _ in range(n_steps):
        topology.step()
    assert TAState(topology) == snapshot_stepped
コード例 #2
0
ファイル: test_topologies.py プロジェクト: jvitku/torchsim
def test_save_load(topology_class):
    """Round-trip a topology through save/load and check that its state is restored.

    Steps:
        1. Create the topology and step it once.
        2. Snapshot its state (the "saved" state).
        3. Save the topology into a temporary directory.
        4. Modify it with ``change_topology``.
        5. Require that its state now differs from the snapshot.
        6. Load the saved topology back.
        7. Require that its state equals the snapshot again.
    """
    topology = topology_class()
    topology.step()
    original_state = TAState(topology)

    with tempfile.TemporaryDirectory() as tmp_dir:
        saver = Saver(tmp_dir)
        topology.save(saver)
        saver.save()

        # The modification must be observable; otherwise the final comparison proves nothing.
        change_topology(topology)
        assert TAState(topology) != original_state

        # Load while the temporary directory still exists.
        topology.load(Loader(tmp_dir))

    assert original_state == TAState(topology)
コード例 #3
0
ファイル: unit.py プロジェクト: jvitku/torchsim
    def load(self, parent_loader: Loader):
        """Restore this unit from its own sub-folder of ``parent_loader``.

        The sub-folder is selected by ``self._unit_folder_name``. Tensors listed under the
        ``'tensors'`` key of that child loader's description are loaded from the child's folder
        first; any unit-specific state is then restored by ``self._load`` using the child loader.
        """
        own_loader = parent_loader.load_child(self._unit_folder_name)
        own_folder = own_loader.get_full_folder_path()

        self.load_tensors(own_folder, own_loader.description['tensors'])
        self._load(own_loader)
コード例 #4
0
    def __init__(self, adapter_name: str):
        """Create the saver/loader pair used to persist this adapter's model.

        Both share the folder returned by ``self._get_persistence_location(adapter_name)``.
        A missing folder is only reported through ``logger.error`` here; how ``Loader`` itself
        reacts to a missing folder is not visible from this method.
        """
        location = self._get_persistence_location(adapter_name)

        self._saver = Saver(location)

        model_missing = not os.path.exists(location)
        if model_missing:
            logger.error(f"There is no saved model at location {location}")

        self._loader = Loader(location)
コード例 #5
0
ファイル: nn_node.py プロジェクト: jvitku/torchsim
    def _load(self, loader: Loader):
        """Restore the node: base-class state first, then the pickled network from 'network.pt'.

        An error is logged on every load because deserialization of this node is known to
        break learning (see the TODO below).
        """
        super()._load(loader)

        # TODO remove this after deserialization fixed
        known_bug_message = (
            "nn_node.py: loading contains some bug which breaks the learning, "
            "please manually disable loading in the TestableExperimentTemplateBase "
            "(comment out the line 477: self._topology_saver.load_data_into(self._topology) )"
        )
        logger.error(known_bug_message)

        # NOTE(review): the description entry 'store_x' is not restored; it was previously
        # assigned via `self.storage.x = loader.description['store_x']` (left disabled).

        # torch.load unpickles the file, so it must only ever read checkpoints from a trusted
        # location such as this node's own save folder.
        network_file = os.path.join(loader.get_full_folder_path(), 'network.pt')
        self.network = torch.load(network_file)
コード例 #6
0
def test_save_load():
    """Loading a saved unit into a second unit must reproduce the first unit's output."""
    creator = AllocatingCreator('cpu')
    source_unit, target_unit = RandomUnitStub(creator), RandomUnitStub(creator)

    with TemporaryDirectory() as tmp_dir:
        saver = Saver(tmp_dir)
        source_unit.save(saver)
        saver.save()

        # Load into the *other* unit while the temporary directory still exists.
        target_unit.load(Loader(tmp_dir))

    assert same(source_unit.output, target_unit.output)
コード例 #7
0
ファイル: test_node_base.py プロジェクト: jvitku/torchsim
def test_save_load():
    """Loading a saved node into a second node must reproduce the first node's unit output."""
    source_node = RandomNodeStub()
    target_node = RandomNodeStub()

    creator = AllocatingCreator('cpu')
    for node in (source_node, target_node):
        node.allocate_memory_blocks(creator)

    with TemporaryDirectory() as tmp_dir:
        saver = Saver(tmp_dir)
        source_node.save(saver)
        saver.save()

        # Load into the *other* node while the temporary directory still exists.
        target_node.load(Loader(tmp_dir))

    assert same(source_node._unit.output, target_node._unit.output)
コード例 #8
0
ファイル: test_graph.py プロジェクト: jvitku/torchsim
def test_graph_save_load():
    """Saving one graph and loading it into another must restore the node outputs of the latter.

    Both graphs are built by ``create_graph`` and stepped once. The outputs of the first three
    nodes of ``graph2`` are then overwritten with random values, so only a working load can make
    them equal to the corresponding outputs of ``graph`` again.
    """
    graph = create_graph()
    graph2 = create_graph()

    graph.step()
    graph2.step()

    # The same indices are used for scrambling and for checking, so every node whose output is
    # scrambled is also verified after the load.
    scrambled_nodes = range(3)

    for i in scrambled_nodes:
        graph2.nodes[i].outputs.output.tensor.random_()

    with TemporaryDirectory() as folder:
        saver = Saver(folder)
        graph.save(saver)
        saver.save()

        # Load while the temporary directory still exists.
        loader = Loader(folder)
        graph2.load(loader)

    for i in scrambled_nodes:
        assert same(graph.nodes[i].outputs.output.tensor, graph2.nodes[i].outputs.output.tensor)
コード例 #9
0
    def test_serialization(self):
        """Test that if the node is saved, changed and then loaded, it still computes the expected result.

        The node state is saved right after ``_prepare_node``. The node is then run without
        result checks and modified through ``_change_node_before_load``. Finally, the saved state
        is loaded back and the node is run again with ``_run_node_for_steps``'s default
        ``check_results`` (presumably enabling the checks — confirm against its definition).

        The test is skipped when ``self.skip_test_serialization()`` returns a truthy value.
        """
        # The docstring above must be the first statement of the body; placed after this check,
        # it would be a no-op string expression rather than the method's __doc__.
        if self.skip_test_serialization():
            pytest.skip()

        node, sources = self._prepare_node()

        # save the node state after initialization
        with tempfile.TemporaryDirectory() as directory:
            saver = Saver(directory)
            node.save(saver)
            saver.save()

            # run steps to change it
            self._run_node_for_steps(node, sources, check_results=False)

            self._change_node_before_load(node)

            # load the initial state (while the temporary directory still exists)
            loader = Loader(directory)
            node.load(loader)

        # run steps again and check they still produce the desired results
        self._run_node_for_steps(node, sources)
コード例 #10
0
ファイル: grid_world_node.py プロジェクト: jvitku/torchsim
    def _load(self, loader: Loader):
        """Restore the node's base state, then each sub-unit from its own child folder.

        Sub-unit number ``k`` (in ``self._units`` order) is loaded from the child
        ``f'sub_unit_{k}'`` of ``loader``.
        """
        super()._load(loader)

        for index, sub_unit in enumerate(self._units):
            child_loader = loader.load_child(f'sub_unit_{index}')
            sub_unit.load(child_loader)
コード例 #11
0
    def load(self, parent_loader: Loader):
        """Load the node from its own sub-folder of ``parent_loader``.

        The sub-folder name comes from ``self._get_persistence_name()``. All actual restoring
        (including any tensors) is delegated to ``self._load`` with the child loader.
        """
        own_loader = parent_loader.load_child(self._get_persistence_name())
        self._load(own_loader)