コード例 #1
0
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root_1: str,
                     data_root_2: str) -> pipeline.Pipeline:
    """Builds a pipeline whose component inputs are formed with channel.union().

    Args:
      pipeline_name: Name forwarded to the returned pipeline.
      pipeline_root: Root forwarded to the returned pipeline.
      data_root_1: Input base handed to the first CsvExampleGen.
      data_root_2: Input base handed to the second CsvExampleGen.

    Returns:
      A pipeline.Pipeline made of both example gens, the channel-union
      component, the latest-artifact resolver and a StatisticsGen.
    """
    # One CSV ingestion component per data root.
    example_gen_1 = CsvExampleGen(input_base=data_root_1)
    example_gen_1 = example_gen_1.with_id('example_gen_1')
    example_gen_2 = CsvExampleGen(input_base=data_root_2)
    example_gen_2 = example_gen_2.with_id('example_gen_2')

    # Feed the examples of both generators into a single component input.
    both_examples = channel.union([
        example_gen_1.outputs['examples'],
        example_gen_2.outputs['examples'],
    ])
    # pylint: disable=no-value-for-parameter
    channel_union = ChannelUnionComponent(input_data=both_examples,
                                          name='channel_union_input')

    # The resolver input is itself a union: the first generator's examples
    # together with the output of the channel-union component.
    resolver_input = channel.union([
        example_gen_1.outputs['examples'],
        channel_union.outputs['output_data'],
    ])
    latest_artifacts_resolver = resolver.Resolver(
        strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
        resolved_channels=resolver_input)
    latest_artifacts_resolver = latest_artifacts_resolver.with_id(
        'latest_artifacts_resolver')

    # Statistics are computed over whatever the resolver hands back.
    statistics_gen = StatisticsGen(
        examples=latest_artifacts_resolver.outputs['resolved_channels'])

    components = [
        example_gen_1,
        example_gen_2,
        channel_union,
        latest_artifacts_resolver,
        statistics_gen,
    ]
    return pipeline.Pipeline(pipeline_name=pipeline_name,
                             pipeline_root=pipeline_root,
                             components=components)
コード例 #2
0
  def testValidUnionChannel(self):
    """Checks the type name and member list of flat and nested unions."""
    channel1 = channel.Channel(type=_MyType)
    channel2 = channel.Channel(type=_MyType)

    # Flat union: members are exposed in the order they were passed in.
    union_channel = channel.union([channel1, channel2])
    # Strings are compared by value; an identity check (assertIs) would only
    # pass when the interpreter happens to intern both string objects.
    self.assertEqual(union_channel.type_name, 'MyTypeName')
    self.assertEqual(union_channel.channels, [channel1, channel2])

    # Nested union: the inner union is expected to be flattened into the
    # outer one, yielding the same member list as the flat case.
    union_channel = channel.union([channel1, channel.union([channel2])])
    self.assertEqual(union_channel.type_name, 'MyTypeName')
    self.assertEqual(union_channel.channels, [channel1, channel2])
コード例 #3
0
    def testResolveInputArtifacts(self):
        """Resolves a union input and a plain input through BaseDriver."""
        # Two artifacts with explicit ids, each published on its own channel.
        first_artifact = standard_artifacts.String()
        first_artifact.id = 1
        first_channel = types.Channel(
            type=standard_artifacts.String, producer_component_id='c1')
        first_channel = first_channel.set_artifacts([first_artifact])

        second_artifact = standard_artifacts.String()
        second_artifact.id = 2
        second_channel = types.Channel(
            type=standard_artifacts.String, producer_component_id='c2')
        second_channel = second_channel.set_artifacts([second_artifact])

        # A third channel, used as a regular (non-union) input.
        plain_channel = types.Channel(
            type=standard_artifacts.String, producer_component_id='c3')
        plain_channel = plain_channel.set_artifacts(
            [standard_artifacts.String()])

        unioned_input = channel.union([first_channel, second_channel])
        input_dict = {
            'input_union': unioned_input,
            'input_string': plain_channel,
        }

        # The mocked metadata search returns these results one call at a
        # time, in exactly this order.
        search_results = [
            plain_channel.get(),
            first_channel.get(),
            second_channel.get(),
        ]
        self._mock_metadata.search_artifacts.side_effect = search_results

        driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
        resolved_artifacts = driver.resolve_input_artifacts(
            input_dict=input_dict,
            exec_properties=self._exec_properties,
            driver_args=self._driver_args,
            pipeline_info=self._pipeline_info)

        union_result = resolved_artifacts['input_union']
        self.assertEqual(len(union_result), 2)
        self.assertEqual(union_result[0].value, _STRING_VALUE)

        string_result = resolved_artifacts['input_string']
        self.assertEqual(len(string_result), 1)
        self.assertEqual(string_result[0].value, _STRING_VALUE)
