def test_build_eval_spec_no_serving(modules, trainer):
    """Build an EvalSpec when the customer script defines no serving_input_fn.

    Checks that EvalSpec is called with the default of 100 steps and no
    exporters. Checks that the customer's eval_input_fn is wrapped into a
    zero-argument callable passed as EvalSpec's first positional argument.
    Checks that _build_eval_spec returns the EvalSpec it created.
    """
    # Set up "customer script".
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    expected_hps = trainer.customer_params.copy()

    def customer_eval_input_fn(training_dir, params):
        assert training_dir == TRAIN_DIR
        assert params == expected_hps
        return tensor_dict, labels

    customer_script = EmptyModule()
    customer_script.eval_input_fn = customer_eval_input_fn
    trainer.customer_script = customer_script

    spec = trainer._build_eval_spec()

    evalspec_mock = modules.estimator.EvalSpec
    # eval_steps not specified by customer, use default of 100.
    # serving_input_fn not specified by customer, don't provide an exporter.
    evalspec_mock.assert_called_with(ANY, steps=100, exporters=None)
    args, _ = evalspec_mock.call_args

    # Assert the customer's eval_input_fn is used correctly.
    eval_input_fn = args[0]
    returned_dict, returned_labels = eval_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)

    # Assert the created EvalSpec is returned from _build_eval_spec
    # (previously `spec` was computed but never checked).
    assert evalspec_mock.return_value == spec
def test_build_eval_spec_with_serving(modules, trainer):
    """Build an EvalSpec when the customer script also defines serving_input_fn.

    Verifies the following:
    - A LatestExporter named 'Servo' wraps the customer's serving_input_fn.
    - A non-default eval_steps and the throttle_secs / start_delay_secs
      hyperparameters are forwarded to EvalSpec.
    - The exporter instance is handed to EvalSpec.
    - The EvalSpec instance is what _build_eval_spec returns.
    """
    # Customer-supplied hyperparameters that should be forwarded to EvalSpec.
    trainer.customer_params.update({'throttle_secs': 13, 'start_delay_secs': 56})
    hyperparameters = trainer.customer_params.copy()

    # Fake "customer script" with both an eval and a serving input function.
    fake_features = {'inputs': ['faketensor']}
    fake_labels = ['fakelabels']

    def customer_eval_input_fn(training_dir, params):
        assert training_dir == TRAIN_DIR
        assert params == hyperparameters
        return fake_features, fake_labels

    fake_receiver = 'fakeservinginputreceiver'

    def customer_serving_input_fn(params):
        assert params == hyperparameters
        return fake_receiver

    script = EmptyModule()
    script.eval_input_fn = customer_eval_input_fn
    script.serving_input_fn = customer_serving_input_fn
    trainer.customer_script = script

    # A non-default eval_steps must be propagated through to the EvalSpec.
    trainer.eval_steps = 567

    built_spec = trainer._build_eval_spec()

    # The serving function is exported through a LatestExporter named 'Servo'.
    latest_exporter = modules.estimator.LatestExporter
    latest_exporter.assert_called_with('Servo', serving_input_receiver_fn=ANY)
    _, exporter_kwargs = latest_exporter.call_args
    wrapped_serving_fn = exporter_kwargs['serving_input_receiver_fn']
    assert fake_receiver == wrapped_serving_fn()

    eval_spec_cls = modules.estimator.EvalSpec
    eval_spec_cls.assert_called_with(ANY, steps=567, exporters=ANY,
                                     throttle_secs=13, start_delay_secs=56)
    spec_args, spec_kwargs = eval_spec_cls.call_args

    # The customer's eval_input_fn arrives wrapped as the first positional arg.
    wrapped_eval_fn = spec_args[0]
    got_features, got_labels = wrapped_eval_fn()
    assert (fake_features, fake_labels) == (got_features, got_labels)

    # The exporter created above is the one handed to EvalSpec...
    assert latest_exporter.return_value == spec_kwargs['exporters']
    # ...and the EvalSpec instance is what _build_eval_spec returns.
    assert eval_spec_cls.return_value == built_spec
def test_build_eval_spec_input_channels(modules, trainer):
    """Build an EvalSpec for an eval_input_fn that takes ``input_channels``.

    When the customer's eval_input_fn declares ``input_channels`` rather than
    ``training_dir``, it is expected to receive INPUT_CHANNELS plus the
    customer hyperparameters. The test also checks that _build_eval_spec
    returns the EvalSpec it created.
    """
    # Set up "customer script".
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    expected_hps = trainer.customer_params.copy()

    def customer_eval_input_fn(input_channels, params):
        assert input_channels == INPUT_CHANNELS
        assert params == expected_hps
        return tensor_dict, labels

    customer_script = EmptyModule()
    customer_script.eval_input_fn = customer_eval_input_fn
    trainer.customer_script = customer_script

    spec = trainer._build_eval_spec()

    evalspec_mock = modules.estimator.EvalSpec
    # call_args is None if EvalSpec was never called, so this unpacking also
    # fails the test in that case.
    args, _ = evalspec_mock.call_args

    # Assert the customer's eval_input_fn is used correctly.
    eval_input_fn = args[0]
    returned_dict, returned_labels = eval_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)

    # Assert the created EvalSpec is returned from _build_eval_spec
    # (previously `spec` was computed but never checked).
    assert evalspec_mock.return_value == spec