def test_pipeline_def_tune_works(pipeline_def, monkeypatch):
    """Check that ``tune`` constructs a Tuner and registers it on the pipeline.

    ``Tuner`` is patched inside ``fluent_tfx.pipeline_def`` so the keyword
    arguments it receives can be inspected without building a real component.
    """
    # Arrange
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema
    mock_transform = mock.MagicMock()
    pipeline_def.transform = mock_transform
    mock_module = mock.Mock()
    mock_tuner = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Tuner', mock_tuner)

    # Act
    pipeline_def = pipeline_def.tune(module_file=mock_module,
                                     train_args=1,
                                     eval_args=2)

    # Assert
    # Verify the call before unpacking: ``call_args`` is None on an uncalled
    # mock, and unpacking None would raise TypeError instead of failing with
    # a readable assertion.
    assert mock_tuner.called
    _, kwargs = mock_tuner.call_args

    assert kwargs['module_file'] is mock_module
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert kwargs['train_args'] == 1
    assert kwargs['eval_args'] == 2
    # Identity check, consistent with the other component tests; a MagicMock
    # returns the same child object for repeated item lookups.
    assert kwargs['transform_graph'] is mock_transform.outputs[
        'transform_graph']
    assert kwargs['examples'] is ExampleInputs.PREPROCESSED_EXAMPLES(
        pipeline_def)

    assert pipeline_def.components['tuner'] is mock_tuner.return_value
    assert pipeline_def.tuner is mock_tuner.return_value
def test_pipeline_def_preprocess_works(pipeline_def, monkeypatch):
    """Check that ``preprocess`` constructs a Transform and registers it.

    ``Transform`` is patched inside ``fluent_tfx.pipeline_def`` so the keyword
    arguments it receives can be inspected without building a real component.
    """
    # Arrange
    mock_examples = mock.MagicMock()
    pipeline_def.example_gen = mock_examples
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema
    mock_transform = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Transform', mock_transform)
    mock_module = mock.Mock()

    # Act
    pipeline_def = pipeline_def.preprocess(mock_module)

    # Assert
    # Verify the call before unpacking: ``call_args`` is None on an uncalled
    # mock, and unpacking None would raise TypeError instead of failing with
    # a readable assertion.
    assert mock_transform.called
    _, kwargs = mock_transform.call_args

    assert kwargs['examples'] is mock_examples.outputs['examples']
    assert kwargs['module_file'] is mock_module
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    # Identity comparison against the boolean singleton (PEP 8) rather than
    # ``== True``, which would also accept values such as ``1``.
    assert kwargs['materialize'] is True

    assert pipeline_def.components['transform'] is mock_transform.return_value
    assert pipeline_def.transform is mock_transform.return_value
def test_schema_channel_inputs_returns_user_provided_if_no_schema_gen():
    """Check that SCHEMA_CHANNEL yields the user-provided schema output.

    The pipeline definition is a mock whose ``user_schema_importer`` exposes
    a ``'result'`` output and whose ``schema_gen`` is absent.
    """
    # Arrange
    mock_pipeline_def = mock.MagicMock()
    # A MagicMock auto-creates truthy attributes, so ``schema_gen`` must be
    # cleared explicitly for the "no schema_gen" precondition to really hold.
    mock_pipeline_def.schema_gen = None
    mock_pipeline_def.user_schema_importer.outputs = {'result': 'schema'}

    # Act
    schema = SchemaInputs.SCHEMA_CHANNEL(mock_pipeline_def)

    # Assert
    assert schema == 'schema'
