def test_build_train_spec(modules, trainer):
    """_build_train_spec wraps the customer's train_input_fn in a TrainSpec.

    Checks that TrainSpec is constructed with ``max_steps`` taken from
    ``trainer.train_steps``, that the TrainSpec instance is what gets
    returned, and that the first positional argument is a zero-argument
    callable forwarding to the customer's ``train_input_fn(training_dir,
    hyperparameters)``.
    """
    features = {'inputs': ['faketensor']}
    fake_labels = ['fakelabels']
    # The trainer injects defaulted hyperparameters on top of the customer's.
    hps_seen_by_customer = {**HYPERPARAMETERS, 'save_checkpoints_secs': 300}

    # Stand-in for the customer's script. Parameter names must stay as-is:
    # they are the signature the trainer calls into.
    def customer_train_input_fn(training_dir, hyperparameters):
        assert training_dir == TRAIN_DIR
        assert hyperparameters == hps_seen_by_customer
        return features, fake_labels

    fake_script = EmptyModule()
    fake_script.train_input_fn = customer_train_input_fn
    trainer.customer_script = fake_script
    trainer.train_steps = 987

    spec = trainer._build_train_spec()

    train_spec_cls = modules.estimator.TrainSpec
    train_spec_cls.assert_called_with(ANY, max_steps=987)
    assert train_spec_cls.return_value == spec

    # The wrapper handed to TrainSpec takes no arguments; invoking it must
    # reach the customer fn (whose own asserts check the forwarded args).
    wrapped_input_fn = train_spec_cls.call_args[0][0]
    got_features, got_labels = wrapped_input_fn()
    assert got_features == features
    assert got_labels == fake_labels
def test_build_train_spec_input_channels(modules, trainer):
    """_build_train_spec supports a customer train_input_fn taking input_channels.

    When the customer's ``train_input_fn`` declares ``input_channels``
    instead of ``training_dir``, the zero-argument wrapper passed to
    TrainSpec must call it with ``INPUT_CHANNELS`` and the hyperparameters.
    The wrapper's result must be returned unchanged, and the TrainSpec
    instance must be what ``_build_train_spec`` returns.
    """
    tensor_dict = {'inputs': ['faketensor']}
    labels = ['fakelabels']
    # We add some defaulted hyperparameters into the customer params.
    expected_hps = HYPERPARAMETERS.copy()
    expected_hps['save_checkpoints_secs'] = 300

    # Set up "customer script". The parameter names matter: this variant
    # uses `input_channels` rather than `training_dir`.
    def customer_train_input_fn(input_channels, hyperparameters):
        assert input_channels == INPUT_CHANNELS
        assert hyperparameters == expected_hps
        return tensor_dict, labels

    customer_script = EmptyModule()
    customer_script.train_input_fn = customer_train_input_fn
    trainer.customer_script = customer_script

    spec = trainer._build_train_spec()

    # Previously `spec` was unused. Assert it is the built TrainSpec,
    # mirroring test_build_train_spec.
    assert modules.estimator.TrainSpec.return_value == spec

    # The first positional argument to TrainSpec is the 0-arg wrapper; calling
    # it invokes the customer's train_input_fn.
    train_input_fn = modules.estimator.TrainSpec.call_args[0][0]
    returned_dict, returned_labels = train_input_fn()
    assert (tensor_dict, labels) == (returned_dict, returned_labels)