Esempio n. 1
0
def test_pipeline_def_tune_works(pipeline_def, monkeypatch):
    """Check that ``tune`` builds a Tuner from the pipeline's current state.

    ``Tuner`` is swapped for a mock so the keyword arguments it receives can
    be inspected. The created component is expected to be stored both under
    ``components['tuner']`` and on the ``tuner`` attribute.
    """
    # Arrange
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema

    mock_transform = mock.MagicMock()
    pipeline_def.transform = mock_transform

    mock_module = mock.Mock()

    mock_tuner = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Tuner', mock_tuner)

    # Act
    pipeline_def = pipeline_def.tune(module_file=mock_module,
                                     train_args=1,
                                     eval_args=2)

    # Assert
    # Verify the call happened before unpacking: ``call_args`` is None on an
    # uncalled mock, and unpacking it would raise a TypeError instead of a
    # readable assertion failure.
    assert mock_tuner.called
    _, kwargs = mock_tuner.call_args

    assert kwargs['module_file'] is mock_module
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert kwargs['train_args'] == 1
    assert kwargs['eval_args'] == 2
    # Identity check, consistent with the other channel assertions.
    assert kwargs['transform_graph'] is mock_transform.outputs[
        'transform_graph']
    assert kwargs['examples'] is ExampleInputs.PREPROCESSED_EXAMPLES(
        pipeline_def)
    assert pipeline_def.components['tuner'] is mock_tuner.return_value
    assert pipeline_def.tuner is mock_tuner.return_value
Esempio n. 2
0
def test_pipeline_def_preprocess_works(pipeline_def, monkeypatch):
    """Check that ``preprocess`` builds a Transform component.

    ``Transform`` is swapped for a mock so the keyword arguments it receives
    can be inspected. The created component is expected to be stored both
    under ``components['transform']`` and on the ``transform`` attribute.
    """
    # Arrange
    mock_examples = mock.MagicMock()
    pipeline_def.example_gen = mock_examples

    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema

    mock_transform = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Transform', mock_transform)

    mock_module = mock.Mock()

    # Act
    pipeline_def = pipeline_def.preprocess(mock_module)

    # Assert
    # Verify the call happened before unpacking: ``call_args`` is None on an
    # uncalled mock, and unpacking it would raise a TypeError instead of a
    # readable assertion failure.
    assert mock_transform.called
    _, kwargs = mock_transform.call_args

    assert kwargs['examples'] is mock_examples.outputs['examples']
    assert kwargs['module_file'] is mock_module
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    # ``is True`` rather than ``== True`` so a merely truthy value such as 1
    # does not satisfy the check.
    assert kwargs['materialize'] is True
    assert pipeline_def.components['transform'] is mock_transform.return_value
    assert pipeline_def.transform is mock_transform.return_value
Esempio n. 3
0
def test_schema_channel_inputs_returns_user_provided_if_no_schema_gen():
    """Check that SCHEMA_CHANNEL returns the user-provided schema output.

    The pipeline definition is a mock whose ``user_schema_importer`` exposes
    a ``'result'`` output and whose ``schema_gen`` is absent.
    """
    # Arrange
    mock_pipeline_def = mock.MagicMock()
    # A MagicMock auto-creates truthy attributes on access, so schema_gen is
    # cleared explicitly to set up the "no schema_gen" case named by the test.
    mock_pipeline_def.schema_gen = None
    mock_pipeline_def.user_schema_importer.outputs = {'result': 'schema'}

    # Act
    schema = SchemaInputs.SCHEMA_CHANNEL(mock_pipeline_def)

    # Assert
    assert schema == 'schema'
