def testPrepareOutputArtifacts(self):
  """Prepared Examples output gets a generated URI plus span/version/fp."""
  examples_artifact = standard_artifacts.Examples()
  output_dict = {
      standard_component_specs.EXAMPLES_KEY:
          channel_utils.as_channel([examples_artifact])
  }
  exec_properties = {
      utils.SPAN_PROPERTY_NAME: 2,
      utils.VERSION_PROPERTY_NAME: 1,
      utils.FINGERPRINT_PROPERTY_NAME: 'fp',
  }
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='name', pipeline_root=self._test_dir, run_id='rid')
  component_info = data_types.ComponentInfo(
      component_type='type', component_id='cid', pipeline_info=pipeline_info)

  # No input artifacts are needed to prepare the outputs; execution id is 1.
  prepared = self._file_based_driver._prepare_output_artifacts(  # pylint: disable=protected-access
      {}, output_dict, exec_properties, 1, pipeline_info, component_info)
  produced = artifact_utils.get_single_instance(
      prepared[standard_component_specs.EXAMPLES_KEY])

  component_output_dir = os.path.join(self._test_dir,
                                      component_info.component_id)
  expected_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
      component_output_dir, 'examples', 1)
  self.assertEqual(produced.uri, expected_uri)

  fingerprint = produced.get_string_custom_property(
      utils.FINGERPRINT_PROPERTY_NAME)
  self.assertEqual(fingerprint, 'fp')
  span = produced.get_int_custom_property(utils.SPAN_PROPERTY_NAME)
  self.assertEqual(span, 2)
  version = produced.get_int_custom_property(utils.VERSION_PROPERTY_NAME)
  self.assertEqual(version, 1)
def test_channel_utils_as_channel_success(self):
  """as_channel() hands back a Channel argument equal to what it was given."""
  first, second = Artifact('MyTypeName'), Artifact('MyTypeName')
  source_channel = Channel('MyTypeName', artifacts=[first, second])
  converted = channel_utils.as_channel(source_channel)
  self.assertEqual(source_channel, converted)
def testConstruct(self):
  """ExampleValidator exposes an anomalies output and encodes exclude_splits."""
  stats = standard_artifacts.ExampleStatistics()
  stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  validator = component.ExampleValidator(
      statistics=channel_utils.as_channel([stats]),
      schema=channel_utils.as_channel([standard_artifacts.Schema()]),
      exclude_splits=['eval'])

  anomalies = validator.outputs[standard_component_specs.ANOMALIES_KEY]
  self.assertEqual(standard_artifacts.ExampleAnomalies.TYPE_NAME,
                   anomalies.type_name)

  # The excluded split list is stored in exec_properties as a JSON string.
  encoded_excludes = validator.spec.exec_properties[
      standard_component_specs.EXCLUDE_SPLITS_KEY]
  self.assertEqual(encoded_excludes, '["eval"]')
def testConstructInferenceResultAndOutputExample(self):
  """BulkInferrer raises ValueError for both mismatched output combinations."""
  shared_inputs = dict(
      examples=self._examples,
      model=self._model,
      model_blessing=self._model_blessing)

  # output_examples passed without an output_example_spec must be rejected.
  with self.assertRaises(ValueError):
    component.BulkInferrer(
        output_examples=channel_utils.as_channel(
            [standard_artifacts.Examples()]),
        **shared_inputs)

  # output_example_spec passed together with inference_result must be rejected.
  with self.assertRaises(ValueError):
    component.BulkInferrer(
        output_example_spec=bulk_inferrer_pb2.OutputExampleSpec(),
        inference_result=channel_utils.as_channel(
            [standard_artifacts.InferenceResult()]),
        **shared_inputs)
