def test_user_model_fn_not_invoked_while_building(modules, trainer):
    """Building an Estimator from a user ``model_fn`` must not call that function.

    ``_build_estimator`` should hand a callable ``model_fn`` to ``Estimator``
    together with the customer hyperparameters and the run config, and return
    the constructed Estimator. The user's ``model_fn`` is only meant to be
    invoked by the Estimator (mocked here), never during the build itself.
    """
    fake_run_config = 'fakerunconfig'
    fake_model_fn = MagicMock(name='fake_model_fn')
    expected_hps = trainer.customer_params.copy()
    customer_script = EmptyModule()
    customer_script.model_fn = fake_model_fn
    trainer.customer_script = customer_script

    estimator = trainer._build_estimator(fake_run_config)

    estimator_mock = modules.estimator.Estimator
    # model_fn may be a forwarding wrapper rather than the user's object, so
    # only its callability is checked here.
    estimator_mock.assert_called_with(model_fn=ANY,
                                      params=expected_hps,
                                      config=fake_run_config)
    passed_model_fn = estimator_mock.call_args[1]['model_fn']
    assert callable(passed_model_fn)
    # Nothing calls the mocked Estimator's model_fn, so any call here came
    # from _build_estimator itself.
    assert not fake_model_fn.called
    assert estimator == estimator_mock.return_value


def test_user_estimator_fn(trainer):
    """A user ``estimator_fn`` receives the run config and hyperparameters.

    Whatever the user's ``estimator_fn`` returns is exactly what
    ``_build_estimator`` must return.
    """
    expected_config = 'fakerunconfig'
    returned_estimator = 'fakeestimator'
    expected_hps = trainer.customer_params.copy()

    # Stand-in for the customer's training script; the asserts fire inside
    # _build_estimator if the trainer forwards the wrong arguments.
    def customer_estimator_fn(run_config, hyperparameters):
        assert run_config == expected_config
        assert hyperparameters == expected_hps
        return returned_estimator

    script = EmptyModule()
    script.estimator_fn = customer_estimator_fn
    trainer.customer_script = script

    assert trainer._build_estimator(expected_config) == returned_estimator


def test_user_model_fn(modules, trainer):
    """A user ``model_fn`` is routed through ``Estimator``.

    Checks that ``Estimator`` receives the customer hyperparameters and run
    config, that the ``model_fn`` it receives forwards its positional
    arguments to the user's ``model_fn``, and that ``_build_estimator``
    returns the constructed Estimator.
    """
    run_config = 'fakerunconfig'
    user_model_fn = MagicMock(name='fake_model_fn')
    expected_hps = trainer.customer_params.copy()

    script = EmptyModule()
    script.model_fn = user_model_fn
    trainer.customer_script = script

    built = trainer._build_estimator(run_config)

    estimator_cls = modules.estimator.Estimator
    estimator_cls.assert_called_with(model_fn=ANY, params=expected_hps, config=run_config)

    # The model_fn handed to Estimator must pass its arguments straight
    # through to the user's script model_fn.
    handed_model_fn = estimator_cls.call_args[1]['model_fn']
    handed_model_fn(1, 2, 3, 4)
    user_model_fn.assert_called_with(1, 2, 3, 4)

    # _build_estimator returns the object Estimator(...) constructed.
    assert built == estimator_cls.return_value


def test_user_keras_model_fn(modules, trainer):
    """A user ``keras_model_fn`` is converted with ``model_to_estimator``.

    The Keras model returned by the user's function is passed, together with
    the run config, to ``keras.estimator.model_to_estimator``. The result of
    that call is what ``_build_estimator`` returns.
    """
    expected_config = 'fakerunconfig'
    keras_model = 'fakekerasmodel'
    expected_hps = trainer.customer_params.copy()

    # Stand-in for the customer's training script; the assert fires inside
    # _build_estimator if the wrong hyperparameters are forwarded.
    def customer_keras_model_fn(hyperparameters):
        assert hyperparameters == expected_hps
        return keras_model

    script = EmptyModule()
    script.keras_model_fn = customer_keras_model_fn
    trainer.customer_script = script

    built = trainer._build_estimator(expected_config)

    converter = modules.keras.estimator.model_to_estimator
    converter.assert_called_with(keras_model=keras_model, config=expected_config)
    assert built == converter.return_value