Beispiel #1
0
    def __test_restore_elsewhere(
        self,
        model: ModelPT,
        attr_for_eq_check: Set[str] = None,
        override_config_path: Optional[Union[str, DictConfig]] = None,
        map_location: Optional[torch.device] = None,
        strict: bool = False,
        return_config: bool = False,
    ):
        """Round-trip ``model`` through a ``.nemo`` file that outlives its save folder.

        Steps:
            1. Save the model into a temporary ``save_folder``.
            2. Copy the resulting ``.nemo`` file into a second temporary
               ``restore_folder``.
            3. Let ``save_folder`` be deleted.
            4. Restore from the copied file and compare against ``model``.

        Args:
            model: Instance to save; ``restore_from`` is invoked on its class.
            attr_for_eq_check: Attribute names, looked up with ``getattr2``, that
                must compare equal on the original and the restored instance.
                ``None`` or an empty set skips this comparison.
            override_config_path: Passed through unchanged to ``restore_from``.
            map_location: Passed through unchanged to ``restore_from``.
            strict: Passed through unchanged to ``restore_from``.
            return_config: Passed through to ``restore_from``; when true, its
                result is returned immediately without any comparison.

        Returns:
            The value returned by ``restore_from``.
        """
        nemo_filename = f"{model.__class__.__name__}.nemo"

        with tempfile.TemporaryDirectory() as restore_folder:
            # Destination the archive is copied to and later restored from
            model_restore_path = os.path.join(restore_folder, nemo_filename)

            with tempfile.TemporaryDirectory() as save_folder:
                # Location the model is originally written to
                model_save_path = os.path.join(save_folder, nemo_filename)
                model.save_to(save_path=model_save_path)
                shutil.copy(model_save_path, model_restore_path)

            # The inner context has exited, so only the copied archive may remain
            assert save_folder is not None
            assert not os.path.exists(save_folder)
            assert not os.path.exists(model_save_path)
            assert os.path.exists(model_restore_path)

            # Restore through the class of the original instance
            restore_kwargs = dict(
                restore_path=model_restore_path,
                map_location=map_location,
                strict=strict,
                return_config=return_config,
                override_config_path=override_config_path,
            )
            restored = model.__class__.restore_from(**restore_kwargs)

            if return_config:
                return restored

            assert model.num_weights == restored.num_weights

            # Nothing further to compare when no attribute names were supplied
            if attr_for_eq_check is None or len(attr_for_eq_check) == 0:
                return restored

            for attr_name in attr_for_eq_check:
                assert getattr2(model, attr_name) == getattr2(restored, attr_name)

            return restored
Beispiel #2
0
    def test_mock_save_to_restore_from_with_target_class(self):
        """Save a ``MockModel`` and restore it through ``ModelPT.restore_from``.

        The model config points ``temp_file`` at a temporary file holding one
        line of data. After saving to a ``.nemo`` file and restoring via the
        ``ModelPT`` base class, the test checks that the restored object is a
        ``MockModel``, that its ``temp_file`` name ends with the original file
        name, that its weights match the original, and that ``temp_data``
        holds the line that was written.
        """
        with tempfile.NamedTemporaryFile('w') as empty_file:
            # Write some data
            empty_file.writelines(["*****\n"])
            empty_file.flush()

            # Update config
            cfg = _mock_model_config()
            cfg.model.temp_file = empty_file.name

            # Create model
            model = MockModel(cfg=cfg.model, trainer=None)
            model = model.to('cpu')  # type: MockModel

            assert model.temp_file == empty_file.name

            # Save file using MockModel
            with tempfile.TemporaryDirectory() as save_folder:
                save_path = os.path.join(save_folder, "temp.nemo")
                model.save_to(save_path)

                # Restore test (using ModelPT as restorer)
                # This forces the target class = MockModel to be used as resolver
                model_copy = ModelPT.restore_from(save_path,
                                                  map_location='cpu')
            # because of caching - cache gets prepended
            assert os.path.basename(model_copy.temp_file).endswith(
                os.path.basename(model.temp_file))
            # NOTE(review): a byte-level filecmp comparison of the two files was
            # disabled here; only the file names are compared.
        # Restore test
        diff = model.w.weight - model_copy.w.weight
        # Use the absolute difference: a signed mean lets negative or mutually
        # cancelling deviations slip under the threshold.
        assert diff.abs().mean() <= 1e-9
        assert isinstance(model_copy, MockModel)
        assert model_copy.temp_data == ["*****\n"]
Beispiel #3
0
 def test_EncDecCTCModelBPE_HF(self):
     """Load a pretrained model through ``ModelPT`` and round-trip it via a .nemo file.

     The restored copy must agree with the original on the decoder's
     ``_feat_in`` and ``_num_classes`` attributes.
     """
     # TODO: Switch to using named configs because here we don't really care about weights
     # Specifically use ModelPT instead of EncDecCTCModelBPE in order to test target class resolution.
     pretrained = ModelPT.from_pretrained(
         model_name="nvidia/stt_en_conformer_ctc_medium")
     decoder_attrs = {"decoder._feat_in", "decoder._num_classes"}
     self.__test_restore_elsewhere(
         model=pretrained,
         attr_for_eq_check=decoder_attrs,
     )