예제 #1
0
    def testResolveInputArtifacts(self):
        """Resolves a union input and a plain input through BaseDriver.

        The mocked metadata handler answers successive search_artifacts
        calls with the artifacts of the string channel, then the first union
        member, then the second union member. The resolved dict must hold two
        artifacts for 'input_union' and one for 'input_string', and the first
        artifact of each must carry _STRING_VALUE.
        """
        def _make_string_channel(component_id, artifact):
            # One-artifact String channel attributed to `component_id`.
            return types.Channel(
                type=standard_artifacts.String,
                producer_component_id=component_id).set_artifacts([artifact])

        left_artifact = standard_artifacts.String()
        left_artifact.id = 1
        union_left = _make_string_channel('c1', left_artifact)

        right_artifact = standard_artifacts.String()
        right_artifact.id = 2
        union_right = _make_string_channel('c2', right_artifact)

        string_channel = _make_string_channel('c3',
                                              standard_artifacts.String())

        input_dict = {
            'input_union': channel.union([union_left, union_right]),
            'input_string': string_channel,
        }
        # Each successive search_artifacts call returns the next list here.
        self._mock_metadata.search_artifacts.side_effect = [
            string_channel.get(),
            union_left.get(),
            union_right.get(),
        ]

        driver = base_driver.BaseDriver(metadata_handler=self._mock_metadata)
        resolved_artifacts = driver.resolve_input_artifacts(
            input_dict=input_dict,
            exec_properties=self._exec_properties,
            driver_args=self._driver_args,
            pipeline_info=self._pipeline_info)

        expected_counts = (('input_union', 2), ('input_string', 1))
        for input_key, expected_count in expected_counts:
            artifacts = resolved_artifacts[input_key]
            self.assertEqual(len(artifacts), expected_count)
            self.assertEqual(artifacts[0].value, _STRING_VALUE)
예제 #2
0
 def setUp(self):
   """Creates the fixtures shared by the BaseDriver tests.

   Every input artifact gets a 1-based id (per channel) and a uri of the form
   <tmp root>/<test method name>/input_dir/<input key>/<id>, and that
   directory is created. Output channels, their unwrapped artifacts, exec
   properties, an Execution proto, driver args, pipeline info and component
   info are stored on the instance.
   """
   super(BaseDriverTest, self).setUp()
   self._mock_metadata = tf.compat.v1.test.mock.Mock()

   data_channel = types.Channel(
       type=_InputArtifact,
       artifacts=[_InputArtifact()],
       producer_component_id='c',
       output_key='k')
   string_channel = types.Channel(
       type=standard_artifacts.String,
       artifacts=[standard_artifacts.String(), standard_artifacts.String()],
       producer_component_id='c2',
       output_key='k2')
   self._input_dict = {
       'input_data': data_channel,
       'input_string': string_channel,
   }

   test_tmp_root = os.environ.get('TEST_TMP_DIR', self.get_temp_dir())
   input_dir = os.path.join(test_tmp_root, self._testMethodName, 'input_dir')
   # An input artifact is only valid when its uri names an existing
   # directory, so create one directory per artifact.
   for key, input_channel in self._input_dict.items():
     for artifact_id, artifact in enumerate(input_channel.get(), start=1):
       artifact_uri = os.path.join(input_dir, key, str(artifact_id))
       artifact.id = artifact_id
       artifact.uri = artifact_uri
       tf.io.gfile.makedirs(artifact_uri)

   output_data_channel = types.Channel(
       type=_OutputArtifact, artifacts=[_OutputArtifact()])
   output_multi_data_channel = types.Channel(
       type=_OutputArtifact, matching_channel_name='input_string')
   self._output_dict = {
       'output_data': output_data_channel,
       'output_multi_data': output_multi_data_channel,
   }
   self._input_artifacts = channel_utils.unwrap_channel_dict(self._input_dict)
   self._output_artifacts = channel_utils.unwrap_channel_dict(
       self._output_dict)

   self._exec_properties = {'key': 'value'}
   self._execution_id = 100
   self._execution = metadata_store_pb2.Execution(id=self._execution_id)
   self._context_id = 123
   self._driver_args = data_types.DriverArgs(enable_cache=True)
   self._pipeline_info = data_types.PipelineInfo(
       pipeline_name='my_pipeline_name',
       pipeline_root=test_tmp_root,
       run_id='my_run_id')
   self._component_info = data_types.ComponentInfo(
       component_type='a.b.c',
       component_id='my_component_id',
       pipeline_info=self._pipeline_info)
