Code Example #1
File: executor_test.py Project: zzhmtxxhh/tfx
    def setUp(self):
        super(ExecutorTest, self).setUp()

        self.source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        self.train_stats_artifact = standard_artifacts.ExampleStatistics(
            split='train')
        self.train_stats_artifact.uri = os.path.join(self.source_data_dir,
                                                     'statistics_gen/train/')

        self.output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        self.schema_output = standard_artifacts.Schema()
        self.schema_output.uri = os.path.join(self.output_data_dir,
                                              'schema_output')

        self.schema = standard_artifacts.Schema()
        self.schema.uri = os.path.join(self.source_data_dir, 'fixed_schema/')

        self.expected_schema = standard_artifacts.Schema()
        self.expected_schema.uri = os.path.join(self.source_data_dir,
                                                'schema_gen/')

        self.input_dict = {
            'stats': [self.train_stats_artifact],
            'schema': None
        }
        self.output_dict = {
            'output': [self.schema_output],
        }
        self.exec_properties = {'infer_feature_shape': False}
Code Example #2
 def testInvalidOutput(self):
     with self.assertRaises(KeyError):
         _executor.Executor().Do({}, {}, self.exec_properties)
     with self.assertRaisesRegex(ValueError, 'expected list length of one'):
         _executor.Executor().Do({}, {
             standard_component_specs.SCHEMA_KEY:
             [standard_artifacts.Schema(),
              standard_artifacts.Schema()]
         }, self.exec_properties)
Code Example #3
 def testEnableCache(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['eval'])
     example_validator_1 = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
     )
     self.assertIsNone(example_validator_1.enable_cache)
     example_validator_2 = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
         enable_cache=True)
     self.assertEqual(True, example_validator_2.enable_cache)
Code Example #4
  def testDo(self):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.uri = os.path.join(source_data_dir, 'statistics_gen')
    statistics_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval', 'test'])

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    schema_output = standard_artifacts.Schema()
    schema_output.uri = os.path.join(output_data_dir, 'schema_output')

    input_dict = {
        standard_component_specs.STATISTICS_KEY: [statistics_artifact],
    }

    exec_properties = {
        # List needs to be serialized before being passed into Do function.
        standard_component_specs.EXCLUDE_SPLITS_KEY:
            json_utils.dumps(['test'])
    }

    output_dict = {
        standard_component_specs.SCHEMA_KEY: [schema_output],
    }

    schema_gen_executor = executor.Executor()
    schema_gen_executor.Do(input_dict, output_dict, exec_properties)
    self.assertNotEqual(0, len(fileio.listdir(schema_output.uri)))
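
The comment in the exec-properties block above is the key detail: exec properties must be primitive types, so list-valued properties such as the exclude-splits list are JSON-serialized before being handed to Do() and decoded again inside the executor. A minimal sketch of that round trip, assuming tfx.utils.json_utils mirrors the standard json module for plain lists:

from tfx.utils import json_utils

# Serialize the split list exactly as the test above does.
serialized = json_utils.dumps(['test'])   # a JSON string such as '["test"]'
# The executor is assumed to decode it back with the matching loads().
assert json_utils.loads(serialized) == ['test']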
Code Example #5
File: component_test.py Project: etarakci-hvl/tfx
 def setUp(self):
     super(ComponentTest, self).setUp()
     self.input_data = channel_utils.as_channel([
         standard_artifacts.Examples(split='train'),
         standard_artifacts.Examples(split='eval'),
     ])
     self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
Code Example #6
File: executor_test.py Project: meixinzhang/tfx
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            executor.EXAMPLES_KEY: [examples],
            executor.SCHEMA_KEY: [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_graph')
        self._transformed_examples = standard_artifacts.Examples()
        self._transformed_examples.uri = output_data_dir
        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()

        self._output_dict = {
            executor.TRANSFORM_GRAPH_KEY: [self._transformed_output],
            executor.TRANSFORMED_EXAMPLES_KEY: [self._transformed_examples],
            executor.TEMP_PATH_KEY: [temp_path_output],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Code Example #7
    def testConstructFromModuleFile(self):
        examples = channel_utils.as_channel([standard_artifacts.Examples()])
        transform_graph = channel_utils.as_channel(
            [standard_artifacts.TransformGraph()])
        schema = channel_utils.as_channel([standard_artifacts.Schema()])
        train_args = trainer_pb2.TrainArgs(splits=['train'], num_steps=100)
        eval_args = trainer_pb2.EvalArgs(splits=['eval'], num_steps=50)
        module_file = '/path/to/module/file'
        trainer = component.Trainer(module_file=module_file,
                                    transformed_examples=examples,
                                    transform_graph=transform_graph,
                                    schema=schema,
                                    train_args=train_args,
                                    eval_args=eval_args)

        self.assertEqual(
            standard_artifacts.Model.TYPE_NAME,
            trainer.outputs[standard_component_specs.MODEL_KEY].type_name)
        self.assertEqual(
            standard_artifacts.ModelRun.TYPE_NAME,
            trainer.outputs[standard_component_specs.MODEL_RUN_KEY].type_name)
        self.assertEqual(
            module_file, trainer.spec.exec_properties[
                standard_component_specs.MODULE_FILE_KEY])
Code Example #8
File: executor_test.py Project: zw39125432/tfx
    def _make_base_do_params(self, source_data_dir, output_data_dir):
        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        self._input_dict = {
            'input_data': [examples],
            'schema': [schema_artifact],
        }

        # Create output dict.
        self._transformed_output = standard_artifacts.TransformGraph()
        self._transformed_output.uri = os.path.join(output_data_dir,
                                                    'transformed_output')
        self._transformed_examples = standard_artifacts.Examples()
        self._transformed_examples.uri = output_data_dir
        self._transformed_examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        temp_path_output = _TempPath()
        temp_path_output.uri = tempfile.mkdtemp()

        self._output_dict = {
            'transform_output': [self._transformed_output],
            'transformed_examples': [self._transformed_examples],
            'temp_path': [temp_path_output],
        }

        # Create exec properties skeleton.
        self._exec_properties = {}
Code Example #9
  def setUp(self):
    super(ExecutorTest, self).setUp()
    self._testdata_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self._module_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'example')
    self._output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    self._context = executor.Executor.Context(
        tmp_dir=self._output_data_dir, unique_id='1')

    # Create input dict.
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(self._testdata_dir, 'data')
    examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(self._testdata_dir, 'schema')

    self._input_dict = {
        'examples': [examples],
        'schema': [schema],
    }

    # Create output dict.
    model = standard_artifacts.Model()
    model.uri = os.path.join(self._output_data_dir, 'model')
    self._best_hparams = standard_artifacts.Model()
    self._best_hparams.uri = os.path.join(self._output_data_dir, 'best_hparams')

    self._output_dict = {
        'model': [model],
        'best_hyperparameters': [self._best_hparams],
    }
Code Example #10
 def testConstruct(self):
   output_data_dir = os.path.join(
     os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
     self._testMethodName)
   examples_channel = channel_utils.as_channel([standard_artifacts.Examples()])
   schema_channel = channel_utils.as_channel([standard_artifacts.Schema()])
   transform_graph_channel = channel_utils.as_channel([standard_artifacts.TransformGraph()])
   component_instance = component.TransformWithGraph(
     examples=examples_channel,
     schema=schema_channel,
     transform_graph=transform_graph_channel,
   )
   self.assertEqual(
     'Schema',
     component_instance.inputs.schema.type_name
   )
   self.assertEqual(
     'Examples',
     component_instance.inputs.examples.type_name
   )
   self.assertEqual(
     'TransformGraph',
     component_instance.inputs.transform_graph.type_name
   )
   self.assertEqual(
     'Examples',
     component_instance.outputs.transformed_examples.type_name
   )
Code Example #11
  def __init__(self,
               stats: types.Channel = None,
               infer_feature_shape: bool = True,
               output: Optional[types.Channel] = None,
               statistics: Optional[types.Channel] = None,
               name: Optional[Text] = None):
    """Constructs a SchemaGen component.

    Args:
      stats: A Channel of 'ExampleStatisticsPath' type (required if spec is not
        passed). This should contain at least a 'train' split; other splits are
        currently ignored.
      infer_feature_shape: Boolean value indicating whether or not to infer the
        shape of features. If the feature shape is not inferred, a downstream
        TensorFlow Transform component using the schema will parse its input
        as tf.SparseTensor.
      output: Optional output 'SchemaPath' channel for the schema result.
      statistics: Forwards-compatibility alias for the 'stats' argument.
      name: Optional unique name. Required only if multiple SchemaGen components
        are declared in the same pipeline.
    """
    stats = stats or statistics
    output = output or types.Channel(
        type=standard_artifacts.Schema, artifacts=[standard_artifacts.Schema()])
    spec = SchemaGenSpec(
        stats=stats, infer_feature_shape=infer_feature_shape, output=output)
    super(SchemaGen, self).__init__(spec=spec, name=name)
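
A hedged usage sketch for the legacy constructor above. This is a sketch only: the import paths and the split= keyword vary across tfx versions, so treat them as assumptions rather than the canonical API.

from tfx.components import SchemaGen
from tfx.types import standard_artifacts
from tfx.types import channel_utils

# Wrap a train-split statistics artifact in a channel and feed it to 'stats'.
stats_channel = channel_utils.as_channel(
    [standard_artifacts.ExampleStatistics(split='train')])
schema_gen = SchemaGen(stats=stats_channel, infer_feature_shape=False)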
Code Example #12
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        statistics_artifact = standard_artifacts.ExampleStatistics()
        statistics_artifact.uri = os.path.join(source_data_dir,
                                               'statistics_gen')
        statistics_artifact.split_names = artifact_utils.encode_split_names(
            ['train'])

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        schema_output = standard_artifacts.Schema()
        schema_output.uri = os.path.join(output_data_dir, 'schema_output')

        input_dict = {
            'stats': [statistics_artifact],
        }
        output_dict = {
            'output': [schema_output],
        }

        exec_properties = {'infer_feature_shape': False}

        schema_gen_executor = executor.Executor()
        schema_gen_executor.Do(input_dict, output_dict, exec_properties)
        self.assertNotEqual(0, len(tf.io.gfile.listdir(schema_output.uri)))
Code Example #13
File: executor_test.py Project: suryaavala/tfx
    def testDo(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        eval_stats_artifact = standard_artifacts.ExampleStatistics()
        eval_stats_artifact.uri = os.path.join(source_data_dir,
                                               'statistics_gen')
        eval_stats_artifact.split_names = artifact_utils.encode_split_names(
            ['train', 'eval', 'test'])

        schema_artifact = standard_artifacts.Schema()
        schema_artifact.uri = os.path.join(source_data_dir, 'schema_gen')

        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        validation_output = standard_artifacts.ExampleAnomalies()
        validation_output.uri = os.path.join(output_data_dir, 'output')

        input_dict = {
            STATISTICS_KEY: [eval_stats_artifact],
            SCHEMA_KEY: [schema_artifact],
        }

        exec_properties = {
            # List needs to be serialized before being passed into Do function.
            EXCLUDE_SPLITS_KEY: json_utils.dumps(['test'])
        }

        output_dict = {
            ANOMALIES_KEY: [validation_output],
        }

        example_validator_executor = executor.Executor()
        example_validator_executor.Do(input_dict, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         validation_output.split_names)

        # Check example_validator outputs.
        train_anomalies_path = os.path.join(validation_output.uri,
                                            'Split-train', 'SchemaDiff.pb')
        eval_anomalies_path = os.path.join(validation_output.uri, 'Split-eval',
                                           'SchemaDiff.pb')
        self.assertTrue(fileio.exists(train_anomalies_path))
        self.assertTrue(fileio.exists(eval_anomalies_path))
        train_anomalies_bytes = io_utils.read_bytes_file(train_anomalies_path)
        train_anomalies = anomalies_pb2.Anomalies()
        train_anomalies.ParseFromString(train_anomalies_bytes)
        eval_anomalies_bytes = io_utils.read_bytes_file(eval_anomalies_path)
        eval_anomalies = anomalies_pb2.Anomalies()
        eval_anomalies.ParseFromString(eval_anomalies_bytes)
        self.assertEqual(0, len(train_anomalies.anomaly_info))
        self.assertEqual(0, len(eval_anomalies.anomaly_info))

        # Assert 'test' split is excluded.
        test_file_path = os.path.join(validation_output.uri, 'Split-test',
                                      'SchemaDiff.pb')
        self.assertFalse(fileio.exists(test_file_path))
Code Example #14
 def setUp(self):
     super(ComponentTest, self).setUp()
     examples_artifact = standard_artifacts.Examples()
     examples_artifact.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     self.examples = channel_utils.as_channel([examples_artifact])
     self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
Code Example #15
File: executor_test.py Project: yongsheng268/tfx
  def test_do(self):
    source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')

    train_stats_artifact = types.Artifact('ExampleStatsPath', split='train')
    train_stats_artifact.uri = os.path.join(source_data_dir,
                                            'statistics_gen/train/')

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    schema_output = standard_artifacts.Schema()
    schema_output.uri = os.path.join(output_data_dir, 'schema_output')

    input_dict = {
        'stats': [train_stats_artifact],
    }
    output_dict = {
        'output': [schema_output],
    }

    exec_properties = {}

    schema_gen_executor = executor.Executor()
    schema_gen_executor.Do(input_dict, output_dict, exec_properties)
    self.assertNotEqual(0, len(tf.gfile.ListDirectory(schema_output.uri)))
Code Example #16
    def testDoValidation(self, exec_properties, blessed, has_baseline):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        model = standard_artifacts.Model()
        baseline_model = standard_artifacts.Model()
        model.uri = os.path.join(source_data_dir, 'trainer/current')
        baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')
        input_dict = {
            EXAMPLES_KEY: [examples],
            MODEL_KEY: [model],
            SCHEMA_KEY: [schema],
        }
        if has_baseline:
            input_dict[BASELINE_MODEL_KEY] = [baseline_model]

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        output_dict = {
            EVALUATION_KEY: [eval_output],
            BLESSING_KEY: [blessing_output],
        }

        # List needs to be serialized before being passed into Do function.
        exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri,
                                                   'metrics')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'validations')))
        if blessed:
            self.assertTrue(
                fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
        else:
            self.assertTrue(
                fileio.exists(os.path.join(blessing_output.uri,
                                           'NOT_BLESSED')))
Code Example #17
File: component.py Project: tvalentyn/tfx
  def __init__(self,
               statistics: Optional[types.Channel] = None,
               infer_feature_shape: Optional[bool] = False,
               output: Optional[types.Channel] = None,
               stats: Optional[types.Channel] = None,
               instance_name: Optional[Text] = None):
    """Constructs a SchemaGen component.

    Args:
      statistics: A Channel of `ExampleStatistics` type (required if spec is not
        passed). This should contain at least a `train` split. Other splits are
        currently ignored. _required_
      infer_feature_shape: Boolean value indicating whether or not to infer the
        shape of features. If the feature shape is not inferred, a downstream
        TensorFlow Transform component using the schema will parse its input
        as tf.SparseTensor.
      output: Output `Schema` channel for schema result.
      stats: Backwards compatibility alias for the 'statistics' argument.
      instance_name: Optional name assigned to this specific instance of
        SchemaGen.  Required only if multiple SchemaGen components are declared
        in the same pipeline.

    Either `statistics` or `stats` must be present in the input arguments.
    """
    statistics = statistics or stats
    output = output or types.Channel(
        type=standard_artifacts.Schema, artifacts=[standard_artifacts.Schema()])

    spec = SchemaGenSpec(
        stats=statistics,
        infer_feature_shape=infer_feature_shape,
        output=output)
    super(SchemaGen, self).__init__(spec=spec, instance_name=instance_name)
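
For contrast with Code Example #11, a minimal sketch of the same construction against this newer signature; as before, the import paths are assumptions and may differ between tfx releases.

from tfx.components import SchemaGen
from tfx.types import standard_artifacts
from tfx.types import channel_utils

# The newer API takes 'statistics' directly and supports 'instance_name'.
statistics_channel = channel_utils.as_channel(
    [standard_artifacts.ExampleStatistics()])
schema_gen = SchemaGen(
    statistics=statistics_channel,
    infer_feature_shape=False,
    instance_name='my_schema_gen')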
Code Example #18
 def testConstruct(self):
   example_validator = component.ExampleValidator(
       statistics=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='eval')]),
       schema=channel_utils.as_channel([standard_artifacts.Schema()]),
   )
   self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                    example_validator.outputs['anomalies'].type_name)
Code Example #19
 def setUp(self):
     super(ComponentTest, self).setUp()
     examples_artifact = standard_artifacts.Examples()
     examples_artifact.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     self.examples = channel_utils.as_channel([examples_artifact])
     self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
     self.custom_config = {'some': 'thing', 'some other': 1, 'thing': 2}
Code Example #20
File: component_test.py Project: reddqian/tfx
 def testConstruct(self):
   example_validator = component.ExampleValidator(
       stats=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='eval')]),
       schema=channel_utils.as_channel([standard_artifacts.Schema()]),
   )
   self.assertEqual('ExampleValidationPath',
                    example_validator.outputs['output'].type_name)
Code Example #21
File: fn_args_utils_test.py Project: jay90099/tfx
    def testGetCommonFnArgs(self):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir,
                                    'transform/transformed_examples')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        transform_output = standard_artifacts.TransformGraph()
        transform_output.uri = os.path.join(source_data_dir,
                                            'transform/transform_graph')

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')

        base_model = standard_artifacts.Model()
        base_model.uri = os.path.join(source_data_dir, 'trainer/previous')

        input_dict = {
            standard_component_specs.EXAMPLES_KEY: [examples],
            standard_component_specs.TRANSFORM_GRAPH_KEY: [transform_output],
            standard_component_specs.SCHEMA_KEY: [schema],
            standard_component_specs.BASE_MODEL_KEY: [base_model],
        }

        # Create exec properties skeleton.
        exec_properties = {
            'train_args':
            proto_utils.proto_to_json(trainer_pb2.TrainArgs(num_steps=1000)),
            'eval_args':
            proto_utils.proto_to_json(trainer_pb2.EvalArgs(num_steps=500)),
        }

        fn_args = fn_args_utils.get_common_fn_args(input_dict, exec_properties,
                                                   'tempdir')
        self.assertEqual(fn_args.working_dir, 'tempdir')
        self.assertEqual(fn_args.train_steps, 1000)
        self.assertEqual(fn_args.eval_steps, 500)
        self.assertLen(fn_args.train_files, 1)
        self.assertEqual(fn_args.train_files[0],
                         os.path.join(examples.uri, 'Split-train', '*'))
        self.assertLen(fn_args.eval_files, 1)
        self.assertEqual(fn_args.eval_files[0],
                         os.path.join(examples.uri, 'Split-eval', '*'))
        self.assertEqual(fn_args.schema_path,
                         os.path.join(schema.uri, 'schema.pbtxt'))
        # Depending on execution environment, the base model may have been stored
        # at .../Format-Servo/... or .../Format-Serving/... directory patterns.
        self.assertRegex(
            fn_args.base_model,
            os.path.join(base_model.uri,
                         r'Format-(Servo|Serving)/export/chicago-taxi/\d+'))
        self.assertEqual(fn_args.transform_graph_path, transform_output.uri)
        self.assertIsInstance(fn_args.data_accessor,
                              fn_args_utils.DataAccessor)
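
As a usage note: the FnArgs built by get_common_fn_args is what a user-supplied training function receives. A minimal, hypothetical run_fn that only touches the fields the test above asserts on (everything other than the FnArgs attribute names is illustrative):

def run_fn(fn_args):
    # train_files / eval_files are lists of file-glob patterns such as
    # '<examples.uri>/Split-train/*'; train_steps / eval_steps are ints.
    print('train:', fn_args.train_files, fn_args.train_steps)
    print('eval:', fn_args.eval_files, fn_args.eval_steps)
    print('schema:', fn_args.schema_path)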
Code Example #22
    def setUp(self):
        super(ExecutorTest, self).setUp()

        self._testdata_dir = os.path.join(os.path.dirname(__file__),
                                          'testdata')

        self._output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                           tempfile.mkdtemp(dir=flags.FLAGS.test_tmpdir)),
            self._testMethodName)

        self._context = executor.Executor.Context(
            tmp_dir=self._output_data_dir, unique_id='1')

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(self._testdata_dir, 'Transform.mockdata_1',
                                    'transformed_examples', '10')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])

        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(self._testdata_dir, 'SchemaGen.mockdata_1',
                                  'schema', '1')

        transform_output = standard_artifacts.TransformGraph()
        transform_output.uri = os.path.join(self._testdata_dir,
                                            'Transform.mockdata_1',
                                            'transform_graph', '10')

        self._input_dict = {
            'examples': [examples],
            'schema': [schema],
            'transform_graph': [transform_output],
        }

        # Create output dict.
        self._best_hparams = standard_artifacts.Model()
        self._best_hparams.uri = os.path.join(self._output_data_dir,
                                              'best_hparams')

        self._tuner_data = tuner_component.TunerData()
        self._tuner_data.uri = os.path.join(self._output_data_dir,
                                            'trial_summary_plot')
        self._output_dict = {
            'best_hyperparameters': [self._best_hparams],
            'trial_summary_plot': [self._tuner_data],
        }

        # Create exec properties.
        self._exec_properties = {
            'train_args':
            json_format.MessageToJson(trainer_pb2.TrainArgs(num_steps=2),
                                      preserving_proto_field_name=True),
            'eval_args':
            json_format.MessageToJson(trainer_pb2.EvalArgs(num_steps=1),
                                      preserving_proto_field_name=True),
        }
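
The exec properties above carry protos as JSON strings. A short sketch of the round trip with google.protobuf.json_format, using only names that appear in the snippet (the executor side is assumed to parse with the matching call):

from google.protobuf import json_format
from tfx.proto import trainer_pb2

# Encode exactly as the setUp above does, preserving proto field names.
train_args_json = json_format.MessageToJson(
    trainer_pb2.TrainArgs(num_steps=2), preserving_proto_field_name=True)
# Decode back into a TrainArgs proto, as an executor would.
restored = json_format.Parse(train_args_json, trainer_pb2.TrainArgs())
assert restored.num_steps == 2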
Code Example #23
File: component_test.py Project: romeokienzler/tfx
  def setUp(self):
    super(ComponentTest, self).setUp()

    self.examples = channel_utils.as_channel([standard_artifacts.Examples()])
    self.transform_output = channel_utils.as_channel(
        [standard_artifacts.TransformResult()])
    self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
    self.train_args = trainer_pb2.TrainArgs(num_steps=100)
    self.eval_args = trainer_pb2.EvalArgs(num_steps=50)
Code Example #24
    def setUp(self):
        super(KubeflowGCPIntegrationTest, self).setUp()

        # Example artifacts for testing.
        raw_train_examples = standard_artifacts.Examples(split='train')
        raw_train_examples.uri = os.path.join(
            self._intermediate_data_root,
            'csv_example_gen/examples/test-pipeline/train/')
        raw_eval_examples = standard_artifacts.Examples(split='eval')
        raw_eval_examples.uri = os.path.join(
            self._intermediate_data_root,
            'csv_example_gen/examples/test-pipeline/eval/')
        self._test_raw_examples = [raw_train_examples, raw_eval_examples]

        # Transformed Example artifacts for testing.
        transformed_train_examples = standard_artifacts.Examples(split='train')
        transformed_train_examples.uri = os.path.join(
            self._intermediate_data_root,
            'transform/transformed_examples/test-pipeline/train/')
        transformed_eval_examples = standard_artifacts.Examples(split='eval')
        transformed_eval_examples.uri = os.path.join(
            self._intermediate_data_root,
            'transform/transformed_examples/test-pipeline/eval/')
        self._test_transformed_examples = [
            transformed_train_examples, transformed_eval_examples
        ]

        # Schema artifact for testing.
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(self._intermediate_data_root,
                                  'schema_gen/output/test-pipeline/')
        self._test_schema = [schema]

        # TransformGraph artifact for testing.
        transform_graph = standard_artifacts.TransformGraph()
        transform_graph.uri = os.path.join(
            self._intermediate_data_root,
            'transform/transform_output/test-pipeline/')
        self._test_transform_graph = [transform_graph]

        # Model artifact for testing.
        model_1 = standard_artifacts.Model()
        model_1.uri = os.path.join(self._intermediate_data_root,
                                   'trainer/output/test-pipeline/1/')
        self._test_model_1 = [model_1]

        model_2 = standard_artifacts.Model()
        model_2.uri = os.path.join(self._intermediate_data_root,
                                   'trainer/output/test-pipeline/2/')
        self._test_model_2 = [model_2]

        # ModelBlessing artifact for testing.
        model_blessing = standard_artifacts.ModelBlessing()
        model_blessing.uri = os.path.join(
            self._intermediate_data_root,
            'model_validator/blessing/test-pipeline/')
        self._test_model_blessing = [model_blessing]
Code Example #25
File: component_test.py Project: yifanmai/tfx
 def setUp(self):
   super(TunerTest, self).setUp()
   self.examples = channel_utils.as_channel([standard_artifacts.Examples()])
   self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
   self.transform_graph = channel_utils.as_channel(
       [standard_artifacts.TransformGraph()])
   self.train_args = trainer_pb2.TrainArgs(splits=['train'], num_steps=100)
   self.eval_args = trainer_pb2.EvalArgs(splits=['eval'], num_steps=50)
   self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
Code Example #26
File: executor_test.py Project: nex3z/tfx
  def setUp(self):
    super(ExecutorTest, self).setUp()
    self._source_data_dir = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata')
    self._output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create input dict.
    examples = standard_artifacts.Examples()
    examples.uri = os.path.join(self._source_data_dir,
                                'transform/transformed_examples')
    examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
    transform_output = standard_artifacts.TransformGraph()
    transform_output.uri = os.path.join(self._source_data_dir,
                                        'transform/transform_output')
    schema = standard_artifacts.Schema()
    schema.uri = os.path.join(self._source_data_dir, 'schema_gen')
    previous_model = standard_artifacts.Model()
    previous_model.uri = os.path.join(self._source_data_dir, 'trainer/previous')

    self._input_dict = {
        executor.EXAMPLES_KEY: [examples],
        executor.TRANSFORM_GRAPH_KEY: [transform_output],
        executor.SCHEMA_KEY: [schema],
        executor.BASE_MODEL_KEY: [previous_model]
    }

    # Create output dict.
    self._model_exports = standard_artifacts.Model()
    self._model_exports.uri = os.path.join(self._output_data_dir,
                                           'model_export_path')
    self._output_dict = {executor.OUTPUT_MODEL_KEY: [self._model_exports]}

    # Create exec properties skeleton.
    self._exec_properties = {
        'train_args':
            json_format.MessageToJson(
                trainer_pb2.TrainArgs(num_steps=1000),
                preserving_proto_field_name=True),
        'eval_args':
            json_format.MessageToJson(
                trainer_pb2.EvalArgs(num_steps=500),
                preserving_proto_field_name=True),
        'warm_starting':
            False,
    }

    self._module_file = os.path.join(self._source_data_dir, 'module_file',
                                     'trainer_module.py')
    self._trainer_fn = '%s.%s' % (trainer_module.trainer_fn.__module__,
                                  trainer_module.trainer_fn.__name__)

    # Executors for test.
    self._trainer_executor = executor.Executor()
    self._generic_trainer_executor = executor.GenericExecutor()
Code Example #27
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['eval'])
     example_validator = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
     )
     self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                      example_validator.outputs['anomalies'].type_name)
Code Example #28
    def testEvaluation(self, exec_properties, model_agnostic=False):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        baseline_model = standard_artifacts.Model()
        baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')
        input_dict = {
            EXAMPLES_KEY: [examples],
            SCHEMA_KEY: [schema],
        }
        if not model_agnostic:
            model = standard_artifacts.Model()
            model.uri = os.path.join(source_data_dir, 'trainer/current')
            input_dict[MODEL_KEY] = [model]

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        output_dict = {
            EVALUATION_KEY: [eval_output],
            BLESSING_KEY: [blessing_output],
        }

        # Test multiple splits.
        exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(
            ['train', 'eval'])

        if MODULE_FILE_KEY in exec_properties:
            exec_properties[MODULE_FILE_KEY] = os.path.join(
                source_data_dir, 'module_file', 'evaluator_module.py')

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri,
                                                   'metrics')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
        self.assertFalse(
            fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
Code Example #29
File: executor_utils_test.py Project: jay90099/tfx
  def testGetStatsOutputPathEntries(self):
    # disabled.
    self.assertEmpty(executor_utils.GetStatsOutputPathEntries(True, {}))

    # enabled.
    pre_transform_stats = standard_artifacts.ExampleStatistics()
    pre_transform_stats.uri = '/pre_transform_stats'
    pre_transform_schema = standard_artifacts.Schema()
    pre_transform_schema.uri = '/pre_transform_schema'
    post_transform_anomalies = standard_artifacts.ExampleAnomalies()
    post_transform_anomalies.uri = '/post_transform_anomalies'
    post_transform_stats = standard_artifacts.ExampleStatistics()
    post_transform_stats.uri = '/post_transform_stats'
    post_transform_schema = standard_artifacts.Schema()
    post_transform_schema.uri = '/post_transform_schema'

    result = executor_utils.GetStatsOutputPathEntries(
        False, {
            standard_component_specs.PRE_TRANSFORM_STATS_KEY:
                [pre_transform_stats],
            standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY:
                [pre_transform_schema],
            standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY:
                [post_transform_anomalies],
            standard_component_specs.POST_TRANSFORM_STATS_KEY:
                [post_transform_stats],
            standard_component_specs.POST_TRANSFORM_SCHEMA_KEY:
                [post_transform_schema],
        })
    self.assertEqual(
        {
            labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL:
                '/pre_transform_stats',
            labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL:
                '/pre_transform_schema',
            labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL:
                '/post_transform_anomalies',
            labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL:
                '/post_transform_stats',
            labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL:
                '/post_transform_schema',
        }, result)
Code Example #30
File: component_test.py Project: jay90099/tfx
 def setUp(self):
     super().setUp()
     self.examples = channel_utils.as_channel(
         [standard_artifacts.Examples()])
     self.schema = channel_utils.as_channel([standard_artifacts.Schema()])
     self.transform_graph = channel_utils.as_channel(
         [standard_artifacts.TransformGraph()])
     self.train_args = trainer_pb2.TrainArgs(num_steps=100)
     self.eval_args = trainer_pb2.EvalArgs(num_steps=50)
     self.tune_args = tuner_pb2.TuneArgs(num_parallel_trials=3)
     self.custom_config = {'key': 'value'}