コード例 #1
0
 def test_construct(self):
     """Evaluator built from examples and model exports outputs ModelEvalPath."""
     examples_artifact = types.TfxType(type_name='ExamplesPath')
     model_exports_artifact = types.TfxType(type_name='ModelExportPath')
     examples_channel = channel.as_channel([examples_artifact])
     model_exports_channel = channel.as_channel([model_exports_artifact])
     evaluator = component.Evaluator(
         examples=examples_channel, model_exports=model_exports_channel)
     output_type_name = evaluator.outputs.output.type_name
     self.assertEqual('ModelEvalPath', output_type_name)
コード例 #2
0
    def test_do(self):
        """Run the CSV ExampleGen executor and check both splits are written.

        Input is the checked-in CSV data under testdata/external/csv/. Each
        split is expected to produce a single
        ``data_tfrecord-00000-of-00001.gz`` file, with the train file larger
        than the eval file.
        """
        input_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        input_base = types.TfxType(type_name='ExternalPath')
        input_base.uri = os.path.join(input_data_dir, 'external/csv/')
        input_dict = {'input-base': [input_base]}

        # Create output dict.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Run executor.
        csv_example_gen = executor.Executor()
        csv_example_gen.Do(input_dict, output_dict, {})

        # Check CSV example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Open the files in `with` blocks so both handles are closed once
        # their sizes have been compared.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
コード例 #3
0
ファイル: channel_test.py プロジェクト: zwcdp/tfx
 def test_channel_as_channel_success(self):
     """as_channel applied to a Channel returns an equal Channel."""
     artifacts = [types.TfxType('MyTypeName') for _ in range(2)]
     original = channel.Channel(
         'MyTypeName', static_artifact_collection=artifacts)
     converted = channel.as_channel(original)
     self.assertEqual(original, converted)
コード例 #4
0
ファイル: executor_test.py プロジェクト: zwcdp/tfx
    def testDo(self, mock_client):
        """Run BigQuery ExampleGen with a mocked client and check the splits.

        ``mock_client`` is presumably injected by a patch decorator outside
        this view. Only the query result schema is stubbed here. Exec
        properties come from ``self._exec_properties`` (set up elsewhere).
        """
        # Mock query result schema for _BigQueryConverter.
        query_result = mock_client.return_value.query.return_value.result
        query_result.return_value.schema = self._schema

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Run executor.
        big_query_example_gen = executor.Executor()
        big_query_example_gen.Do({}, output_dict, self._exec_properties)

        # Check BigQuery example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Use `with` so the GFile handles are closed after reading sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
コード例 #5
0
 def test_construct(self):
     """StatisticsGen over train/eval examples outputs ExampleStatisticsPath."""
     split_examples = [
         types.TfxType(type_name='ExamplesPath', split=split_name)
         for split_name in ('train', 'eval')
     ]
     statistics_gen = component.StatisticsGen(
         input_data=channel.as_channel(split_examples))
     output_type_name = statistics_gen.outputs.output.type_name
     self.assertEqual('ExampleStatisticsPath', output_type_name)
コード例 #6
0
ファイル: executor_test.py プロジェクト: zwcdp/tfx
  def setUp(self):
    """Build inputs, outputs and exec properties for the Pusher executor.

    The model export points at testdata/trainer/current/. The push output and
    the filesystem serving directory are created under a per-test directory.
    """
    self._source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self._output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)
    tf.gfile.MakeDirs(self._output_data_dir)

    # Inputs: the model export and a blessing artifact (no uri assigned).
    self._model_export = types.TfxType(type_name='ModelExportPath')
    self._model_export.uri = os.path.join(self._source_data_dir,
                                          'trainer/current/')
    self._model_blessing = types.TfxType(type_name='ModelBlessingPath')
    self._input_dict = dict(
        model_export=[self._model_export],
        model_blessing=[self._model_blessing],
    )

    # Output: the model_push artifact, whose directory is created up front.
    self._model_push = types.TfxType(type_name='ModelPushPath')
    self._model_push.uri = os.path.join(self._output_data_dir, 'model_push')
    tf.gfile.MakeDirs(self._model_push.uri)
    self._output_dict = dict(model_push=[self._model_push])

    self._serving_model_dir = os.path.join(self._output_data_dir,
                                           'serving_model_dir')
    tf.gfile.MakeDirs(self._serving_model_dir)

    # The push destination is handed to the executor as a JSON-serialized
    # PushDestination proto targeting the serving directory above.
    filesystem_target = pusher_pb2.PushDestination.Filesystem(
        base_directory=self._serving_model_dir)
    push_destination = pusher_pb2.PushDestination(filesystem=filesystem_target)
    self._exec_properties = {
        'push_destination': json_format.MessageToJson(push_destination),
    }
    self._executor = executor.Executor()
