def test_user_model_fn(modules, trainer):
    """Check that a customer script's ``model_fn`` reaches ``Estimator`` with params and config.

    The test also checks that ``_build_estimator`` returns the object created by
    the mocked ``Estimator`` class.
    """
    # NOTE(review): a second ``test_user_model_fn`` is defined later in this file.
    # Python keeps only the last definition bound to a name, so pytest never
    # collects this copy. Its expectation that Estimator receives the user's
    # ``model_fn`` object unchanged also conflicts with that later version, which
    # matches ``model_fn=ANY`` and treats it as a pass-through wrapper.
    # Consider deleting this stale copy.
    fake_run_config = 'fakerunconfig'
    fake_model_fn = 'fakemodelfn'
    # Snapshot the hyperparameters before the build so the assertion compares
    # against the values the trainer started with.
    expected_hps = trainer.customer_params.copy()
    # Fake "customer script" module that only exposes model_fn.
    customer_script = EmptyModule()
    customer_script.model_fn = fake_model_fn
    trainer.customer_script = customer_script
    estimator = trainer._build_estimator(fake_run_config)
    estimator_mock = modules.estimator.Estimator
    estimator_mock.assert_called_with(model_fn=fake_model_fn, params=expected_hps, config=fake_run_config)
    assert estimator == estimator_mock.return_value
def test_user_estimator_fn(trainer):
    """A customer ``estimator_fn`` drives estimator construction.

    ``_build_estimator`` must call it with the run config and a copy-equal
    set of the customer hyperparameters, then return its result unchanged.
    """
    config_stub = 'fakerunconfig'
    estimator_stub = 'fakeestimator'
    hps = trainer.customer_params.copy()

    # Stand-in for the customer script's estimator_fn. Parameter names are kept
    # as-is in case the trainer passes them by keyword.
    def customer_estimator_fn(run_config, hyperparameters):
        assert run_config == config_stub
        assert hyperparameters == hps
        return estimator_stub

    script = EmptyModule()
    script.estimator_fn = customer_estimator_fn
    trainer.customer_script = script

    assert trainer._build_estimator(config_stub) == estimator_stub
def test_user_model_fn(modules, trainer):
    """``_build_estimator`` routes a customer ``model_fn`` through ``Estimator``.

    The mocked ``Estimator`` class must be constructed with the customer params
    and the run config. The callable it receives as ``model_fn`` must forward its
    positional arguments to the customer's ``model_fn``. The constructed
    estimator must be what ``_build_estimator`` returns.
    """
    config_stub = 'fakerunconfig'
    user_model_fn = MagicMock(name='fake_model_fn')
    hps = trainer.customer_params.copy()

    script = EmptyModule()
    script.model_fn = user_model_fn
    trainer.customer_script = script

    built = trainer._build_estimator(config_stub)

    estimator_cls = modules.estimator.Estimator
    # model_fn is matched loosely here; its pass-through behaviour is checked
    # directly by invoking the captured callable below.
    estimator_cls.assert_called_with(model_fn=ANY, params=hps, config=config_stub)
    passthrough = estimator_cls.call_args[1]['model_fn']
    passthrough(1, 2, 3, 4)
    user_model_fn.assert_called_with(1, 2, 3, 4)

    assert built == estimator_cls.return_value
def test_user_keras_model_fn(modules, trainer):
    """A customer ``keras_model_fn`` is turned into an estimator via ``model_to_estimator``.

    The customer function receives the hyperparameters. The Keras model it
    returns, together with the run config, must be handed to the mocked
    ``keras.estimator.model_to_estimator``, and that call's result must be
    returned.
    """
    config_stub = 'fakerunconfig'
    keras_model_stub = 'fakekerasmodel'
    hps = trainer.customer_params.copy()

    def customer_keras_model_fn(hyperparameters):
        assert hyperparameters == hps
        return keras_model_stub

    script = EmptyModule()
    script.keras_model_fn = customer_keras_model_fn
    trainer.customer_script = script

    built = trainer._build_estimator(config_stub)

    converter = modules.keras.estimator.model_to_estimator
    converter.assert_called_with(keras_model=keras_model_stub, config=config_stub)
    assert built == converter.return_value