Пример #1
0
    def testChannelParameterType(self):
        """Exercises ChannelParameter.type_check with matching and mismatching values.

        For a parameter declared with `_FooArtifact`, four cases are covered:
        a channel of that type (no error), a non-channel value and a channel
        of another type (both expected to raise TypeError mentioning the
        argument name), and a channel of the other type after it has been
        listed under `COMPATIBLE_TYPES_KEY` on `_FooArtifact` (no error).
        """
        class _FooArtifact(artifact.Artifact):
            TYPE_NAME = 'FooArtifact'

        class _BarArtifact(artifact.Artifact):
            TYPE_NAME = 'BarArtifact'

        arg_name = 'foo'
        foo_parameter = ChannelParameter(type=_FooArtifact)

        # A channel carrying the declared type is accepted.
        foo_channel = channel.Channel(type=_FooArtifact)
        foo_parameter.type_check(arg_name, foo_channel)

        # A plain value that is not a channel is rejected.
        with self.assertRaisesRegex(TypeError, arg_name):
            foo_parameter.type_check(arg_name, 42)

        # A channel of a different artifact type is rejected.
        with self.assertRaisesRegex(TypeError, arg_name):
            bar_channel = channel.Channel(type=_BarArtifact)
            foo_parameter.type_check(arg_name, bar_channel)

        # Once _BarArtifact is declared compatible, the same check passes.
        setattr(_FooArtifact, component_spec.COMPATIBLE_TYPES_KEY,
                {_BarArtifact})
        compatible_channel = channel.Channel(type=_BarArtifact)
        foo_parameter.type_check(arg_name, compatible_channel)
Пример #2
0
  def testBuildLatestBlessedModelResolverSucceed(self):
    """Builds steps for a LatestBlessedModelResolver node and compares protos.

    The first two task specs returned by the builder (by position) and the
    deployment config handed to the builder are compared against the
    module-level expected text protos.
    """
    resolver_node = components.ResolverNode(
        instance_name='my_resolver2',
        resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(type=standard_artifacts.ModelBlessing))
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='test-pipeline', pipeline_root='gs://path/to/my/root')
    deployment_config = pipeline_pb2.PipelineDeploymentConfig()

    built_steps = step_builder.StepBuilder(
        node=resolver_node,
        deployment_config=deployment_config,
        pipeline_info=pipeline_info).build()

    # Task specs are checked in the order the builder returned them.
    expected_task_texts = (
        _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_1,
        _EXPECTED_LATEST_BLESSED_MODEL_RESOLVER_2,
    )
    for position, expected_text in enumerate(expected_task_texts):
      self.assertProtoEquals(
          text_format.Parse(expected_text, pipeline_pb2.PipelineTaskSpec()),
          built_steps[position])

    expected_deployment_config = text_format.Parse(
        _EXPECTED_LATEST_BLESSED_MODEL_EXECUTOR,
        pipeline_pb2.PipelineDeploymentConfig())
    self.assertProtoEquals(expected_deployment_config, deployment_config)
Пример #3
0
    def testBuildLatestArtifactResolverSucceed(self):
        """Builds a LatestArtifactsResolver step and compares it to golden protos.

        The single task spec, the single component definition and the
        deployment config are each compared against a .pbtxt file loaded via
        test_utils.get_proto_from_test_data.
        """
        resolver_node = resolver.Resolver(
            instance_name='my_resolver',
            strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
            model=channel.Channel(type=standard_artifacts.Model),
            examples=channel.Channel(type=standard_artifacts.Examples))
        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')

        builder = step_builder.StepBuilder(
            node=resolver_node,
            deployment_config=deployment_config,
            pipeline_info=pipeline_info,
            component_defs=component_defs)
        built_step_spec = self._sole(builder.build())
        built_component_def = self._sole(component_defs)

        # (golden file, message class to parse into, actual proto)
        golden_checks = (
            ('expected_latest_artifact_resolver_component.pbtxt',
             pipeline_pb2.ComponentSpec, built_component_def),
            ('expected_latest_artifact_resolver_task.pbtxt',
             pipeline_pb2.PipelineTaskSpec, built_step_spec),
            ('expected_latest_artifact_resolver_executor.pbtxt',
             pipeline_pb2.PipelineDeploymentConfig, deployment_config),
        )
        for golden_file, message_cls, actual_proto in golden_checks:
            expected_proto = test_utils.get_proto_from_test_data(
                golden_file, message_cls())
            self.assertProtoEquals(expected_proto, actual_proto)
Пример #4
0
 def testComponentSpec_JsonProto(self):
     """Checks that a JSON string passed as `proto` is stored unchanged as a str."""
     json_proto = '{"splits": [{"name": "name1", "pattern": "pattern1"}]}'
     input_channel = channel.Channel(type=_InputArtifact)
     output_channel = channel.Channel(type=_OutputArtifact)
     spec = _BasicComponentSpec(
         folds=10,
         proto=json_proto,
         input=input_channel,
         output=output_channel)

     stored_proto = spec.exec_properties['proto']
     self.assertIsInstance(stored_proto, str)
     self.assertEqual(stored_proto, json_proto)
