def testResolveInputArtifacts(self):
  """Resolves a union channel and a plain channel through BaseDriver.

  Three single-artifact String channels are built. Two of them are
  combined with `channel.union` under 'input_union'; the third is passed
  directly as 'input_string'. The test checks that the union resolves to
  two artifacts, the plain input resolves to one, and the first artifact
  of each carries `_STRING_VALUE`.
  """
  artifact_1 = standard_artifacts.String()
  artifact_1.id = 1
  channel_1 = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c1').set_artifacts([artifact_1])

  artifact_2 = standard_artifacts.String()
  artifact_2.id = 2
  channel_2 = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c2').set_artifacts([artifact_2])

  channel_3 = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c3').set_artifacts(
          [standard_artifacts.String()])

  input_dict = {
      'input_union': channel.union([channel_1, channel_2]),
      'input_string': channel_3,
  }
  # NOTE(review): search_artifacts is consumed in this order (channel_3
  # first, then the two union members). Presumably this mirrors the order
  # in which resolve_input_artifacts queries its inputs; confirm against
  # BaseDriver before reordering input_dict or the channels.
  self._mock_metadata.search_artifacts.side_effect = [
      channel_3.get(), channel_1.get(), channel_2.get()
  ]

  driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
  resolved_artifacts = driver.resolve_input_artifacts(
      input_dict=input_dict,
      exec_properties=self._exec_properties,
      driver_args=self._driver_args,
      pipeline_info=self._pipeline_info)

  # assertLen (also used by testPreExecutionNewExecution) reports the
  # actual container contents on failure, unlike assertEqual(len(...)).
  self.assertLen(resolved_artifacts['input_union'], 2)
  self.assertEqual(resolved_artifacts['input_union'][0].value, _STRING_VALUE)
  self.assertLen(resolved_artifacts['input_string'], 1)
  self.assertEqual(resolved_artifacts['input_string'][0].value, _STRING_VALUE)
def testPreExecutionCached(self, mock_verify_input_artifacts_fn):
  """pre_execution reports a cache hit when cached outputs are returned.

  The metadata mock returns the 'input_string' artifacts on search, a
  fresh Context for the run context, the fixture execution on
  registration, and the fixture output artifacts as cached outputs. The
  resulting decision must use cached results and echo the execution id,
  exec properties and outputs.
  """
  md = self._mock_metadata
  md.search_artifacts.return_value = list(
      self._input_dict['input_string'].get())
  md.register_run_context_if_not_exists.side_effect = [
      metadata_store_pb2.Context()
  ]
  md.register_execution.side_effect = [self._execution]
  md.get_cached_outputs.side_effect = [self._output_artifacts]

  decision = base_driver.BaseDriver(metadata_handler=md).pre_execution(
      input_dict=self._input_dict,
      output_dict=self._output_dict,
      exec_properties=self._exec_properties,
      driver_args=self._driver_args,
      pipeline_info=self._pipeline_info,
      component_info=self._component_info)

  self.assertTrue(decision.use_cached_results)
  self.assertEqual(decision.execution_id, self._execution_id)
  # assertCountEqual over dicts compares their keys only.
  self.assertCountEqual(decision.exec_properties, self._exec_properties)
  self.assertCountEqual(decision.output_dict, self._output_artifacts)
def testPreExecutionNewExecution(self, mock_verify_input_artifacts_fn):
  """pre_execution prepares a fresh execution when nothing is cached.

  With get_cached_outputs returning None, the decision must not use
  cached results. Output URIs are expected to be laid out as
  <pipeline_root>/<component_id>/<output_key>/<execution_id>, with a
  trailing per-artifact index for the multi-artifact output.
  """
  md = self._mock_metadata
  md.search_artifacts.return_value = list(
      self._input_dict['input_string'].get())
  md.register_execution.side_effect = [self._execution]
  md.get_cached_outputs.side_effect = [None]
  md.register_run_context_if_not_exists.side_effect = [
      metadata_store_pb2.Context()
  ]

  drv = base_driver.BaseDriver(metadata_handler=md)
  decision = drv.pre_execution(
      input_dict=self._input_dict,
      output_dict=self._output_dict,
      exec_properties=self._exec_properties,
      driver_args=self._driver_args,
      pipeline_info=self._pipeline_info,
      component_info=self._component_info)

  self.assertFalse(decision.use_cached_results)
  self.assertEqual(decision.execution_id, self._execution_id)
  self.assertCountEqual(decision.exec_properties, self._exec_properties)

  root = self._pipeline_info.pipeline_root
  component = self._component_info.component_id
  exec_id = str(self._execution_id)

  self.assertEqual(
      decision.output_dict['output_data'][0].uri,
      os.path.join(root, component, 'output_data', exec_id))

  multi_outputs = decision.output_dict['output_multi_data']
  self.assertLen(multi_outputs, 2)
  # The length is pinned to 2 above, so this visits indices 0 and 1.
  for idx, artifact in enumerate(multi_outputs):
    self.assertEqual(
        artifact.uri,
        os.path.join(root, component, 'output_multi_data', exec_id,
                     str(idx)))

  self.assertEqual(decision.input_dict['input_string'][0].value,
                   _STRING_VALUE)
def testVerifyInputArtifactsNotExists(self):
  """verify_input_artifacts raises RuntimeError for a missing artifact."""
  verifier = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
  with self.assertRaises(RuntimeError):
    missing = {'artifact': [_InputArtifact()]}
    verifier.verify_input_artifacts(missing)
def testVerifyInputArtifactsOk(self):
  """verify_input_artifacts accepts the fixture's input artifacts."""
  base_driver.BaseDriver(
      metadata_handler=self._mock_metadata).verify_input_artifacts(
          self._input_artifacts)