def testPrepareOutputArtifacts(self):
  """Prepared Examples output lives under <root>/cid/examples/1 with props."""
  output_dict = {
      utils.EXAMPLES_KEY:
          channel_utils.as_channel([standard_artifacts.Examples()])
  }
  exec_properties = {
      utils.SPAN_PROPERTY_NAME: '02',
      utils.FINGERPRINT_PROPERTY_NAME: 'fp',
  }
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='name', pipeline_root=self._test_dir, run_id='rid')
  component_info = data_types.ComponentInfo(
      component_type='type', component_id='cid', pipeline_info=pipeline_info)

  prepared = self._example_gen_driver._prepare_output_artifacts(
      output_dict, exec_properties, 1, pipeline_info, component_info)
  produced = artifact_utils.get_single_instance(prepared[utils.EXAMPLES_KEY])

  expected_uri = os.path.join(self._test_dir, 'cid', 'examples', '1')
  self.assertEqual(produced.uri, expected_uri)
  fingerprint = produced.get_string_custom_property(
      utils.FINGERPRINT_PROPERTY_NAME)
  self.assertEqual(fingerprint, 'fp')
  # Span is kept as the original zero-padded string, not an int.
  span = produced.get_string_custom_property(utils.SPAN_PROPERTY_NAME)
  self.assertEqual(span, '02')
def __init__(self, *unused_args, **kwargs):
  """Builds the component spec from keyword-only arguments.

  Every required input and parameter declared on `SPEC_CLASS` must be
  supplied. Each output declared on `SPEC_CLASS.OUTPUTS` is given a channel
  holding one freshly constructed artifact of the declared type. The optional
  `instance_name` keyword is forwarded to the base class.

  Raises:
    ValueError: if positional arguments are passed, a required input or
      parameter is missing, or an unrecognized keyword is supplied.
  """
  class_name = self.__class__.__name__
  if unused_args:
    raise ValueError(
        '%s expects arguments to be passed as keyword arguments' %
        (class_name,))

  spec_kwargs = {}
  leftover = set(kwargs)
  # Inputs are validated before parameters; each has its own error message.
  declared_groups = (
      (self.SPEC_CLASS.INPUTS,
       '%s expects input %r to be a Channel of type %s.'),
      (self.SPEC_CLASS.PARAMETERS, '%s expects parameter %r of type %s.'),
  )
  for declared, missing_message in declared_groups:
    for key, declaration in declared.items():
      if key in kwargs:
        spec_kwargs[key] = kwargs[key]
        leftover.remove(key)
      elif not declaration.optional:
        raise ValueError(missing_message %
                         (class_name, key, declaration.type))

  # instance_name is consumed here rather than passed to the spec.
  instance_name = kwargs.get('instance_name')
  leftover.discard('instance_name')
  if leftover:
    raise ValueError('Unknown arguments to %r: %s.' %
                     (class_name, ', '.join(sorted(leftover))))

  spec_kwargs.update(
      (key, channel_utils.as_channel([output_parameter.type()]))
      for key, output_parameter in self.SPEC_CLASS.OUTPUTS.items())
  spec = self.SPEC_CLASS(**spec_kwargs)
  super(_SimpleComponent, self).__init__(spec, instance_name=instance_name)
def testConstructNoPipelineConfiguration(self):
  """Filter can be built with pipeline_configuration=None and an inline fn."""
  examples = standard_artifacts.Examples()
  predicate_fn = """ def predicate(m): return m.features.feature[key].float_list.value[0] > 0.5 """
  # Named `filter_component` rather than `filter` so the builtin is not
  # shadowed inside this test.
  filter_component = Filter(
      pipeline_configuration=None,
      examples=channel_utils.as_channel([examples]),
      filtered_examples=channel_utils.as_channel(
          [standard_artifacts.Examples()]),
      splits_to_transform=['eval'],
      splits_to_copy=['train'],
      predicate_fn=predicate_fn)
  self.assertEqual('Examples',
                   filter_component.outputs[FILTERED_EXAMPLES_KEY].type_name)