Пример #5
0
  def testValidUnionChannel(self):
    """Checks channel.union for a flat list and for a nested union.

    In both cases the resulting union is expected to report the shared type
    name and to list the two member channels in order.
    """
    channel1 = channel.Channel(type=_MyType)
    channel2 = channel.Channel(type=_MyType)

    union_channel = channel.union([channel1, channel2])
    # Strings are compared by value: whether two equal str objects are the
    # same object is an interpreter interning detail, not something to assert.
    self.assertEqual(union_channel.type_name, 'MyTypeName')
    self.assertEqual(union_channel.channels, [channel1, channel2])

    # A union nested inside another union yields the same member list.
    union_channel = channel.union([channel1, channel.union([channel2])])
    self.assertEqual(union_channel.type_name, 'MyTypeName')
    self.assertEqual(union_channel.channels, [channel1, channel2])
Пример #6
0
    def testBuildLatestBlessedModelResolverSucceed(self):
        """Builds LatestBlessedModelResolver steps and compares them to golden protos.

        The builder output is expected to be keyed by exactly two step ids
        (model-blessing resolver and model resolver). For each id the
        component definition and the task spec are compared against .pbtxt
        files; finally the deployment config is compared as well.
        """
        resolver_node = resolver.Resolver(
            instance_name='my_resolver2',
            strategy_class=(
                latest_blessed_model_resolver.LatestBlessedModelResolver),
            model=channel.Channel(type=standard_artifacts.Model),
            model_blessing=channel.Channel(
                type=standard_artifacts.ModelBlessing))
        pipeline_info = data_types.PipelineInfo(
            pipeline_name='test-pipeline',
            pipeline_root='gs://path/to/my/root')

        deployment_config = pipeline_pb2.PipelineDeploymentConfig()
        component_defs = {}
        built_steps = step_builder.StepBuilder(
            node=resolver_node,
            deployment_config=deployment_config,
            pipeline_info=pipeline_info,
            component_defs=component_defs).build()

        blessing_step_id = 'Resolver.my_resolver2-model-blessing-resolver'
        model_step_id = 'Resolver.my_resolver2-model-resolver'
        self.assertSameElements(built_steps.keys(),
                                [blessing_step_id, model_step_id])

        # (step id, golden component file, golden task file), checked in order.
        golden_files_per_step = (
            (blessing_step_id,
             'expected_latest_blessed_model_resolver_component_1.pbtxt',
             'expected_latest_blessed_model_resolver_task_1.pbtxt'),
            (model_step_id,
             'expected_latest_blessed_model_resolver_component_2.pbtxt',
             'expected_latest_blessed_model_resolver_task_2.pbtxt'),
        )
        for step_id, component_file, task_file in golden_files_per_step:
            expected_component = test_utils.get_proto_from_test_data(
                component_file, pipeline_pb2.ComponentSpec())
            self.assertProtoEquals(expected_component,
                                   component_defs[step_id])

            expected_task = test_utils.get_proto_from_test_data(
                task_file, pipeline_pb2.PipelineTaskSpec())
            self.assertProtoEquals(expected_task, built_steps[step_id])

        expected_deployment_config = test_utils.get_proto_from_test_data(
            'expected_latest_blessed_model_resolver_executor.pbtxt',
            pipeline_pb2.PipelineDeploymentConfig())
        self.assertProtoEquals(expected_deployment_config, deployment_config)
Пример #7
0
    def testComponentSpec_Basic(self):
        """Covers the basic _BasicComponentSpec contract.

        Verifies that a proto execution property is stored as a JSON string
        with the expected content, that the other properties and channels are
        kept as passed, and that wrongly typed parameters or channels raise
        TypeError with the expected message.
        """
        split_names = ['name1', 'name2', 'name3']
        split_patterns = ['pattern1', 'pattern2', 'pattern3']
        proto = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name=name, pattern=pattern)
            for name, pattern in zip(split_names, split_patterns)
        ])
        input_channel = channel.Channel(type=_InputArtifact)
        output_channel = channel.Channel(type=_OutputArtifact)
        spec = _BasicComponentSpec(folds=10,
                                   proto=proto,
                                   input=input_channel,
                                   output=output_channel)

        # The proto property is stored as a JSON string; decode and inspect it.
        stored_proto = spec.exec_properties['proto']
        self.assertIsInstance(stored_proto, str)
        decoded = json.loads(stored_proto)
        self.assertCountEqual(['splits'], decoded.keys())
        decoded_splits = decoded['splits']
        self.assertLen(decoded_splits, 3)
        self.assertCountEqual(split_names,
                              [split['name'] for split in decoded_splits])
        self.assertCountEqual(split_patterns,
                              [split['pattern'] for split in decoded_splits])

        # Remaining properties and channels are kept as passed.
        self.assertEqual(10, spec.exec_properties['folds'])
        self.assertIs(spec.inputs['input'], input_channel)
        self.assertIs(spec.outputs['output'], output_channel)

        # Wrong type for an execution parameter.
        folds_error = (
            "Expected type <(class|type) 'int'> for parameter u?'folds' but got "
            'string.')
        with self.assertRaisesRegex(TypeError, folds_error):
            _BasicComponentSpec(folds='string',
                                input=input_channel,
                                output=output_channel)

        # Wrong artifact type on the input channel.
        input_error = (
            '.*should be a Channel of .*InputArtifact.*got (.|\\s)*Examples.*')
        with self.assertRaisesRegex(TypeError, input_error):
            _BasicComponentSpec(folds=10,
                                input=channel.Channel(type=Examples),
                                output=output_channel)

        # Wrong artifact type on the output channel.
        output_error = (
            '.*should be a Channel of .*OutputArtifact.*got (.|\\s)*Examples.*')
        with self.assertRaisesRegex(TypeError, output_error):
            _BasicComponentSpec(folds=10,
                                input=input_channel,
                                output=channel.Channel(type=Examples))