コード例 #7
0
ファイル: channel_test.py プロジェクト: zwcdp/tfx
 def test_valid_channel(self):
     """A Channel of same-typed artifacts reports its type and its artifacts."""
     first = types.TfxType('MyTypeName')
     second = types.TfxType('MyTypeName')
     chnl = channel.Channel(
         'MyTypeName', static_artifact_collection=[first, second])
     self.assertEqual(chnl.type_name, 'MyTypeName')
     # Order-insensitive comparison of the artifacts held by the channel.
     self.assertItemsEqual(chnl.get(), [first, second])
コード例 #8
0
ファイル: metadata_test.py プロジェクト: rohithreddy/tfx
  def test_fetch_previous_result(self):
    """Check previous-run lookup and reuse of previously published outputs.

    Publishes one execution. Then checks that previous_run() returns its id
    only when the component name, inputs and exec properties all match, and
    None otherwise. Finally checks that fetch_previous_result_artifacts() maps
    a fresh output artifact onto the previously published one.
    """
    with Metadata(
        connection_config=self._connection_config, logger=self._logger) as m:

      # Create a 'previous' execution.
      exec_properties = {'log_root': 'path'}
      eid = m.prepare_execution('Test', exec_properties)
      input_artifact = types.TfxType(type_name='ExamplesPath')
      m.publish_artifacts([input_artifact])
      output_artifact = types.TfxType(type_name='ExamplesPath')
      input_dict = {'input': [input_artifact]}
      output_dict = {'output': [output_artifact]}
      m.publish_execution(eid, input_dict, output_dict)

      # Test previous_run: a mismatch in exec properties, inputs or
      # component name yields no previous run.
      self.assertIsNone(m.previous_run('Test', input_dict, {}))
      self.assertIsNone(m.previous_run('Test', {}, exec_properties))
      self.assertIsNone(m.previous_run('Test2', input_dict, exec_properties))
      self.assertEqual(eid, m.previous_run('Test', input_dict, exec_properties))

      # Test fetch_previous_result_artifacts.
      new_output_artifact = types.TfxType(type_name='ExamplesPath')
      self.assertNotEqual(types.ARTIFACT_STATE_PUBLISHED,
                          new_output_artifact.state)
      new_output_dict = {'output': [new_output_artifact]}
      updated_output_dict = m.fetch_previous_result_artifacts(
          new_output_dict, eid)
      previous_artifact = output_dict['output'][-1].artifact
      current_artifact = updated_output_dict['output'][-1].artifact
      # The returned artifact should be the published artifact from the
      # previous execution (same id and type id).
      self.assertEqual(types.ARTIFACT_STATE_PUBLISHED,
                       current_artifact.properties['state'].string_value)
      self.assertEqual(previous_artifact.id, current_artifact.id)
      self.assertEqual(previous_artifact.type_id, current_artifact.type_id)
コード例 #9
0
ファイル: channel_test.py プロジェクト: zwcdp/tfx
 def test_invalid_channel_type(self):
     """Channel rejects artifacts whose type differs from the channel type."""
     artifacts = [types.TfxType('MyTypeName'), types.TfxType('MyTypeName')]
     with self.assertRaises(ValueError):
         channel.Channel('AnotherTypeName',
                         static_artifact_collection=artifacts)