def testConstruct(self):
  """HelloComponent exposes an Examples output carrying the default splits."""
  input_data = standard_artifacts.Examples()
  input_data.split_names = json.dumps(artifact.DEFAULT_EXAMPLE_SPLITS)
  output_data = standard_artifacts.Examples()
  output_data.split_names = json.dumps(artifact.DEFAULT_EXAMPLE_SPLITS)
  this_component = component.HelloComponent(
      input_data=channel_utils.as_channel([input_data]),
      output_data=channel_utils.as_channel([output_data]),
      name=u'Testing123')
  self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
                   this_component.outputs['output_data'].type_name)
  artifact_collection = this_component.outputs['output_data'].get()
  for output_artifact in artifact_collection:
    split_list = json.loads(output_artifact.split_names)
    # Compare sorted copies. list.sort() returns None, so comparing two
    # .sort() results always passed. It also reordered the shared
    # DEFAULT_EXAMPLE_SPLITS constant in place.
    self.assertEqual(
        sorted(artifact.DEFAULT_EXAMPLE_SPLITS), sorted(split_list))
def __init__(self,
             function_name: Optional[Text] = None,
             model: Optional[types.Channel] = None,
             model_blessing: Optional[types.Channel] = None,
             infra_blessing: Optional[types.Channel] = None,
             pushed_model: Optional[types.Channel] = None,
             output: Optional[types.Channel] = None,
             pipeline_configuration: Optional[types.Channel] = None,
             transform_graph: Optional[types.Channel] = None):
  """Construct a model export component.

  Args:
    function_name: The name of the function to apply on the model.
    model: A Channel of type `standard_artifacts.Model`.
    model_blessing: A Channel of type `standard_artifacts.ModelBlessing`.
    infra_blessing: A Channel of type `standard_artifacts.InfraBlessing`.
    pushed_model: A Channel of type `standard_artifacts.PushedModel`.
    output: A Channel of type `ExportedModel`. If not provided, a channel
      holding one new `ExportedModel` artifact is created.
    pipeline_configuration: A Channel of 'PipelineConfiguration' type, usually
      produced by FromCustomConfig component.
    transform_graph: A channel of type `standard_artifacts.TransformGraph`.
  """
  if not output:
    output = channel_utils.as_channel([ExportedModel()])
  spec = ExportSpec(function_name=function_name,
                    pipeline_configuration=pipeline_configuration,
                    model=model,
                    model_blessing=model_blessing,
                    infra_blessing=infra_blessing,
                    pushed_model=pushed_model,
                    output=output,
                    transform_graph=transform_graph)
  super(Export, self).__init__(spec=spec)
def __init__(self,
             input_data: Optional[types.Channel] = None,
             output_data: Optional[types.Channel] = None,
             name: Optional[Text] = None):
  """Construct a HelloComponent.

  Args:
    input_data: A Channel of type `standard_artifacts.Examples`. This will
      often contain two splits: 'train', and 'eval'.
    output_data: A Channel of type `standard_artifacts.Examples`. This will
      usually contain the same splits as input_data. If not provided, a
      channel holding one new `standard_artifacts.Examples` artifact is
      created.
    name: Optional unique name. Necessary if multiple Hello components are
      declared in the same pipeline.
  """
  # HelloComponent passes the input data through to the output, so the
  # splits in output_data will be the same as the splits in input_data,
  # which were generated by the upstream component.
  if not output_data:
    output_data = channel_utils.as_channel([standard_artifacts.Examples()])
  spec = HelloComponentSpec(input_data=input_data,
                            output_data=output_data,
                            name=name)
  super(HelloComponent, self).__init__(spec=spec)