Пример #8
0
    def testGetInidividualChannels(self):
        """Checks get_individual_channels for a single channel and for a union.

        A plain channel is expected to come back as a one-element list; a
        union of two channels is expected to come back as both members in
        order.
        """
        first_artifact = _MyArtifact()
        second_artifact = _MyArtifact()
        first_channel = channel.Channel(_MyArtifact).set_artifacts(
            [first_artifact])
        second_channel = channel.Channel(_MyArtifact).set_artifacts(
            [second_artifact])

        self.assertEqual(
            channel_utils.get_individual_channels(first_channel),
            [first_channel])

        unioned = channel.union([first_channel, second_channel])
        self.assertEqual(
            channel_utils.get_individual_channels(unioned),
            [first_channel, second_channel])
Пример #9
0
    def testExecutionParameterUseProto(self):
        """Checks values stored in exec_properties when use_proto=True.

        A spec is declared with proto, bool, list-of-proto and list-of-bool
        parameters, all flagged use_proto=True. The proto parameter is given
        as a JSON string and is expected to be stored as a parsed message; the
        other values are expected to be stored as given.
        """
        class SpecWithNonPrimitiveTypes(ComponentSpec):
            PARAMETERS = {
                'config_proto':
                ExecutionParameter(type=example_gen_pb2.Input, use_proto=True),
                'boolean':
                ExecutionParameter(type=bool, use_proto=True),
                'list_config_proto':
                ExecutionParameter(type=List[example_gen_pb2.Input],
                                   use_proto=True),
                'list_boolean':
                ExecutionParameter(type=List[bool], use_proto=True),
            }
            INPUTS = {
                'input': ChannelParameter(type=_InputArtifact),
            }
            OUTPUTS = {
                'output': ChannelParameter(type=_OutputArtifact),
            }

        trainer_input = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='trainer', pattern='train.data')
        ])
        eval_input = example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(name='eval', pattern='*eval.data')
        ])
        spec = SpecWithNonPrimitiveTypes(
            config_proto='{"splits": [{"name": "name", "pattern": "pattern"}]}',
            boolean=True,
            list_config_proto=[trainer_input, eval_input],
            list_boolean=[False, True],
            input=channel.Channel(type=_InputArtifact),
            output=channel.Channel(type=_OutputArtifact))
        stored = spec.exec_properties

        # The JSON string given for config_proto is stored as a parsed message.
        expected_text = """
            splits {
              name: "name"
              pattern: "pattern"
            }
          """
        expected_config = text_format.Parse(expected_text,
                                            example_gen_pb2.Input())
        self.assertProtoEquals(expected_config, stored['config_proto'])
        self.assertEqual(True, stored['boolean'])
        self.assertIsInstance(stored['list_config_proto'], list)
        self.assertEqual(stored['list_boolean'], [False, True])
Пример #10
0
    def testComponentSpec_WithUnionChannel(self):
        """Checks that a union channel can be used as a component spec input.

        The stored input is expected to expose the declared artifact type and
        the member channels in order; the output channel is expected to be
        stored as the same object.
        """
        first_input = channel.Channel(type=_InputArtifact)
        second_input = channel.Channel(type=_InputArtifact)
        output_channel = channel.Channel(type=_OutputArtifact)
        unioned_input = channel.union([first_input, second_input])
        spec = _BasicComponentSpec(folds=10,
                                   input=unioned_input,
                                   output=output_channel)

        self.assertEqual(10, spec.exec_properties['folds'])
        stored_input = spec.inputs['input']
        self.assertEqual(stored_input.type, _InputArtifact)
        self.assertEqual(stored_input.channels, [first_input, second_input])
        self.assertIs(spec.outputs['output'], output_channel)
Пример #11
0
 def testFutureProducesPlaceholder(self):
   """Checks that Channel.future() and its derived values are wrapped placeholders.

   The future itself, its first element and its `.value` are all expected to
   be ChannelWrappedPlaceholder instances, and the future is expected to point
   back at the channel that created it.
   """
   source_channel = channel.Channel(type=_MyType)
   channel_future = source_channel.future()
   wrapped_cls = placeholder.ChannelWrappedPlaceholder

   self.assertIsInstance(channel_future, wrapped_cls)
   self.assertIs(channel_future.channel, source_channel)
   self.assertIsInstance(channel_future[0], wrapped_cls)
   self.assertIsInstance(channel_future.value, wrapped_cls)
Пример #12
0
    def testComponentSpec_MissingArguments(self):
        """Checks that omitting a required argument raises ValueError.

        `x` (parameter) and `z` (input) are required, `y` is optional. Leaving
        out either required argument is expected to raise; leaving out only
        `y` is expected to succeed.
        """
        class SimpleComponentSpec(ComponentSpec):
            PARAMETERS = {
                'x': ExecutionParameter(type=int),
                'y': ExecutionParameter(type=int, optional=True),
            }
            INPUTS = {'z': ChannelParameter(type=_Z)}
            OUTPUTS = {}

        # Each kwargs set lacks one required argument (z, then x).
        incomplete_kwargs = (
            {'x': 10},
            {'z': channel.Channel(type=_Z)},
        )
        for kwargs in incomplete_kwargs:
            with self.assertRaisesRegex(ValueError, 'Missing argument'):
                SimpleComponentSpec(**kwargs)

        # Only the optional y is missing here, so construction succeeds.
        SimpleComponentSpec(x=10, z=channel.Channel(type=_Z))
