コード例 #1
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
 def test_module_training_state_unchanged(self, model, trainer, path, training):
     if training:
         model.train()
     else:
         model.eval()
     callback = TorchScriptCallback(path)
     callback.on_train_end(trainer, model)
     assert model.training == training
コード例 #2
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
    def test_module_set_to_eval_mode(self, model, trainer, path, mocker, training):
        if training:
            model.train()
        else:
            model.eval()
        spy = mocker.spy(model, "eval")
        callback = TorchScriptCallback(path)
        callback.on_train_end(trainer, model)

        if training:
            spy.assert_called()
コード例 #3
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
 def test_trace_example_input_array(self, model, trainer, path, data):
     model.example_input_array = data
     callback = TorchScriptCallback(path, True)
     callback.on_train_end(trainer, model)
     assert os.path.isfile(path)
コード例 #4
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
 def test_exception_on_device_type_ellipsis(self, trainer, path):
     model = BadModel(10, 10, 3)
     callback = TorchScriptCallback(path)
     with pytest.raises(RuntimeError):
         callback.on_train_end(trainer, model)
コード例 #5
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
 def test_exported_trace_is_loadable(self, model, trainer, path, data):
     callback = TorchScriptCallback(path, True, data)
     callback.on_train_end(trainer, model)
     loaded = torch.jit.load(path)
     assert isinstance(loaded, ScriptModule)
コード例 #6
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
 def test_trace_exported(self, model, trainer, path, data):
     callback = TorchScriptCallback(path, True, data)
     callback.on_train_end(trainer, model)
     assert os.path.isfile(path)
コード例 #7
0
ファイル: test_other.py プロジェクト: TidalPaladin/combustion
 def test_script_exported(self, model, trainer, path):
     callback = TorchScriptCallback(path)
     callback.on_train_end(trainer, model)
     assert os.path.isfile(path)