def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root_1: str,
                     data_root_2: str) -> pipeline.Pipeline:
  """Builds a two-source pipeline whose example streams are joined by union.

  Two CsvExampleGen components ingest CSV data from separate roots. Their
  `examples` outputs are combined with channel.union() and fed into
  ChannelUnionComponent. A LatestArtifactStrategy resolver then selects from
  the union of the first generator's examples and the union component's
  `output_data`, and StatisticsGen consumes the resolver's output.

  Args:
    pipeline_name: Name forwarded to pipeline.Pipeline.
    pipeline_root: Root forwarded to pipeline.Pipeline.
    data_root_1: input_base of the first CsvExampleGen.
    data_root_2: input_base of the second CsvExampleGen.

  Returns:
    A pipeline.Pipeline holding the five components, in declaration order.
  """
  # One ExampleGen per CSV source, each given an explicit component id.
  first_gen = CsvExampleGen(input_base=data_root_1).with_id('example_gen_1')
  second_gen = CsvExampleGen(input_base=data_root_2).with_id('example_gen_2')

  # pylint: disable=no-value-for-parameter
  both_sources = channel.union(
      [first_gen.outputs['examples'], second_gen.outputs['examples']])
  union_component = ChannelUnionComponent(
      input_data=both_sources, name='channel_union_input')

  # Pick the latest artifact across the raw and the unioned example channels.
  resolver_inputs = channel.union(
      [first_gen.outputs['examples'], union_component.outputs['output_data']])
  latest_resolver = resolver.Resolver(
      strategy_class=latest_artifact_strategy.LatestArtifactStrategy,
      resolved_channels=resolver_inputs,
  ).with_id('latest_artifacts_resolver')

  # Statistics over the resolved examples, for visualization and validation.
  stats_gen = StatisticsGen(
      examples=latest_resolver.outputs['resolved_channels'])

  all_components = [
      first_gen, second_gen, union_component, latest_resolver, stats_gen
  ]
  return pipeline.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      components=all_components)
def testValidUnionChannel(self):
  """channel.union() keeps the member type and flattens nested unions."""
  channel1 = channel.Channel(type=_MyType)
  channel2 = channel.Channel(type=_MyType)

  union_channel = channel.union([channel1, channel2])
  # Compare strings by value. assertIs checks object identity and only passes
  # here because CPython happens to intern string literals.
  self.assertEqual(union_channel.type_name, 'MyTypeName')
  self.assertEqual(union_channel.channels, [channel1, channel2])

  # A union nested inside another union is expected to be flattened into its
  # member channels.
  union_channel = channel.union([channel1, channel.union([channel2])])
  self.assertEqual(union_channel.type_name, 'MyTypeName')
  self.assertEqual(union_channel.channels, [channel1, channel2])
def testResolveInputArtifacts(self):
  """BaseDriver resolves both a unioned input and a plain channel input."""
  first_artifact = standard_artifacts.String()
  first_artifact.id = 1
  first_channel = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c1').set_artifacts([first_artifact])

  second_artifact = standard_artifacts.String()
  second_artifact.id = 2
  second_channel = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c2').set_artifacts([second_artifact])

  # The third channel's artifact is left without an id.
  third_channel = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c3').set_artifacts([standard_artifacts.String()])

  inputs = {
      'input_union': channel.union([first_channel, second_channel]),
      'input_string': third_channel,
  }
  # Each search_artifacts call consumes one entry. NOTE(review): the order
  # (c3, c1, c2) presumably mirrors the driver's lookup order (e.g. sorted
  # input keys); confirm against BaseDriver.
  self._mock_metadata.search_artifacts.side_effect = [
      third_channel.get(),
      first_channel.get(),
      second_channel.get(),
  ]

  test_driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
  resolved = test_driver.resolve_input_artifacts(
      input_dict=inputs,
      exec_properties=self._exec_properties,
      driver_args=self._driver_args,
      pipeline_info=self._pipeline_info)

  union_artifacts = resolved['input_union']
  self.assertEqual(len(union_artifacts), 2)
  self.assertEqual(union_artifacts[0].value, _STRING_VALUE)

  string_artifacts = resolved['input_string']
  self.assertEqual(len(string_artifacts), 1)
  self.assertEqual(string_artifacts[0].value, _STRING_VALUE)
def testGetInidividualChannels(self):
  """get_individual_channels wraps a lone channel and unpacks a union."""
  artifact_a, artifact_b = _MyArtifact(), _MyArtifact()
  single = channel.Channel(_MyArtifact).set_artifacts([artifact_a])
  other = channel.Channel(_MyArtifact).set_artifacts([artifact_b])

  # A plain channel comes back as a one-element list.
  self.assertEqual(channel_utils.get_individual_channels(single), [single])

  # A union is unpacked into its member channels, in order.
  unioned = channel.union([single, other])
  self.assertEqual(
      channel_utils.get_individual_channels(unioned), [single, other])
def testComponentSpec_WithUnionChannel(self):
  """A ComponentSpec accepts a union channel as the value of an input."""
  in_a = channel.Channel(type=_InputArtifact)
  in_b = channel.Channel(type=_InputArtifact)
  out = channel.Channel(type=_OutputArtifact)

  union_input = channel.union([in_a, in_b])
  spec = _BasicComponentSpec(folds=10, input=union_input, output=out)

  # Execution properties, inputs and outputs are exposed on the spec.
  self.assertEqual(10, spec.exec_properties['folds'])
  unioned = spec.inputs['input']
  self.assertEqual(unioned.type, _InputArtifact)
  self.assertEqual(unioned.channels, [in_a, in_b])
  self.assertIs(spec.outputs['output'], out)
def testResolverUnionChannel(self):
  """A Resolver takes a union channel as input and mirrors it as an output."""
  examples_a = types.Channel(type=standard_artifacts.Examples)
  examples_b = types.Channel(type=standard_artifacts.Examples)
  merged = channel.union([examples_a, examples_b])

  strategy = latest_artifact_strategy.LatestArtifactStrategy
  rnode = resolver.Resolver(
      strategy_class=strategy,
      config={'desired_num_of_artifacts': 5},
      unioned_channel=merged)

  # Build the expectation from a fresh literal, not the dict passed above.
  expected_props = {
      resolver.RESOLVER_STRATEGY_CLASS: strategy,
      resolver.RESOLVER_CONFIG: {'desired_num_of_artifacts': 5},
  }
  self.assertDictEqual(rnode.exec_properties, expected_props)
  self.assertEqual(rnode.inputs['unioned_channel'], merged)
  self.assertEqual(rnode.outputs['unioned_channel'].type_name,
                   merged.type_name)
def testEmptyUnionChannel(self):
  """channel.union() raises AssertionError when given no channels."""
  self.assertRaises(AssertionError, channel.union, [])
def testMismatchedUnionChannelType(self):
  """channel.union() raises TypeError for channels of different types."""
  mismatched = [
      channel.Channel(type=_MyType),
      channel.Channel(type=_AnotherType),
  ]
  self.assertRaises(TypeError, channel.union, mismatched)