Пример #13
0
    def testBuildOutputArtifactSpec(self):
        """Checks build_output_artifact_spec for a populated and a type-only channel.

        First case: a channel holding one Examples artifact with `span` set
        and two custom properties; the spec is expected to carry the artifact
        type schema plus those values as metadata fields. Second case: a Model
        channel with no artifacts; only the artifact type schema is expected.
        """
        examples = standard_artifacts.Examples()
        examples.span = 1
        examples.set_int_custom_property(key='int_param', value=42)
        examples.set_string_custom_property(key='str_param', value='42')
        example_channel = channel.Channel(
            type=standard_artifacts.Examples).set_artifacts([examples])
        spec = compiler_utils.build_output_artifact_spec(example_channel)
        expected_spec = text_format.Parse(
            """
        artifact_type {
          instance_schema: "title: tfx.Examples\\ntype: object\\nproperties:\\n  span:\\n    type: integer\\n    description: Span for an artifact.\\n  version:\\n    type: integer\\n    description: Version for an artifact.\\n  split_names:\\n    type: string\\n    description: JSON-encoded list of splits for an artifact. Empty string means artifact has no split.\\n"
        }
        metadata {
          fields {
            key: "int_param"
            value {
              number_value: 42.0
            }
          }
          fields {
            key: "span"
            value {
              number_value: 1.0
            }
          }
          fields {
            key: "str_param"
            value {
              string_value: "42"
            }
          }
        }
        """, pipeline_pb2.ComponentOutputsSpec.ArtifactSpec())
        # Expected message first, actual second, so that a failure is reported
        # in the expected-vs-actual direction.
        self.assertProtoEquals(expected_spec, spec)

        # Empty output channel with only type info.
        model_channel = channel.Channel(type=standard_artifacts.Model)
        spec = compiler_utils.build_output_artifact_spec(model_channel)
        expected_spec = text_format.Parse(
            """
        artifact_type {
          instance_schema: "title: tfx.Model\\ntype: object\\n"
        }
        """, pipeline_pb2.ComponentOutputsSpec.ArtifactSpec())
        self.assertProtoEquals(expected_spec, spec)
Пример #14
0
    def testBuildInputArtifactSpec(self):
        """Checks build_input_artifact_spec for types without and with properties.

        A Model channel is expected to yield a spec with only the type title;
        an Examples channel is expected to yield a spec whose schema also
        lists the span, version and split_names properties.
        """
        spec = compiler_utils.build_input_artifact_spec(
            channel.Channel(type=standard_artifacts.Model))
        expected_spec = text_format.Parse(
            r'artifact_type { instance_schema: "title: tfx.Model\ntype: object\n" }',
            pipeline_pb2.ComponentInputsSpec.ArtifactSpec())
        # Expected message first, actual second, so that a failure is reported
        # in the expected-vs-actual direction.
        self.assertProtoEquals(expected_spec, spec)

        # Test artifact type with properties
        spec = compiler_utils.build_input_artifact_spec(
            channel.Channel(type=standard_artifacts.Examples))
        expected_spec = text_format.Parse(
            """
        artifact_type {
          instance_schema: "title: tfx.Examples\\ntype: object\\nproperties:\\n  span:\\n    type: integer\\n    description: Span for an artifact.\\n  version:\\n    type: integer\\n    description: Version for an artifact.\\n  split_names:\\n    type: string\\n    description: JSON-encoded list of splits for an artifact. Empty string means artifact has no split.\\n"
        }
        """, pipeline_pb2.ComponentInputsSpec.ArtifactSpec())
        self.assertProtoEquals(expected_spec, spec)
Пример #15
0
 def testUnwrapChannelDict(self):
     """Checks that unwrap_channel_dict maps each key to its channel's artifacts."""
     first_artifact = _MyArtifact()
     second_artifact = _MyArtifact()
     artifacts = [first_artifact, second_artifact]
     wrapped = {'id': channel.Channel(_MyArtifact).set_artifacts(artifacts)}

     unwrapped = channel_utils.unwrap_channel_dict(wrapped)

     self.assertDictEqual(unwrapped,
                          {'id': [first_artifact, second_artifact]})
Пример #16
0
  def testJsonRoundTripUnknownArtifactClass(self):
    """Round-trips a channel through JSON after renaming its type.

    The serialized type name is replaced with one that is not declared in
    this test before deserializing. The rehydrated channel is expected to
    carry the new name, the same artifact type properties, and a type flagged
    as _AUTOGENERATED.
    """
    original = channel.Channel(type=_MyType)

    json_dict = original.to_json_dict()
    json_dict['type']['name'] = 'UnknownTypeName'
    restored = channel.Channel.from_json_dict(json_dict)

    self.assertEqual('UnknownTypeName', restored.type_name)
    original_properties = original.type._get_artifact_type().properties
    restored_properties = restored.type._get_artifact_type().properties
    self.assertEqual(original_properties, restored_properties)
    self.assertTrue(restored.type._AUTOGENERATED)
