Пример #1
0
 def _get_value_of_string_artifact(
         self, string_artifact: metadata_store_pb2.Artifact) -> Text:
     """Helper function returns the actual value of a ValueArtifact.

     Reads the serialized value file stored under the artifact's uri and
     decodes it with `standard_artifacts.StringType`.

     Args:
       string_artifact: Artifact whose `uri` is the directory holding the
         file named `standard_artifacts.StringType.VALUE_FILE`.

     Returns:
       The result of `standard_artifacts.StringType().decode` applied to the
       raw bytes of the value file.

     Raises:
       RuntimeError: If the value file does not exist or is a directory.
     """
     file_path = os.path.join(string_artifact.uri,
                              standard_artifacts.StringType.VALUE_FILE)
     # The value must be stored in an existing regular file, not a directory.
     if not tf.io.gfile.exists(file_path) or tf.io.gfile.isdir(file_path):
         raise RuntimeError(
             'Given path does not exist or is not a valid file: %s' %
             file_path)
     # Use a context manager so the handle is closed deterministically rather
     # than being left open until garbage collection.
     with tf.io.gfile.GFile(file_path, 'rb') as value_file:
         serialized_value = value_file.read()
     return standard_artifacts.StringType().decode(serialized_value)
Пример #2
0
 def setUp(self):
   """Builds the channels, artifacts and pipeline metadata used by the tests.

   Input artifacts get ids starting at 1 and a uri directory that is created
   on disk under <tmp>/<test method name>/input_dir/<channel key>/<id>.
   """
   super(BaseDriverTest, self).setUp()
   self._mock_metadata = tf.compat.v1.test.mock.Mock()

   # Input side: one generic artifact channel and one StringType channel.
   self._string_artifact = standard_artifacts.StringType()
   data_channel = types.Channel(
       type=_InputArtifact,
       artifacts=[_InputArtifact()],
       producer_component_id='c',
       output_key='k')
   string_channel = types.Channel(
       type=standard_artifacts.StringType,
       artifacts=[self._string_artifact],
       producer_component_id='c2',
       output_key='k2')
   self._input_dict = {
       'input_data': data_channel,
       'input_string': string_channel,
   }

   # Valid input artifacts must have a uri pointing to an existing directory,
   # so assign ids and materialize one directory per artifact.
   input_dir = os.path.join(
       os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
       self._testMethodName, 'input_dir')
   for channel_key, channel in self._input_dict.items():
     for artifact_id, artifact in enumerate(channel.get(), start=1):
       artifact.id = artifact_id
       artifact_uri = os.path.join(input_dir, channel_key, str(artifact.id))
       artifact.uri = artifact_uri
       tf.io.gfile.makedirs(artifact_uri)

   # Output side.
   output_channel = types.Channel(
       type=_OutputArtifact, artifacts=[_OutputArtifact()])
   self._output_dict = {'output_data': output_channel}

   self._input_artifacts = channel_utils.unwrap_channel_dict(self._input_dict)
   self._output_artifacts = channel_utils.unwrap_channel_dict(
       self._output_dict)
   self._exec_properties = {'key': 'value'}

   # Execution / context identifiers.
   self._execution_id = 100
   self._execution = metadata_store_pb2.Execution(id=self._execution_id)
   self._context_id = 123

   # Driver and pipeline/component descriptors.
   self._driver_args = data_types.DriverArgs(enable_cache=True)
   self._pipeline_info = data_types.PipelineInfo(
       pipeline_name='my_pipeline_name',
       pipeline_root=os.environ.get('TEST_TMP_DIR', self.get_temp_dir()),
       run_id='my_run_id')
   self._component_info = data_types.ComponentInfo(
       component_type='a.b.c',
       component_id='my_component_id',
       pipeline_info=self._pipeline_info)
Пример #3
0
 def __init__(self, word, greeting=None):
     """Constructs a HelloWorldComponent.

     Args:
       word: Passed through to `_HelloWorldSpec` as `word`.
       greeting: Optional channel passed to `_HelloWorldSpec` as `greeting`.
         When omitted (None), a channel wrapping one freshly created
         `standard_artifacts.StringType` artifact is used instead.
     """
     # Compare with None explicitly so that only an omitted argument triggers
     # the default; a caller-supplied value is never silently replaced just
     # because it happens to be falsy.
     if greeting is None:
         greeting = channel_utils.as_channel([standard_artifacts.StringType()])
     super(HelloWorldComponent,
           self).__init__(_HelloWorldSpec(word=word, greeting=greeting))
Пример #4
0
# limitations under the License.
"""Tests for tfx.scripts.entrypoint_utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
from typing import Any, Dict, Text
import tensorflow as tf

from tfx.scripts import ai_platform_entrypoint_utils
from tfx.types import standard_artifacts

_ARTIFACT_1 = standard_artifacts.StringType()
_KEY_1 = 'input_1'

_ARTIFACT_2 = standard_artifacts.ModelBlessing()
_KEY_2 = 'input_2'

# Sample execution properties. 'output_config' is a JSON-formatted string
# describing two hash-bucketed splits: 'train' (2 buckets) and 'eval' (1).
# Double quotes need no escaping inside these single-quoted literals.
_EXEC_PROPERTIES = {
    'input_config': 'input config string',
    'output_config': (
        '{ "split_config": { "splits": [ { "hash_buckets": 2, "name": '
        '"train" }, { "hash_buckets": 1, "name": "eval" } ] } }'),
}


class EntrypointUtilsTest(tf.test.TestCase):
Пример #5
0
 def testStringType(self):
     """Checks StringType encodes and decodes the module's test strings."""
     artifact = standard_artifacts.StringType()
     # Decoded -> raw.
     encoded = artifact.encode(_TEST_STRING_DECODED)
     self.assertEqual(_TEST_STRING_RAW, encoded)
     # Raw -> decoded.
     decoded = artifact.decode(_TEST_STRING_RAW)
     self.assertEqual(_TEST_STRING_DECODED, decoded)