def testConstructWithFairnessThresholds(self):
  """Evaluator stores fairness thresholds as '[0.1, 0.3, 0.5, 0.9]'."""
  examples = standard_artifacts.Examples()
  model_exports = standard_artifacts.Model()
  examples_channel = channel_utils.as_channel([examples])
  model_channel = channel_utils.as_channel([model_exports])
  slicing_spec = evaluator_pb2.FeatureSlicingSpec(specs=[
      evaluator_pb2.SingleSlicingSpec(column_for_slicing=['trip_start_hour'])
  ])
  evaluator = component.Evaluator(
      examples=examples_channel,
      model=model_channel,
      feature_slicing_spec=slicing_spec,
      fairness_indicator_thresholds=[0.1, 0.3, 0.5, 0.9])

  evaluation_type = evaluator.outputs['evaluation'].type_name
  self.assertEqual(standard_artifacts.ModelEvaluation.TYPE_NAME,
                   evaluation_type)
  stored_thresholds = evaluator.exec_properties['fairness_indicator_thresholds']
  self.assertEqual('[0.1, 0.3, 0.5, 0.9]', stored_thresholds)
def testFindComponentLaunchInfoReturnConfigOverride(self):
  """A matching per-component override is returned instead of the default."""
  fake_component = test_utils._FakeComponent(
      name='FakeComponent',
      input_channel=channel_utils.as_channel([test_utils._InputArtifact()]),
      custom_executor_spec=executor_spec.ExecutorContainerSpec(
          image='gcr://test', args=['{{input_dict["input"][0].uri}}']))

  default_config = docker_component_config.DockerComponentConfig()
  override_config = docker_component_config.DockerComponentConfig(name='test')
  overrides = {'_FakeComponent.FakeComponent': override_config}
  p_config = pipeline_config.PipelineConfig(
      supported_launcher_classes=[
          docker_component_launcher.DockerComponentLauncher
      ],
      default_component_configs=[default_config],
      component_config_overrides=overrides)

  launcher_class, chosen_config = config_utils.find_component_launch_info(
      p_config, fake_component)

  self.assertEqual(docker_component_launcher.DockerComponentLauncher,
                   launcher_class)
  self.assertEqual(override_config, chosen_config)
def testConstruct(self):
  """SchemaGen exposes a SchemaPath output and records infer_feature_shape."""
  train_stats = standard_artifacts.ExampleStatistics(split='train')
  stats_channel = channel_utils.as_channel([train_stats])
  schema_gen = component.SchemaGen(stats=stats_channel,
                                   infer_feature_shape=True)

  output_type = schema_gen.outputs.output.type_name
  self.assertEqual('SchemaPath', output_type)
  self.assertTrue(schema_gen.spec.exec_properties['infer_feature_shape'])
def testConstruct(self):
  """StatisticsGen over train/eval examples exposes ExampleStatisticsPath."""
  split_artifacts = [
      standard_artifacts.Examples(split=split_name)
      for split_name in ('train', 'eval')
  ]
  statistics_gen = component.StatisticsGen(
      examples=channel_utils.as_channel(split_artifacts))
  output_type = statistics_gen.outputs['statistics'].type_name
  self.assertEqual('ExampleStatisticsPath', output_type)
def setUp(self):
  """Wraps StatisticsGen from a two-component pipeline in a BaseComponent.

  The wrapper is created inside a `dsl.Pipeline` context and declares a
  dependency on the upstream CsvExampleGen.
  """
  super(BaseComponentTest, self).setUp()
  example_gen = csv_example_gen_component.CsvExampleGen(
      input_base=channel_utils.as_channel(
          [standard_artifacts.ExternalArtifact()]))
  statistics_gen = statistics_gen_component.StatisticsGen(
      input_data=example_gen.outputs.examples)
  pipeline = tfx_pipeline.Pipeline(
      pipeline_name='test_pipeline',
      pipeline_root='test_pipeline_root',
      components=[example_gen, statistics_gen],
  )

  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  host_config = self._metadata_config.mysql_db_service_host
  host_config.environment_variable = 'MYSQL_SERVICE_HOST'

  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=statistics_gen,
        depends_on={example_gen},
        pipeline=pipeline,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
    )
def testArtifactCollectionAsChannel(self):
  """as_channel() over artifacts infers the type and keeps every member."""
  first, second = _MyArtifact(), _MyArtifact()
  chnl = channel_utils.as_channel([first, second])
  self.assertEqual(chnl.type, _MyArtifact)
  self.assertEqual(chnl.type_name, 'MyTypeName')
  self.assertCountEqual(chnl.get(), [first, second])