Пример #17
0
    def testOptionalOutputs(self):
        """Checks a spec whose only output is optional.

        Without the output the spec is expected to build, omit the key from
        `outputs`, and report it as optional; with the output supplied the key
        is expected to be present.
        """
        class SpecWithOptionalOutput(ComponentSpec):
            PARAMETERS = {}
            INPUTS = {}
            OUTPUTS = {'x': ChannelParameter(type=_Z, optional=True)}

        # Output left out.
        without_output = SpecWithOptionalOutput()
        self.assertNotIn('x', without_output.outputs.keys())
        self.assertTrue(without_output.is_optional_output('x'))

        # Output supplied.
        with_output = SpecWithOptionalOutput(x=channel.Channel(type=_Z))
        self.assertIn('x', with_output.outputs.keys())
Пример #18
0
 def testCompileDynamicExecPropTypeError(self):
   """Checks that compiling with a non-ValueArtifact dynamic exec property fails.

   The `input_num` exec property of the pipeline's DownstreamComponent is
   replaced with the future of a `_MyType` channel; compilation is then
   expected to raise ValueError with the matching message.
   """
   dsl_compiler = compiler.Compiler()
   test_pipeline = self._get_test_pipeline_definition(
       dynamic_exec_properties_pipeline)

   # First component of the pipeline that is a DownstreamComponent.
   downstream_cls = dynamic_exec_properties_pipeline.DownstreamComponent
   downstream_component = next(
       filter(lambda node: isinstance(node, downstream_cls),
              test_pipeline.components))

   first_artifact = _MyType()
   second_artifact = _MyType()
   wrong_type_channel = channel.Channel(_MyType).set_artifacts(
       [first_artifact, second_artifact])
   downstream_component.exec_properties["input_num"] = (
       wrong_type_channel.future())

   expected_message = (
       "output channel to dynamic exec properties is not ValueArtifact")
   with self.assertRaisesRegex(ValueError, expected_message):
     dsl_compiler.compile(test_pipeline)
Пример #19
0
 def testJsonRoundTrip(self):
   """Round-trips a channel with additional properties through its JSON dict.

   The rehydrated channel is expected to share the same type object and type
   name, and to carry equal additional and additional-custom properties.
   """
   properties = {
       'string_value': metadata_store_pb2.Value(string_value='forty-two')
   }
   custom_properties = {
       'int_value': metadata_store_pb2.Value(int_value=42)
   }
   original = channel.Channel(
       type=_MyType,
       additional_properties=properties,
       additional_custom_properties=custom_properties)

   restored = channel.Channel.from_json_dict(original.to_json_dict())

   self.assertIs(original.type, restored.type)
   self.assertEqual(original.type_name, restored.type_name)
   self.assertEqual(original.additional_properties,
                    restored.additional_properties)
   self.assertEqual(original.additional_custom_properties,
                    restored.additional_custom_properties)
Пример #20
0
def as_channel(artifacts: Iterable[artifact.Artifact]) -> channel.Channel:
    """Converts artifact collection of the same artifact type into a Channel.

    The input is materialized into a list before it is inspected. This keeps
    one-shot iterables (generators, iterators) intact: peeking at the first
    element with `next(iter(...))` would otherwise consume it, so the channel
    would be populated without that element.

    Args:
      artifacts: An iterable of Artifact.

    Returns:
      A static Channel containing the source artifact collection. Its type is
      taken from the first element.

    Raises:
      ValueError: If the iterable is empty, or if its first element is not an
        Artifact.
    """
    artifact_list = list(artifacts)
    if not artifact_list:
        raise ValueError('Cannot convert empty artifact iterable into Channel')

    first_element = artifact_list[0]
    if not isinstance(first_element, artifact.Artifact):
        # The message shows the caller's original object, as before.
        raise ValueError('Invalid artifact iterable: {}'.format(artifacts))

    return channel.Channel(
        type=first_element.type).set_artifacts(artifact_list)
Пример #21
0
def _make_channel_dict(
        artifact_dict: Dict[Text, Text]) -> Dict[Text, channel.Channel]:
    """Makes a dictionary of artifact channels from a dictionary of artifacts.

  Args:
    artifact_dict: Dictionary of artifacts.

  Returns:
    Dictionary of artifact channels.

  Raises:
    RuntimeError: If list of artifacts is malformed.
  """
    channel_dict = {}
    for name, artifact_list in artifact_dict.items():
        if not artifact_list:
            raise RuntimeError(
                'Found empty list of artifacts for input/output named {}: {}'.
                format(name, artifact_list))
        type_name = artifact_list[0].type_name
        channel_dict[name] = channel.Channel(type_name=type_name,
                                             artifacts=artifact_list)

    return channel_dict
Пример #22
0
 def testInvalidChannelType(self):
   """Checks that setting artifacts of a different type on a channel raises ValueError."""
   mismatched_artifacts = [_MyType(), _MyType()]
   with self.assertRaises(ValueError):
     channel.Channel(_AnotherType).set_artifacts(mismatched_artifacts)
Пример #23
0
 def setUp(self):
     """Runs the parent setUp and creates the channel stored in `_test_channel`."""
     super().setUp()
     property_channel = channel.Channel(type=_MyArtifactWithProperty)
     self._test_channel = property_channel
