def test_build_train_spec(modules, trainer):
    """Check that _build_train_spec wraps the customer's train_input_fn in a TrainSpec.

    TrainSpec must be called with a zero-argument input function and
    ``max_steps`` equal to the trainer's ``train_steps``. The method must
    return that TrainSpec. Calling the wrapper must invoke the customer's
    ``train_input_fn`` with TRAIN_DIR and the hyperparameters (including the
    defaulted ``save_checkpoints_secs``), and return its result unchanged.
    """
    features = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    max_steps = 987
    # The trainer adds a defaulted hyperparameter on top of the customer's.
    expected_hps = dict(HYPERPARAMETERS, save_checkpoints_secs=300)

    def customer_train_input_fn(training_dir, hyperparameters):
        # Stand-in for the customer's script; verifies what the trainer passes in.
        assert training_dir == TRAIN_DIR
        assert hyperparameters == expected_hps
        return features, labels

    script = EmptyModule()
    script.train_input_fn = customer_train_input_fn

    trainer.train_steps = max_steps
    trainer.customer_script = script

    spec = trainer._build_train_spec()

    train_spec_cls = modules.estimator.TrainSpec
    train_spec_cls.assert_called_with(ANY, max_steps=max_steps)
    assert spec == train_spec_cls.return_value

    # The first positional argument handed to TrainSpec is the zero-arg
    # input_fn wrapper. Calling it must reach the customer's function and
    # hand back its result.
    positional_args, _ = train_spec_cls.call_args
    wrapped_input_fn = positional_args[0]
    returned_features, returned_labels = wrapped_input_fn()
    assert (returned_features, returned_labels) == (features, labels)


def test_build_train_spec_input_channels(modules, trainer):
    """Check _build_train_spec with a customer train_input_fn that takes input_channels.

    The customer's function declares ``input_channels`` rather than a
    training directory. The wrapper handed to TrainSpec must call it with
    INPUT_CHANNELS and the hyperparameters (including the defaulted
    ``save_checkpoints_secs``). _build_train_spec must return the TrainSpec
    it built.
    """
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    # We add some defaulted hyperparameters into the customer params.
    expected_hps = HYPERPARAMETERS.copy()
    expected_hps['save_checkpoints_secs'] = 300

    # Set up "customer script". Its train_input_fn asks for input_channels;
    # presumably the trainer inspects the signature to decide what to pass
    # (NOTE(review): confirm against Trainer._build_train_spec).
    def customer_train_input_fn(input_channels, hyperparameters):
        assert input_channels == INPUT_CHANNELS
        assert hyperparameters == expected_hps
        return tensor_dict, labels

    customer_script = EmptyModule()
    customer_script.train_input_fn = customer_train_input_fn

    trainer.customer_script = customer_script

    spec = trainer._build_train_spec()

    # The method must return the TrainSpec it constructed; without this
    # check `spec` would be computed and silently ignored.
    assert modules.estimator.TrainSpec.return_value == spec
    # Assert that the 0-arg train_input_fn passed to TrainSpec invokes the
    # customer's train_input_fn and returns its result.
    train_input_fn = modules.estimator.TrainSpec.call_args[0][0]
    returned_dict, returned_labels = train_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)