def setUp(self):
  """Wraps ExampleGen and StatisticsGen as BaseComponents sharing one IR.

  The pipeline root is a `dsl.PipelineParam`. ExampleGen's hash bucket count
  is a `data_types.RuntimeParameter` (default 10).
  """
  super(BaseComponentWithPipelineParamTest, self).setUp()

  test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
  example_gen_buckets = data_types.RuntimeParameter(
      name='example-gen-buckets', ptype=int, default=10)
  split_config = {
      'splits': [{'name': 'examples', 'hash_buckets': example_gen_buckets}]
  }
  example_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([standard_artifacts.ExternalArtifact()]),
      output_config={'split_config': split_config})
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'], instance_name='foo')
  pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  host_config = self._metadata_config.mysql_db_service_host
  host_config.environment_variable = 'MYSQL_SERVICE_HOST'
  self._tfx_ir = pipeline_pb2.Pipeline()

  def _wrap(tfx_component):
    # Each call builds its own empty depends_on set, so the wrappers never
    # share one mutable set.
    return base_component.BaseComponent(
        component=tfx_component,
        component_launcher_class=(
            in_process_component_launcher.InProcessComponentLauncher),
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
        tfx_ir=self._tfx_ir)

  with dsl.Pipeline('test_pipeline'):
    self.example_gen = _wrap(example_gen)
    self.statistics_gen = _wrap(statistics_gen)

  self.tfx_example_gen = example_gen
  self.tfx_statistics_gen = statistics_gen
def setUp(self):
  """Creates the input_base dir, a mocked-metadata Driver and exec props."""
  super(DriverTest, self).setUp()

  # Create input splits under a per-test directory.
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                self.get_temp_dir())
  test_dir = os.path.join(outputs_root, self._testMethodName)
  self._input_base_path = os.path.join(test_dir, 'input_base')
  tf.io.gfile.makedirs(self._input_base_path)

  # Mock metadata.
  self._mock_metadata = tf.compat.v1.test.mock.Mock()
  self._example_gen_driver = driver.Driver(self._mock_metadata)

  # Create input dict.
  input_base = standard_artifacts.ExternalArtifact()
  input_base.uri = self._input_base_path
  self._input_channels = {'input_base': channel_utils.as_channel([input_base])}

  # Create exec properties: two splits whose patterns contain a {SPAN}
  # placeholder.
  split_patterns = (
      ('s1', 'span{SPAN}/split1/*'),
      ('s2', 'span{SPAN}/split2/*'),
  )
  input_config = example_gen_pb2.Input(splits=[
      example_gen_pb2.Input.Split(name=split_name, pattern=pattern)
      for split_name, pattern in split_patterns
  ])
  self._exec_properties = {
      'input_config':
          json_format.MessageToJson(
              input_config, preserving_proto_field_name=True),
  }
def setUp(self):
  """Wraps StatisticsGen (instance 'foo') in a BaseComponent for testing.

  The pipeline root passed to the wrapper is a `dsl.PipelineParam`, and the
  in-process launcher class is used.
  """
  super(BaseComponentTest, self).setUp()
  example_gen = csv_example_gen_component.CsvExampleGen(
      input=channel_utils.as_channel([standard_artifacts.ExternalArtifact()]))
  statistics_gen = statistics_gen_component.StatisticsGen(
      examples=example_gen.outputs['examples'], instance_name='foo')
  pipeline = tfx_pipeline.Pipeline(
      pipeline_name=self._test_pipeline_name,
      pipeline_root='test_pipeline_root',
      metadata_connection_config=metadata_store_pb2.ConnectionConfig(),
      components=[example_gen, statistics_gen],
  )

  test_pipeline_root = dsl.PipelineParam(name='pipeline-root-param')
  self._metadata_config = kubeflow_pb2.KubeflowMetadataConfig()
  host_config = self._metadata_config.mysql_db_service_host
  host_config.environment_variable = 'MYSQL_SERVICE_HOST'
  launcher_class = in_process_component_launcher.InProcessComponentLauncher

  with dsl.Pipeline('test_pipeline'):
    self.component = base_component.BaseComponent(
        component=statistics_gen,
        component_launcher_class=launcher_class,
        depends_on=set(),
        pipeline=pipeline,
        pipeline_name=self._test_pipeline_name,
        pipeline_root=test_pipeline_root,
        tfx_image='container_image',
        kubeflow_metadata_config=self._metadata_config,
        component_config=None,
    )
  self.tfx_component = statistics_gen