Пример #24
0
    def testExecutionParameterTypeCheck(self):
        """Walks ExecutionParameter.type_check through several declared types.

        Sections, in order: int, List[int], Dict[str, int], a proto message
        type, and placeholder values passed to a str parameter. Each section
        first passes values that are expected to be accepted, then values
        expected to raise with a specific message.
        """
        # int parameter.
        int_param = ExecutionParameter(type=int)
        int_param.type_check('int_parameter', 8)
        with self.assertRaisesRegex(
                TypeError,
                "Expected type <(class|type) 'int'> for parameter u?'int_parameter'"
        ):
            int_param.type_check('int_parameter', 'string')

        # List[int] parameter: empty and well-typed lists pass.
        list_param = ExecutionParameter(type=List[int])
        for accepted_list in ([], [42]):
            list_param.type_check('list_parameter', accepted_list)
        with self.assertRaisesRegex(TypeError,
                                    'Expecting a list for parameter'):
            list_param.type_check('list_parameter', 42)
        with self.assertRaisesRegex(
                TypeError,
                "Expecting item type <(class|type) 'int'> for parameter u?'list_parameter'"
        ):
            list_param.type_check('list_parameter', [42, 'wrong item'])

        # Dict[str, int] parameter: empty and well-typed dicts pass.
        dict_param = ExecutionParameter(type=Dict[str, int])
        for accepted_dict in ({}, {'key1': 1, 'key2': 2}):
            dict_param.type_check('dict_parameter', accepted_dict)
        with self.assertRaisesRegex(TypeError,
                                    'Expecting a dict for parameter'):
            dict_param.type_check('dict_parameter', 'simple string')
        with self.assertRaisesRegex(
                TypeError, "Expecting value type <(class|type) 'int'>"):
            dict_param.type_check('dict_parameter', {'key1': '1'})

        # Proto parameter: a message, its JSON form and dicts are passed.
        proto_param = ExecutionParameter(type=example_gen_pb2.Input)
        proto_param.type_check('proto_parameter', example_gen_pb2.Input())
        json_proto = proto_utils.proto_to_json(example_gen_pb2.Input())
        proto_param.type_check('proto_parameter', json_proto)
        proto_param.type_check('proto_parameter',
                               {'splits': [{'name': 'hello'}]})
        # A dict with a field name not used elsewhere in this test is not
        # expected to raise here.
        proto_param.type_check('proto_parameter', {'wrong_field': 42})
        with self.assertRaisesRegex(
                TypeError,
                "Expected type <class 'tfx.proto.example_gen_pb2.Input'>"):
            proto_param.type_check('proto_parameter', 42)
        with self.assertRaises(json_format.ParseError):
            proto_param.type_check('proto_parameter', {'splits': 42})

        # Placeholders passed to a str parameter.
        output_channel = channel.Channel(type=_OutputArtifact)
        str_param = ExecutionParameter(type=str)
        channel_value_placeholder = output_channel.future()[0].value
        str_param.type_check('wrapped_channel_placeholder_parameter',
                             channel_value_placeholder)
        str_param.type_check(
            'placeholder_parameter',
            placeholder.runtime_info('platform_config').base_dir)
        with self.assertRaisesRegex(
                TypeError,
                'Only simple RuntimeInfoPlaceholders are supported'):
            str_param.type_check(
                'placeholder_parameter',
                placeholder.runtime_info('platform_config').base_dir +
                placeholder.exec_property('version'))
