Example #1
  def testPrepareOutputArtifacts(self):
    examples = standard_artifacts.Examples()
    output_dict = {
        standard_component_specs.EXAMPLES_KEY:
            channel_utils.as_channel([examples])
    }
    exec_properties = {
        utils.SPAN_PROPERTY_NAME: 2,
        utils.VERSION_PROPERTY_NAME: 1,
        utils.FINGERPRINT_PROPERTY_NAME: 'fp'
    }

    pipeline_info = data_types.PipelineInfo(
        pipeline_name='name', pipeline_root=self._test_dir, run_id='rid')
    component_info = data_types.ComponentInfo(
        component_type='type', component_id='cid', pipeline_info=pipeline_info)

    input_artifacts = {}
    output_artifacts = self._file_based_driver._prepare_output_artifacts(  # pylint: disable=protected-access
        input_artifacts, output_dict, exec_properties, 1, pipeline_info,
        component_info)
    examples = artifact_utils.get_single_instance(
        output_artifacts[standard_component_specs.EXAMPLES_KEY])
    base_output_dir = os.path.join(self._test_dir, component_info.component_id)
    expected_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        base_output_dir, 'examples', 1)

    self.assertEqual(examples.uri, expected_uri)
    self.assertEqual(
        examples.get_string_custom_property(utils.FINGERPRINT_PROPERTY_NAME),
        'fp')
    self.assertEqual(
        examples.get_int_custom_property(utils.SPAN_PROPERTY_NAME), 2)
    self.assertEqual(
        examples.get_int_custom_property(utils.VERSION_PROPERTY_NAME), 1)
Example #2
 def test_channel_utils_as_channel_success(self):
     instance_a = Artifact('MyTypeName')
     instance_b = Artifact('MyTypeName')
     chnl_original = Channel('MyTypeName',
                             artifacts=[instance_a, instance_b])
     chnl_result = channel_utils.as_channel(chnl_original)
     self.assertEqual(chnl_original, chnl_result)
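
Read together with the list-based examples below, this test pins down the contract of channel_utils.as_channel: an existing Channel is returned unchanged, while a list of artifacts is wrapped in a new Channel. A minimal sketch of that dispatch, assuming only the Channel constructor and type_name attribute used in Examples #2 and #16 (illustrative, not the TFX source):

def as_channel_sketch(source):
    # Pass-through: an existing Channel comes back unchanged, which is what
    # the test above asserts.
    if isinstance(source, Channel):
        return source
    # A non-empty artifact list is wrapped in a new Channel of the first
    # artifact's type, as in the list-based examples below.
    if isinstance(source, list):
        if not source:
            raise ValueError('Cannot convert an empty list to a Channel.')
        return Channel(source[0].type_name, artifacts=source)
    raise ValueError('Invalid source for as_channel: %r.' % (source,))
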
Example #3
 def testConstruct(self):
     statistics_artifact = standard_artifacts.ExampleStatistics()
     statistics_artifact.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     exclude_splits = ['eval']
     example_validator = component.ExampleValidator(
         statistics=channel_utils.as_channel([statistics_artifact]),
         schema=channel_utils.as_channel([standard_artifacts.Schema()]),
         exclude_splits=exclude_splits)
     self.assertEqual(
         standard_artifacts.ExampleAnomalies.TYPE_NAME,
         example_validator.outputs[
             standard_component_specs.ANOMALIES_KEY].type_name)
     self.assertEqual(
         example_validator.spec.exec_properties[
             standard_component_specs.EXCLUDE_SPLITS_KEY], '["eval"]')
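
Note the expected value here: exclude_splits lands in exec_properties as its JSON encoding ('["eval"]'), not as a Python list; Example #11 below shows the same convention for fairness_indicator_thresholds. A quick round-trip check with the standard library (illustrative only):

import json

assert json.dumps(['eval']) == '["eval"]'
assert json.loads('[0.1, 0.3, 0.5, 0.9]') == [0.1, 0.3, 0.5, 0.9]
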
Example #4
    def testConstructInferenceResultAndOutputExample(self):
        with self.assertRaises(ValueError):
            component.BulkInferrer(examples=self._examples,
                                   model=self._model,
                                   model_blessing=self._model_blessing,
                                   output_examples=channel_utils.as_channel(
                                       [standard_artifacts.Examples()]))

        with self.assertRaises(ValueError):
            component.BulkInferrer(
                examples=self._examples,
                model=self._model,
                model_blessing=self._model_blessing,
                output_example_spec=bulk_inferrer_pb2.OutputExampleSpec(),
                inference_result=channel_utils.as_channel(
                    [standard_artifacts.InferenceResult()]))
Example #5
    def testPrepareOutputArtifacts(self):
        examples = standard_artifacts.Examples()
        output_dict = {
            utils.EXAMPLES_KEY: channel_utils.as_channel([examples])
        }
        exec_properties = {
            utils.SPAN_PROPERTY_NAME: '02',
            utils.FINGERPRINT_PROPERTY_NAME: 'fp'
        }

        pipeline_info = data_types.PipelineInfo(pipeline_name='name',
                                                pipeline_root=self._test_dir,
                                                run_id='rid')
        component_info = data_types.ComponentInfo(component_type='type',
                                                  component_id='cid',
                                                  pipeline_info=pipeline_info)

        output_artifacts = self._example_gen_driver._prepare_output_artifacts(
            output_dict, exec_properties, 1, pipeline_info, component_info)
        examples = artifact_utils.get_single_instance(
            output_artifacts[utils.EXAMPLES_KEY])
        self.assertEqual(examples.uri,
                         os.path.join(self._test_dir, 'cid', 'examples', '1'))
        self.assertEqual(
            examples.get_string_custom_property(
                utils.FINGERPRINT_PROPERTY_NAME), 'fp')
        self.assertEqual(
            examples.get_string_custom_property(utils.SPAN_PROPERTY_NAME),
            '02')
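
Examples #1 and #5 assert the same output URI layout from two directions: #1 delegates to base_driver._generate_output_uri, while #5 spells the resulting path out. A sketch consistent with both assertions (the real helper lives inside TFX; this reconstruction is only what the tests imply):

def generate_output_uri_sketch(base_output_dir, key, execution_id):
    # Layout implied by both tests:
    #   <pipeline_root>/<component id>/<output key>/<execution id>
    return os.path.join(base_output_dir, key, str(execution_id))
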
Example #6
 def __init__(self, *unused_args, **kwargs):
   if unused_args:
     raise ValueError(('%s expects arguments to be passed as keyword '
                       'arguments') % (self.__class__.__name__,))
   spec_kwargs = {}
   unseen_args = set(kwargs.keys())
   for key, channel_parameter in self.SPEC_CLASS.INPUTS.items():
     if key not in kwargs and not channel_parameter.optional:
       raise ValueError('%s expects input %r to be a Channel of type %s.' %
                        (self.__class__.__name__, key, channel_parameter.type))
     if key in kwargs:
       spec_kwargs[key] = kwargs[key]
       unseen_args.remove(key)
   for key, parameter in self.SPEC_CLASS.PARAMETERS.items():
     if key not in kwargs and not parameter.optional:
       raise ValueError('%s expects parameter %r of type %s.' %
                        (self.__class__.__name__, key, parameter.type))
     if key in kwargs:
       spec_kwargs[key] = kwargs[key]
       unseen_args.remove(key)
   instance_name = kwargs.get('instance_name', None)
   unseen_args.discard('instance_name')
   if unseen_args:
     raise ValueError(
         'Unknown arguments to %r: %s.' %
         (self.__class__.__name__, ', '.join(sorted(unseen_args))))
   for key, channel_parameter in self.SPEC_CLASS.OUTPUTS.items():
     spec_kwargs[key] = channel_utils.as_channel([channel_parameter.type()])
   spec = self.SPEC_CLASS(**spec_kwargs)
   super(_SimpleComponent, self).__init__(spec, instance_name=instance_name)
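
The loops above read the component contract straight off SPEC_CLASS: every non-optional entry in INPUTS and PARAMETERS must be supplied, leftovers in unseen_args are rejected, and every OUTPUTS entry is auto-populated via channel_utils.as_channel. A hypothetical spec to make that concrete (ChannelParameter and ExecutionParameter as in tfx.types.component_spec; the class and key names are invented, and the executor wiring a real component also needs is omitted):

class _DemoSpec(types.ComponentSpec):
  INPUTS = {
      'examples': ChannelParameter(type=standard_artifacts.Examples),
  }
  PARAMETERS = {
      'num_steps': ExecutionParameter(type=int),
  }
  OUTPUTS = {
      'statistics': ChannelParameter(type=standard_artifacts.ExampleStatistics),
  }


class _DemoComponent(_SimpleComponent):
  SPEC_CLASS = _DemoSpec


# Accepted: both required keys are present; 'statistics' is never passed in,
# because the OUTPUTS loop creates as_channel([ExampleStatistics()]) for it.
_DemoComponent(
    examples=channel_utils.as_channel([standard_artifacts.Examples()]),
    num_steps=100)
# Omitting examples= raises "expects input 'examples' to be a Channel ...";
# any extra keyword stays in unseen_args and raises "Unknown arguments".
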
Example #7
    def testConstructNoPipelineConfiguration(self):
        examples = standard_artifacts.Examples()
        predicate_fn = """
  def predicate(m):
    return m.features.feature[key].float_list.value[0] > 0.5
"""

        filter = Filter(pipeline_configuration=None,
                        examples=channel_utils.as_channel([examples]),
                        filtered_examples=channel_utils.as_channel(
                            [standard_artifacts.Examples()]),
                        splits_to_transform=['eval'],
                        splits_to_copy=['train'],
                        predicate_fn=predicate_fn)
        self.assertEqual('Examples',
                         filter.outputs[FILTERED_EXAMPLES_KEY].type_name)
Example #8
 def testConstruct(self):
     input_data = standard_artifacts.Examples()
     input_data.split_names = json.dumps(artifact.DEFAULT_EXAMPLE_SPLITS)
     output_data = standard_artifacts.Examples()
     output_data.split_names = json.dumps(artifact.DEFAULT_EXAMPLE_SPLITS)
     this_component = component.HelloComponent(
         input_data=channel_utils.as_channel([input_data]),
         output_data=channel_utils.as_channel([output_data]),
         name=u'Testing123')
     self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                      this_component.outputs['output_data'].type_name)
     artifact_collection = this_component.outputs['output_data'].get()
     for output_artifact in artifact_collection:
         split_list = json.loads(output_artifact.split_names)
         # list.sort() sorts in place and returns None, so the original
         # comparison of two .sort() calls always passed vacuously; compare
         # sorted copies instead.
         self.assertEqual(sorted(artifact.DEFAULT_EXAMPLE_SPLITS),
                          sorted(split_list))
Example #9
    def __init__(self,
                 function_name: Optional[Text] = None,
                 model: Optional[types.Channel] = None,
                 model_blessing: Optional[types.Channel] = None,
                 infra_blessing: Optional[types.Channel] = None,
                 pushed_model: Optional[types.Channel] = None,
                 output: Optional[types.Channel] = None,
                 pipeline_configuration: Optional[types.Channel] = None,
                 transform_graph: Optional[types.Channel] = None):
        """Construct a model export component.

        Args:
          function_name: The instance_name of the function to apply on the
            model.
          model: A Channel of type `standard_artifacts.Model`.
          model_blessing: A Channel of type `standard_artifacts.ModelBlessing`.
          infra_blessing: A Channel of type `standard_artifacts.InfraBlessing`.
          pushed_model: A Channel of type `standard_artifacts.PushedModel`.
          output: A Channel of type `ExportedModel`.
          pipeline_configuration: A Channel of type `PipelineConfiguration`,
            usually produced by a FromCustomConfig component.
          transform_graph: A Channel of type `standard_artifacts.TransformGraph`.
        """

        if not output:
            output = channel_utils.as_channel([ExportedModel()])

        spec = ExportSpec(function_name=function_name,
                          pipeline_configuration=pipeline_configuration,
                          model=model,
                          model_blessing=model_blessing,
                          infra_blessing=infra_blessing,
                          pushed_model=pushed_model,
                          output=output,
                          transform_graph=transform_graph)
        super(Export, self).__init__(spec=spec)
Example #10
    def __init__(self,
                 input_data: Optional[types.Channel] = None,
                 output_data: Optional[types.Channel] = None,
                 name: Optional[Text] = None):
        """Construct a HelloComponent.

        Args:
          input_data: A Channel of type `standard_artifacts.Examples`. This
            will often contain two splits: 'train' and 'eval'.
          output_data: A Channel of type `standard_artifacts.Examples`. This
            will usually contain the same splits as input_data.
          name: Optional unique name. Necessary if multiple Hello components
            are declared in the same pipeline.
        """
        # output_data contains the example artifacts for each split of the
        # data, by default a 'train' split and an 'eval' split. Since
        # HelloComponent passes the input data through to its output, the
        # splits in output_data will be the same as the splits in input_data,
        # which were generated by the upstream component.
        if not output_data:
            output_data = channel_utils.as_channel(
                [standard_artifacts.Examples()])

        spec = HelloComponentSpec(input_data=input_data,
                                  output_data=output_data,
                                  name=name)
        super(HelloComponent, self).__init__(spec=spec)
Example #11
 def testConstructWithFairnessThresholds(self):
     examples = standard_artifacts.Examples()
     model_exports = standard_artifacts.Model()
     evaluator = component.Evaluator(
         examples=channel_utils.as_channel([examples]),
         model=channel_utils.as_channel([model_exports]),
         feature_slicing_spec=evaluator_pb2.FeatureSlicingSpec(specs=[
             evaluator_pb2.SingleSlicingSpec(
                 column_for_slicing=['trip_start_hour'])
         ]),
         fairness_indicator_thresholds=[0.1, 0.3, 0.5, 0.9])
     self.assertEqual(standard_artifacts.ModelEvaluation.TYPE_NAME,
                      evaluator.outputs['evaluation'].type_name)
     self.assertEqual(
         '[0.1, 0.3, 0.5, 0.9]',
         evaluator.exec_properties['fairness_indicator_thresholds'])
Example #12
    def testFindComponentLaunchInfoReturnConfigOverride(self):
        input_artifact = test_utils._InputArtifact()
        component = test_utils._FakeComponent(
            name='FakeComponent',
            input_channel=channel_utils.as_channel([input_artifact]),
            custom_executor_spec=executor_spec.ExecutorContainerSpec(
                image='gcr://test', args=['{{input_dict["input"][0].uri}}']))
        default_config = docker_component_config.DockerComponentConfig()
        override_config = docker_component_config.DockerComponentConfig(
            name='test')
        p_config = pipeline_config.PipelineConfig(
            supported_launcher_classes=[
                docker_component_launcher.DockerComponentLauncher
            ],
            default_component_configs=[default_config],
            component_config_overrides={
                '_FakeComponent.FakeComponent': override_config
            })

        (launcher_class, c_config) = config_utils.find_component_launch_info(
            p_config, component)

        self.assertEqual(docker_component_launcher.DockerComponentLauncher,
                         launcher_class)
        self.assertEqual(override_config, c_config)
Example #13
 def testConstruct(self):
   schema_gen = component.SchemaGen(
       stats=channel_utils.as_channel(
           [standard_artifacts.ExampleStatistics(split='train')]),
       infer_feature_shape=True)
   self.assertEqual('SchemaPath', schema_gen.outputs.output.type_name)
   self.assertTrue(schema_gen.spec.exec_properties['infer_feature_shape'])
Example #14
 def testConstruct(self):
   train_examples = standard_artifacts.Examples(split='train')
   eval_examples = standard_artifacts.Examples(split='eval')
   statistics_gen = component.StatisticsGen(
       examples=channel_utils.as_channel([train_examples, eval_examples]))
   self.assertEqual('ExampleStatisticsPath',
                    statistics_gen.outputs['statistics'].type_name)
Example #15
  def setUp(self):
    super(BaseComponentTest, self).setUp()
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input_base=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        input_data=example_gen.outputs.examples)

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name='test_pipeline',
        pipeline_root='test_pipeline_root',
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    with dsl.Pipeline('test_pipeline'):

      self.component = base_component.BaseComponent(
          component=statistics_gen,
          depends_on=set([example_gen]),
          pipeline=pipeline,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
      )
Example #16
 def testArtifactCollectionAsChannel(self):
     instance_a = _MyArtifact()
     instance_b = _MyArtifact()
     chnl = channel_utils.as_channel([instance_a, instance_b])
     self.assertEqual(chnl.type, _MyArtifact)
     self.assertEqual(chnl.type_name, 'MyTypeName')
     self.assertCountEqual(chnl.get(), [instance_a, instance_b])
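
The assertions above (chnl.type is _MyArtifact, chnl.type_name == 'MyTypeName') imply a minimal test helper along these lines; its source is not shown here, so this is an assumed reconstruction:

class _MyArtifact(types.Artifact):
    TYPE_NAME = 'MyTypeName'
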
Example #17
  def setUp(self):
    super(BaseComponentWithPipelineParamTest, self).setUp()

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
    example_gen_buckets = data_types.RuntimeParameter(
        name='example-gen-buckets', ptype=int, default=10)

    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]),
        output_config={
            'split_config': {
                'splits': [{
                    'name': 'examples',
                    'hash_buckets': example_gen_buckets
                }]
            }
        })
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    self._tfx_ir = pipeline_pb2.Pipeline()
    with dsl.Pipeline('test_pipeline'):
      self.example_gen = base_component.BaseComponent(
          component=example_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir)
      self.statistics_gen = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
          tfx_ir=self._tfx_ir
      )

    self.tfx_example_gen = example_gen
    self.tfx_statistics_gen = statistics_gen
Example #18
    def setUp(self):
        super(DriverTest, self).setUp()
        # Create input splits.
        test_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)
        self._input_base_path = os.path.join(test_dir, 'input_base')
        tf.io.gfile.makedirs(self._input_base_path)

        # Mock metadata.
        self._mock_metadata = tf.compat.v1.test.mock.Mock()
        self._example_gen_driver = driver.Driver(self._mock_metadata)

        # Create input dict.
        input_base = standard_artifacts.ExternalArtifact()
        input_base.uri = self._input_base_path
        self._input_channels = {
            'input_base': channel_utils.as_channel([input_base])
        }
        # Create exec properties.
        self._exec_properties = {
            'input_config':
                json_format.MessageToJson(
                    example_gen_pb2.Input(splits=[
                        example_gen_pb2.Input.Split(
                            name='s1', pattern='span{SPAN}/split1/*'),
                        example_gen_pb2.Input.Split(
                            name='s2', pattern='span{SPAN}/split2/*'),
                    ]),
                    preserving_proto_field_name=True),
        }
Example #19
  def setUp(self):
    super(BaseComponentTest, self).setUp()
    examples = standard_artifacts.ExternalArtifact()
    example_gen = csv_example_gen_component.CsvExampleGen(
        input=channel_utils.as_channel([examples]))
    statistics_gen = statistics_gen_component.StatisticsGen(
        examples=example_gen.outputs['examples'], instance_name='foo')

    pipeline = tfx_pipeline.Pipeline(
        pipeline_name=self._test_pipeline_name,
        pipeline_root='test_pipeline_root',
        metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
        components=[example_gen, statistics_gen],
    )

    test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')

    self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
    self._metadata_config.mysql_db_service_host.environment_variable = 'MYSQL_SERVICE_HOST'
    with dsl.Pipeline('test_pipeline'):
      self.component = base_component.BaseComponent(
          component=statistics_gen,
          component_launcher_class=in_process_component_launcher
          .InProcessComponentLauncher,
          depends_on=set(),
          pipeline=pipeline,
          pipeline_name=self._test_pipeline_name,
          pipeline_root=test_pipeline_root,
          tfx_image='container_image',
          kubeflow_metadata_config=self._metadata_config,
          component_config=None,
      )
    self.tfx_component = statistics_gen
Example #20
    def pre_execution(
        self,
        input_dict: Dict[Text, types.Channel],
        output_dict: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
        component_info: data_types.ComponentInfo,
    ) -> data_types.ExecutionDecision:
        output_artifacts = {
            IMPORT_RESULT_KEY:
            self._import_artifacts(
                source_uri=exec_properties[SOURCE_URI_KEY],
                destination_channel=output_dict[IMPORT_RESULT_KEY],
                reimport=exec_properties[REIMPORT_OPTION_KEY],
                split_names=exec_properties[SPLIT_KEY])
        }

        output_dict[IMPORT_RESULT_KEY] = channel_utils.as_channel(
            output_artifacts[IMPORT_RESULT_KEY])

        return data_types.ExecutionDecision(
            input_dict={},
            output_dict=output_artifacts,
            exec_properties={},
            execution_id=self._register_execution(
                exec_properties={},
                pipeline_info=pipeline_info,
                component_info=component_info),
            use_cached_results=False)
Example #21
 def __init__(self, *unused_args, **kwargs):
     if unused_args:
         raise ValueError(('%s expects arguments to be passed as keyword '
                           'arguments') % (self.__class__.__name__, ))
     spec_kwargs = {}
     unseen_args = set(kwargs.keys())
     for key, channel_parameter in self.SPEC_CLASS.INPUTS.items():
         if key not in kwargs and not channel_parameter.optional:
             raise ValueError(
                 '%s expects input %r to be a Channel of type %s.' %
                 (self.__class__.__name__, key, channel_parameter.type))
         if key in kwargs:
             spec_kwargs[key] = kwargs[key]
             unseen_args.remove(key)
     for key, parameter in self.SPEC_CLASS.PARAMETERS.items():
         if key not in kwargs and not parameter.optional:
             raise ValueError(
                 '%s expects parameter %r of type %s.' %
                 (self.__class__.__name__, key, parameter.type))
         if key in kwargs:
             spec_kwargs[key] = kwargs[key]
             unseen_args.remove(key)
     if unseen_args:
         raise ValueError(
             'Unknown arguments to %r: %s.' %
             (self.__class__.__name__, ', '.join(sorted(unseen_args))))
     for key, channel_parameter in self.SPEC_CLASS.OUTPUTS.items():
         spec_kwargs[key] = channel_utils.as_channel(
             [channel_parameter.type()])
     spec = self.SPEC_CLASS(**spec_kwargs)
     super().__init__(spec)
     # Set class name, which is the decorated function name, as the default id.
     # It can be overwritten by the user.
     self._id = self.__class__.__name__
Example #22
    def _create_launcher_context(self, component_config=None):
        test_dir = self.get_temp_dir()

        connection_config = metadata_store_pb2.ConnectionConfig()
        connection_config.sqlite.SetInParent()
        metadata_connection = metadata.Metadata(connection_config)

        pipeline_root = os.path.join(test_dir, 'Test')

        input_artifact = test_utils._InputArtifact()
        input_artifact.uri = os.path.join(test_dir, 'input')

        component = test_utils._FakeComponent(
            name='FakeComponent',
            input_channel=channel_utils.as_channel([input_artifact]),
            custom_executor_spec=executor_spec.ExecutorContainerSpec(
                image='gcr://test', args=['{{input_dict["input"][0].uri}}']))

        pipeline_info = data_types.PipelineInfo(pipeline_name='Test',
                                                pipeline_root=pipeline_root,
                                                run_id='123')

        driver_args = data_types.DriverArgs(enable_cache=True)

        launcher = kubernetes_component_launcher.KubernetesComponentLauncher.create(
            component=component,
            pipeline_info=pipeline_info,
            driver_args=driver_args,
            metadata_connection=metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={},
            component_config=component_config)

        return {'launcher': launcher, 'input_artifact': input_artifact}
Example #23
 def testConstruct(self):
   train_examples_in = standard_artifacts.Examples(split='train')
   eval_examples_in = standard_artifacts.Examples(split='eval')
   train_examples_out = standard_artifacts.Examples(split='train')
   eval_examples_out = standard_artifacts.Examples(split='eval')
   this_component = component.HelloComponent(
       input_data=channel_utils.as_channel(
           [train_examples_in, eval_examples_in]),
       output_data=channel_utils.as_channel(
           [train_examples_out, eval_examples_out]),
       name=u'Testing123')
   self.assertEqual('ExamplesPath',
                    this_component.outputs['output_data'].type_name)
   artifact_collection = this_component.outputs['output_data'].get()
   self.assertEqual('train', artifact_collection[0].split)
   self.assertEqual('eval', artifact_collection[1].split)
Example #24
 def test_construct_with_cache_disabled_but_input_cache(self):
     with self.assertRaises(ValueError):
         _ = component.Transform(examples=self.examples,
                                 schema=self.schema,
                                 preprocessing_fn='my_preprocessing_fn',
                                 disable_analyzer_cache=True,
                                 analyzer_cache=channel_utils.as_channel(
                                     [standard_artifacts.TransformCache()]))
Example #25
 def testEnableCache(self):
     input_base = standard_artifacts.ExternalArtifact()
     custom_config = example_gen_pb2.CustomConfig(
         custom_config=any_pb2.Any())
     example_gen_1 = component.FileBasedExampleGen(
         input=channel_utils.as_channel([input_base]),
         custom_config=custom_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor))
     self.assertEqual(None, example_gen_1.enable_cache)
     example_gen_2 = component.FileBasedExampleGen(
         input=channel_utils.as_channel([input_base]),
         custom_config=custom_config,
         custom_executor_spec=executor_spec.ExecutorClassSpec(
             TestExampleGenExecutor),
         enable_cache=True)
     self.assertEqual(True, example_gen_2.enable_cache)
Example #26
 def testInvalidBenchmarkNameThrows(self):
     with self.assertRaises(ValueError):
         BenchmarkResultPublisher('',
                                  channel_utils.as_channel([
                                      standard_artifacts.ModelEvaluation()
                                  ]),
                                  run=1,
                                  num_runs=2)
Example #27
 def testConstruct(self):
     examples = standard_artifacts.Examples()
     examples.split_names = artifact_utils.encode_split_names(
         ['train', 'eval'])
     statistics_gen = component.StatisticsGen(
         examples=channel_utils.as_channel([examples]))
     self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
                      statistics_gen.outputs['statistics'].type_name)
Example #28
 def testConstructWithParameter(self):
     column_name = data_types.RuntimeParameter(name='column-name',
                                               ptype=Text)
     threshold = data_types.RuntimeParameter(name='threshold', ptype=float)
     examples = standard_artifacts.Examples()
     model_exports = standard_artifacts.Model()
     evaluator = component.Evaluator(
         examples=channel_utils.as_channel([examples]),
         model=channel_utils.as_channel([model_exports]),
         feature_slicing_spec={
             'specs': [{
                 'column_for_slicing': [column_name]
             }]
         },
         fairness_indicator_thresholds=[threshold])
     self.assertEqual(standard_artifacts.ModelEvaluation.TYPE_NAME,
                      evaluator.outputs['evaluation'].type_name)
Example #29
    def testConstructHybridPipelineConfiguration(self):
        to_key_fn = """
      def to_key(m):
        return m.features.feature[key].float_list.value[0] > 0.5
    """

        examples = standard_artifacts.Examples()
        stratified_sampler = StratifiedSampler(
            examples=channel_utils.as_channel([examples]),
            pipeline_configuration=types.Channel(type=PipelineConfiguration),
            stratified_examples=channel_utils.as_channel(
                [standard_artifacts.Examples()]),
            to_key_fn=to_key_fn,
            samples_per_key=112)
        self.assertEqual(
            'Examples',
            stratified_sampler.outputs[STRATIFIED_EXAMPLES_KEY].type_name)
Example #30
    def __init__(
            self,
            model_export: types.Channel,
            model_blessing: types.Channel,
            push_destination: Optional[pusher_pb2.PushDestination] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            executor_class: Optional[Type[base_executor.BaseExecutor]] = None,
            model_push: Optional[types.Channel] = None,
            name: Optional[Text] = None):
        """Construct a Pusher component.

    Args:
      model_export: A Channel of 'ModelExportPath' type, usually produced by
        Trainer component.
      model_blessing: A Channel of 'ModelBlessingPath' type, usually produced by
        ModelValidator component.
      push_destination: A pusher_pb2.PushDestination instance, providing
        info for tensorflow serving to load models. Optional if executor_class
        doesn't require push_destination.
      custom_config: A dict which contains the deployment job parameters to be
        passed to Google Cloud ML Engine.  For the full set of parameters
        supported by Google Cloud ML Engine, refer to
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
      executor_class: Optional custom python executor class.
      model_push: Optional output 'ModelPushPath' channel with result of push.
      name: Optional unique name. Necessary if multiple Pusher components are
        declared in the same pipeline.
    """
        model_push = model_push or types.Channel(
            type=standard_artifacts.PushedModel,
            artifacts=[standard_artifacts.PushedModel()])
        if push_destination is None and not executor_class:
            raise ValueError(
                'push_destination is required unless a custom '
                'executor_class is supplied that does not require '
                'it.')
        spec = PusherSpec(
            model_export=channel_utils.as_channel(model_export),
            model_blessing=channel_utils.as_channel(model_blessing),
            push_destination=push_destination,
            custom_config=custom_config,
            model_push=model_push)
        super(Pusher, self).__init__(spec=spec,
                                     custom_executor_class=executor_class,
                                     name=name)
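
Because PusherSpec construction normalizes model_export and model_blessing with as_channel, callers can hand in prepared Channels as below (or, at runtime, raw artifact lists, which as_channel wraps). A hedged usage sketch; the push_destination value is illustrative:

pusher = Pusher(
    model_export=channel_utils.as_channel([standard_artifacts.Model()]),
    model_blessing=channel_utils.as_channel(
        [standard_artifacts.ModelBlessing()]),
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory='/tmp/model_serving')))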