Example 1
    def testRun_MultipleInputs_ExecutionFailed(self, mock_resolve):
        """Checks that resolving to several input dicts fails the execution.

        The mocked input resolution returns two resolved input dicts; the
        assertion below verifies the handler records the execution as FAILED
        in MLMD for that case.
        """
        mock_resolve.return_value = inputs_utils.Trigger([
            {'model': [self._create_model_artifact(uri='/tmp/model/1')]},
            {'model': [self._create_model_artifact(uri='/tmp/model/2')]},
        ])

        execution_info = resolver_node_handler.ResolverNodeHandler().run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._my_resolver,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # An execution must have been registered even though it failed.
            self.assertTrue(execution_info.execution_id)
            stored = m.store.get_executions_by_id(
                [execution_info.execution_id])
            [execution] = stored
            self.assertProtoPartiallyEquals(
                """
          id: 1
          last_known_state: FAILED
          """,
                execution,
                ignored_fields=[
                    'type_id',
                    'custom_properties',
                    'create_time_since_epoch',
                    'last_update_time_since_epoch',
                ])
Example 2
    def testRun_InputResolutionError_ExecutionFailed(self, mock_resolve):
        """Checks that an InputResolutionError marks the execution FAILED.

        The mocked input resolution raises; the handler is expected to still
        register an execution and record it as FAILED in MLMD.
        """
        mock_resolve.side_effect = exceptions.InputResolutionError('Meh')

        execution_info = resolver_node_handler.ResolverNodeHandler().run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._my_resolver,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # The failure still produces a registered execution.
            self.assertTrue(execution_info.execution_id)
            stored = m.store.get_executions_by_id(
                [execution_info.execution_id])
            [execution] = stored
            self.assertProtoPartiallyEquals(
                """
          id: 1
          last_known_state: FAILED
          """,
                execution,
                ignored_fields=[
                    'type_id',
                    'custom_properties',
                    'create_time_since_epoch',
                    'last_update_time_since_epoch',
                ])
Example 3
    def testSuccess(self):
        """Checks the happy path: the resolver's output reaches a consumer.

        Two trainer model artifacts are published to MLMD; the resolver node
        is run; then a fake downstream node wired to the resolver's output is
        resolved, and the assertions show it receives exactly one model (the
        second published one) and that the resolver execution is COMPLETE.
        """
        with self._mlmd_connection as m:
            # Publish two models from the trainer so the resolver downstream
            # has artifacts to resolve.
            model_type = self._my_trainer.outputs.outputs[
                'model'].artifact_spec.type
            first_model = types.Artifact(model_type)
            first_model.uri = 'my_model_uri_1'
            second_model = types.Artifact(model_type)
            second_model.uri = 'my_model_uri_2'

            contexts = context_lib.prepare_contexts(m,
                                                    self._my_trainer.contexts)
            trainer_execution = execution_publish_utils.register_execution(
                m, self._my_trainer.node_info.type, contexts)
            execution_publish_utils.publish_succeeded_execution(
                m, trainer_execution.id, contexts,
                {'model': [first_model, second_model]})

        execution_metadata = resolver_node_handler.ResolverNodeHandler().run(
            mlmd_connection=self._mlmd_connection,
            pipeline_node=self._resolver_node,
            pipeline_info=self._pipeline_info,
            pipeline_runtime_spec=self._pipeline_runtime_spec)

        with self._mlmd_connection as m:
            # The resolver's output cannot be inspected directly, so a fake
            # downstream node is built that consumes the resolver's "models"
            # output, and that node's resolved inputs are checked instead.
            down_stream_node = text_format.Parse(
                """
        inputs {
          inputs {
            key: "input_models"
            value {
              channels {
                producer_node_query {
                  id: "my_resolver"
                }
                context_queries {
                  type {
                    name: "pipeline"
                  }
                  name {
                    field_value {
                      string_value: "my_pipeline"
                    }
                  }
                }
                context_queries {
                  type {
                    name: "component"
                  }
                  name {
                    field_value {
                      string_value: "my_resolver"
                    }
                  }
                }
                artifact_query {
                  type {
                    name: "Model"
                  }
                }
                output_key: "models"
              }
              min_count: 1
            }
          }
        }
        upstream_nodes: "my_resolver"
        """, pipeline_pb2.PipelineNode())

            resolved = inputs_utils.resolve_input_artifacts(
                metadata_handler=m, node_inputs=down_stream_node.inputs)
            resolved_models = resolved['input_models']
            # Exactly one model — the second published one — comes through.
            self.assertLen(resolved_models, 1)
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 5
          uri: "my_model_uri_2"
          state: LIVE""",
                resolved_models[0].mlmd_artifact,
                ignored_fields=[
                    'create_time_since_epoch',
                    'last_update_time_since_epoch',
                ])

            # The resolver's own execution is recorded as COMPLETE.
            [execution] = m.store.get_executions_by_id([execution_metadata.id])
            self.assertProtoPartiallyEquals(
                """
          id: 2
          type_id: 6
          last_known_state: COMPLETE
          """,
                execution,
                ignored_fields=[
                    'create_time_since_epoch',
                    'last_update_time_since_epoch',
                ])