Пример #25
0
def create_pipeline_components(
    pipeline_root: Text,
    transform_module: Text,
    trainer_module: Text,
    bigquery_query: Text = '',
    csv_input_location: Text = '',
) -> List[base_node.BaseNode]:
    """Creates components for a simple Chicago Taxi TFX pipeline for testing.

    Args:
      pipeline_root: The root of the pipeline output. Used here to derive the
        Pusher's `model_serving` base directory.
      transform_module: The location of the transform module file.
      trainer_module: The location of the trainer module file.
      bigquery_query: The query to get input data from BigQuery. If not empty,
        BigQueryExampleGen will be used.
      csv_input_location: The location of the input data directory. If not
        empty, CsvExampleGen will be used.

    Returns:
      A list of TFX components that constitutes an end-to-end test pipeline.

    Raises:
      ValueError: If both or neither of `bigquery_query` and
        `csv_input_location` are provided.
    """

    # Exactly one of the two sources must be set (both-empty and both-set are
    # rejected alike).
    if bool(bigquery_query) == bool(csv_input_location):
        # Adjacent literals are concatenated into ONE message; a comma between
        # them would instead pass two separate arguments to ValueError.
        raise ValueError(
            'Exactly one example gen is expected. '
            'Please provide either bigquery_query or csv_input_location.')

    if bigquery_query:
        example_gen = big_query_example_gen_component.BigQueryExampleGen(
            query=bigquery_query)
    else:
        example_gen = components.CsvExampleGen(input_base=csv_input_location)

    statistics_gen = components.StatisticsGen(
        examples=example_gen.outputs['examples'])
    schema_gen = components.SchemaGen(
        statistics=statistics_gen.outputs['statistics'],
        infer_feature_shape=False)
    example_validator = components.ExampleValidator(
        statistics=statistics_gen.outputs['statistics'],
        schema=schema_gen.outputs['schema'])
    transform = components.Transform(examples=example_gen.outputs['examples'],
                                     schema=schema_gen.outputs['schema'],
                                     module_file=transform_module)
    # Resolver whose `model` output is used as the trainer's base model.
    latest_model_resolver = resolver.Resolver(
        strategy_class=latest_artifacts_resolver.LatestArtifactsResolver,
        model=channel.Channel(type=standard_artifacts.Model)).with_id(
            'Resolver.latest_model_resolver')
    trainer = components.Trainer(
        custom_executor_spec=executor_spec.ExecutorClassSpec(Executor),
        transformed_examples=transform.outputs['transformed_examples'],
        schema=schema_gen.outputs['schema'],
        base_model=latest_model_resolver.outputs['model'],
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10),
        eval_args=trainer_pb2.EvalArgs(num_steps=5),
        module_file=trainer_module,
    )
    # Get the latest blessed model for model validation.
    model_resolver = resolver.Resolver(
        strategy_class=(
            latest_blessed_model_resolver.LatestBlessedModelResolver),
        model=channel.Channel(type=standard_artifacts.Model),
        model_blessing=channel.Channel(
            type=standard_artifacts.ModelBlessing)).with_id(
                'Resolver.latest_blessed_model_resolver')
    # Set the TFMA config for Model Evaluation and Validation.
    eval_config = tfma.EvalConfig(
        model_specs=[tfma.ModelSpec(signature_name='eval')],
        metrics_specs=[
            tfma.MetricsSpec(
                metrics=[tfma.MetricConfig(class_name='ExampleCount')],
                thresholds={
                    'binary_accuracy':
                    tfma.MetricThreshold(
                        value_threshold=tfma.GenericValueThreshold(
                            lower_bound={'value': 0.5}),
                        change_threshold=tfma.GenericChangeThreshold(
                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                            absolute={'value': -1e-10}))
                })
        ],
        slicing_specs=[
            tfma.SlicingSpec(),
            tfma.SlicingSpec(feature_keys=['trip_start_hour'])
        ])
    evaluator = components.Evaluator(
        examples=example_gen.outputs['examples'],
        model=trainer.outputs['model'],
        baseline_model=model_resolver.outputs['model'],
        eval_config=eval_config)

    pusher = components.Pusher(
        model=trainer.outputs['model'],
        model_blessing=evaluator.outputs['blessing'],
        push_destination=pusher_pb2.PushDestination(
            filesystem=pusher_pb2.PushDestination.Filesystem(
                base_directory=os.path.join(pipeline_root, 'model_serving'))))

    return [
        example_gen, statistics_gen, schema_gen, example_validator, transform,
        latest_model_resolver, trainer, model_resolver, evaluator, pusher
    ]
Пример #26
0
 def testEncodeTemplatedExecutorContainerSpec(self):
   """Encodes a TemplatedExecutorContainerSpec and compares it to a text proto.

   The command list mixes a plain text value, input value placeholders, input
   and output URI placeholders and a concatenation placeholder (all taken
   from fixtures on `self`); the encoded result is compared against the
   expected text proto below.
   """
   command_template = [
       self._text, self._input_value_placeholder,
       self._another_input_value_placeholder, self._input_uri_placeholder,
       self._output_uri_placeholder, self._concat_placeholder
   ]
   container_spec = executor_specs.TemplatedExecutorContainerSpec(
       image='image', command=command_template)
   component_spec_instance = TestComponentSpec(
       input_artifact=channel.Channel(type=standard_artifacts.Examples),
       output_artifact=channel.Channel(type=standard_artifacts.Model),
       input_parameter=42)
   encode_result = container_spec.encode(
       component_spec=component_spec_instance)

   # One `commands` entry per element of the command template, in order.
   expected_text_proto = """
     image: "image"
     commands {
       value {
         string_value: "text"
       }
     }
     commands {
       operator {
         artifact_value_op {
           expression {
             operator {
               index_op {
                 expression {
                   placeholder {
                     key: "input_artifact"
                   }
                 }
               }
             }
           }
         }
       }
     }
     commands {
       placeholder {
         type: EXEC_PROPERTY
         key: "input_parameter"
       }
     }
     commands {
       operator {
         artifact_uri_op {
           expression {
             operator {
               index_op {
                 expression {
                   placeholder {
                     key: "input_artifact"
                   }
                 }
                 index: 0
               }
             }
           }
         }
       }
     }
     commands {
       operator {
         artifact_uri_op {
           expression {
             operator {
               index_op {
                 expression {
                   placeholder {
                     type: OUTPUT_ARTIFACT
                     key: "output_artifact"
                   }
                 }
                 index: 0
               }
             }
           }
         }
       }
     }
     commands {
       operator {
         concat_op {
           expressions {
             value {
               string_value: "text"
             }
           }
           expressions {
             operator {
               artifact_value_op {
                 expression {
                   operator {
                     index_op {
                       expression {
                         placeholder {
                           key: "input_artifact"
                         }
                       }
                       index: 0
                     }
                   }
                 }
               }
             }
           }
           expressions {
             operator {
               artifact_uri_op {
                 expression {
                   operator {
                     index_op {
                       expression {
                         placeholder {
                           key: "input_artifact"
                         }
                       }
                       index: 0
                     }
                   }
                 }
               }
             }
           }
           expressions {
             operator {
               artifact_uri_op {
                 expression {
                   operator {
                     index_op {
                       expression {
                         placeholder {
                           type: OUTPUT_ARTIFACT
                           key: "output_artifact"
                         }
                       }
                       index: 0
                     }
                   }
                 }
               }
             }
           }
         }
       }
     }"""
   self.assertProtoEquals(expected_text_proto, encode_result)