コード例 #10
0
  def setUp(self):
    """Prepare input/output dicts and an executor context for ModelValidator."""
    self._source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Inputs: the eval split of the CSV examples and the current model.
    eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
    eval_examples.uri = os.path.join(self._source_data_dir,
                                     'csv_example_gen/eval/')
    model = types.TfxType(type_name='ModelExportPath')
    model.uri = os.path.join(self._source_data_dir, 'trainer/current/')
    self._input_dict = dict(examples=[eval_examples], model=[model])

    # Output: a blessing artifact under the per-test output directory.
    self._blessing = types.TfxType('ModelBlessingPath')
    self._blessing.uri = os.path.join(output_data_dir, 'blessing')
    self._output_dict = dict(blessing=[self._blessing])

    # Executor context with a scratch directory and a fixed unique id.
    self._tmp_dir = os.path.join(output_data_dir, '.temp')
    self._context = executor.Executor.Context(
        tmp_dir=self._tmp_dir, unique_id='2')
コード例 #11
0
ファイル: executor_test.py プロジェクト: zwcdp/tfx
  def test_do(self):
    """SchemaGen executor writes a non-empty schema output directory."""
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    # Training-split statistics from testdata. 'ExampleStatisticsPath' is the
    # statistics type name used elsewhere (e.g. StatisticsGen's output).
    train_stats_artifact = types.TfxType('ExampleStatisticsPath', split='train')
    train_stats_artifact.uri = os.path.join(source_data_dir,
                                            'statistics_gen/train/')

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    schema_output = types.TfxType('SchemaPath')
    schema_output.uri = os.path.join(output_data_dir, 'schema_output')

    input_dict = {
        'stats': [train_stats_artifact],
    }
    output_dict = {
        'output': [schema_output],
    }

    exec_properties = {}

    schema_gen_executor = executor.Executor()
    schema_gen_executor.Do(input_dict, output_dict, exec_properties)
    # The executor must have written at least one file to the schema uri.
    self.assertNotEqual(0, len(tf.gfile.ListDirectory(schema_output.uri)))
コード例 #12
0
ファイル: base_driver_test.py プロジェクト: zwcdp/tfx
 def setUp(self):
     """Build artifacts, driver options and a logger for base driver tests.

     Each input artifact gets an id starting at 1 and a uri of
     <input_dir>/<key>/<id>/, which is created on disk.
     """
     self._mock_metadata = tf.test.mock.Mock()
     self._input_dict = {
         'input_data': [types.TfxType(type_name='InputType')],
     }
     input_dir = os.path.join(
         os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
         self._testMethodName, 'input_dir')
     # Valid input artifacts must have a uri pointing to an existing directory.
     for key in self._input_dict:
         for artifact_id, artifact in enumerate(self._input_dict[key], 1):
             artifact.id = artifact_id
             artifact_uri = os.path.join(input_dir, key, str(artifact.id), '')
             artifact.uri = artifact_uri
             tf.gfile.MakeDirs(artifact_uri)
     self._output_dict = {
         'output_data': [types.TfxType(type_name='OutputType')],
     }
     self._exec_properties = {'key': 'value'}
     self._base_output_dir = os.path.join(
         os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
         self._testMethodName, 'base_output_dir')
     self._driver_options = base_driver.DriverOptions(
         worker_name='worker_name',
         base_output_dir=self._base_output_dir,
         enable_cache=True)
     self._execution_id = 100
     # Logger configured with log_root <base_output_dir>/log_dir.
     log_root = os.path.join(self._base_output_dir, 'log_dir')
     self._logger = logging_utils.get_logger(
         logging_utils.LoggerConfig(log_root=log_root))
