Example #1
def compare_file_sizes(output_uri: Text, expected_uri: Text,
                       threshold: float) -> bool:
    """Compares pipeline output files sizes in output and recorded uri.

  Args:
    output_uri: pipeline output artifact uri.
    expected_uri: recorded pipeline output artifact uri.
    threshold: a float between 0 and 1.

  Returns:
     True if all file sizes differ within the threshold.
  """
    for dir_name, sub_dirs, leaf_files in fileio.walk(expected_uri):
        for sub_dir in sub_dirs:
            new_file_path = os.path.join(
                dir_name.replace(expected_uri, output_uri, 1), sub_dir)
            if not fileio.exists(new_file_path):
                return False
        for leaf_file in leaf_files:
            expected_file_name = os.path.join(dir_name, leaf_file)
            file_name = os.path.join(
                dir_name.replace(expected_uri, output_uri, 1), leaf_file)
            if not _compare_relative_difference(
                    fileio.open(file_name).size(),
                    fileio.open(expected_file_name).size(), threshold):
                return False
    return True
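The helper _compare_relative_difference used above is not shown on this page. A
minimal sketch of what such a helper plausibly looks like (an assumption, not
the TFX source) is:

def _compare_relative_difference(value: float, expected_value: float,
                                 threshold: float) -> bool:
  """Returns True if the relative difference is within the threshold."""
  if value == expected_value:
    return True
  if expected_value == 0:
    return False
  return abs(value - expected_value) / abs(expected_value) <= threshold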
Example #2
def compare_model_file_sizes(output_uri: Text, expected_uri: Text,
                             threshold: float) -> bool:
    """Compares pipeline output files sizes in output and recorded uri.

  Args:
    output_uri: pipeline output artifact uri.
    expected_uri: recorded pipeline output artifact uri.
    threshold: a float between 0 and 1.

  Returns:
     True if all file sizes differ within the threshold.
  """
    for dir_name, sub_dirs, leaf_files in fileio.walk(expected_uri):
        if ('Format-TFMA' in dir_name or 'eval_model_dir' in dir_name
                or 'export' in dir_name):
            continue
        for sub_dir in sub_dirs:
            new_file_path = os.path.join(
                dir_name.replace(expected_uri, output_uri, 1), sub_dir)
            if not fileio.exists(new_file_path):
                return False
        for leaf_file in leaf_files:
            if leaf_file.startswith('events.out.tfevents'):
                continue
            expected_file_name = os.path.join(dir_name, leaf_file)
            file_name = os.path.join(
                dir_name.replace(expected_uri, output_uri, 1), leaf_file)
            if not _compare_relative_difference(
                    fileio.open(file_name).size(),
                    fileio.open(expected_file_name).size(), threshold):
                return False
    return True
Example #3
    def setUp(self):
        super().setUp()
        self._test_dir = tempfile.mkdtemp()

        self._executor_invocation = pipeline_pb2.ExecutorInput()
        self._executor_invocation.outputs.output_file = _TEST_OUTPUT_METADATA_JSON
        self._executor_invocation.inputs.parameters[
            'input_base_uri'].string_value = _TEST_INPUT_DIR
        self._executor_invocation.inputs.parameters[
            'input_config'].string_value = json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='s1',
                                                pattern='span{SPAN}/split1/*'),
                    example_gen_pb2.Input.Split(name='s2',
                                                pattern='span{SPAN}/split2/*')
                ]))
        self._executor_invocation.outputs.artifacts[
            'examples'].artifacts.append(
                pipeline_pb2.RuntimeArtifact(
                    type=pipeline_pb2.ArtifactTypeSchema(
                        instance_schema=compiler_utils.get_artifact_schema(
                            standard_artifacts.Examples()))))

        self._executor_invocation_from_file = fileio.open(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'executor_invocation.json'), 'r').read()
        self._expected_result_from_file = fileio.open(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'expected_output_metadata.json'), 'r').read()

        self._olddir = os.getcwd()
        os.chdir(self._test_dir)
        fileio.makedirs(os.path.dirname(_TEST_OUTPUT_METADATA_JSON))
        fileio.makedirs(os.path.dirname(_TEST_INPUT_DIR))
Example #4
    def testInvokeTFLiteRewriterNoAssetsSucceeds(self, converter):
        m = self.ConverterMock()
        converter.return_value = m

        src_model_path = tempfile.mkdtemp()
        dst_model_path = tempfile.mkdtemp()

        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with fileio.open(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        tfrw = tflite_rewriter.TFLiteRewriter(name='myrw', filename='fname')
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(saved_model_path=mock.ANY,
                                          enable_quantization=False)
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(expected_model))
        with fileio.open(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')
Example #5
    def run(self,
            pipeline: tfx_pipeline.Pipeline,
            parameter_values: Optional[Dict[Text, Any]] = None,
            write_out: Optional[bool] = True) -> Dict[Text, Any]:
        """Compiles a pipeline DSL object into pipeline file.

    Args:
      pipeline: TFX pipeline object.
      parameter_values: mapping from runtime parameter names to their values.
      write_out: set to True to write the job spec out to the file designated
        by output_dir and output_filename; the JSON-serialized pipeline job
        spec is returned in either case.

    Returns:
      The JSON-serialized pipeline job spec.

    Raises:
      RuntimeError: if the output path points to an existing file.
    """
        # TODO(b/166343606): Support user-provided labels.
        # TODO(b/169095387): Deprecate .run() method in favor of the unified API
        # client.
        display_name = (self._config.display_name
                        or pipeline.pipeline_info.pipeline_name)
        pipeline_spec = pipeline_builder.PipelineBuilder(
            tfx_pipeline=pipeline,
            default_image=self._config.default_image,
            default_commands=self._config.default_commands).build()
        pipeline_spec.sdk_version = version.__version__
        pipeline_spec.schema_version = _SCHEMA_VERSION
        runtime_config = pipeline_builder.RuntimeConfigBuilder(
            pipeline_info=pipeline.pipeline_info,
            parameter_values=parameter_values).build()
        with telemetry_utils.scoped_labels(
            {telemetry_utils.LABEL_TFX_RUNNER: 'kubeflow_v2'}):
            result = pipeline_pb2.PipelineJob(
                display_name=display_name
                or pipeline.pipeline_info.pipeline_name,
                labels=telemetry_utils.get_labels_dict(),
                runtime_config=runtime_config)
        result.pipeline_spec.update(json_format.MessageToDict(pipeline_spec))
        pipeline_json_dict = json_format.MessageToDict(result)
        if write_out:
            if fileio.exists(
                    self._output_dir) and not fileio.isdir(self._output_dir):
                raise RuntimeError('Output path: %s points to an existing '
                                   'file.' % self._output_dir)
            if not fileio.exists(self._output_dir):
                fileio.makedirs(self._output_dir)

            fileio.open(os.path.join(self._output_dir, self._output_filename),
                        'wb').write(
                            json.dumps(pipeline_json_dict, sort_keys=True))

        return pipeline_json_dict
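A hypothetical usage sketch of the run() method above; `runner` and
`my_tfx_pipeline` are assumed, pre-constructed objects (a Kubeflow V2 style
runner configured with an output_dir and output_filename, and a TFX pipeline),
not names taken from this page:

import json

# `runner` is an instance of the class defining run() above; `my_tfx_pipeline`
# is a tfx_pipeline.Pipeline. Both are assumptions for illustration.
job_spec = runner.run(
    pipeline=my_tfx_pipeline,
    parameter_values={'train_steps': 1000},
    write_out=True)
print(json.dumps(job_spec, indent=2, sort_keys=True))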
Example #6
  def testDo(self, mock_client):
    # Mock query result schema for _BigQueryConverter.
    mock_client.return_value.query.return_value.result.return_value.schema = self._schema

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {'examples': [examples]}

    # Create exec properties.
    exec_properties = {
        'input_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, b, f, s FROM `fake`'),
                ])),
        'output_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])))
    }

    # Run executor.
    big_query_example_gen = executor.Executor(
        base_beam_executor.BaseBeamExecutor.Context(
            beam_pipeline_args=['--project=test-project']))
    big_query_example_gen.Do({}, output_dict, exec_properties)

    mock_client.assert_called_with(project='test-project')

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check BigQuery example gen outputs.
    train_output_file = os.path.join(examples.uri, 'Split-train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
    self.assertGreater(
        fileio.open(train_output_file).size(),
        fileio.open(eval_output_file).size())
Example #7
  def testDo(self):
    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    examples = standard_artifacts.Examples()
    examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [examples]}

    # Create exec properties.
    exec_properties = {
        utils.INPUT_BASE_KEY:
            self._input_data_dir,
        utils.INPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='avro', pattern='avro/*.avro'),
                ]),
                preserving_proto_field_name=True),
        utils.OUTPUT_CONFIG_KEY:
            json_format.MessageToJson(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(
                            name='train', hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(
                            name='eval', hash_buckets=1)
                    ])),
                preserving_proto_field_name=True)
    }

    # Run executor.
    avro_example_gen = avro_executor.Executor()
    avro_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        examples.split_names)

    # Check Avro example gen outputs.
    train_output_file = os.path.join(examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
    self.assertGreater(
        fileio.open(train_output_file).size(),
        fileio.open(eval_output_file).size())
Example #8
def _export_fn(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
    del estimator, checkpoint_path, eval_result, is_the_final_export
    path = os.path.join(export_path, BASE_EXPORT_SUBDIR)
    fileio.makedirs(path)
    with fileio.open(os.path.join(path, ORIGINAL_SAVED_MODEL), 'w') as f:
        f.write(str(ORIGINAL_SAVED_MODEL))

    assets_path = os.path.join(path, tf.saved_model.ASSETS_DIRECTORY)
    fileio.makedirs(assets_path)
    with fileio.open(os.path.join(assets_path, ORIGINAL_VOCAB), 'w') as f:
        f.write(str(ORIGINAL_VOCAB))

    return path
Example #9
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        m = self.ConverterMock()
        converter.return_value = m

        src_model, dst_model, src_model_path, dst_model_path = (
            self.create_temp_model_template())

        assets_dir = os.path.join(src_model_path,
                                  tf.saved_model.ASSETS_DIRECTORY)
        fileio.mkdir(assets_dir)
        assets_file_path = os.path.join(assets_dir, 'assets_file')
        with fileio.open(assets_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_file'))

        assets_extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
        fileio.mkdir(assets_extra_dir)
        assets_extra_file_path = os.path.join(assets_extra_dir,
                                              'assets_extra_file')
        with fileio.open(assets_extra_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_extra_file'))

        tfrw = tflite_rewriter.TFLiteRewriter(
            name='myrw',
            filename='fname',
            quantization_optimizations=[tf.lite.Optimize.DEFAULT])
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(
            saved_model_path=mock.ANY,
            quantization_optimizations=[tf.lite.Optimize.DEFAULT],
            quantization_supported_types=[],
            representative_dataset=None,
            signature_key=None)
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(expected_model))
        with fileio.open(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')

        expected_assets_file = os.path.join(dst_model_path,
                                            tf.saved_model.ASSETS_DIRECTORY,
                                            'assets_file')
        with fileio.open(expected_assets_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'assets_file')

        expected_assets_extra_file = os.path.join(dst_model_path,
                                                  EXTRA_ASSETS_DIRECTORY,
                                                  'assets_extra_file')
        with fileio.open(expected_assets_extra_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()),
                             'assets_extra_file')
Example #10
    def testDo(self):
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]}

        # Create exec properties.
        exec_properties = {
            standard_component_specs.INPUT_BASE_KEY:
            self._input_data_dir,
            standard_component_specs.INPUT_CONFIG_KEY:
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='parquet',
                                                pattern='parquet/*'),
                ])),
            standard_component_specs.OUTPUT_CONFIG_KEY:
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ])))
        }

        # Run executor.
        parquet_example_gen = parquet_executor.Executor()
        parquet_example_gen.Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Check Parquet example gen outputs.
        train_output_file = os.path.join(examples.uri, 'Split-train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(fileio.exists(train_output_file))
        self.assertTrue(fileio.exists(eval_output_file))
        self.assertGreater(
            fileio.open(train_output_file).size(),
            fileio.open(eval_output_file).size())
Example #11
    def testDo(self):
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create output dict.
        examples = standard_artifacts.Examples()
        examples.uri = output_data_dir
        output_dict = {'examples': [examples]}

        # Create exec properties.
        exec_properties = {
            'input_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(
                        name='bq', pattern='SELECT i, f, s FROM `fake`'),
                ])),
            'custom_config':
            proto_utils.proto_to_json(example_gen_pb2.CustomConfig()),
            'output_config':
            proto_utils.proto_to_json(
                example_gen_pb2.Output(
                    split_config=example_gen_pb2.SplitConfig(splits=[
                        example_gen_pb2.SplitConfig.Split(name='train',
                                                          hash_buckets=2),
                        example_gen_pb2.SplitConfig.Split(name='eval',
                                                          hash_buckets=1)
                    ]))),
        }

        # Run executor.
        presto_example_gen = executor.Executor()
        presto_example_gen.Do({}, output_dict, exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         examples.split_names)

        # Check Presto example gen outputs.
        train_output_file = os.path.join(examples.uri, 'Split-train',
                                         'data_tfrecord-00000-of-00001.gz')
        eval_output_file = os.path.join(examples.uri, 'Split-eval',
                                        'data_tfrecord-00000-of-00001.gz')
        self.assertTrue(fileio.exists(train_output_file))
        self.assertTrue(fileio.exists(eval_output_file))
        self.assertGreater(
            fileio.open(train_output_file).size(),
            fileio.open(eval_output_file).size())
Example #12
    def _testDo(self):
        # Run executor.
        example_gen = TestExampleGenExecutor()
        example_gen.Do({}, self._output_dict, self._exec_properties)

        self.assertEqual(artifact_utils.encode_split_names(['train', 'eval']),
                         self._examples.split_names)

        # Check example gen outputs.
        self.assertTrue(fileio.exists(self._train_output_file))
        self.assertTrue(fileio.exists(self._eval_output_file))

        # Output split ratio: train:eval=2:1.
        self.assertGreater(
            fileio.open(self._train_output_file).size(),
            fileio.open(self._eval_output_file).size())
Example #13
def run_fn(fn_args: TrainerFnArgs):
  """Train the model based on given args.

  Args:
    fn_args: Holds args used to train the model as name/value pairs.
  """
  schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

  x_train, y_train = _input_fn(fn_args.train_files, fn_args.data_accessor,
                               schema)
  x_eval, y_eval = _input_fn(fn_args.eval_files, fn_args.data_accessor, schema)

  steps_per_epoch = _TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE

  model = MLPClassifier(
      hidden_layer_sizes=[8, 8, 8],
      activation='relu',
      solver='adam',
      batch_size=_TRAIN_BATCH_SIZE,
      learning_rate_init=0.0005,
      max_iter=int(fn_args.train_steps / steps_per_epoch),
      verbose=True)
  model.fit(x_train, y_train)
  absl.logging.info(model)

  score = model.score(x_eval, y_eval)
  absl.logging.info('Accuracy: %f', score)

  os.makedirs(fn_args.serving_model_dir)

  model_path = os.path.join(fn_args.serving_model_dir, 'model.pkl')
  with fileio.open(model_path, 'wb+') as f:
    pickle.dump(model, f)
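The model written above can be read back through the same fileio layer. A
minimal sketch, assuming the standard TFX import path for fileio and the
model.pkl layout produced by run_fn:

import os
import pickle

from tfx.dsl.io import fileio


def _load_serving_model(serving_model_dir):
  """Loads the pickled sklearn model written by run_fn above."""
  model_path = os.path.join(serving_model_dir, 'model.pkl')
  with fileio.open(model_path, 'rb') as f:
    return pickle.load(f)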
Example #14
  def testInvokeTFLiteRewriterQuantizationFullIntegerSucceeds(self, converter):
    m = self.ConverterMock()
    converter.return_value = m

    src_model, dst_model, _, dst_model_path = self.create_temp_model_template()

    def representative_dataset():
      for i in range(2):
        yield [np.array(i)]

    tfrw = tflite_rewriter.TFLiteRewriter(
        name='myrw',
        filename='fname',
        quantization_optimizations=[tf.lite.Optimize.DEFAULT],
        quantization_enable_full_integer=True,
        representative_dataset=representative_dataset)
    tfrw.perform_rewrite(src_model, dst_model)

    converter.assert_called_once_with(
        saved_model_path=mock.ANY,
        quantization_optimizations=[tf.lite.Optimize.DEFAULT],
        quantization_supported_types=[],
        representative_dataset=representative_dataset,
        signature_key=None)
    expected_model = os.path.join(dst_model_path, 'fname')
    self.assertTrue(fileio.exists(expected_model))
    with fileio.open(expected_model, 'rb') as f:
      self.assertEqual(f.read(), b'model')
Example #15
 def Do(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> None:
     executor_output = execution_result_pb2.ExecutorOutput()
     outputs_utils.populate_output_artifact(executor_output, output_dict)
     with fileio.open(self._context.executor_output_uri, 'wb') as f:
         f.write(executor_output.SerializeToString())
Example #16
def main(_):

    flags.mark_flag_as_required(EXECUTION_INVOCATION_FLAG.name)
    flags.mark_flag_as_required(EXECUTABLE_SPEC_FLAG.name)

    execution_info = python_execution_binary_utils.deserialize_execution_info(
        EXECUTION_INVOCATION_FLAG.value)
    python_class_executable_spec = (
        python_execution_binary_utils.deserialize_executable_spec(
            EXECUTABLE_SPEC_FLAG.value))
    logging.info('execution_info = %r\n', execution_info)
    logging.info('python_class_executable_spec = %s\n',
                 text_format.MessageToString(python_class_executable_spec))

    # A set MLMD connection config indicates a driver execution instead of an
    # executor execution, as accessing MLMD is not supported for executors.
    if MLMD_CONNECTION_CONFIG_FLAG.value:
        mlmd_connection_config = (
            python_execution_binary_utils.deserialize_mlmd_connection_config(
                MLMD_CONNECTION_CONFIG_FLAG.value))
        run_result = _run_driver(python_class_executable_spec,
                                 mlmd_connection_config, execution_info)
    else:
        run_result = _run_executor(python_class_executable_spec,
                                   execution_info)

    if run_result:
        with fileio.open(execution_info.execution_output_uri, 'wb') as f:
            f.write(run_result.SerializeToString())
Example #17
    def testDriverJsonContract(self):
        # This test is identical to testDriverWithoutSpan, but uses raw JSON
        # strings as inputs and asserts against the raw JSON output of the
        # driver, to better illustrate the driver's JSON I/O contract.
        split1 = os.path.join(_TEST_INPUT_DIR, 'split1', 'data')
        io_utils.write_string_file(split1, 'testing')
        os.utime(split1, (0, 1))
        split2 = os.path.join(_TEST_INPUT_DIR, 'split2', 'data')
        io_utils.write_string_file(split2, 'testing2')
        os.utime(split2, (0, 3))

        serialized_args = [
            '--json_serialized_invocation_args',
            self._executor_invocation_from_file
        ]

        # Invoke the driver
        driver.main(driver._parse_flags(serialized_args))

        # Check the output metadata file for the expected outputs
        with fileio.open(_TEST_OUTPUT_METADATA_JSON, 'rb') as output_meta_json:
            self.assertEqual(
                json.dumps(json.loads(output_meta_json.read()),
                           indent=2,
                           sort_keys=True),
                json.dumps(json.loads(self._expected_result_from_file),
                           indent=2,
                           sort_keys=True))
Example #18
 def testGetExecutorOutputUri(self):
   executor_output_uri = self._output_resolver.get_executor_output_uri(1)
   self.assertRegex(executor_output_uri,
                    '.*/test_node/executor_execution/1/executor_output.pb')
   # Verify that executor_output_uri is writable.
   with fileio.open(executor_output_uri, mode='w') as f:
     executor_output = execution_result_pb2.ExecutorOutput()
     f.write(executor_output.SerializeToString())
Example #19
    def testArtifactSchemaMapping(self):
        # Test first party standard artifact.
        example_schema = compiler_utils.get_artifact_schema(
            standard_artifacts.Examples)
        expected_example_schema = fileio.open(
            os.path.join(self._schema_base_dir, 'Examples.yaml'), 'rb').read()
        self.assertEqual(expected_example_schema, example_schema)

        # Test Kubeflow simple artifact.
        file_schema = compiler_utils.get_artifact_schema(simple_artifacts.File)
        expected_file_schema = fileio.open(
            os.path.join(self._schema_base_dir, 'File.yaml'), 'rb').read()
        self.assertEqual(expected_file_schema, file_schema)

        # Test custom artifact type.
        my_artifact_schema = compiler_utils.get_artifact_schema(_MyArtifact)
        self.assertDictEqual(yaml.safe_load(my_artifact_schema),
                             yaml.safe_load(_EXPECTED_MY_ARTIFACT_SCHEMA))
Example #20
 def write(self, value):
     if value is None:
         self.set_int_custom_property(_IS_NULL_KEY, 1)
         serialized_value = b''
     else:
         self.set_int_custom_property(_IS_NULL_KEY, 0)
         serialized_value = self.encode(value)
     with fileio.open(self.uri, 'wb') as f:
         f.write(serialized_value)
Example #21
def read_string_file(file_name: Text) -> Text:
    """Reads a string from a file."""
    if not fileio.exists(file_name):
        msg = '{} does not exist'.format(file_name)
        if six.PY2:
            raise OSError(msg)
        else:
            raise FileNotFoundError(msg)  # pylint: disable=undefined-variable
    return fileio.open(file_name).read()
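A companion writer in the same style as read_string_file; this is a sketch and
may differ from the io_utils.write_string_file used in Example #17:

def write_string_file(file_name: Text, string_value: Text) -> None:
  """Writes a string to a file, creating parent directories if needed."""
  fileio.makedirs(os.path.dirname(file_name))
  with fileio.open(file_name, 'w') as f:
    f.write(string_value)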
Example #22
def _get_golden_subgraph(graph_name, spec):
  """Retrieves a corresponding golden subgraph."""
  filename = _generate_unique_filename(spec.input_names)
  filepath = os.path.join(
      os.path.dirname(__file__), 'testdata', graph_name, filename)

  graph_def = tf.compat.v1.GraphDef()
  with fileio.open(filepath, 'r') as f:
    text_format.Parse(f.read(), graph_def)
  return graph_def
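_generate_unique_filename is referenced above but not shown. A plausible,
purely illustrative sketch that derives a deterministic golden-file name from
the subgraph's input names (the hashing scheme is an assumption):

import hashlib


def _generate_unique_filename(input_names):
  joined = '-'.join(sorted(input_names))
  return hashlib.sha256(joined.encode('utf-8')).hexdigest() + '.pbtxt'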
Example #23
 def testLogging(self):
     """Ensure a logged string actually appears in the log file."""
     logger = logging_utils.get_logger(self._logger_config)
     logger.info('Test')
     log_file_path = os.path.join(self._log_root)
     f = fileio.open(os.path.join(log_file_path, 'tfx.log'), mode='r')
     self.assertRegex(
         f.read(),
         r'^\d\d\d\d-\d\d-\d\d \d\d:\d\d:\d\d,\d\d\d - : \(logging_utils_test.py:\d\d\) - INFO: Test$'
     )
Example #24
 def _rewrite(self, original_model, rewritten_model):
     self.rewrite_called = True
     assert fileio.exists(
         os.path.join(original_model.path, ORIGINAL_SAVED_MODEL))
     assert fileio.exists(
         os.path.join(original_model.path,
                      tf.saved_model.ASSETS_DIRECTORY, ORIGINAL_VOCAB))
     with fileio.open(
             os.path.join(rewritten_model.path, REWRITTEN_SAVED_MODEL),
             'w') as f:
         f.write(str(REWRITTEN_SAVED_MODEL))
     assets_path = os.path.join(rewritten_model.path,
                                tf.saved_model.ASSETS_DIRECTORY)
     fileio.makedirs(assets_path)
     with fileio.open(os.path.join(assets_path, REWRITTEN_VOCAB),
                      'w') as f:
         f.write(str(REWRITTEN_VOCAB))
     if self._rewrite_raises_error:
         raise ValueError('rewrite-error')
Example #25
 def _PrintTaskLogsOnError(self, task):
     task_log_dir = os.path.join(self._airflow_home, 'logs',
                                 '%s.%s' % (self._dag_id, task))
     for dir_name, _, leaf_files in fileio.walk(task_log_dir):
         for leaf_file in leaf_files:
             leaf_file_path = os.path.join(dir_name, leaf_file)
             absl.logging.error('Print task log %s:', leaf_file_path)
             with fileio.open(leaf_file_path, 'r') as f:
                 lines = f.readlines()
                 for line in lines:
                     absl.logging.error(line)
Example #26
  def _testDo(self, payload_format):
    exec_properties = {
        utils.INPUT_BASE_KEY: self._input_data_dir,
        utils.INPUT_CONFIG_KEY: self._input_config,
        utils.OUTPUT_CONFIG_KEY: self._output_config,
        utils.OUTPUT_DATA_FORMAT_KEY: payload_format,
    }

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    self.examples = standard_artifacts.Examples()
    self.examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [self.examples]}

    # Run executor.
    import_example_gen = executor.Executor()
    import_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        self.examples.split_names)

    # Check import_example_gen outputs.
    train_output_file = os.path.join(self.examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(self.examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
    self.assertGreater(
        fileio.open(train_output_file).size(),
        fileio.open(eval_output_file).size())
Example #27
def get_artifact_schema(artifact_type: Type[artifact.Artifact]) -> str:
  """Gets the YAML schema string associated with the artifact type."""
  if artifact_type in _SUPPORTED_STANDARD_ARTIFACT_TYPES:
    # For supported first-party artifact types, load the built-in schema YAML
    # keyed by the type name.
    schema_path = os.path.join(
        os.path.dirname(__file__), 'artifact_types',
        '{}.yaml'.format(artifact_type.TYPE_NAME))
    return fileio.open(schema_path, 'rb').read()
  else:
    # Otherwise, fall back to the generic `Artifact` type schema.
    # To recover the Python type object at runtime, the class import path will
    # be encoded as the schema title.

    # Read the generic artifact schema template.
    schema_path = os.path.join(
        os.path.dirname(__file__), 'artifact_types', 'Artifact.yaml')
    data = yaml.safe_load(fileio.open(schema_path, 'rb').read())
    # Encode class import path.
    data['title'] = f'{artifact_type.__module__}.{artifact_type.__name__}'
    return yaml.dump(data, sort_keys=False)
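The comment above notes that the class import path is encoded as the schema
title so the Python type can be recovered at runtime. A hedged sketch of that
recovery step (importlib-based; an assumption, not the TFX source):

import importlib


def _artifact_type_from_schema_title(title: str):
  """Resolves 'package.module.ClassName' back to the Python class."""
  module_name, class_name = title.rsplit('.', 1)
  return getattr(importlib.import_module(module_name), class_name)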
Example #28
  def read(self):
    if not self._has_value:
      file_path = self.uri
      # Assert the file exists.
      if not fileio.exists(file_path):
        raise RuntimeError(
            'Given path does not exist or is not a valid file: %s' % file_path)

      serialized_value = fileio.open(file_path, 'rb').read()
      self._has_value = True
      self._value = self.decode(serialized_value)
    return self._value
Example #29
    def setUp(self):
        super().setUp()

        self._executor_invocation = pipeline_pb2.ExecutorInput()
        self._executor_invocation.outputs.output_file = _TEST_OUTPUT_METADATA_JSON
        self._executor_invocation.inputs.parameters[
            'input_base'].string_value = _TEST_INPUT_DIR
        self._executor_invocation.inputs.parameters[
            'output_config'].string_value = '{}'
        self._executor_invocation.inputs.parameters[
            'input_config'].string_value = json_format.MessageToJson(
                example_gen_pb2.Input(splits=[
                    example_gen_pb2.Input.Split(name='s1',
                                                pattern='span{SPAN}/split1/*'),
                    example_gen_pb2.Input.Split(name='s2',
                                                pattern='span{SPAN}/split2/*')
                ]))
        self._executor_invocation.outputs.artifacts[
            'examples'].artifacts.append(
                pipeline_pb2.RuntimeArtifact(
                    type=pipeline_pb2.ArtifactTypeSchema(
                        instance_schema=compiler_utils.get_artifact_schema(
                            standard_artifacts.Examples))))

        self._executor_invocation_from_file = fileio.open(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'executor_invocation.json'), 'r').read()

        logging.debug('Executor invocation under test: %s',
                      self._executor_invocation_from_file)
        self._expected_result_from_file = fileio.open(
            os.path.join(os.path.dirname(__file__), 'testdata',
                         'expected_output_metadata.json'), 'r').read()
        logging.debug('Expecting output metadata JSON: %s',
                      self._expected_result_from_file)

        # Change working directory after all the testdata files have been read.
        self.enter_context(test_case_utils.change_working_dir(self.tmp_dir))

        fileio.makedirs(os.path.dirname(_TEST_INPUT_DIR))
Example #30
  def setUp(self):
    self._executor_invocation = pipeline_pb2.ExecutorInput()
    self._executor_invocation.outputs.output_file = _TEST_OUTPUT_METADATA_JSON
    self._executor_invocation.inputs.parameters[
        'input_base_uri'].string_value = _TEST_INPUT_DIR
    self._executor_invocation.inputs.parameters[
        'input_config'].string_value = json_format.MessageToJson(
            example_gen_pb2.Input(splits=[
                example_gen_pb2.Input.Split(
                    name='s1', pattern='span{SPAN}/split1/*'),
                example_gen_pb2.Input.Split(
                    name='s2', pattern='span{SPAN}/split2/*')
            ]))
    self._executor_invocation.outputs.artifacts['examples'].artifacts.append(
        pipeline_pb2.RuntimeArtifact(
            type=pipeline_pb2.ArtifactTypeSchema(
                instance_schema=compiler_utils.get_artifact_schema(
                    standard_artifacts.Examples()))))

    self._executor_invocation_from_file = fileio.open(
        os.path.join(
            os.path.dirname(__file__), 'testdata', 'executor_invocation.json'),
        'r').read()

    logging.debug('Executor invocation under test: %s',
                  self._executor_invocation_from_file)
    self._expected_result_from_file = fileio.open(
        os.path.join(
            os.path.dirname(__file__), 'testdata',
            'expected_output_metadata.json'), 'r').read()
    logging.debug('Expecting output metadata JSON: %s',
                  self._expected_result_from_file)

    # The initialization of TempWorkingDirTestCase has to be called after all
    # the testdata files have been read. Otherwise the original testdata files
    # are not accessible after cwd is changed.
    super().setUp()

    fileio.makedirs(os.path.dirname(_TEST_OUTPUT_METADATA_JSON))
    fileio.makedirs(os.path.dirname(_TEST_INPUT_DIR))