def test_pipeline_def_train_works_with_optional_args(pipeline_def,
                                                     monkeypatch):
    """Check that ``train`` forwards every optional argument to Trainer.

    ``Trainer`` is patched inside ``fluent_tfx.pipeline_def`` so the keyword
    arguments it receives can be inspected. The pipeline is arranged with a
    user hyperparameters importer, no tuner, and no cached example input, and
    ``train`` is asked to use ``ExampleInputs.RAW_EXAMPLES``.
    """
    # Arrange
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema
    mock_hparams = mock.MagicMock()
    pipeline_def.user_hyperparameters_importer = mock_hparams
    mock_transform = mock.MagicMock()
    pipeline_def.transform = mock_transform
    mock_example_input = mock.MagicMock()
    pipeline_def.example_gen = mock_example_input
    pipeline_def.cached_example_input = None
    pipeline_def.tuner = None

    mock_trainer = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Trainer', mock_trainer)
    mock_module = mock.Mock()
    mock_config = mock.Mock()
    mock_executor_spec = mock.Mock()

    # Act
    pipeline_def = pipeline_def.train(
        module_file=mock_module,
        train_args=1,
        eval_args=2,
        custom_config=mock_config,
        custom_executor_spec=mock_executor_spec,
        example_input=ExampleInputs.RAW_EXAMPLES)

    # Assert
    # Verify the call before unpacking: ``call_args`` is None on an uncalled
    # mock, and unpacking None would raise TypeError instead of failing with
    # a readable assertion.
    assert mock_trainer.called
    _, kwargs = mock_trainer.call_args

    assert kwargs['train_args'] == 1
    assert kwargs['eval_args'] == 2
    assert kwargs['transformed_examples'] is ExampleInputs.RAW_EXAMPLES(
        pipeline_def)
    assert pipeline_def.cached_example_input is ExampleInputs.RAW_EXAMPLES(
        pipeline_def)
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert kwargs['transform_graph'] is mock_transform.outputs[
        'transform_graph']
    assert kwargs['custom_config'] is mock_config
    assert kwargs['custom_executor_spec'] is mock_executor_spec
    assert kwargs[
        'hyperparameters'] is HyperParameterInputs.BEST_HYPERPARAMETERS(
            pipeline_def)

    assert pipeline_def.components['trainer'] is mock_trainer.return_value
    assert pipeline_def.trainer is mock_trainer.return_value
def test_pipeline_def_train_minimally_works(pipeline_def, monkeypatch):
    """Check that ``train`` works when only ``module_file`` is supplied.

    ``Trainer`` is patched inside ``fluent_tfx.pipeline_def`` so the keyword
    arguments it receives can be inspected. The pipeline is arranged with no
    tuner, no user hyperparameters importer, and no cached example input; the
    test asserts which keyword arguments are omitted and which are defaulted.
    """
    # Arrange
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema
    mock_transform = mock.MagicMock()
    pipeline_def.transform = mock_transform
    pipeline_def.cached_example_input = None
    pipeline_def.tuner = None
    pipeline_def.user_hyperparameters_importer = None

    mock_trainer = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Trainer', mock_trainer)
    mock_module = mock.Mock()

    # Act
    pipeline_def = pipeline_def.train(module_file=mock_module)

    # Assert
    # Verify the call before unpacking: ``call_args`` is None on an uncalled
    # mock, and unpacking None would raise TypeError instead of failing with
    # a readable assertion.
    assert mock_trainer.called
    _, kwargs = mock_trainer.call_args

    # Optional arguments that were not supplied must not be forwarded.
    assert 'train_args' not in kwargs
    assert 'eval_args' not in kwargs
    assert 'custom_config' not in kwargs
    assert 'hyperparameters' not in kwargs

    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert kwargs['transform_graph'] is mock_transform.outputs[
        'transform_graph']
    assert isinstance(kwargs['custom_executor_spec'],
                      executor_spec.ExecutorClassSpec)
    assert kwargs[
        'transformed_examples'] is ExampleInputs.PREPROCESSED_EXAMPLES(
            pipeline_def)

    assert pipeline_def.components['trainer'] is mock_trainer.return_value
    assert pipeline_def.trainer is mock_trainer.return_value
def test_pipeline_def_validate_input_data_works(pipeline_def, monkeypatch):
    """Check that ``validate_input_data`` constructs an ExampleValidator.

    ``ExampleValidator`` is patched inside ``fluent_tfx.pipeline_def`` so the
    keyword arguments it receives can be inspected without building a real
    component.
    """
    # Arrange
    mock_statistics_gen = mock.MagicMock()
    pipeline_def.statistics_gen = mock_statistics_gen
    mock_schema_gen = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema_gen
    mock_example_validator = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.ExampleValidator',
                        mock_example_validator)

    # Act
    pipeline_def = pipeline_def.validate_input_data()

    # Assert
    # Verify the call before unpacking: ``call_args`` is None on an uncalled
    # mock, and unpacking None would raise TypeError instead of failing with
    # a readable assertion.
    assert mock_example_validator.called
    _, kwargs = mock_example_validator.call_args

    assert kwargs['statistics'] is mock_statistics_gen.outputs['statistics']
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    # No splits were requested for exclusion, so the key must be absent.
    assert 'exclude_splits' not in kwargs

    assert pipeline_def.components[
        'example_validator'] is mock_example_validator.return_value
    assert (pipeline_def.example_validator is
            mock_example_validator.return_value)