コード例 #13
0
    def setUp(self):
        """Prepare artifacts and AI Platform custom_config for the pusher."""
        # Four directory levels above this file, then components/testdata.
        package_root = __file__
        for _ in range(4):
            package_root = os.path.dirname(package_root)
        self._source_data_dir = os.path.join(package_root, 'components',
                                             'testdata')
        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        tf.gfile.MakeDirs(self._output_data_dir)

        # Inputs: the checked-in model export and a blessing artifact.
        self._model_export = types.TfxType(type_name='ModelExportPath')
        self._model_export.uri = os.path.join(self._source_data_dir,
                                              'trainer/current/')
        self._model_blessing = types.TfxType(type_name='ModelBlessingPath')
        self._input_dict = dict(model_export=[self._model_export],
                                model_blessing=[self._model_blessing])

        # Output: a model_push artifact whose directory is created up front.
        self._model_push = types.TfxType(type_name='ModelPushPath')
        self._model_push.uri = os.path.join(self._output_data_dir,
                                            'model_push')
        tf.gfile.MakeDirs(self._model_push.uri)
        self._output_dict = dict(model_push=[self._model_push])

        serving_args = {'model_name': 'model_name', 'project_id': 'project_id'}
        self._exec_properties = {
            'custom_config': {'ai_platform_serving_args': serving_args},
        }
        self._executor = Executor()
