Example #1
0
def test_terminate_on_nan_and_inf(state_output, should_terminate):
    """Run ``TerminateOnNan`` once and check the engine's termination flag.

    Args:
        state_output: value installed as ``trainer.state.output`` before the
            handler is invoked.
        should_terminate: value ``trainer.should_terminate`` is expected to
            hold after the handler has run.

    Both arguments are supplied from outside this function (presumably by a
    parametrize decorator that is not visible here -- TODO confirm).
    """
    torch.manual_seed(12)

    def _noop_step(engine, batch):
        # The engine is never run; the process function is only a placeholder.
        return None

    nan_guard = TerminateOnNan()
    if isinstance(state_output, np.ndarray):
        # NumPy arrays are converted to plain Python lists before the handler
        # inspects them.
        nan_guard._output_transform = lambda arr: arr.tolist()

    trainer = Engine(_noop_step)
    trainer.state = State()
    trainer.state.output = state_output

    nan_guard(trainer)
    assert trainer.should_terminate == should_terminate
Example #2
0
def test_terminate_on_nan_and_inf():
    """Feed a series of outputs to ``TerminateOnNan`` and check the flag.

    Each case installs a value as ``trainer.state.output``, calls the handler
    on the trainer and asserts whether ``trainer.should_terminate`` was set.
    After every case that is expected to terminate, the flag is reset to
    ``False`` so the next case starts clean.
    """
    # Fixed seed: several cases below rely on random draws.
    torch.manual_seed(12)

    def _noop_step(engine, batch):
        # The engine is never run; the process function is only a placeholder.
        return None

    trainer = Engine(_noop_step)
    trainer.state = State()
    nan_guard = TerminateOnNan()

    def expect_continue(output):
        # The handler must leave the termination flag unset for this output.
        trainer.state.output = output
        nan_guard(trainer)
        assert not trainer.should_terminate

    def expect_stop(output):
        # The handler must set the termination flag; clear it afterwards.
        trainer.state.output = output
        nan_guard(trainer)
        assert trainer.should_terminate
        trainer.should_terminate = False

    # Plain finite scalars, as a Python float and as a tensor.
    expect_continue(1.0)
    expect_continue(torch.tensor(123.45))

    # asin of standard-normal draws; the test expects at least one NaN.
    expect_stop(torch.asin(torch.randn(10)))

    # NumPy input is converted to a list through a temporary output transform.
    nan_guard._output_transform = lambda arr: arr.tolist()
    expect_continue(np.array([1.0, 2.0]))
    nan_guard._output_transform = lambda value: value

    expect_stop(torch.asin(torch.randn(4, 4)))

    # Dividing by draws from {0, 1}; the test expects at least one zero
    # divisor (hence an infinite entry) inside the tuple.
    zero_or_one = torch.randint(0, 2, size=(4,)).type(torch.float)
    expect_stop((10.0, 1.0 / zero_or_one, 1.0))

    # A tuple mixing a float, a tensor and a string, all expected to pass.
    expect_continue((1.0, torch.tensor(1.0), "abc"))

    zero_or_one_grid = torch.randint(0, 2, size=(4, 4)).type(torch.float)
    expect_stop(1.0 / zero_or_one_grid)

    # Non-finite Python floats: inside a tuple, bare, and inside a list.
    expect_stop((float("nan"), 10.0))
    expect_stop(float("inf"))
    expect_stop([float("nan"), 10.0])