def pre_execution(
    self,
    input_dict: Dict[Text, types.Channel],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    driver_args: data_types.DriverArgs,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> data_types.ExecutionDecision:
  """Imports artifacts and returns a non-cached execution decision.

  Artifacts are imported from `exec_properties[SOURCE_URI_KEY]` into the
  `IMPORT_RESULT_KEY` output. `output_dict[IMPORT_RESULT_KEY]` is then
  replaced, in the caller's dict, with a channel over the imported artifacts.
  An execution is registered with empty exec properties. `input_dict` and
  `driver_args` are not used.
  """
  imported = self._import_artifacts(
      source_uri=exec_properties[SOURCE_URI_KEY],
      destination_channel=output_dict[IMPORT_RESULT_KEY],
      reimport=exec_properties[REIMPORT_OPTION_KEY],
      split_names=exec_properties[SPLIT_KEY])
  output_artifacts = {IMPORT_RESULT_KEY: imported}

  # Expose the imported artifacts to downstream consumers of output_dict.
  output_dict[IMPORT_RESULT_KEY] = channel_utils.as_channel(imported)

  execution_id = self._register_execution(
      exec_properties={},
      pipeline_info=pipeline_info,
      component_info=component_info)
  return data_types.ExecutionDecision(
      input_dict={},
      output_dict=output_artifacts,
      exec_properties={},
      execution_id=execution_id,
      use_cached_results=False)
def __init__(self, *unused_args, **kwargs):
  """Builds the component spec from keyword-only arguments.

  Required inputs and parameters declared on `SPEC_CLASS` must be present in
  `kwargs`. Each declared output gets a channel holding one new artifact of
  its type. The component id defaults to the class name.

  Raises:
    ValueError: on positional arguments, a missing required input or
      parameter, or any unrecognized keyword.
  """
  if unused_args:
    raise ValueError(
        '%s expects arguments to be passed as keyword arguments' %
        (self.__class__.__name__,))

  spec_kwargs = {}
  unconsumed = set(kwargs)

  for key, channel_parameter in self.SPEC_CLASS.INPUTS.items():
    if key in kwargs:
      spec_kwargs[key] = kwargs[key]
      unconsumed.remove(key)
    elif not channel_parameter.optional:
      raise ValueError(
          '%s expects input %r to be a Channel of type %s.' %
          (self.__class__.__name__, key, channel_parameter.type))

  for key, parameter in self.SPEC_CLASS.PARAMETERS.items():
    if key in kwargs:
      spec_kwargs[key] = kwargs[key]
      unconsumed.remove(key)
    elif not parameter.optional:
      raise ValueError(
          '%s expects parameter %r of type %s.' %
          (self.__class__.__name__, key, parameter.type))

  if unconsumed:
    raise ValueError(
        'Unknown arguments to %r: %s.' %
        (self.__class__.__name__, ', '.join(sorted(unconsumed))))

  spec_kwargs.update(
      (key, channel_utils.as_channel([output_parameter.type()]))
      for key, output_parameter in self.SPEC_CLASS.OUTPUTS.items())
  super().__init__(self.SPEC_CLASS(**spec_kwargs))

  # Set class name, which is the decorated function name, as the default id.
  # It can be overwritten by the user.
  self._id = self.__class__.__name__
def _create_launcher_context(self, component_config=None):
  """Creates a KubernetesComponentLauncher around a fake container component.

  Metadata is backed by an sqlite ConnectionConfig with no fields set beyond
  `SetInParent()`, and caching is enabled in the driver args.

  Args:
    component_config: Forwarded unchanged to the launcher's `create()`.

  Returns:
    A dict with keys 'launcher' and 'input_artifact' (the artifact fed to the
    fake component).
  """
  test_dir = self.get_temp_dir()

  connection_config = metadata_store_pb2.ConnectionConfig()
  connection_config.sqlite.SetInParent()
  metadata_connection = metadata.Metadata(connection_config)

  pipeline_root = os.path.join(test_dir, 'Test')
  input_artifact = test_utils._InputArtifact()
  input_artifact.uri = os.path.join(test_dir, 'input')
  fake_component = test_utils._FakeComponent(
      name='FakeComponent',
      input_channel=channel_utils.as_channel([input_artifact]),
      custom_executor_spec=executor_spec.ExecutorContainerSpec(
          image='gcr://test', args=['{{input_dict["input"][0].uri}}']))

  pipeline_info = data_types.PipelineInfo(
      pipeline_name='Test', pipeline_root=pipeline_root, run_id='123')
  launcher = kubernetes_component_launcher.KubernetesComponentLauncher.create(
      component=fake_component,
      pipeline_info=pipeline_info,
      driver_args=data_types.DriverArgs(enable_cache=True),
      metadata_connection=metadata_connection,
      beam_pipeline_args=[],
      additional_pipeline_args={},
      component_config=component_config)
  return {'launcher': launcher, 'input_artifact': input_artifact}
def testConstruct(self):
  """HelloComponent keeps the ExamplesPath type and train/eval output order."""

  def _split_examples():
    return [standard_artifacts.Examples(split=s) for s in ('train', 'eval')]

  input_examples = _split_examples()
  output_examples = _split_examples()
  hello = component.HelloComponent(
      input_data=channel_utils.as_channel(input_examples),
      output_data=channel_utils.as_channel(output_examples),
      name=u'Testing123')

  output_channel = hello.outputs['output_data']
  self.assertEqual('ExamplesPath', output_channel.type_name)
  collection = output_channel.get()
  self.assertEqual('train', collection[0].split)
  self.assertEqual('eval', collection[1].split)
def test_construct_with_cache_disabled_but_input_cache(self):
  """Passing analyzer_cache with disable_analyzer_cache=True raises."""
  with self.assertRaises(ValueError):
    cache_channel = channel_utils.as_channel(
        [standard_artifacts.TransformCache()])
    _ = component.Transform(
        examples=self.examples,
        schema=self.schema,
        preprocessing_fn='my_preprocessing_fn',
        disable_analyzer_cache=True,
        analyzer_cache=cache_channel)
def testEnableCache(self):
  """enable_cache is None by default and True when explicitly requested."""
  input_base = standard_artifacts.ExternalArtifact()
  custom_config = example_gen_pb2.CustomConfig(custom_config=any_pb2.Any())

  def _make_example_gen(**extra_kwargs):
    # Both components share input, config and custom executor; only the
    # caching flag differs between the two checks below.
    return component.FileBasedExampleGen(
        input=channel_utils.as_channel([input_base]),
        custom_config=custom_config,
        custom_executor_spec=executor_spec.ExecutorClassSpec(
            TestExampleGenExecutor),
        **extra_kwargs)

  self.assertIsNone(_make_example_gen().enable_cache)
  self.assertTrue(_make_example_gen(enable_cache=True).enable_cache)
def testInvalidBenchmarkNameThrows(self):
  """An empty benchmark name is rejected with ValueError."""
  with self.assertRaises(ValueError):
    evaluation_channel = channel_utils.as_channel(
        [standard_artifacts.ModelEvaluation()])
    BenchmarkResultPublisher('', evaluation_channel, run=1, num_runs=2)
def testConstruct(self):
  """StatisticsGen over train/eval examples exposes ExampleStatistics."""
  examples = standard_artifacts.Examples()
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  stats_gen = component.StatisticsGen(
      examples=channel_utils.as_channel([examples]))
  output_type = stats_gen.outputs['statistics'].type_name
  self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME, output_type)