コード例 #14
0
    def test_do(self):
        """Run TestExampleGenExecutor and check both split outputs are written.

        Each split is expected to produce a single
        ``data_tfrecord-00000-of-00001.gz`` file, with the train file larger
        than the eval file.
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Run executor.
        example_gen = TestExampleGenExecutor()
        example_gen.Do({}, output_dict, {})

        # Check example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Use `with` so the GFile handles are closed after reading sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
コード例 #15
0
ファイル: executor_test.py プロジェクト: zwcdp/tfx
    def test_do(self):
        """Train with the Trainer executor and check both model dirs exist.

        Uses the checked-in transformed examples, transform output, schema and
        trainer module file from testdata. Expects ``eval_model_dir`` and
        ``serving_model_dir`` under the output model export uri.
        """
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(
            source_data_dir, 'transform/transformed_examples/train/')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(
            source_data_dir, 'transform/transformed_examples/eval/')
        transform_output = types.TfxType(type_name='TransformPath')
        transform_output.uri = os.path.join(source_data_dir,
                                            'transform/transform_output/')
        # The schema input is a SchemaPath artifact, matching the type the
        # Trainer component is constructed with.
        schema = types.TfxType(type_name='SchemaPath')
        schema.uri = os.path.join(source_data_dir, 'schema_gen/')

        input_dict = {
            'transformed_examples': [train_examples, eval_examples],
            'transform_output': [transform_output],
            'schema': [schema],
        }

        # Create output dict.
        model_exports = types.TfxType(type_name='ModelExportPath')
        model_exports.uri = os.path.join(output_data_dir, 'model_export_path')
        output_dict = {'output': [model_exports]}

        # Create exec properties.
        module_file_path = os.path.join(source_data_dir, 'module_file',
                                        'trainer_module.py')

        exec_properties = {
            'train_args':
            json_format.MessageToJson(trainer_pb2.TrainArgs(num_steps=1000)),
            'eval_args':
            json_format.MessageToJson(trainer_pb2.EvalArgs(num_steps=500)),
            'module_file':
            module_file_path,
            'warm_starting':
            False,
        }

        # Run executor.
        pipeline = beam.Pipeline()
        trainer = executor.Executor(pipeline)
        trainer.Do(input_dict=input_dict,
                   output_dict=output_dict,
                   exec_properties=exec_properties)

        # Check outputs.
        self.assertTrue(
            tf.gfile.Exists(os.path.join(model_exports.uri, 'eval_model_dir')))
        self.assertTrue(
            tf.gfile.Exists(
                os.path.join(model_exports.uri, 'serving_model_dir')))
コード例 #16
0
 def test_construct(self):
     """ModelValidator exposes a blessing channel of type ModelBlessingPath."""
     examples_artifact = types.TfxType(type_name='ExamplesPath')
     model_artifact = types.TfxType(type_name='ModelExportPath')
     validator = component.ModelValidator(
         examples=channel.as_channel([examples_artifact]),
         model=channel.as_channel([model_artifact]))
     blessing_type_name = validator.outputs.blessing.type_name
     self.assertEqual('ModelBlessingPath', blessing_type_name)
コード例 #17
0
ファイル: component_test.py プロジェクト: ashishML/tfx
 def test_construct(self):
     """ExampleValidator built from stats and schema outputs ExampleValidationPath."""
     example_validator = component.ExampleValidator(
         stats=channel.as_channel([
             types.TfxType(type_name='ExampleStatisticsPath', split='eval')
         ]),
         schema=channel.as_channel([types.TfxType(type_name='SchemaPath')]),
     )
     self.assertEqual('ExampleValidationPath',
                      example_validator.outputs.output.type_name)
コード例 #18
0
ファイル: airflow_adapter_test.py プロジェクト: zhitaoli/tfx
 def setUp(self):
     """Create one input and one output artifact, each with an orchestration source."""
     source_cls = airflow_component._OrchestrationSource
     self.input_one = types.TfxType('INPUT_ONE')
     self.input_one.source = source_cls('input_one_key',
                                        'input_one_component_id')
     self.output_one = types.TfxType('OUTPUT_ONE')
     self.output_one.source = source_cls('output_one_key',
                                         'output_one_component_id')
     # Single-element JSON lists holding each artifact's json_dict().
     self.input_one_json = json.dumps([self.input_one.json_dict()])
     self.output_one_json = json.dumps([self.output_one.json_dict()])
     self._logger_config = logging_utils.LoggerConfig()
コード例 #19
0
ファイル: metadata_test.py プロジェクト: zwcdp/tfx
  def test_execution(self):
    """Exercise prepare_execution and publish_execution on the metadata store.

    Checks the stored execution record before and after publishing, that the
    output artifact is marked published, and that exactly one input event and
    one output event (with their paths) are recorded.
    """
    with Metadata(
        connection_config=self._connection_config,
        logger=self._logger) as metadata:

      # prepare_execution: a single execution is stored in state "new".
      no_exec_properties = {}
      execution_id = metadata.prepare_execution('Test', no_exec_properties)
      [execution] = metadata.store.get_executions()
      self.assertProtoEquals(
          """
        id: 1
        type_id: 1
        properties {
          key: "state"
          value {
            string_value: "new"
          }
        }""", execution)

      # publish_execution: one already-published input, one new output.
      input_artifact = types.TfxType(type_name='ExamplesPath')
      metadata.publish_artifacts([input_artifact])
      output_artifact = types.TfxType(type_name='ExamplesPath')
      metadata.publish_execution(execution_id,
                                 {'input': [input_artifact]},
                                 {'output': [output_artifact]})

      # The output artifact is now published ...
      self.assertEqual(types.ARTIFACT_STATE_PUBLISHED, output_artifact.state)
      # ... and the execution has moved to state "complete".
      [execution] = metadata.store.get_executions_by_id([execution_id])
      self.assertEqual('complete', execution.properties['state'].string_value)

      # Two events are recorded: the declared input, then the declared output.
      events = metadata.store.get_events_by_execution_ids([execution_id])
      self.assertEqual(2, len(events))
      input_event = events[0]
      output_event = events[1]
      self.assertEqual(input_artifact.id, input_event.artifact_id)
      self.assertEqual(metadata_store_pb2.Event.DECLARED_INPUT,
                       input_event.type)
      self.assertProtoEquals(
          """
          steps {
            key: "input"
          }
          steps {
            index: 0
          }""", input_event.path)
      self.assertEqual(output_artifact.id, output_event.artifact_id)
      self.assertEqual(metadata_store_pb2.Event.DECLARED_OUTPUT,
                       output_event.type)
      self.assertProtoEquals(
          """
          steps {
            key: "output"
          }
          steps {
            index: 0
          }""", output_event.path)
コード例 #20
0
 def test_construct_with_slice_spec(self):
     """Evaluator with a feature slicing spec still outputs ModelEvalPath."""
     examples_artifact = types.TfxType(type_name='ExamplesPath')
     model_exports_artifact = types.TfxType(type_name='ModelExportPath')
     examples_channel = channel.as_channel([examples_artifact])
     model_exports_channel = channel.as_channel([model_exports_artifact])
     hour_slice = evaluator_pb2.SingleSlicingSpec(
         column_for_slicing=['trip_start_hour'])
     slicing_spec = evaluator_pb2.FeatureSlicingSpec(specs=[hour_slice])
     evaluator = component.Evaluator(
         examples=examples_channel,
         model_exports=model_exports_channel,
         feature_slicing_spec=slicing_spec)
     self.assertEqual('ModelEvalPath', evaluator.outputs.output.type_name)
コード例 #21
0
 def test_construct(self):
     """Trainer built from its three input channels outputs ModelExportPath."""
     examples_artifact = types.TfxType(type_name='ExamplesPath')
     transform_artifact = types.TfxType(type_name='TransformPath')
     schema_artifact = types.TfxType(type_name='SchemaPath')
     examples_channel = channel.as_channel([examples_artifact])
     transform_channel = channel.as_channel([transform_artifact])
     schema_channel = channel.as_channel([schema_artifact])
     train_args = trainer_pb2.TrainArgs(num_steps=100)
     eval_args = trainer_pb2.EvalArgs(num_steps=50)
     trainer = component.Trainer(
         module_file='/path/to/module/file',
         transformed_examples=examples_channel,
         transform_output=transform_channel,
         schema=schema_channel,
         train_args=train_args,
         eval_args=eval_args)
     self.assertEqual('ModelExportPath', trainer.outputs.output.type_name)
コード例 #22
0
ファイル: executor_test.py プロジェクト: mwalenia/tfx
    def testDo(self, mock_client):
        """Run BigQuery ExampleGen with a mocked client and check the splits.

        ``mock_client`` is presumably injected by a patch decorator outside
        this view. Only the query result schema is stubbed. The single 'bq'
        input split is hashed 2:1 into train and eval, so the train output
        file is expected to be larger.
        """
        # Mock query result schema for _BigQueryConverter.
        query_result = mock_client.return_value.query.return_value.result
        query_result.return_value.schema = self._schema

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'output':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        big_query_example_gen = executor.Executor()
        big_query_example_gen.Do({}, output_dict, exec_properties)

        # Check BigQuery example gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Use `with` so the GFile handles are closed after reading sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())
コード例 #23
0
    def test_fetch_last_blessed_model(self):
        """_fetch_last_blessed_model picks the blessed model with highest span.

        With no artifacts it returns (None, None). With spans 4..1, where only
        odd spans are blessed, it returns span 3's uri and model id.
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        log_root = os.path.join(output_data_dir, 'log_root')

        # Driver backed by a mocked metadata handle.
        mock_metadata = tf.test.mock.Mock()
        model_validator_driver = driver.Driver(log_root, mock_metadata)

        # Without any blessing artifacts there is no blessed model.
        mock_metadata.get_all_artifacts.return_value = []
        self.assertEqual((None, None),
                         model_validator_driver._fetch_last_blessed_model())

        # Blessing artifacts for spans 4, 3, 2, 1; only odd spans are blessed.
        blessing_artifacts = []
        for span in range(4, 0, -1):
            blessing = types.TfxType(type_name='ModelBlessingPath')
            blessing.span = span
            blessing.set_string_custom_property('current_model',
                                                'uri-%d' % span)
            blessing.set_int_custom_property('current_model_id', span)
            blessing.set_int_custom_property('blessed', span % 2)
            blessing_artifacts.append(blessing.artifact)
        mock_metadata.get_all_artifacts.return_value = blessing_artifacts
        self.assertEqual(('uri-3', 3),
                         model_validator_driver._fetch_last_blessed_model())