Пример #27
0
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
                     module_file: Text, serving_model_dir: Text,
                     metadata_path: Text,
                     direct_num_workers: int) -> pipeline.Pipeline:
    """Assembles the graph-augmented training pipeline.

    The components are wired in this order: example import, example
    identification, statistics, schema, example validation, graph synthesis,
    transform, graph augmentation, training, model validation and pushing.

    NOTE(review): the body never reads `data_root` or `module_file`. It
    relies on the names `examples_path`, `_transform_module_file` and
    `_trainer_module_file`, which are not defined inside this function --
    presumably module-level values; confirm against the rest of the file.

    Args:
        pipeline_name: Forwarded to `pipeline.Pipeline` as `pipeline_name`.
        pipeline_root: Forwarded to `pipeline.Pipeline` as `pipeline_root`.
        data_root: Not used by the body.
        module_file: Not used by the body.
        serving_model_dir: Base directory of the Pusher's filesystem push
            destination.
        metadata_path: Passed to `metadata.sqlite_metadata_connection_config`.
        direct_num_workers: Formatted into the `--direct_num_workers` Beam
            pipeline argument.

    Returns:
        The `pipeline.Pipeline` built from the components above, with caching
        enabled.
    """
    # Ingest the examples using fixed 'train' / 'eval' file patterns.
    example_gen = ImportExampleGen(
        input=external_input(examples_path),
        input_config=example_gen_pb2.Input(splits=[
            example_gen_pb2.Input.Split(
                name='train', pattern='train.tfrecord'),
            example_gen_pb2.Input.Split(
                name='eval', pattern='eval.tfrecord'),
        ]))
    raw_examples = example_gen.outputs['examples']

    identify_examples = IdentifyExamples(
        orig_examples=raw_examples,
        component_name=u'IdentifyExamples',
        id_feature_name=u'id')
    identified = identify_examples.outputs['identified_examples']

    # Computes statistics over data for visualization and example validation.
    statistics_gen = StatisticsGen(examples=identified)
    statistics = statistics_gen.outputs['statistics']

    schema_gen = SchemaGen(statistics=statistics)
    schema = schema_gen.outputs['schema']

    # Performs anomaly detection based on statistics and data schema.
    validate_stats = ExampleValidator(statistics=statistics, schema=schema)

    synthesize_graph = SynthesizeGraph(
        identified_examples=identified,
        component_name=u'SynthesizeGraph',
        similarity_threshold=0.99)

    # TODO(b/169218106): Remove transformed_examples kwargs after bugfix is released.
    transformed_examples_channel = channel.Channel(
        type=standard_artifacts.Examples,
        artifacts=[standard_artifacts.Examples()])
    transform = Transform(
        examples=identified,
        schema=schema,
        transformed_examples=transformed_examples_channel,
        module_file=_transform_module_file)

    # Augments training data with graph neighbors.
    graph_augmentation = GraphAugmentation(
        identified_examples=transform.outputs['transformed_examples'],
        synthesized_graph=synthesize_graph.outputs['synthesized_graph'],
        component_name=u'GraphAugmentation',
        num_neighbors=3)

    trainer = Trainer(
        module_file=_trainer_module_file,
        transformed_examples=graph_augmentation.outputs['augmented_examples'],
        schema=schema,
        transform_graph=transform.outputs['transform_graph'],
        train_args=trainer_pb2.TrainArgs(num_steps=10000),
        eval_args=trainer_pb2.EvalArgs(num_steps=5000))
    trained_model = trainer.outputs['model']

    model_validator = ModelValidator(examples=raw_examples,
                                     model=trained_model)

    # The model is pushed to a filesystem location under serving_model_dir.
    filesystem_destination = pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=serving_model_dir))
    pusher = Pusher(
        model=trained_model,
        model_blessing=model_validator.outputs['blessing'],
        push_destination=filesystem_destination)

    all_components = [
        example_gen,
        identify_examples,
        statistics_gen,
        schema_gen,
        validate_stats,
        synthesize_graph,
        transform,
        graph_augmentation,
        trainer,
        model_validator,
        pusher,
    ]
    beam_args = ['--direct_num_workers=%d' % direct_num_workers]
    connection_config = metadata.sqlite_metadata_connection_config(
        metadata_path)

    return pipeline.Pipeline(
        pipeline_name=pipeline_name,
        pipeline_root=pipeline_root,
        components=all_components,
        enable_cache=True,
        metadata_connection_config=connection_config,
        beam_pipeline_args=beam_args)
Пример #28
0
 def testMismatchedUnionChannelType(self):
   """channel.union raises TypeError for channels of differing types."""
   mismatched_channels = [
       channel.Channel(type=_MyType),
       channel.Channel(type=_AnotherType),
   ]
   with self.assertRaises(TypeError):
     channel.union(mismatched_channels)
Пример #29
0
 def testStringTypeNameNotAllowed(self):
   """Constructing a Channel from a plain string raises ValueError."""
   self.assertRaises(ValueError, channel.Channel, 'StringTypeName')
Пример #30
0
 def testValidChannel(self):
   """A Channel reports its type name and returns the artifacts set on it."""
   first, second = _MyType(), _MyType()
   base_channel = channel.Channel(_MyType)
   populated = base_channel.set_artifacts([first, second])
   self.assertEqual(populated.type_name, 'MyTypeName')
   # Order-insensitive comparison of the stored artifacts.
   self.assertCountEqual(populated.get(), [first, second])