def test_lambda_callback_value_error(stage):
    """Check that LambdaCallback rejects invalid hooks for ``stage``.

    Two cases are exercised:

    * a hook that takes four positional arguments
      ``(nn_state, epoch, batch, extra)`` must raise ``ValueError``;
    * a hook that is a plain string (neither callable nor ``None``) must
      raise ``TypeError``. This matches ``test_lambda_callback_type_error``,
      which asserts the same call raises ``TypeError``.

    Args:
        stage: keyword name of a LambdaCallback hook. Presumably supplied by a
            pytest fixture/parametrization (e.g. ``"on_epoch_end"``); TODO
            confirm against the fixture definition.
    """
    msg = "LambdaCallback should fail if {} gets wrong # of arguments.".format(stage)
    msg_wrong_type = "LambdaCallback should fail if {} isn't a function or None.".format(
        stage
    )

    # ``pytest.raises(..., message=...)`` was removed in pytest 5.0. Calling
    # ``pytest.fail`` right after the call under test reproduces the old
    # "custom message when nothing is raised" behaviour. ``Failed`` is not a
    # ValueError/TypeError, so it propagates out of ``pytest.raises``.
    kwargs = {stage: lambda nn_state, epoch, batch, extra: "foobar"}
    with pytest.raises(ValueError):
        LambdaCallback(**kwargs)
        pytest.fail(msg)

    with pytest.raises(TypeError):
        LambdaCallback(**{stage: "foobar"})
        pytest.fail(msg_wrong_type)
def test_lambda_callback_type_error(stage):
    """Check that a non-callable, non-None hook for ``stage`` raises TypeError.

    Args:
        stage: keyword name of a LambdaCallback hook. Presumably a pytest
            fixture/parametrization value; TODO confirm.
    """
    failure_msg = f"LambdaCallback should fail if {stage} isn't a function or None."
    bad_hook = "foobar"  # a string: neither a function nor None

    with pytest.raises(TypeError):
        LambdaCallback(**{stage: bad_hook})
        # Reached only if no exception was raised; ``Failed`` escapes pytest.raises.
        pytest.fail(failure_msg)
def test_lambda_callback_value_error_num_args(stage):
    """Check that a four-argument hook for ``stage`` makes LambdaCallback raise ValueError.

    The hook signature ``(nn_state, epoch, batch, extra)`` is treated by this
    test as having the wrong number of arguments.

    Args:
        stage: keyword name of a LambdaCallback hook. Presumably a pytest
            fixture/parametrization value; TODO confirm.
    """
    failure_msg = f"LambdaCallback should fail if {stage} gets wrong # of arguments."
    four_arg_hook = lambda nn_state, epoch, batch, extra: "foobar"  # noqa: E731
    hook_kwargs = {stage: four_arg_hook}

    with pytest.raises(ValueError):
        LambdaCallback(**hook_kwargs)
        # Reached only if no exception was raised; ``Failed`` escapes pytest.raises.
        pytest.fail(failure_msg)
def test_stop_training_in_epoch(gpu):
    """Check that an ``on_epoch_end`` hook calling ``set_stop_training`` sets the flag.

    A PositiveWaveFunction is fitted on constant all-ones data with a single
    LambdaCallback. Afterwards ``nn_state.stop_training`` must be truthy.

    Args:
        gpu: whether to run on GPU. Presumably a pytest fixture; TODO confirm.
    """
    qucumber.set_random_seed(SEED, cpu=True, gpu=gpu, quiet=True)

    nn_state = PositiveWaveFunction(10, gpu=gpu)
    # 100 rows x 10 columns of ones; presumably one column per visible unit
    # of the 10-unit model above. TODO confirm.
    data = torch.ones(100, 10)

    # The lambda parameter names are kept as-is: LambdaCallback inspects hook
    # signatures.
    stop_cb = LambdaCallback(
        on_epoch_end=lambda nn_state, ep: set_stop_training(nn_state)
    )
    nn_state.fit(data, callbacks=[stop_cb])

    assert nn_state.stop_training, "stop_training wasn't set!"