def testResolveInputArtifacts(self):
  artifact_1 = standard_artifacts.String()
  artifact_1.id = 1
  channel_1 = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c1').set_artifacts([artifact_1])
  artifact_2 = standard_artifacts.String()
  artifact_2.id = 2
  channel_2 = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c2').set_artifacts([artifact_2])
  channel_3 = types.Channel(
      type=standard_artifacts.String,
      producer_component_id='c3').set_artifacts([standard_artifacts.String()])
  input_dict = {
      'input_union': channel.union([channel_1, channel_2]),
      'input_string': channel_3,
  }
  self._mock_metadata.search_artifacts.side_effect = [
      channel_3.get(), channel_1.get(), channel_2.get()
  ]
  driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
  resolved_artifacts = driver.resolve_input_artifacts(
      input_dict=input_dict,
      exec_properties=self._exec_properties,
      driver_args=self._driver_args,
      pipeline_info=self._pipeline_info)
  self.assertEqual(len(resolved_artifacts['input_union']), 2)
  self.assertEqual(resolved_artifacts['input_union'][0].value, _STRING_VALUE)
  self.assertEqual(len(resolved_artifacts['input_string']), 1)
  self.assertEqual(resolved_artifacts['input_string'][0].value, _STRING_VALUE)
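# A standalone sketch (standard library only, not part of the test above) of
# the side_effect pattern the mock relies on: assigning a list makes each
# successive call return the next element, which is why the list order in the
# test must match the driver's resolution order (channel_3 first, then
# channel_1 and channel_2).
from unittest import mock

m = mock.Mock()
m.search_artifacts.side_effect = ['first', 'second']
assert m.search_artifacts() == 'first'
assert m.search_artifacts() == 'second'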
def setUp(self):
  super(BaseDriverTest, self).setUp()
  self._mock_metadata = tf.compat.v1.test.mock.Mock()
  self._input_dict = {
      'input_data':
          types.Channel(
              type=_InputArtifact,
              artifacts=[_InputArtifact()],
              producer_component_id='c',
              output_key='k'),
      'input_string':
          types.Channel(
              type=standard_artifacts.String,
              artifacts=[
                  standard_artifacts.String(),
                  standard_artifacts.String()
              ],
              producer_component_id='c2',
              output_key='k2'),
  }
  input_dir = os.path.join(
      os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      self._testMethodName, 'input_dir')
  # Valid input artifacts must have a URI pointing to an existing directory.
  for key, input_channel in self._input_dict.items():
    for index, artifact in enumerate(input_channel.get()):
      artifact.id = index + 1
      uri = os.path.join(input_dir, key, str(artifact.id))
      artifact.uri = uri
      tf.io.gfile.makedirs(uri)
  self._output_dict = {
      'output_data':
          types.Channel(type=_OutputArtifact, artifacts=[_OutputArtifact()]),
      'output_multi_data':
          types.Channel(
              type=_OutputArtifact, matching_channel_name='input_string')
  }
  self._input_artifacts = channel_utils.unwrap_channel_dict(self._input_dict)
  self._output_artifacts = channel_utils.unwrap_channel_dict(
      self._output_dict)
  self._exec_properties = {
      'key': 'value',
  }
  self._execution_id = 100
  self._execution = metadata_store_pb2.Execution()
  self._execution.id = self._execution_id
  self._context_id = 123
  self._driver_args = data_types.DriverArgs(enable_cache=True)
  self._pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline_name',
      pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
      run_id='my_run_id')
  self._component_info = data_types.ComponentInfo(
      component_type='a.b.c',
      component_id='my_component_id',
      pipeline_info=self._pipeline_info)
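# A minimal sketch of what channel_utils.unwrap_channel_dict does with the
# fixtures above: it flattens {key: Channel} into {key: [Artifact, ...]}.
# The channel below is hypothetical, not part of the fixture.
from tfx import types
from tfx.types import channel_utils, standard_artifacts

single = types.Channel(
    type=standard_artifacts.String, artifacts=[standard_artifacts.String()])
unwrapped = channel_utils.unwrap_channel_dict({'k': single})
assert len(unwrapped['k']) == 1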
def _get_value_of_string_artifact(
    self, string_artifact: metadata_store_pb2.Artifact) -> Text:
  """Helper function that returns the actual value of a ValueArtifact."""
  string_artifact_obj = standard_artifacts.String()
  string_artifact_obj.uri = string_artifact.uri
  string_artifact_obj.read()
  return string_artifact_obj.value
def _get_value_of_string_artifact(
    self, string_artifact: metadata_store_pb2.Artifact) -> Text:
  """Helper function that returns the actual value of a ValueArtifact."""
  file_path = os.path.join(string_artifact.uri,
                           standard_artifacts.String.VALUE_FILE)
  # Assert that the value file exists and is a file, not a directory.
  if (not tf.io.gfile.exists(file_path)) or tf.io.gfile.isdir(file_path):
    raise RuntimeError(
        'Given path does not exist or is not a valid file: %s' % file_path)
  serialized_value = tf.io.gfile.GFile(file_path, 'rb').read()
  return standard_artifacts.String().decode(serialized_value)
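# A minimal, self-contained sketch of the round trip this helper undoes,
# using only APIs visible above (the uri is hypothetical): serialize a value
# with String.encode, write it to the VALUE_FILE location, then read the raw
# bytes back and decode them.
import os
import tensorflow as tf
from tfx.types import standard_artifacts

uri = '/tmp/string_artifact'  # hypothetical artifact uri
tf.io.gfile.makedirs(uri)
value_path = os.path.join(uri, standard_artifacts.String.VALUE_FILE)
with tf.io.gfile.GFile(value_path, 'wb') as f:
  f.write(standard_artifacts.String().encode(u'hello tfx'))

serialized = tf.io.gfile.GFile(value_path, 'rb').read()
assert standard_artifacts.String().decode(serialized) == u'hello tfx'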
def testStringType(self):
  instance = standard_artifacts.String()
  self.assertEqual(_TEST_STRING_RAW, instance.encode(_TEST_STRING_DECODED))
  self.assertEqual(_TEST_STRING_DECODED, instance.decode(_TEST_STRING_RAW))
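# A minimal sketch of the invariant checked above, with hypothetical literals
# standing in for _TEST_STRING_RAW / _TEST_STRING_DECODED (their definitions
# are outside this excerpt). This assumes String serializes via UTF-8, so
# encode maps Text to bytes and decode inverts it.
s = standard_artifacts.String()
assert s.encode(u'hello') == b'hello'
assert s.decode(b'hello') == u'hello'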
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kubeflow_v2_entrypoint_utils.py."""

import os

from kfp.pipeline_spec import pipeline_spec_pb2 as pipeline_pb2
import tensorflow as tf
from tfx.components.evaluator import constants
from tfx.orchestration.kubeflow.v2.container import kubeflow_v2_entrypoint_utils
from tfx.types import standard_artifacts
from tfx.utils import io_utils

_ARTIFACT_1 = standard_artifacts.String()
_KEY_1 = 'input_1'
_ARTIFACT_2 = standard_artifacts.ModelBlessing()
_KEY_2 = 'input_2'
_ARTIFACT_3 = standard_artifacts.Examples()
_KEY_3 = 'input_3'

_EXEC_PROPERTIES = {
    'input_config': 'input config string',
    'output_config':
        '{ "split_config": { "splits": [ { "hash_buckets": 2, "name": '
        '"train" }, { "hash_buckets": 1, "name": "eval" } ] } }',
}
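# A quick standard-library check (not part of the original file) that the
# serialized output_config above parses to the expected split layout:
import json

_splits = json.loads(
    _EXEC_PROPERTIES['output_config'])['split_config']['splits']
assert [s['name'] for s in _splits] == ['train', 'eval']
assert [s['hash_buckets'] for s in _splits] == [2, 1]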
def __init__(self, word, greeting=None):
  if not greeting:
    artifact = standard_artifacts.String()
    greeting = channel_utils.as_channel([artifact])
  super(HelloWorldComponent, self).__init__(
      _HelloWorldSpec(word=word, greeting=greeting))
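# A minimal sketch of the default-greeting fallback above: as_channel infers
# the channel type from the artifacts it wraps, so the synthesized greeting
# should be a String-typed channel carrying one empty artifact. The identity
# check on .type is an assumption about the Channel API, not from the source.
from tfx.types import channel_utils, standard_artifacts

default_greeting = channel_utils.as_channel([standard_artifacts.String()])
assert default_greeting.type is standard_artifacts.String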