コード例 #24
0
    def setUp(self):
        """Create train/eval output artifacts and their expected output files."""
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        def _split_artifact(split):
            # ExamplesPath artifact for `split`, rooted at <output_dir>/<split>.
            artifact = types.TfxType(type_name='ExamplesPath', split=split)
            artifact.uri = os.path.join(output_data_dir, split)
            return artifact

        train_examples = _split_artifact('train')
        eval_examples = _split_artifact('eval')
        self._output_dict = {'examples': [train_examples, eval_examples]}

        # Single-shard gzipped output file expected in each split directory.
        shard_name = 'data_tfrecord-00000-of-00001.gz'
        self._train_output_file = os.path.join(train_examples.uri, shard_name)
        self._eval_output_file = os.path.join(eval_examples.uri, shard_name)
コード例 #25
0
ファイル: executor_test.py プロジェクト: zwcdp/tfx
    def test_do(self):
        """Evaluator executor writes eval_config, metrics and plots outputs."""
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Inputs: train/eval example artifacts and the current model export.
        # NOTE(review): train_examples is given no uri; presumably the
        # evaluator only reads the eval split -- confirm.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(source_data_dir,
                                         'csv_example_gen/eval/')
        model_exports = types.TfxType(type_name='ModelExportPath')
        model_exports.uri = os.path.join(source_data_dir, 'trainer/current/')
        input_dict = dict(examples=[train_examples, eval_examples],
                          model_exports=[model_exports])

        # Output: a single ModelEvalPath artifact.
        eval_output = types.TfxType('ModelEvalPath')
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        output_dict = dict(output=[eval_output])

        # Exec properties: two slicing specs, serialized to JSON.
        slicing_spec = evaluator_pb2.FeatureSlicingSpec(specs=[
            evaluator_pb2.SingleSlicingSpec(
                column_for_slicing=['trip_start_hour']),
            evaluator_pb2.SingleSlicingSpec(
                column_for_slicing=['trip_start_day', 'trip_miles']),
        ])
        exec_properties = {
            'feature_slicing_spec': json_format.MessageToJson(slicing_spec),
        }

        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Each expected evaluator output must exist under eval_output.uri.
        for output_name in ('eval_config', 'metrics', 'plots'):
            self.assertTrue(
                tf.gfile.Exists(os.path.join(eval_output.uri, output_name)))
