def test_build_eval_spec_no_serving(modules, trainer):
    """_build_eval_spec when the customer script has no serving_input_fn.

    Checks that EvalSpec is constructed with the default of 100 steps and no
    exporter, that the input fn handed to EvalSpec delegates to the
    customer's eval_input_fn, and that the EvalSpec is what gets returned.
    """
    # Set up "customer script".
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    expected_hps = trainer.customer_params.copy()

    def customer_eval_input_fn(training_dir, params):
        assert training_dir == TRAIN_DIR
        assert params == expected_hps
        return tensor_dict, labels

    customer_script = EmptyModule()
    customer_script.eval_input_fn = customer_eval_input_fn
    trainer.customer_script = customer_script

    spec = trainer._build_eval_spec()

    evalspec_mock = modules.estimator.EvalSpec
    # eval_steps not specified by customer, use default of 100.
    # serving_input_fn not specified by customer, don't provide an exporter.
    evalspec_mock.assert_called_with(ANY, steps=100, exporters=None)
    args, _ = evalspec_mock.call_args
    # Assert the customer's eval_input_fn is used correctly
    eval_input_fn = args[0]
    returned_dict, returned_labels = eval_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)
    # Assert the created EvalSpec is returned from _build_eval_spec
    assert evalspec_mock.return_value == spec


def test_build_eval_spec_with_serving(modules, trainer):
    """_build_eval_spec with a serving_input_fn and eval hyperparameters.

    Checks that a LatestExporter named 'Servo' wraps the customer's
    serving_input_fn, and that EvalSpec receives the trainer's eval_steps,
    that exporter, and the throttle_secs / start_delay_secs values taken from
    the customer hyperparameters. Also checks that the created EvalSpec is
    returned.
    """
    # Special hyperparameters passed in by customer should be passed to EvalSpec
    eval_params = {'throttle_secs': 13, 'start_delay_secs': 56}
    trainer.customer_params.update(eval_params)
    expected_hps = trainer.customer_params.copy()

    # Set up "customer script".
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']

    def customer_eval_input_fn(training_dir, params):
        assert training_dir == TRAIN_DIR
        assert params == expected_hps
        return tensor_dict, labels

    input_receiver = 'fakeservinginputreceiver'

    def customer_serving_input_fn(params):
        assert params == expected_hps
        return input_receiver

    customer_script = EmptyModule()
    customer_script.eval_input_fn = customer_eval_input_fn
    customer_script.serving_input_fn = customer_serving_input_fn
    trainer.customer_script = customer_script
    # Set a non-default eval_steps, which should be propagated through to the EvalSpec
    trainer.eval_steps = 567

    spec = trainer._build_eval_spec()

    exporter_mock = modules.estimator.LatestExporter
    exporter_mock.assert_called_with('Servo', serving_input_receiver_fn=ANY)
    _, kwargs = exporter_mock.call_args
    serving_input_fn = kwargs['serving_input_receiver_fn']
    returned_input_receiver = serving_input_fn()
    assert input_receiver == returned_input_receiver

    evalspec_mock = modules.estimator.EvalSpec
    evalspec_mock.assert_called_with(ANY,
                                     steps=567,
                                     exporters=ANY,
                                     throttle_secs=13,
                                     start_delay_secs=56)
    args, kwargs = evalspec_mock.call_args
    # Assert the customer's eval_input_fn is used correctly
    eval_input_fn = args[0]
    returned_dict, returned_labels = eval_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)
    # Assert the created LatestExporter is passed correctly to the EvalSpec
    assert exporter_mock.return_value == kwargs['exporters']
    # Assert the created EvalSpec is returned from _build_eval_spec
    assert evalspec_mock.return_value == spec


def test_build_eval_spec_input_channels(modules, trainer):
    """_build_eval_spec with an eval_input_fn that declares `input_channels`.

    The customer's eval_input_fn takes `input_channels` instead of
    `training_dir`; checks that it is called with INPUT_CHANNELS and the
    customer hyperparameters via the input fn handed to EvalSpec, and that
    the created EvalSpec is returned.
    """
    # Set up "customer script".
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    expected_hps = trainer.customer_params.copy()

    def customer_eval_input_fn(input_channels, params):
        assert input_channels == INPUT_CHANNELS
        assert params == expected_hps
        return tensor_dict, labels

    customer_script = EmptyModule()
    customer_script.eval_input_fn = customer_eval_input_fn
    trainer.customer_script = customer_script

    spec = trainer._build_eval_spec()

    evalspec_mock = modules.estimator.EvalSpec
    args, _ = evalspec_mock.call_args
    # Assert the customer's eval_input_fn is used correctly
    eval_input_fn = args[0]
    returned_dict, returned_labels = eval_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)
    # Assert the created EvalSpec is returned from _build_eval_spec
    assert evalspec_mock.return_value == spec