def test_module_training_state_unchanged(self, model, trainer, path, training):
    """Running the export callback leaves ``model.training`` as it was before."""
    # Put the module into the mode selected by the ``training`` fixture.
    (model.train if training else model.eval)()
    exporter = TorchScriptCallback(path)
    exporter.on_train_end(trainer, model)
    assert model.training == training
def test_module_set_to_eval_mode(self, model, trainer, path, mocker, training):
    """A module that starts in training mode has ``eval()`` called on it during export."""
    (model.train if training else model.eval)()
    # The spy is installed only after the mode is set, so the call above is not counted.
    eval_spy = mocker.spy(model, "eval")
    exporter = TorchScriptCallback(path)
    exporter.on_train_end(trainer, model)
    if not training:
        # NOTE(review): nothing is asserted for a module already in eval mode.
        return
    eval_spy.assert_called()
def test_trace_example_input_array(self, model, trainer, path, data):
    """Export with the second positional flag set to True and no explicit input writes a file.

    The module's ``example_input_array`` is set to ``data`` beforehand; presumably the
    callback traces with it when no input is passed (TODO confirm against the callback).
    """
    model.example_input_array = data
    exporter = TorchScriptCallback(path, True)
    exporter.on_train_end(trainer, model)
    assert os.path.isfile(path)
def test_exception_on_device_type_ellipsis(self, trainer, path):
    """Exporting a ``BadModel`` instance raises ``RuntimeError``."""
    bad_model = BadModel(10, 10, 3)
    exporter = TorchScriptCallback(path)
    with pytest.raises(RuntimeError):
        exporter.on_train_end(trainer, bad_model)
def test_exported_trace_is_loadable(self, model, trainer, path, data):
    """The traced export written to ``path`` reloads via ``torch.jit.load`` as a ScriptModule."""
    TorchScriptCallback(path, True, data).on_train_end(trainer, model)
    reloaded = torch.jit.load(path)
    assert isinstance(reloaded, ScriptModule)
def test_trace_exported(self, model, trainer, path, data):
    """Export with the trace flag and explicit input ``data`` creates a file at ``path``."""
    exporter = TorchScriptCallback(path, True, data)
    exporter.on_train_end(trainer, model)
    file_written = os.path.isfile(path)
    assert file_written
def test_script_exported(self, model, trainer, path):
    """Export with only a path (default callback options) creates a file at ``path``."""
    exporter = TorchScriptCallback(path)
    exporter.on_train_end(trainer, model)
    file_written = os.path.isfile(path)
    assert file_written