Esempio n. 4
0
def test_pipeline_def_train_works_with_optional_args(pipeline_def,
                                                     monkeypatch):
    """Check that ``train`` forwards every optional argument to Trainer.

    ``Trainer`` is swapped for a mock so the keyword arguments it receives
    can be inspected. The test passes explicit train/eval args, a custom
    config, a custom executor spec and ``ExampleInputs.RAW_EXAMPLES`` as the
    example input, with a user hyperparameters importer set and no tuner.
    """
    # Arrange
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema

    mock_hparams = mock.MagicMock()
    pipeline_def.user_hyperparameters_importer = mock_hparams

    mock_transform = mock.MagicMock()
    pipeline_def.transform = mock_transform

    mock_example_input = mock.MagicMock()
    pipeline_def.example_gen = mock_example_input
    pipeline_def.cached_example_input = None

    pipeline_def.tuner = None

    mock_trainer = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Trainer', mock_trainer)

    mock_module = mock.Mock()
    mock_config = mock.Mock()
    mock_executor_spec = mock.Mock()

    # Act
    pipeline_def = pipeline_def.train(module_file=mock_module,
                                      train_args=1,
                                      eval_args=2,
                                      custom_config=mock_config,
                                      custom_executor_spec=mock_executor_spec,
                                      example_input=ExampleInputs.RAW_EXAMPLES)

    # Assert
    # Verify the call happened before unpacking: ``call_args`` is None on an
    # uncalled mock, and unpacking it would raise a TypeError instead of a
    # readable assertion failure.
    assert mock_trainer.called
    _, kwargs = mock_trainer.call_args

    assert kwargs['train_args'] == 1
    assert kwargs['eval_args'] == 2
    assert kwargs['transformed_examples'] is ExampleInputs.RAW_EXAMPLES(
        pipeline_def)
    assert pipeline_def.cached_example_input is ExampleInputs.RAW_EXAMPLES(
        pipeline_def)
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert kwargs['transform_graph'] is mock_transform.outputs[
        'transform_graph']
    assert kwargs['custom_config'] is mock_config
    assert kwargs['custom_executor_spec'] is mock_executor_spec
    assert kwargs[
        'hyperparameters'] is HyperParameterInputs.BEST_HYPERPARAMETERS(
            pipeline_def)
    assert pipeline_def.components['trainer'] is mock_trainer.return_value
    assert pipeline_def.trainer is mock_trainer.return_value
Esempio n. 5
0
def test_pipeline_def_train_minimally_works(pipeline_def, monkeypatch):
    """Check that ``train`` works when only ``module_file`` is supplied.

    ``Trainer`` is swapped for a mock so the keyword arguments it receives
    can be inspected. With no tuner and no user hyperparameters importer,
    the optional keys are expected to be absent from the call, and the
    executor spec is expected to be an ``ExecutorClassSpec`` instance.
    """
    # Arrange
    mock_schema = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema

    mock_transform = mock.MagicMock()
    pipeline_def.transform = mock_transform

    pipeline_def.cached_example_input = None

    pipeline_def.tuner = None
    pipeline_def.user_hyperparameters_importer = None

    mock_trainer = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.Trainer', mock_trainer)

    mock_module = mock.Mock()

    # Act
    pipeline_def = pipeline_def.train(module_file=mock_module)

    # Assert
    # Verify the call happened before unpacking: ``call_args`` is None on an
    # uncalled mock, and unpacking it would raise a TypeError instead of a
    # readable assertion failure.
    assert mock_trainer.called
    _, kwargs = mock_trainer.call_args

    assert 'train_args' not in kwargs
    assert 'eval_args' not in kwargs
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert kwargs['transform_graph'] is mock_transform.outputs[
        'transform_graph']
    assert 'custom_config' not in kwargs
    assert 'hyperparameters' not in kwargs
    assert isinstance(kwargs['custom_executor_spec'],
                      executor_spec.ExecutorClassSpec)
    assert kwargs[
        'transformed_examples'] is ExampleInputs.PREPROCESSED_EXAMPLES(
            pipeline_def)
    assert pipeline_def.components['trainer'] is mock_trainer.return_value
    assert pipeline_def.trainer is mock_trainer.return_value
Esempio n. 6
0
def test_pipeline_def_validate_input_data_works(pipeline_def, monkeypatch):
    """Check that ``validate_input_data`` builds an ExampleValidator.

    ``ExampleValidator`` is swapped for a mock so the keyword arguments it
    receives can be inspected. The created component is expected to be stored
    both under ``components['example_validator']`` and on the
    ``example_validator`` attribute.
    """
    # Arrange
    mock_statistics_gen = mock.MagicMock()
    pipeline_def.statistics_gen = mock_statistics_gen
    mock_schema_gen = mock.MagicMock()
    pipeline_def.schema_gen = mock_schema_gen

    mock_example_validator = mock.Mock()
    monkeypatch.setattr('fluent_tfx.pipeline_def.ExampleValidator',
                        mock_example_validator)

    # Act
    pipeline_def = pipeline_def.validate_input_data()

    # Assert
    # Verify the call happened before unpacking: ``call_args`` is None on an
    # uncalled mock, and unpacking it would raise a TypeError instead of a
    # readable assertion failure.
    assert mock_example_validator.called
    _, kwargs = mock_example_validator.call_args

    assert kwargs['statistics'] is mock_statistics_gen.outputs['statistics']
    assert kwargs['schema'] is SchemaInputs.SCHEMA_CHANNEL(pipeline_def)
    assert 'exclude_splits' not in kwargs
    assert pipeline_def.components[
        'example_validator'] is mock_example_validator.return_value
    assert pipeline_def.example_validator is mock_example_validator.return_value