コード例 #26
0
ファイル: executor_test.py プロジェクト: zwcdp/tfx
  def setUp(self):
    """Point the input dict at the checked-in external CSV test data."""
    # Three directory levels above this file, then testdata/.
    package_root = __file__
    for _ in range(3):
      package_root = os.path.dirname(package_root)
    testdata_root = os.path.join(package_root, 'testdata')

    csv_input = types.TfxType(type_name='ExternalPath')
    csv_input.uri = os.path.join(testdata_root, 'external/csv/')
    self._input_dict = {'input-base': [csv_input]}
コード例 #27
0
ファイル: types_test.py プロジェクト: zorrock/tfx
 def test_tfxtype_deprecated(self):
     """Constructing a TfxType logs the rename-deprecation warning once."""
     # patch.object installs a MagicMock for tf_logging.warning and restores
     # the original on exit, so the mock is taken from the context manager
     # rather than swapped in by hand.
     with mock.patch.object(tf_logging, 'warning') as warn_mock:
         types.TfxType('FakeType')
         warn_mock.assert_called_once()
         # NOTE(review): positional arg 5 is assumed to carry the deprecation
         # instructions text of the warning's format call -- confirm if the
         # warning call changes.
         self.assertIn('TfxType has been renamed to TfxArtifact',
                       warn_mock.call_args[0][5])
コード例 #28
0
ファイル: component_test.py プロジェクト: zwcdp/tfx
 def test_construct(self):
     """Transform exposes TransformPath and ExamplesPath output channels."""
     module_dir = os.path.join(os.path.dirname(__file__), 'testdata', 'taxi',
                               'module')
     preprocessing_fn_file = os.path.join(module_dir, 'preprocess.py')
     examples_channel = channel.as_channel([
         types.TfxType(type_name='ExamplesPath', split=split_name)
         for split_name in ('train', 'eval')
     ])
     schema_channel = channel.as_channel(
         [types.TfxType(type_name='SchemaPath')])
     transform = component.Transform(
         input_data=examples_channel,
         schema=schema_channel,
         module_file=preprocessing_fn_file,
     )
     outputs = transform.outputs
     self.assertEqual('TransformPath', outputs.transform_output.type_name)
     self.assertEqual('ExamplesPath', outputs.transformed_examples.type_name)