def testConstructWithParameter(self):
  """Evaluator accepts RuntimeParameters in slicing spec and thresholds."""
  column_name = data_types.RuntimeParameter(name='column-name', ptype=Text)
  threshold = data_types.RuntimeParameter(name='threshold', ptype=float)
  examples = standard_artifacts.Examples()
  model_exports = standard_artifacts.Model()

  # Slicing spec given as a plain dict containing a runtime parameter.
  slicing_spec = {'specs': [{'column_for_slicing': [column_name]}]}
  evaluator = component.Evaluator(
      examples=channel_utils.as_channel([examples]),
      model=channel_utils.as_channel([model_exports]),
      feature_slicing_spec=slicing_spec,
      fairness_indicator_thresholds=[threshold])

  evaluation_type = evaluator.outputs['evaluation'].type_name
  self.assertEqual(standard_artifacts.ModelEvaluation.TYPE_NAME,
                   evaluation_type)
def testConstructHybridPipelineConfiguration(self):
  """StratifiedSampler accepts a config channel plus an inline to_key_fn."""
  to_key_fn = """ def to_key(m): return m.features.feature[key].float_list.value[0] > 0.5 """
  examples = standard_artifacts.Examples()
  sampler = StratifiedSampler(
      examples=channel_utils.as_channel([examples]),
      pipeline_configuration=types.Channel(type=PipelineConfiguration),
      stratified_examples=channel_utils.as_channel(
          [standard_artifacts.Examples()]),
      to_key_fn=to_key_fn,
      samples_per_key=112)
  output_type = sampler.outputs[STRATIFIED_EXAMPLES_KEY].type_name
  self.assertEqual('Examples', output_type)
