Ejemplo n.º 1
0
 def test_component_basic(self):
     """A basic component passes its input through and exposes a typed output."""
     source = channel.Channel(type_name='InputType')
     comp = _BasicComponent(folds=10, input=source)
     # The very same channel object (identity, not a copy) must be wired in.
     self.assertIs(source, comp.inputs.input)
     produced = comp.outputs.output
     self.assertIsInstance(produced, channel.Channel)
     self.assertEqual(produced.type_name, 'OutputType')
Ejemplo n.º 2
0
 def test_type_check_success(self):
     """type_check accepts the channel's own type name without raising."""
     expected_type = 'MyTypeName'
     channel.Channel(expected_type).type_check(expected_type)
Ejemplo n.º 3
0
 def test_type_check_fail(self):
     """type_check raises TypeError for a name other than the channel's type."""
     mismatched = channel.Channel('MyTypeName')
     # Construct outside the context manager so only type_check can satisfy it.
     with self.assertRaises(TypeError):
         mismatched.type_check('AnotherTypeName')
Ejemplo n.º 4
0
    def __init__(
            self,
            examples: Optional[channel.Channel] = None,
            transformed_examples: Optional[channel.Channel] = None,
            transform_output: Optional[channel.Channel] = None,
            schema: Optional[channel.Channel] = None,
            module_file: Optional[Text] = None,
            train_args: Optional[trainer_pb2.TrainArgs] = None,
            eval_args: Optional[trainer_pb2.EvalArgs] = None,
            custom_config: Optional[Dict[Text, Any]] = None,
            executor_class: Optional[Type[base_executor.BaseExecutor]] = None,
            output: Optional[channel.Channel] = None,
            name: Optional[Text] = None):
        """Construct a Trainer component.

        Args:
          examples: A Channel of 'ExamplesPath' type, serving as the source of
            examples that are used in training. Exactly one of `examples` and
            `transformed_examples` must be set.
          transformed_examples: Deprecated field. Please set 'examples' instead.
            If set, `transform_output` must be set too.
          transform_output: An optional Channel of 'TransformPath' type, serving
            as the input transform graph if present.
          schema: A Channel of 'SchemaPath' type, serving as the schema of
            training and eval data.
          module_file: A python module file containing UDF model definition.
          train_args: A trainer_pb2.TrainArgs instance, containing args used for
            training. Currently only num_steps is available.
          eval_args: A trainer_pb2.EvalArgs instance, containing args used for
            eval. Currently only num_steps is available.
          custom_config: A dict which contains the training job parameters to be
            passed to Google Cloud ML Engine. For the full set of parameters
            supported by Google Cloud ML Engine, refer to
            https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#Job
          executor_class: Optional custom executor class.
          output: Optional 'ModelExportPath' channel for result of exported
            models. When omitted (or falsy), a new 'ModelExportPath' channel
            holding a single 'ModelExportPath' artifact is created.
          name: Optional unique name. Necessary iff multiple Trainer components
            are declared in the same pipeline.

        Raises:
          AssertionError: If both or neither of `examples` and
            `transformed_examples` are set, or if `transformed_examples` is set
            without `transform_output`.
        """
        output = output or channel.Channel(
            type_name='ModelExportPath',
            artifacts=[types.Artifact('ModelExportPath')])
        # Validate with explicit raises instead of `assert` so the checks still
        # run under `python -O`. AssertionError is kept for backward
        # compatibility with callers that already expect it.
        if bool(examples) == bool(transformed_examples):
            raise AssertionError(
                'Exactly one of examples or transformed_examples should be set.')
        if transformed_examples and not transform_output:
            raise AssertionError(
                'If transformed_examples is set, transform_output should be set '
                'too.')
        # Fold the deprecated field into `examples` so the spec sees one input.
        examples = examples or transformed_examples
        transform_output_channel = channel.as_channel(
            transform_output) if transform_output else None
        spec = TrainerSpec(examples=channel.as_channel(examples),
                           transform_output=transform_output_channel,
                           schema=channel.as_channel(schema),
                           train_args=train_args,
                           eval_args=eval_args,
                           module_file=module_file,
                           custom_config=custom_config,
                           output=output)
        super(Trainer, self).__init__(spec=spec,
                                      custom_executor_class=executor_class,
                                      name=name)
Ejemplo n.º 5
0
 def test_invalid_channel_type(self):
     """Channel raises ValueError for artifacts not matching its type_name."""
     artifacts = [types.Artifact('MyTypeName') for _ in range(2)]
     with self.assertRaises(ValueError):
         channel.Channel('AnotherTypeName', artifacts=artifacts)
Ejemplo n.º 6
0
 def __init__(self, name: Text, spec_kwargs: Dict[Text, Any]):
     """Builds a fake component with an output channel typed after `name`.

     Args:
       name: Used as the output channel's type_name and as the component name.
       spec_kwargs: Extra keyword arguments forwarded to _FakeComponentSpec.
     """
     output_channel = channel.Channel(type_name=name)
     fake_spec = _FakeComponentSpec(output=output_channel, **spec_kwargs)
     super(_FakeComponent, self).__init__(spec=fake_spec, name=name)
    def __init__(self, stats: str, schema: str):
        """Wraps an ExampleValidator fed by fresh, unpopulated input channels.

        Args:
          stats: Value handed to the base constructor under the "stats" key.
          schema: Value handed to the base constructor under the "schema" key.
        """
        stats_channel = channel.Channel('ExampleStatisticsPath')
        schema_channel = channel.Channel('SchemaPath')
        validator = example_validator_component.ExampleValidator(
            stats_channel, schema_channel)
        arguments = dict(stats=stats, schema=schema)
        super().__init__(validator, arguments)
 def __init__(self, stats: str):
     """Wraps a SchemaGen fed by a fresh 'ExampleStatisticsPath' channel.

     Args:
       stats: Value handed to the base constructor under the "stats" key.
     """
     stats_input = channel.Channel('ExampleStatisticsPath')
     schema_gen = schema_gen_component.SchemaGen(stats_input)
     super().__init__(schema_gen, dict(stats=stats))
 def __init__(self, input_data: str):
     """Wraps a StatisticsGen fed by a fresh 'ExamplesPath' channel.

     Args:
       input_data: Value handed to the base constructor under the
         "input_data" key.
     """
     examples_input = channel.Channel('ExamplesPath')
     stats_gen = statistics_gen_component.StatisticsGen(examples_input)
     super().__init__(stats_gen, dict(input_data=input_data))