コード例 #29
0
  def test_tfx_type(self):
    """Exercise TfxType accessors, custom properties and serialization."""
    instance = types.TfxType('MyTypeName', split='eval')

    # Values reported by a freshly constructed instance.
    self.assertEqual('', instance.uri)
    self.assertEqual(0, instance.id)
    self.assertEqual(0, instance.type_id)
    self.assertEqual('MyTypeName', instance.type_name)
    self.assertIsNone(instance.state)
    self.assertEqual('eval', instance.split)
    self.assertIsNone(instance.span)

    # Each settable property reads back the value just written.
    new_values = (
        ('uri', '/tmp/uri2'),
        ('id', 1),
        ('type_id', 2),
        ('state', types.ARTIFACT_STATE_DELETED),
        ('split', ''),
        ('span', 20190101),
    )
    for attr_name, value in new_values:
      setattr(instance, attr_name, value)
      self.assertEqual(value, getattr(instance, attr_name))

    # Custom properties are stored on the underlying artifact proto.
    instance.set_int_custom_property('int_key', 20)
    self.assertEqual(
        20, instance.artifact.custom_properties['int_key'].int_value)
    instance.set_string_custom_property('string_key', 'string_value')
    self.assertEqual(
        'string_value',
        instance.artifact.custom_properties['string_key'].string_value)

    self.assertEqual('MyTypeName:/tmp/uri2.1', str(instance))

    # JSON and pickle round trips preserve the artifact and its type.
    def _json_round_trip(obj):
      serialized = json.dumps(obj.json_dict())
      return types.TfxType.parse_from_json_dict(json.loads(serialized))

    def _pickle_round_trip(obj):
      return pickle.loads(pickle.dumps(obj))

    for round_trip in (_json_round_trip, _pickle_round_trip):
      restored = round_trip(instance)
      self.assertEqual(instance.artifact, restored.artifact)
      self.assertEqual(instance.artifact_type, restored.artifact_type)

    # source is unset by default and can be assigned.
    self.assertIsNone(instance.source)
    instance.source = 'hello_world'
    self.assertEqual('hello_world', instance.source)
コード例 #30
0
    def testDo(self):
        """Run ImportExampleGen and check the train/eval split outputs.

        Uses ``self._input_dict`` (set up elsewhere) as input. The single
        'tfrecord/*' input split is hashed 2:1 into train and eval, so the
        train output file is expected to be larger than the eval one.
        """
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        train_examples = types.TfxType(type_name='ExamplesPath', split='train')
        train_examples.uri = os.path.join(output_data_dir, 'train')
        eval_examples = types.TfxType(type_name='ExamplesPath', split='eval')
        eval_examples.uri = os.path.join(output_data_dir, 'eval')
        output_dict = {'examples': [train_examples, eval_examples]}

        # Create exec properties.
        exec_properties = {
            'input':
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='tfrecord',
                                                pattern='tfrecord/*'),
                ])),
            'output':
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        import_example_gen = executor.Executor()
        import_example_gen.Do(self._input_dict, output_dict, exec_properties)

        # Check import_example_gen outputs.
        train_output_file = os.path.join(train_examples.uri,
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(eval_examples.uri,
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(tf.gfile.Exists(train_output_file))
        self.assertTrue(tf.gfile.Exists(eval_output_file))
        # Use `with` so the GFile handles are closed after reading sizes.
        with tf.gfile.GFile(train_output_file) as train_file:
            with tf.gfile.GFile(eval_output_file) as eval_file:
                self.assertGreater(train_file.size(), eval_file.size())