def test_terminate_on_nan_and_inf(state_output, should_terminate):
    """Check that TerminateOnNan sets ``trainer.should_terminate`` as expected.

    ``state_output`` is placed on ``trainer.state.output`` and the handler is
    invoked once. ``trainer.should_terminate`` must then equal
    ``should_terminate``. ``numpy.ndarray`` outputs are converted with
    ``.tolist()`` before the handler inspects them.

    NOTE(review): both arguments are presumably supplied by a
    ``pytest.mark.parametrize`` decorator outside this view — confirm.
    Another function later in this file reuses this exact name. That later
    definition rebinds the name at import time, so pytest would never collect
    this version; confirm which one is meant to exist.
    """
    torch.manual_seed(12)

    def _noop_step(engine, batch):
        # This test never runs the engine; a step function is still needed
        # to construct it.
        return None

    trainer = Engine(_noop_step)
    trainer.state = State()
    handler = TerminateOnNan()
    trainer.state.output = state_output

    # The handler is given plain Python lists rather than raw ndarrays.
    needs_list = isinstance(state_output, np.ndarray)
    if needs_list:
        handler._output_transform = lambda arr: arr.tolist()

    handler(trainer)
    assert trainer.should_terminate == should_terminate
def test_terminate_on_nan_and_inf():
    """Feed TerminateOnNan a fixed sequence of outputs and check each verdict.

    Every case starts from a trainer whose ``should_terminate`` flag has been
    cleared, so each expectation is checked independently of the previous
    case. NaN/inf values are built deterministically instead of being drawn
    from torch's random number generator:

    - ``asin`` is evaluated outside [-1, 1] to produce NaN.
    - Division by zero produces inf.

    This way the test cannot silently stop exercising the terminate path if
    the RNG stream changes.

    NOTE(review): the parametrized function earlier in this file has the same
    name. This definition rebinds it at import time, so only one of the two
    is collected by pytest — confirm which one is meant to survive.
    """

    def update_fn(engine, batch):
        # The engine is never run here; it only needs a step function.
        pass

    def to_list(x):
        # ndarray outputs are handed to the handler as plain lists.
        return x.tolist()

    trainer = Engine(update_fn)
    trainer.state = State()
    h = TerminateOnNan()
    # Remember the handler's own transform so non-ndarray cases use it.
    default_transform = h._output_transform

    # (output placed on trainer.state.output, expected should_terminate)
    cases = [
        (1.0, False),
        (torch.tensor(123.45), False),
        # linspace includes the endpoints -2.0 and 2.0, where asin is NaN.
        (torch.asin(torch.linspace(-2.0, 2.0, 10)), True),
        (np.array([1.0, 2.0]), False),
        (torch.asin(torch.linspace(-2.0, 2.0, 16).reshape(4, 4)), True),
        # 1 / 0 gives inf inside a tuple alongside finite floats.
        ((10.0, 1.0 / torch.tensor([1.0, 0.0, 1.0, 1.0]), 1.0), True),
        # A tuple with a non-numeric entry ("abc") must not trigger termination.
        ((1.0, torch.tensor(1.0), "abc"), False),
        # The off-diagonal zeros of the identity matrix give inf.
        (1.0 / torch.eye(4), True),
        ((float("nan"), 10.0), True),
        (float("inf"), True),
        ([float("nan"), 10.0], True),
    ]

    for output, expected in cases:
        trainer.should_terminate = False
        trainer.state.output = output
        if isinstance(output, np.ndarray):
            h._output_transform = to_list
        else:
            h._output_transform = default_transform
        h(trainer)
        assert trainer.should_terminate == expected, (output, expected)