コード例 #4
0
ファイル: channel_utils_test.py プロジェクト: jay90099/tfx
    def testGetInidividualChannels(self):
        """Checks get_individual_channels() on a single channel and a union."""
        first_artifact, second_artifact = _MyArtifact(), _MyArtifact()
        first_channel = channel.Channel(_MyArtifact)
        first_channel = first_channel.set_artifacts([first_artifact])
        second_channel = channel.Channel(_MyArtifact)
        second_channel = second_channel.set_artifacts([second_artifact])

        # A plain channel comes back as a one-element list holding itself.
        self.assertEqual(
            channel_utils.get_individual_channels(first_channel),
            [first_channel])

        # A union comes back as the list of its member channels, in order.
        unioned = channel.union([first_channel, second_channel])
        self.assertEqual(
            channel_utils.get_individual_channels(unioned),
            [first_channel, second_channel])
コード例 #5
0
ファイル: component_spec_test.py プロジェクト: jay90099/tfx
    def testComponentSpec_WithUnionChannel(self):
        """Builds a component spec whose input is a union of two channels."""
        first_input = channel.Channel(type=_InputArtifact)
        second_input = channel.Channel(type=_InputArtifact)
        produced = channel.Channel(type=_OutputArtifact)

        unioned_input = channel.union([first_input, second_input])
        spec = _BasicComponentSpec(
            folds=10, input=unioned_input, output=produced)

        # Execution property is stored as given.
        self.assertEqual(10, spec.exec_properties['folds'])

        # The input keeps the artifact type and both member channels.
        spec_input = spec.inputs['input']
        self.assertEqual(spec_input.type, _InputArtifact)
        self.assertEqual(spec_input.channels, [first_input, second_input])

        # The output is the very same channel object that was passed in.
        self.assertIs(spec.outputs['output'], produced)
コード例 #6
0
ファイル: resolver_test.py プロジェクト: jay90099/tfx
 def testResolverUnionChannel(self):
     """Passes a union channel to a Resolver and checks its properties and I/O."""
     first_channel = types.Channel(type=standard_artifacts.Examples)
     second_channel = types.Channel(type=standard_artifacts.Examples)
     merged = channel.union([first_channel, second_channel])

     rnode = resolver.Resolver(
         strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
         config={'desired_num_of_artifacts': 5},
         unioned_channel=merged)

     # Strategy class and config must surface as exec properties.
     expected_exec_properties = {
         resolver.RESOLVER_STRATEGY_CLASS:
             latest_artifact_strategy.LatestArtifactStrategy,
         resolver.RESOLVER_CONFIG: {'desired_num_of_artifacts': 5},
     }
     self.assertDictEqual(rnode.exec_properties, expected_exec_properties)

     # The union is kept as the input; the output shares its type name.
     self.assertEqual(rnode.inputs['unioned_channel'], merged)
     resolved_output = rnode.outputs['unioned_channel']
     self.assertEqual(resolved_output.type_name, merged.type_name)
コード例 #7
0
 def testEmptyUnionChannel(self):
   """Expects channel.union() to raise AssertionError for an empty list."""
   no_channels = []
   with self.assertRaises(AssertionError):
     channel.union(no_channels)
コード例 #8
0
 def testMismatchedUnionChannelType(self):
   """Expects channel.union() to raise TypeError for mixed channel types."""
   mixed_channels = [
       channel.Channel(type=_MyType),
       channel.Channel(type=_AnotherType),
   ]
   with self.assertRaises(TypeError):
     channel.union(mixed_channels)