예제 #3
0
    def _get_value_of_string_artifact(
            self, string_artifact: metadata_store_pb2.Artifact) -> Text:
        """Returns the value stored at the uri of an MLMD String artifact.

        A fresh `standard_artifacts.String` is pointed at the same uri as
        `string_artifact`, its `read()` is invoked, and its `value` is
        returned.
        """
        value_holder = standard_artifacts.String()
        value_holder.uri = string_artifact.uri
        value_holder.read()
        return value_holder.value
예제 #4
0
 def _get_value_of_string_artifact(
     self, string_artifact: metadata_store_pb2.Artifact) -> Text:
   """Reads and decodes the value file of a String artifact.

   Args:
     string_artifact: MLMD artifact proto; its `uri` is treated as the
       directory expected to contain `standard_artifacts.String.VALUE_FILE`.

   Returns:
     The file contents decoded by `standard_artifacts.String().decode`.

   Raises:
     RuntimeError: if the value file does not exist or is a directory.
   """
   file_path = os.path.join(string_artifact.uri,
                            standard_artifacts.String.VALUE_FILE)
   # The value must live in a regular file, not a directory.
   if not tf.io.gfile.exists(file_path) or tf.io.gfile.isdir(file_path):
     raise RuntimeError(
         'Given path does not exist or is not a valid file: %s' % file_path)
   # Close the handle deterministically instead of leaking it.
   with tf.io.gfile.GFile(file_path, 'rb') as value_file:
     serialized_value = value_file.read()
   return standard_artifacts.String().decode(serialized_value)
예제 #5
0
 def testStringType(self):
     """String.encode/decode map between the raw and decoded fixtures."""
     string_artifact = standard_artifacts.String()
     encoded = string_artifact.encode(_TEST_STRING_DECODED)
     self.assertEqual(_TEST_STRING_RAW, encoded)
     decoded = string_artifact.decode(_TEST_STRING_RAW)
     self.assertEqual(_TEST_STRING_DECODED, decoded)
예제 #6
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kubeflow_v2_entrypoint_utils.py."""

import os
from kfp.pipeline_spec import pipeline_spec_pb2 as pipeline_pb2
import tensorflow as tf
from tfx.components.evaluator import constants
from tfx.orchestration.kubeflow.v2.container import kubeflow_v2_entrypoint_utils
from tfx.types import standard_artifacts
from tfx.utils import io_utils

# Test input artifacts, each paired with the input key it is used under.
_ARTIFACT_1 = standard_artifacts.String()
_KEY_1 = 'input_1'

_ARTIFACT_2 = standard_artifacts.ModelBlessing()
_KEY_2 = 'input_2'

_ARTIFACT_3 = standard_artifacts.Examples()
_KEY_3 = 'input_3'

# String-valued execution properties. 'output_config' is JSON text describing
# a split config with a 'train' split (2 hash buckets) and an 'eval' split
# (1 hash bucket).
_EXEC_PROPERTIES = {
    'input_config': 'input config string',
    'output_config': (
        '{ "split_config": { "splits": [ '
        '{ "hash_buckets": 2, "name": "train" }, '
        '{ "hash_buckets": 1, "name": "eval" } ] } }'),
}
예제 #7
0
파일: test_utils.py 프로젝트: htahir1/tfx
 def __init__(self, word, greeting=None):
     """Initializes the component from a `_HelloWorldSpec`.

     Args:
       word: forwarded unchanged to `_HelloWorldSpec` as `word`.
       greeting: channel forwarded to `_HelloWorldSpec` as `greeting`; when
         falsy (e.g. the default None), a channel wrapping one newly created
         String artifact is used instead.
     """
     greeting = greeting or channel_utils.as_channel(
         [standard_artifacts.String()])
     super(HelloWorldComponent, self).__init__(
         _HelloWorldSpec(word=word, greeting=greeting))