Beispiel #1
0
    def test_graph_save_load(self, tmpdir):
        """
            Tests graph saving and loading.

            The weights are snapshotted as independent NumPy copies before the
            save/restore round trip, so the final comparison checks the values
            present after `restore_from` instead of comparing storage with itself.

            Args:
                tmpdir: Fixture which will provide a temporary directory.
        """

        dl = RealFunctionDataLayer(n=10, batch_size=1)
        tn = TaylorNet(dim=4)
        # Snapshot the "original" weights. `.copy()` decouples the snapshot from the
        # module's storage: if get_state_dict returns references to the live
        # parameters and `.cpu().numpy()` shares memory (CPU tensors), an in-place
        # restore would silently update the snapshot too and the comparison below
        # could never fail.
        weights1 = {
            key: value.cpu().numpy().copy()
            for key, value in get_state_dict(tn).items()
        }

        # Create a simple graph. The outputs are not used afterwards; the calls are
        # made inside the context presumably so that the modules end up in `g1`.
        with NeuralGraph() as g1:
            x, _ = dl()
            tn(x=x)

        # Generate filename in the temporary directory.
        tmp_file_name = str(tmpdir.join("tgsl_g1.chkpt"))
        # Save graph.
        g1.save_to(tmp_file_name)

        # Load graph.
        g1.restore_from(tmp_file_name)

        # Get the "restored" weights.
        weights2 = get_state_dict(tn)

        # The restored state must cover exactly the same entries as the snapshot...
        assert set(weights1) == set(weights2)
        # ...and every entry must hold the same values.
        for key, expected in weights1.items():
            assert array_equal(expected, weights2[key].cpu().numpy())
Beispiel #2
0
    def test_state_dict(self):
        """
            Checks that writing a module's state back with set_state_dict and reading
            it again with get_state_dict yields the same values for every entry.
        """
        module = TaylorNet(dim=4)

        # Read the module's state, then write it straight back.
        before = get_state_dict(module)
        set_state_dict(module, before)

        # Read it once more and make sure no entry changed.
        after = get_state_dict(module)
        for name, tensor in before.items():
            assert array_equal(tensor.cpu().numpy(), after[name].cpu().numpy())
Beispiel #3
0
    def save_to(self, filename: str, module_names: Optional[List[str]] = None):
        """
        Saves the state of trainable modules in the graph to a checkpoint file.

        The checkpoint is a dict with a "header" (NeMo core version and graph name)
        and a "modules" mapping from module name to that module's state dict. Only
        modules whose type is ``ModuleType.trainable`` are stored; any other
        requested module is skipped with a debug message.

        Args:
            filename (string): Name of the file where the checkpoint will be saved.
            module_names: Names of the modules to be saved (Optional). If omitted
                (None), all modules of the graph are considered.
        Raises:
            KeyError: If one of the requested module names is not present in the
                graph. All names are checked before any state is collected or any
                file is written.
        """
        # Work on all modules. Materialize as a list so the names can be iterated
        # twice (validation + collection) even if an iterator was passed in.
        if module_names is None:
            module_names = list(self._modules.keys())
        else:
            module_names = list(module_names)

        # Fail fast on unknown names, before doing any work.
        for name in module_names:
            if name not in self._modules:
                raise KeyError(
                    "Module `{}` not present in the `{}` graph".format(
                        name, self.name))

        # Prepare the "graph checkpoint".
        chkpt = {
            "header": {
                "nemo_core_version": nemo_version,
                "name": self.name
            },
            "modules": {}
        }

        # One summary line per saved module; joined into a single log entry below.
        saved_lines = []
        for name in module_names:
            module = self._modules[name]
            if module.type != ModuleType.trainable:
                logging.debug(
                    "Module `{}` is not trainable so cannot be saved".format(
                        name))
                continue
            # Get module state_dict().
            chkpt["modules"][name] = get_state_dict(module)
            saved_lines.append("  * Module '{}' ({}) params saved \n".format(
                module.name,
                type(module).__name__))

        # Save checkpoint.
        save(chkpt, filename)
        logging.info(
            "Saved  the '{}' graph to a checkpoint `{}`:\n".format(
                self.name, filename) + "".join(saved_lines))
Beispiel #4
0
    def test_save_load(self, tmpdir):
        """
            Round-trips a module's state dict through save() and load() and checks
            that every entry comes back with the same values.

            Args:
                tmpdir: Fixture which will provide a temporary directory.
        """
        net = TaylorNet(dim=4)
        checkpoint_path = str(tmpdir.join("tsl_taylornet.chkpt"))

        # Write the weights out...
        original = get_state_dict(net)
        save(original, checkpoint_path)

        # ...and read them back from disk.
        restored = load(checkpoint_path)

        for name, tensor in original.items():
            assert array_equal(tensor.cpu().numpy(), restored[name].cpu().numpy())