コード例 #1
0
def test_torchscript_retain_training_state():
    """Test that torchscript export does not alter the training mode of the original model.

    The model is exported twice: once while in training mode and once while in
    eval mode. After each export the source model must still report the mode it
    had before exporting. Every scripted module produced is expected to have
    ``training`` set to False.
    """
    model = EvalModelTemplate()

    # Export while the source model is in training mode.
    model.train(True)
    script = model.to_torchscript()
    assert model.training
    assert not script.training

    # Export again while the source model is in eval mode. Verify the newly
    # produced script, and that the earlier script is unaffected by re-export.
    model.train(False)
    eval_script = model.to_torchscript()
    assert not model.training
    assert not eval_script.training
    assert not script.training
コード例 #2
0
def test_torchscript_device(device):
    """Check that a scripted module's parameters and outputs live on ``device``.

    The model is moved to ``device`` before export. The scripted copy is then
    expected to hold its parameters on that device and to produce its output
    there when fed the model's example input (also moved to ``device``).
    """
    template = EvalModelTemplate().to(device)
    scripted = template.to_torchscript()

    # Only the first parameter is inspected, as a representative sample.
    first_param = next(scripted.parameters())
    assert first_param.device == device

    sample = template.example_input_array.to(device)
    assert scripted(sample).device == device