Esempio n. 1
0
    def testResolveInputArtifacts(self):
        """Checks resolution of a union channel alongside a plain channel.

        Channels produced by 'c1' and 'c2' are merged with `channel.union`
        under the 'input_union' key, while the 'c3' channel is passed directly
        as 'input_string'. The mocked `search_artifacts` returns each
        channel's artifacts in turn. The test asserts that the union resolves
        to two artifacts, the plain input to one, and that the inspected
        artifacts carry `_STRING_VALUE`.
        """

        def _string_channel(producer_id, artifact=None):
            # Builds a String channel for `producer_id` holding one artifact.
            # When no artifact is given, a fresh String is created *after* the
            # Channel itself, matching the original construction order.
            new_channel = types.Channel(type=standard_artifacts.String,
                                        producer_component_id=producer_id)
            if artifact is None:
                artifact = standard_artifacts.String()
            return new_channel.set_artifacts([artifact])

        first_artifact = standard_artifacts.String()
        first_artifact.id = 1
        first_channel = _string_channel('c1', first_artifact)

        second_artifact = standard_artifacts.String()
        second_artifact.id = 2
        second_channel = _string_channel('c2', second_artifact)

        string_channel = _string_channel('c3')

        input_dict = {
            'input_union': channel.union([first_channel, second_channel]),
            'input_string': string_channel,
        }
        # One entry per search_artifacts call. The order ('c3', then 'c1',
        # then 'c2') presumably mirrors the order in which the driver queries
        # its inputs -- NOTE(review): confirm against
        # BaseDriver.resolve_input_artifacts.
        self._mock_metadata.search_artifacts.side_effect = [
            string_channel.get(),
            first_channel.get(),
            second_channel.get(),
        ]

        resolved = base_driver.BaseDriver(
            metadata_handler=self._mock_metadata).resolve_input_artifacts(
                input_dict=input_dict,
                exec_properties=self._exec_properties,
                driver_args=self._driver_args,
                pipeline_info=self._pipeline_info)

        union_artifacts = resolved['input_union']
        self.assertEqual(len(union_artifacts), 2)
        self.assertEqual(union_artifacts[0].value, _STRING_VALUE)

        string_artifacts = resolved['input_string']
        self.assertEqual(len(string_artifacts), 1)
        self.assertEqual(string_artifacts[0].value, _STRING_VALUE)
Esempio n. 2
0
    def testPreExecutionCached(self, mock_verify_input_artifacts_fn):
        """pre_execution reports a cache hit when cached outputs are returned.

        `get_cached_outputs` is mocked to yield `self._output_artifacts`. The
        test then asserts that the decision sets `use_cached_results` and
        carries `self._execution_id`. It also asserts that the decision's exec
        properties and output dict line up with the fixture's.

        Args:
          mock_verify_input_artifacts_fn: Mock injected by a decorator outside
            this view (presumably patching input verification); unused here.
        """
        metadata = self._mock_metadata
        metadata.search_artifacts.return_value = list(
            self._input_dict['input_string'].get())
        metadata.register_run_context_if_not_exists.side_effect = [
            metadata_store_pb2.Context()
        ]
        metadata.register_execution.side_effect = [self._execution]
        # Supplying cached outputs is what this test relies on to drive the
        # driver down its cache-hit path.
        metadata.get_cached_outputs.side_effect = [self._output_artifacts]

        decision = base_driver.BaseDriver(
            metadata_handler=metadata).pre_execution(
                input_dict=self._input_dict,
                output_dict=self._output_dict,
                exec_properties=self._exec_properties,
                driver_args=self._driver_args,
                pipeline_info=self._pipeline_info,
                component_info=self._component_info)

        self.assertTrue(decision.use_cached_results)
        self.assertEqual(decision.execution_id, self._execution_id)
        # NOTE(review): assertCountEqual over two dicts iterates their keys,
        # so these checks compare key sets only, not values.
        self.assertCountEqual(decision.exec_properties, self._exec_properties)
        self.assertCountEqual(decision.output_dict, self._output_artifacts)
Esempio n. 3
0
  def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
    """pre_execution prepares fresh output URIs when nothing is cached.

    With `get_cached_outputs` mocked to return None, the decision must not
    reuse results. Single outputs are expected under
    <pipeline_root>/<component_id>/<key>/<execution_id>. Multi-artifact
    outputs additionally get a per-index subdirectory.

    Args:
      mock_verify_input_artifacts_fn: Mock injected by a decorator outside
        this view (presumably patching input verification); unused here.
    """
    metadata = self._mock_metadata
    metadata.search_artifacts.return_value = list(
        self._input_dict['input_string'].get())
    metadata.register_execution.side_effect = [self._execution]
    # No cached outputs: the driver is expected to start a new execution.
    metadata.get_cached_outputs.side_effect = [None]
    metadata.register_run_context_if_not_exists.side_effect = [
        metadata_store_pb2.Context()
    ]

    decision = base_driver.BaseDriver(metadata_handler=metadata).pre_execution(
        input_dict=self._input_dict,
        output_dict=self._output_dict,
        exec_properties=self._exec_properties,
        driver_args=self._driver_args,
        pipeline_info=self._pipeline_info,
        component_info=self._component_info)

    self.assertFalse(decision.use_cached_results)
    self.assertEqual(decision.execution_id, self._execution_id)
    # NOTE(review): assertCountEqual over dicts compares keys only.
    self.assertCountEqual(decision.exec_properties, self._exec_properties)

    # os.path.join folds left, so joining onto this prefix yields the same
    # path as one join over all components.
    component_root = os.path.join(self._pipeline_info.pipeline_root,
                                  self._component_info.component_id)
    execution_suffix = str(self._execution_id)
    outputs = decision.output_dict

    self.assertEqual(
        outputs['output_data'][0].uri,
        os.path.join(component_root, 'output_data', execution_suffix))

    multi_outputs = outputs['output_multi_data']
    self.assertLen(multi_outputs, 2)
    # Each artifact of the multi-artifact output lives in its own index subdir.
    for index, artifact in enumerate(multi_outputs):
      self.assertEqual(
          artifact.uri,
          os.path.join(component_root, 'output_multi_data', execution_suffix,
                       str(index)))

    self.assertEqual(decision.input_dict['input_string'][0].value,
                     _STRING_VALUE)
Esempio n. 4
0
 def testVerifyInputArtifactsNotExists(self):
     """verify_input_artifacts raises RuntimeError for an unusable artifact.

     The input is a single `_InputArtifact()`, presumably built so that its
     backing data does not exist -- NOTE(review): confirm in its definition.
     """
     driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
     with self.assertRaises(RuntimeError):
         # Built inside the context, as before, so any RuntimeError raised
         # while constructing the input is still covered.
         missing_inputs = {'artifact': [_InputArtifact()]}
         driver.verify_input_artifacts(missing_inputs)
Esempio n. 5
0
 def testVerifyInputArtifactsOk(self):
     """verify_input_artifacts accepts the fixture's `_input_artifacts`.

     The test passes as long as the call completes without raising.
     """
     base_driver.BaseDriver(
         metadata_handler=self._mock_metadata).verify_input_artifacts(
             self._input_artifacts)