def __init__(
    self,
    model_export: types.Channel,
    model_blessing: types.Channel,
    push_destination: Optional[pusher_pb2.PushDestination] = None,
    custom_config: Optional[Dict[Text, Any]] = None,
    executor_class: Optional[Type[base_executor.BaseExecutor]] = None,
    model_push: Optional[types.Channel] = None,
    name: Optional[Text] = None):
  """Construct a Pusher component.

  Args:
    model_export: A Channel of 'ModelExportPath' type, usually produced by
      Trainer component.
    model_blessing: A Channel of 'ModelBlessingPath' type, usually produced by
      ModelValidator component.
    push_destination: A pusher_pb2.PushDestination instance, providing info
      for tensorflow serving to load models. Optional if executor_class
      doesn't require push_destination.
    custom_config: A dict which contains the deployment job parameters to be
      passed to Google Cloud ML Engine. For the full set of parameters
      supported by Google Cloud ML Engine, refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models
    executor_class: Optional custom python executor class.
    model_push: Optional output 'ModelPushPath' channel with result of push.
      When falsy, a channel holding one new `PushedModel` artifact is used.
    name: Optional unique name. Necessary if multiple Pusher components are
      declared in the same pipeline.

  Raises:
    ValueError: if neither push_destination nor executor_class is given.
  """
  if not model_push:
    model_push = types.Channel(
        type=standard_artifacts.PushedModel,
        artifacts=[standard_artifacts.PushedModel()])

  # The default executor needs a destination; a custom executor may not.
  if push_destination is None and not executor_class:
    raise ValueError('push_destination is required unless a custom '
                     'executor_class is supplied that does not require it.')

  spec = PusherSpec(
      model_export=channel_utils.as_channel(model_export),
      model_blessing=channel_utils.as_channel(model_blessing),
      push_destination=push_destination,
      custom_config=custom_config,
      model_push=model_push)
  super(Pusher, self).__init__(
      spec=spec, custom_executor_class=executor_class, name=name)