Пример #1
0
def run_fn(fn_args: executor.TrainerFnArgs):
    """Trains an Estimator and exports an eval SavedModel for TFMA.

    Args:
      fn_args: Holds args used to train the model as name/value pairs. The
        fields read here are `schema_file`, `serving_model_dir`,
        `eval_model_dir` and `model_run_dir`.
    """
    parsed_schema = io_utils.parse_pbtxt_file(fn_args.schema_file,
                                              schema_pb2.Schema())
    spec = trainer_fn(fn_args, parsed_schema)
    estimator = spec['estimator']

    # Train the model.
    absl.logging.info('Training model.')
    tf.estimator.train_and_evaluate(estimator, spec['train_spec'],
                                    spec['eval_spec'])
    absl.logging.info('Training complete.  Model written to %s',
                      fn_args.serving_model_dir)

    # Export an eval SavedModel for TFMA.
    # NOTE: When trained in a distributed training cluster, the eval
    # SavedModel must be exported only by the chief worker.
    absl.logging.info('Exporting eval_savedmodel for TFMA.')
    tfma.export.export_eval_savedmodel(
        estimator=estimator,
        export_dir_base=fn_args.eval_model_dir,
        eval_input_receiver_fn=spec['eval_input_receiver_fn'])

    # Simulate writing a log to the run directory given by fn_args.
    fake_log_path = os.path.join(fn_args.model_run_dir, 'fake_log.txt')
    io_utils.write_string_file(fake_log_path, '')

    absl.logging.info('Exported eval_savedmodel to %s.',
                      fn_args.eval_model_dir)
Пример #2
0
 def setUp(self):
     """Creates per-test output paths, a fake package file and fixtures."""
     output_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
     self._output_data_dir = os.path.join(output_root, self._testMethodName)
     self._job_dir = os.path.join(self._output_data_dir, 'jobDir')
     self._fake_package = os.path.join(self._output_data_dir, 'fake_package')
     self._project_id = '12345'
     io_utils.write_string_file(self._fake_package, 'fake package content')
     self._mock_api_client = mock.Mock()
     self._inputs = {}
     self._outputs = {}
     # Training inputs are nested under custom_config in exec properties.
     self._training_inputs = {
         'project': self._project_id,
         'jobDir': self._job_dir,
     }
     self._exec_properties = {
         'custom_config': {'gaip_training_args': self._training_inputs},
     }
     self._cmle_serving_args = {
         'model_name': 'model_name',
         'project_id': self._project_id,
     }
Пример #3
0
    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Runs the user-supplied tuner and writes the best hyperparameters.

        Args:
          input_dict: Input artifacts, forwarded to `get_common_fn_args`.
          output_dict: Output artifacts; the best hyperparameter config is
            written as JSON to a file named `_DEFAULT_FILE_NAME` under the
            single artifact stored at `_BEST_HYPERPARAMETERS_KEY`.
          exec_properties: Execution properties, used to resolve the
            'tuner_fn' UDF and to build the common fn args.

        Raises:
          ValueError: If `exec_properties` sets `_TUNE_ARGS_KEY`.
        """
        if exec_properties.get(_TUNE_ARGS_KEY):
            raise ValueError(
                "TuneArgs is not supported for default Tuner's Executor.")

        tuner_fn = udf_utils.get_fn(exec_properties, 'tuner_fn')
        fn_args = fn_args_utils.get_common_fn_args(input_dict, exec_properties,
                                                   self._get_tmp_dir())

        tuner_fn_result = tuner_fn(fn_args)
        tuner = tuner_fn_result.tuner
        fit_kwargs = tuner_fn_result.fit_kwargs

        # TODO(b/156966497): set logger for printing.
        tuner.search_space_summary()
        absl.logging.info('Start tuning...')
        tuner.search(**fit_kwargs)
        tuner.results_summary()
        best_hparams_config = tuner.get_best_hyperparameters()[0].get_config()
        # Pass values as logger args so formatting is deferred to the logger
        # instead of being done eagerly with `%`.
        absl.logging.info('Best hyperParameters: %s', best_hparams_config)
        best_hparams_path = os.path.join(
            artifact_utils.get_single_uri(
                output_dict[_BEST_HYPERPARAMETERS_KEY]), _DEFAULT_FILE_NAME)
        io_utils.write_string_file(best_hparams_path,
                                   json.dumps(best_hparams_config))
        absl.logging.info('Best Hyperparameters are written to %s.',
                          best_hparams_path)
Пример #4
0
    def testDriverJsonContract(self):
        """Checks the driver's raw JSON input/output contract.

        Mirrors testDriverWithoutSpan, but feeds raw JSON invocation args and
        compares the raw JSON output metadata (ignoring all whitespace) to
        better illustrate the JSON I/O contract of the driver.
        """
        # Two input splits with distinct modification times.
        for split_name, content, mtime in (('split1', 'testing', 1),
                                           ('split2', 'testing2', 3)):
            split_path = os.path.join(_TEST_INPUT_DIR, split_name, 'data')
            io_utils.write_string_file(split_path, content)
            os.utime(split_path, (0, mtime))

        # Invoke the driver with JSON-serialized invocation args.
        driver.main([
            'driver.py', '--json_serialized_invocation_args',
            self._executor_invocation_from_file
        ])

        def _strip_whitespace(text):
            return re.sub(r'\s+', '', text)

        # Compare the output metadata file against the expected result.
        with open(_TEST_OUTPUT_METADATA_JSON) as output_meta_json:
            actual = json.loads(_strip_whitespace(output_meta_json.read()))
        expected = json.loads(
            _strip_whitespace(self._expected_result_from_file))
        self.assertDictEqual(actual, expected)
Пример #5
0
def _mock_subprocess_call(cmd: Sequence[Optional[Text]],
                          env: Mapping[Text, Text]) -> int:
  """Mocks the subprocess call."""
  assert len(cmd) == 2, 'Unexpected number of commands: {}'.format(cmd)
  del env
  dsl_path = cmd[1]

  if dsl_path.endswith('test_pipeline_bad.py'):
    sys.exit(1)
  if not dsl_path.endswith(
      'test_pipeline_1.py') and not dsl_path.endswith(
          'test_pipeline_2.py'):
    raise ValueError('Unexpected dsl path: {}'.format(dsl_path))

  spec_pb = pipeline_pb2.PipelineSpec(
      pipeline_info=pipeline_pb2.PipelineInfo(name='chicago_taxi_kubeflow'))
  runtime_pb = pipeline_pb2.PipelineJob.RuntimeConfig(
      gcs_output_directory=os.path.join(os.environ['HOME'], 'tfx', 'pipelines',
                                        'chicago_taxi_kubeflow'))
  job_pb = pipeline_pb2.PipelineJob(runtime_config=runtime_pb)
  job_pb.pipeline_spec.update(json_format.MessageToDict(spec_pb))
  io_utils.write_string_file(
      file_name='pipeline.json',
      string_value=json_format.MessageToJson(message=job_pb, sort_keys=True))
  return 0
Пример #6
0
  def Do(self, input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Runs a KerasTuner search and writes the best hyperparameters as JSON.

    Args:
      input_dict: Input artifacts; the 'train' and 'eval' splits of
        'examples' and the single schema file under 'schema' are read.
      output_dict: Output artifacts; a file named `_DEFAULT_FILE_NAME`
        holding the best trial's hyperparameter config as JSON is written
        under the single 'study_best_hparams_path' artifact.
      exec_properties: Execution properties, passed to `_GetTunerFn` to
        obtain the tuner function.
    """
    # KerasTuner generates tuning state (e.g., oracle, trials) to working dir.
    working_dir = self._get_tmp_dir()

    train_path = artifact_utils.get_split_uri(input_dict['examples'], 'train')
    eval_path = artifact_utils.get_split_uri(input_dict['examples'], 'eval')
    schema_file = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(input_dict['schema']))
    schema = io_utils.parse_pbtxt_file(schema_file, schema_pb2.Schema())

    tuner_fn = self._GetTunerFn(exec_properties)
    tuner_spec = tuner_fn(working_dir, io_utils.all_files_pattern(train_path),
                          io_utils.all_files_pattern(eval_path), schema)
    tuner = tuner_spec.tuner

    tuner.search_space_summary()
    # TODO(jyzhao): assert v2 behavior as KerasTuner doesn't work in v1.
    # TODO(jyzhao): make epochs configurable.
    tuner.search(
        tuner_spec.train_dataset,
        epochs=5,
        validation_data=tuner_spec.eval_dataset)
    tuner.results_summary()

    best_hparams = tuner.oracle.get_best_trials(
        1)[0].hyperparameters.get_config()
    best_hparams_path = os.path.join(
        artifact_utils.get_single_uri(output_dict['study_best_hparams_path']),
        _DEFAULT_FILE_NAME)
    io_utils.write_string_file(best_hparams_path, json.dumps(best_hparams))
    # Pass the path as a logger arg so formatting is deferred to the logger.
    absl.logging.info('Best HParams is written to %s.', best_hparams_path)
Пример #7
0
  def Do(self, input_dict: Dict[Text, List[types.TfxType]],
         output_dict: Dict[Text, List[types.TfxType]],
         exec_properties: Dict[Text, Any]) -> None:
    """Gets a human review result on a model through a Slack channel.

    The model counts as blessed only when both hold:
      - Model Validator blessed it, i.e. a file named 'BLESSED' exists in the
        model_blessing artifact's URI.
      - A human reviewer blessed it, as decided by _fetch_slack_blessing().
    Both checks run inside `Timeout(timeout_sec)`; a TimeoutError counts as
    not blessed.

    Args:
      input_dict: Input dict from input key to a list of artifacts, including:
        - model_export: exported model from trainer.
        - model_blessing: model blessing path from model_validator.
      output_dict: Output dict from key to a list of artifacts, including:
        - slack_blessing: model blessing result.
      exec_properties: A dict of execution properties, including:
        - slack_token: Token used to setup connection with slack server.
        - channel_id: The id of the Slack channel to send and receive messages.
        - timeout_sec: How long do we wait for response, in seconds.

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

    # Execution properties.
    slack_token = exec_properties['slack_token']
    channel_id = exec_properties['channel_id']
    timeout_sec = exec_properties['timeout_sec']

    # Input URIs and the output artifact.
    model_export_uri = types.get_single_uri(input_dict['model_export'])
    model_blessing_uri = types.get_single_uri(input_dict['model_blessing'])
    slack_blessing = types.get_single_instance(output_dict['slack_blessing'])

    try:
      with Timeout(timeout_sec):
        validator_blessed = tf.gfile.Exists(
            os.path.join(model_blessing_uri, 'BLESSED'))
        # Short-circuits: Slack is only asked when the validator blessed it.
        blessed = validator_blessed and self._fetch_slack_blessing(
            slack_token, channel_id, model_export_uri)
    except TimeoutError:  # pylint: disable=undefined-variable
      tf.logging.info('Timeout fetching manual model evaluation result.')
      blessed = False

    # Write an empty marker file named after the result and mirror the result
    # in the 'blessed' int custom property.
    if blessed:
      marker_name, blessed_flag = 'BLESSED', 1
    else:
      marker_name, blessed_flag = 'NOT_BLESSED', 0
    io_utils.write_string_file(
        os.path.join(slack_blessing.uri, marker_name), '')
    slack_blessing.set_int_custom_property('blessed', blessed_flag)
    tf.logging.info('Blessing result %s written to %s.', blessed,
                    slack_blessing.uri)
Пример #8
0
 def testGetOnlyDirInDir(self):
   """get_only_uri_in_dir returns the lone subdirectory of dir_1."""
   top_level_dir = os.path.join(self._base_dir, 'dir_1')
   # Writing dir_1/dir_2/file leaves dir_2 as the only entry in dir_1.
   io_utils.write_string_file(
       os.path.join(top_level_dir, 'dir_2', 'file'), 'testing')
   only_uri = io_utils.get_only_uri_in_dir(top_level_dir)
   self.assertEqual('dir_2', os.path.basename(only_uri))
Пример #9
0
 def testGetOnlyDirInDir(self):
   """Checks that the single child directory of dir_1 is found."""
   base = self._base_dir
   top_level_dir = os.path.join(base, 'dir_1')
   nested_file = os.path.join(top_level_dir, 'dir_2', 'file')
   io_utils.write_string_file(nested_file, 'testing')

   found = io_utils.get_only_uri_in_dir(top_level_dir)
   self.assertEqual('dir_2', os.path.basename(found))
Пример #10
0
    def testStubExecutor(self, mock_publisher):
        """Checks that stub executor substitution copies recorded outputs."""
        mock_publisher.return_value.publish_execution.return_value = {}

        # Record an output that the stubbed executor is expected to copy.
        record_file = os.path.join(self.record_dir, 'output', 'recorded.txt')
        io_utils.write_string_file(record_file, 'hello world')

        stub_launcher_cls = stub_component_launcher.get_stub_launcher_class(
            test_data_dir=self.record_dir,
            stubbed_component_ids=['_FakeComponent.FakeComponent'],
            stubbed_component_map={})
        launcher = stub_launcher_cls.create(
            component=self.component,
            pipeline_info=self.pipeline_info,
            driver_args=self.driver_args,
            metadata_connection=self.metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        launcher.launch()

        output_uri = self.component.outputs['output'].get()[0].uri
        copied_file = os.path.join(output_uri, 'recorded.txt')
        self.assertTrue(tf.io.gfile.exists(copied_file))
        self.assertEqual('hello world', io_utils.read_string_file(copied_file))
Пример #11
0
  def Do(self,
         input_dict: Dict[Text, List[types.Artifact]],
         output_dict: Dict[Text, List[types.Artifact]],
         exec_properties: Dict[Text, Any]) -> None:
    """Stores `custom_config` as an `artifacts.PipelineConfiguration` artifact.

    The configuration is written to `custom_config.json` under the output
    artifact's URI. String (or bytes) values are written verbatim; a missing
    or `None` value is written as "{}"; any other value (e.g. a dict) is
    serialized with `json.dumps` first.

    Args:
      input_dict: Empty.
      output_dict: Output dict from key to a list of artifacts, including:
        - pipeline_configuration: A list of type
          `artifacts.PipelineConfiguration`.
      exec_properties: A dict of execution properties, including:
        - custom_config: the configuration to save.

    Returns:
      None

    Raises:
      OSError and its subclasses
      ValueError
      TypeError: If `custom_config` is a non-string value that `json.dumps`
        cannot serialize.
    """
    import json

    self._log_startup(input_dict, output_dict, exec_properties)

    pipeline_configuration = artifact_utils.get_single_instance(
        output_dict[PIPELINE_CONFIGURATION_KEY])
    custom_config = exec_properties.get(CUSTOM_CONFIG_KEY)
    if custom_config is None:
      # Treat an explicit None the same as a missing key.
      custom_config = "{}"
    elif not isinstance(custom_config, (str, bytes)):
      # A structured config (e.g. a dict) cannot be written as a string file
      # directly; serialize it to JSON to match the '.json' output name.
      custom_config = json.dumps(custom_config)

    output_dir = artifact_utils.get_single_uri([pipeline_configuration])
    output_file = os.path.join(output_dir, 'custom_config.json')

    io_utils.write_string_file(output_file, custom_config)
Пример #12
0
    def testPipelineCompile(self):
        """Covers `pipeline compile` for invalid, wrong-runner and valid DSLs."""

        def compile_with_beam(path):
            # Every case compiles with the beam engine and must print the CLI
            # banner plus a 'Compiling pipeline' line; returns the output.
            result = self.runner.invoke(cli_group, [
                'pipeline', 'compile', '--engine', 'beam', '--pipeline_path',
                path
            ])
            self.assertIn('CLI', result.output)
            self.assertIn('Compiling pipeline', result.output)
            return result.output

        # Invalid DSL path.
        invalid_path = os.path.join(self._testdata_dir,
                                    'test_pipeline_flink.py')
        self.assertIn('Invalid pipeline path: {}'.format(invalid_path),
                      compile_with_beam(invalid_path))

        # Wrong runner: an empty DSL file contains no BeamDagRunner.run().
        empty_path = os.path.join(self.tmp_dir, 'empty_file.py')
        io_utils.write_string_file(empty_path, '')
        self.assertIn('Cannot find BeamDagRunner.run()',
                      compile_with_beam(empty_path))

        # Successful compilation.
        beam_path = os.path.join(self._testdata_dir, 'test_pipeline_beam_2.py')
        self.assertIn('Pipeline compiled successfully',
                      compile_with_beam(beam_path))
Пример #13
0
    def Do(self, input_dict: Dict[Text, List[types.TfxArtifact]],
           output_dict: Dict[Text, List[types.TfxArtifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Reads the input example, appends a marker string and writes it out.

        Args:
          input_dict: Input dict from input key to a list of artifacts, including:
            - input_example: an example for an input
          output_dict: Output dict from key to a list of artifacts, including:
            - output_example: an example for an output. Its `uri` is rewritten
              to point at the `_DEFAULT_FILE_NAME` file written here.
          exec_properties: A dict of execution properties, including:
            - string_execution_parameter: A string execution parameter (only
              used here, not persisted or shared downstream).
            - integer_execution_parameter: An integer execution parameter
              (only used here, not persisted or shared downstream).
            - input_config: not of concern here, only relevant for Driver
            - output_config: not of concern here, only relevant for Driver

        Returns:
          None
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        # Fetch execution properties from exec_properties dict.
        string_parameter = exec_properties['string_execution_parameter']
        integer_parameter = exec_properties['integer_execution_parameter']

        # Fetch input URIs from input_dict.
        input_example_uri = types.get_single_uri(input_dict['input_example'])

        # Fetch output artifact from output_dict.
        output_example = types.get_single_instance(
            output_dict['output_example'])

        print("I AM RUNNING!")
        print(string_parameter)
        print(integer_parameter)
        print(input_example_uri)
        print(output_example)

        input_data = ""

        # Load the input. Read through tf.gfile, the same filesystem layer as
        # the Exists check, rather than builtin open(), so that URIs which
        # tf.gfile supports but the local filesystem does not still work.
        if tf.gfile.Exists(input_example_uri):
            with tf.gfile.GFile(input_example_uri, "r") as input_file:
                input_data = input_file.read()

        # Make some changes.
        output_data = input_data + " changed by an awesome custom executor!"

        # Point the output URI at the written file so downstream components
        # know the filename.
        output_example.uri = os.path.join(output_example.uri,
                                          _DEFAULT_FILE_NAME)

        # Write the changes back to the output.
        io_utils.write_string_file(output_example.uri, output_data)

        # Optionally set custom properties so downstream components can run
        # quick checks.
        output_example.set_string_custom_property('stringProperty', "Awesome")
        output_example.set_int_custom_property('intProperty', 42)
Пример #14
0
    def testRangeConfigSpanWidthPresence(self):
        """RangeConfig.static_range needs a span width for zero-padded spans.

        With input under 'span01', the pattern 'span{SPAN}' fails to match,
        while 'span{SPAN:2}' is substituted to 'span01' and resolves span 1.
        """
        # Test RangeConfig.static_range behavior when span width is not given.
        span1_split1 = os.path.join(self._input_base_path, 'span01', 'split1',
                                    'data')
        io_utils.write_string_file(span1_split1, 'testing11')

        range_config = range_config_pb2.RangeConfig(
            static_range=range_config_pb2.StaticRange(start_span_number=1,
                                                      end_span_number=1))
        splits1 = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN}/split1/*')
        ]

        # RangeConfig cannot find zero padding span without width modifier.
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'Cannot find matching for split'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits1, range_config=range_config)

        splits2 = [
            example_gen_pb2.Input.Split(name='s1',
                                        pattern='span{SPAN:2}/split1/*')
        ]

        # With width modifier in span spec, RangeConfig.static_range makes
        # correct zero-padded substitution.
        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits2, range_config=range_config)
        self.assertEqual(span, 1)
        self.assertIsNone(version)
        self.assertEqual(splits2[0].pattern, 'span01/split1/*')
Пример #15
0
    def testStubExecutor(self, mock_publisher):
        """Checks that the stub executor copies recorded outputs into place."""
        mock_publisher.return_value.publish_execution.return_value = {}

        # Recorded output lives at <record_dir>/<component id>/<output key>/0.
        record_file = os.path.join(self.record_dir, self.component.id,
                                   self.output_key, '0', 'recorded.txt')
        io_utils.write_string_file(record_file, 'hello world')

        launcher_cls = stub_component_launcher.StubComponentLauncher
        # test_component_ids is empty, so the test expects self.component to
        # be stubbed rather than run for real.
        launcher_cls.initialize(test_data_dir=self.record_dir,
                                test_component_ids=[])
        launcher = launcher_cls.create(
            component=self.component,
            pipeline_info=self.pipeline_info,
            driver_args=self.driver_args,
            metadata_connection=self.metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})
        launcher.launch()

        output_uri = self.component.outputs[self.output_key].get()[0].uri
        copied_file = os.path.join(output_uri, 'recorded.txt')
        self.assertTrue(fileio.exists(copied_file))
        self.assertEqual('hello world', io_utils.read_string_file(copied_file))
Пример #16
0
    def testExecutor(self, mock_publisher):
        """Checks that the original executor runs for a listed component."""
        mock_publisher.return_value.publish_execution.return_value = {}

        input_file = os.path.join(self.input_dir, 'result.txt')
        io_utils.write_string_file(input_file, 'test')

        # self.component is listed in test_component_ids; the test expects
        # its real executor to run instead of a stub.
        launcher_cls = stub_component_launcher.StubComponentLauncher
        launcher_cls.initialize(
            test_data_dir=self.record_dir,
            test_component_ids=[self.component.id])
        launcher = launcher_cls.create(
            component=self.component,
            pipeline_info=self.pipeline_info,
            driver_args=self.driver_args,
            metadata_connection=self.metadata_connection,
            beam_pipeline_args=[],
            additional_pipeline_args={})

        fake_component_cls = test_utils._FakeComponent  # pylint: disable=protected-access
        expected_type = '{}.{}'.format(fake_component_cls.__module__,
                                       fake_component_cls.__name__)
        self.assertEqual(
            launcher._component_info.component_type,  # pylint: disable=protected-access
            expected_type)
        launcher.launch()

        output_uri = self.component.outputs[self.output_key].get()[0].uri
        self.assertTrue(fileio.exists(output_uri))
        self.assertEqual('test', io_utils.read_string_file(output_uri))
Пример #17
0
    def testVersionWidth(self):
        """A {VERSION:n} width must match the version directory's digits.

        The input lives under span1/ver1, so width 2 fails and width 1
        resolves span 1, version 1.
        """
        split1 = os.path.join(self._input_base_path, 'span1', 'ver1', 'split1',
                              'data')
        io_utils.write_string_file(split1, 'testing')

        splits = [
            example_gen_pb2.Input.Split(
                name='s1', pattern='span{SPAN}/ver{VERSION:2}/split1/*')
        ]

        # TODO(jjma): find a better way of describing this error to user.
        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
        # which was removed in Python 3.12.
        with self.assertRaisesRegex(
                ValueError, 'Glob pattern does not match regex pattern'):
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path, splits)

        splits = [
            example_gen_pb2.Input.Split(
                name='s1', pattern='span{SPAN}/ver{VERSION:1}/split1/*')
        ]

        _, span, version = utils.calculate_splits_fingerprint_span_and_version(
            self._input_base_path, splits)
        self.assertEqual(span, 1)
        self.assertEqual(version, 1)
Пример #18
0
    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """Tunes hyperparameters, optionally warm-started by metalearning.

        When `exec_properties` provides a truthy 'metalearning_algorithm', a
        warmup stage runs first and its trial count is subtracted from the
        final tuner's `max_trials` (keeping at least one trial). If warmup
        trial data exists, the tuner selected via `merge_trial_data`'s index
        (0 selects the warmup tuner) supplies the best hyperparameters.
        Trial progress data is written as JSON to 'tuner_plot_data.txt' in
        the 'trial_summary_plot' output.

        Raises:
          ValueError: If tune args are present in `exec_properties`.
        """
        if tfx_tuner.get_tune_args(exec_properties):
            raise ValueError(
                "TuneArgs is not supported by this Tuner's Executor.")

        metalearning_algorithm = exec_properties.get('metalearning_algorithm')

        warmup_trials = 0
        warmup_trial_data = None
        if not metalearning_algorithm:
            logging.info('MetaLearning Algorithm not provided.')
        else:
            warmup_tuner, warmup_trials = self.warmup(input_dict,
                                                      exec_properties,
                                                      metalearning_algorithm)
            warmup_trial_data = extract_tuner_trial_progress(warmup_tuner)

        # Create new fn_args for final tuning stage.
        fn_args = fn_args_utils.get_common_fn_args(
            input_dict, exec_properties, working_dir=self._get_tmp_dir())
        tuner_fn = udf_utils.get_fn(exec_properties, 'tuner_fn')
        tuner_fn_result = tuner_fn(fn_args)
        # Spend only the trials the warmup stage did not use (at least one).
        oracle = tuner_fn_result.tuner.oracle
        oracle.max_trials = max(oracle.max_trials - warmup_trials, 1)
        tuner = self.search(tuner_fn_result)
        tuner_trial_data = extract_tuner_trial_progress(tuner)

        best_tuner = tuner
        if warmup_trial_data:
            cumulative_data, best_tuner_ix = merge_trial_data(
                warmup_trial_data, tuner_trial_data)
            cumulative_data['warmup_trial_data'] = (
                warmup_trial_data[BEST_CUMULATIVE_SCORE])
            cumulative_data['tuner_trial_data'] = (
                tuner_trial_data[BEST_CUMULATIVE_SCORE])
            objective = tuner.oracle.objective
            if isinstance(objective, kerastuner.Objective):
                cumulative_data['objective'] = objective.name
            else:
                cumulative_data['objective'] = 'objective not understood'
            tuner_trial_data = cumulative_data
            if best_tuner_ix == 0:
                best_tuner = warmup_tuner

        tfx_tuner.write_best_hyperparameters(best_tuner, output_dict)
        plot_dir = artifact_utils.get_single_uri(
            output_dict['trial_summary_plot'])
        tuner_plot_path = os.path.join(plot_dir, 'tuner_plot_data.txt')
        io_utils.write_string_file(tuner_plot_path,
                                   json.dumps(tuner_trial_data))
        logging.info('Tuner plot data written at: %s', tuner_plot_path)
Пример #19
0
  def testSpanWrongFormat(self):
    """A non-numeric span directory ('spanx') makes input resolution fail."""
    wrong_span = os.path.join(self._input_base_path, 'spanx', 'split1', 'data')
    io_utils.write_string_file(wrong_span, 'testing_wrong_span')

    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed in Python 3.12.
    # NOTE(review): 'Cannot not' looks like a typo; presumably it mirrors the
    # driver's actual error text, so it is kept as-is.
    with self.assertRaisesRegex(ValueError, 'Cannot not find span number'):
      self._example_gen_driver.resolve_input_artifacts(self._input_channels,
                                                       self._exec_properties,
                                                       None, None)
Пример #20
0
  def testResolveInputArtifacts(self):
    """Checks fingerprint-based cache misses and hits in input resolution."""
    # Create two input splits with distinct modification times.
    split1 = os.path.join(self._input_base_path, 'split1', 'data')
    io_utils.write_string_file(split1, 'testing')
    os.utime(split1, (0, 1))
    split2 = os.path.join(self._input_base_path, 'split2', 'data')
    io_utils.write_string_file(split2, 'testing2')
    os.utime(split2, (0, 3))

    matching_fingerprint = (
        'split:s1,num_files:1,total_bytes:7,xor_checksum:1,sum_checksum:1\n'
        'split:s2,num_files:1,total_bytes:8,xor_checksum:3,sum_checksum:3')

    # Mock artifacts with ids 4, 3, 2, 1; only odd ids get the fingerprint
    # that matches the splits above.
    artifacts = []
    for artifact_id in (4, 3, 2, 1):
      artifact = metadata_store_pb2.Artifact()
      artifact.id = artifact_id
      artifact.uri = self._input_base_path
      artifact.custom_properties['span'].string_value = '0'
      artifact.custom_properties['input_fingerprint'].string_value = (
          matching_fingerprint if artifact_id % 2 == 1 else 'not_match')
      artifacts.append(artifact)

    # Create exec properties.
    input_config = example_gen_pb2.Input(splits=[
        example_gen_pb2.Input.Split(name='s1', pattern='split1/*'),
        example_gen_pb2.Input.Split(name='s2', pattern='split2/*')
    ])
    exec_properties = {
        'input_config':
            json_format.MessageToJson(
                input_config, preserving_proto_field_name=True),
    }

    def resolve_and_check(expected_id):
      # Resolution must yield exactly one 'input_base' artifact with the
      # expected id and the input base URI.
      updated_input_dict = self._example_gen_driver.resolve_input_artifacts(
          self._input_channels, exec_properties, None, None)
      self.assertEqual(1, len(updated_input_dict))
      self.assertEqual(1, len(updated_input_dict['input_base']))
      updated_input_base = updated_input_dict['input_base'][0]
      self.assertEqual(expected_id, updated_input_base.id)
      self.assertEqual(self._input_base_path, updated_input_base.uri)

    # Cache not hit.
    self._mock_metadata.get_artifacts_by_uri.return_value = [artifacts[0]]
    self._mock_metadata.publish_artifacts.return_value = [artifacts[3]]
    resolve_and_check(1)

    # Cache hit.
    self._mock_metadata.get_artifacts_by_uri.return_value = artifacts
    self._mock_metadata.publish_artifacts.return_value = []
    resolve_and_check(3)
Пример #21
0
 def testCopyFile(self):
     """copy_file creates a copy holding the same contents as the source."""
     file_path = os.path.join(self._base_dir, 'temp_file')
     io_utils.write_string_file(file_path, 'testing')
     copy_path = os.path.join(self._base_dir, 'copy_file')
     io_utils.copy_file(file_path, copy_path)
     self.assertTrue(file_io.file_exists(copy_path))
     # Read the copy (not the source) so the test verifies what copy_file
     # wrote; the context manager closes the handle afterwards.
     with file_io.FileIO(copy_path, mode='r') as f:
         self.assertEqual('testing', f.read())
         self.assertEqual(7, f.tell())
Пример #22
0
 def testCopyDir(self):
   """copy_dir copies a directory's files, preserving their contents."""
   old_path = os.path.join(self._base_dir, 'old', 'path')
   new_path = os.path.join(self._base_dir, 'new', 'path')
   io_utils.write_string_file(old_path, 'testing')
   io_utils.copy_dir(os.path.dirname(old_path), os.path.dirname(new_path))
   self.assertTrue(file_io.file_exists(new_path))
   # The context manager closes the handle even if an assertion fails.
   with file_io.FileIO(new_path, mode='r') as f:
     self.assertEqual('testing', f.read())
     self.assertEqual(7, f.tell())
Пример #23
0
 def testCopyFile(self):
   """copy_file creates a copy holding the same contents as the source."""
   file_path = os.path.join(self._base_dir, 'temp_file')
   io_utils.write_string_file(file_path, 'testing')
   copy_path = os.path.join(self._base_dir, 'copy_file')
   io_utils.copy_file(file_path, copy_path)
   self.assertTrue(file_io.file_exists(copy_path))
   # Read the copy (not the source) so the test verifies what copy_file
   # wrote; the context manager closes the handle afterwards.
   with file_io.FileIO(copy_path, mode='r') as f:
     self.assertEqual('testing', f.read())
     self.assertEqual(7, f.tell())
Пример #24
0
 def testCopyDir(self):
     """copy_dir copies a directory's files, preserving their contents."""
     old_path = os.path.join(self._base_dir, 'old', 'path')
     new_path = os.path.join(self._base_dir, 'new', 'path')
     io_utils.write_string_file(old_path, 'testing')
     io_utils.copy_dir(os.path.dirname(old_path), os.path.dirname(new_path))
     self.assertTrue(file_io.file_exists(new_path))
     # The context manager closes the handle even if an assertion fails.
     with file_io.FileIO(new_path, mode='r') as f:
         self.assertEqual('testing', f.read())
         self.assertEqual(7, f.tell())
Пример #25
0
def copy_and_change_pipeline_name(orig_path: str, new_path: str,
                                  origin_pipeline_name: str,
                                  new_pipeline_name: str) -> None:
    """Write a copy of a pipeline DSL file with its pipeline name swapped.

    Args:
      orig_path: Path of the DSL file to read.
      new_path: Path the rewritten copy is written to.
      origin_pipeline_name: Name currently used in the file; it must occur
        exactly once.
      new_pipeline_name: Name that replaces `origin_pipeline_name`.

    Raises:
      AssertionError: if `origin_pipeline_name` does not appear exactly once
        (only when assertions are enabled).
    """
    source_text = io_utils.read_string_file(orig_path)
    occurrences = source_text.count(origin_pipeline_name)
    assert occurrences == 1, 'DSL file can only contain one pipeline name'
    renamed_text = source_text.replace(origin_pipeline_name, new_pipeline_name)
    io_utils.write_string_file(new_path, renamed_text)
Пример #26
0
 def Do(self, input_dict: Dict[Text, List[types.Artifact]],
        output_dict: Dict[Text, List[types.Artifact]],
        exec_properties: Dict[Text, Any]) -> None:
   """Write a stub 'result.txt' file into every output artifact's URI.

   super().Do() is deliberately not called, so no input copying happens.

   Args:
     input_dict: Input artifacts (unused).
     output_dict: Output artifacts; each one gets a 'result.txt' under its uri.
     exec_properties: Execution properties (unused).
   """
   absl.logging.info('Running CustomStubExecutor')
   every_output = (artifact
                   for artifacts in output_dict.values()
                   for artifact in artifacts)
   for output_artifact in every_output:
     result_file = os.path.join(output_artifact.uri, 'result.txt')
     io_utils.write_string_file(result_file, 'custom component')
Пример #27
0
    def _uncommentMultiLineVariables(self, filepath: Text,
                                     variables: Iterable[Text]) -> Text:
        """Uncomment the (possibly multi-line) definitions of `variables`.

        A definition starts at a line beginning with `# NAME =`. It continues
        through following lines that begin with `#  ` or `# }`, or that are a
        bare `#` line. Any other line ends it. For example

            # ....
            # VARIABLE_NAME = ...
            #   long indented line
            #
            #   long indented line
            # OTHER STUFF

        becomes

            # ....
            VARIABLE_NAME = ...
              long indented line

              long indented line
            # OTHER STUFF

        The file is rewritten in place.

        Args:
          filepath: File to modify, relative to the project directory.
          variables: Names of the variables whose definitions to uncomment.

        Returns:
          Path of the modified file (the project directory joined with
          `filepath`).
        """
        path = os.path.join(self._project_dir, filepath)
        # Commented-out assignment prefixes that open a definition.
        openers = tuple('# ' + name + ' =' for name in variables)
        uncommented = []
        inside = False

        with open(path) as handle:
            for line in handle:
                if inside:
                    if line.startswith(('#  ', '# }')):
                        # Continuation line: strip the leading '# '.
                        uncommented.append(line[2:])
                        continue
                    if line == '#\n':
                        # Empty comment line inside the definition.
                        uncommented.append(line[1:])
                        continue
                    # Anything else closes the definition; the same line may
                    # still open a new one below.
                    inside = False
                if line.startswith(openers):
                    inside = True
                    uncommented.append(line[2:])
                else:
                    # Not part of any definition being uncommented.
                    uncommented.append(line)

        io_utils.write_string_file(path, ''.join(uncommented))
        return path
Пример #28
0
 def _replaceFileContent(self, filepath: Text,
                         replacements: Iterable[Tuple[Text, Text]]) -> Text:
     """Apply literal text substitutions to a project file, in place.

     Args:
       filepath: File to modify, relative to `self._project_dir`.
       replacements: (old, new) pairs applied in order with `str.replace`;
         later pairs operate on the output of earlier ones.

     Returns:
       The joined path of the rewritten file.
     """
     target_path = os.path.join(self._project_dir, filepath)
     with open(target_path) as source_file:
         text = source_file.read()
     for needle, substitute in replacements:
         text = text.replace(needle, substitute)
     io_utils.write_string_file(target_path, text)
     return target_path
Пример #29
0
  def testSpanWrongFormat(self):
    """A span directory without a span number should raise ValueError."""
    # 'spanx' matches the 'span' prefix but carries no number for {SPAN}.
    bad_span_file = os.path.join(self._input_base_path, 'spanx', 'split1',
                                 'data')
    io_utils.write_string_file(bad_span_file, 'testing_wrong_span')

    split_configs = [
        example_gen_pb2.Input.Split(name='s1', pattern='span{SPAN}/split1/*')
    ]
    with self.assertRaisesRegex(ValueError, 'Cannot find span number'):
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, split_configs)
 def _addAllComponents(self):
     """Enable every component in the project's 'pipeline.py'.

     The template ships with `components.append(` calls commented out; this
     strips the `# ` prefix from them and rewrites the file in place.

     Returns:
       Path of the rewritten 'pipeline.py'.
     """
     pipeline_file = os.path.join(self._project_dir, 'pipeline.py')
     with open(pipeline_file) as reader:
         original = reader.read()
     updated = original.replace('# components.append(', 'components.append(')
     io_utils.write_string_file(pipeline_file, updated)
     return pipeline_file
Пример #31
0
  def testVersionNoMatching(self):
    """A span dir lacking a version subdirectory should fail to match."""
    # The file lives under 'span01/wrong/', not 'span01/version*/split1/'.
    misplaced_file = os.path.join(self._input_base_path, 'span01', 'wrong',
                                  'data')
    io_utils.write_string_file(misplaced_file, 'testing_version_no_matching')

    split_configs = [
        example_gen_pb2.Input.Split(
            name='s1', pattern='span{SPAN}/version{VERSION}/split1/*')
    ]
    with self.assertRaisesRegex(ValueError, 'Cannot find matching for split'):
      utils.calculate_splits_fingerprint_span_and_version(
          self._input_base_path, split_configs)
Пример #32
0
    def testNoSpanOrVersion(self):
        """With no {SPAN}/{VERSION} spec, span should be 0 and version None."""
        data_file = os.path.join(self._input_base_path, 'split1', 'data')
        io_utils.write_string_file(data_file, 'testing')

        _fingerprint, span, version = (
            utils.calculate_splits_fingerprint_span_and_version(
                self._input_base_path,
                [example_gen_pb2.Input.Split(name='s1', pattern='split1/*')]))

        self.assertEqual(span, 0)
        self.assertIsNone(version)
Пример #33
0
def build_ephemeral_package() -> Text:
  """Repackage current installation of TFX into a tfx_ephemeral sdist.

  The deepest 'tfx' directory on this file's path is copied into a fresh
  temporary directory. A generated setup.py is written next to it, and
  `setup.py sdist` is run there with output captured in setup.log.

  Returns:
    Path to ephemeral sdist package.
  Raises:
    RuntimeError: if no 'tfx' directory appears in this file's path, or if
      the dist directory has zero or multiple files.
  """
  tmp_dir = os.path.join(tempfile.mkdtemp(), 'build', 'tfx')
  # Find the last directory named 'tfx' in this file's path and package it.
  path_split = __file__.split(os.path.sep)
  tfx_indices = [i for i, part in enumerate(path_split) if part == 'tfx']
  if not tfx_indices:
    raise RuntimeError('Cannot locate directory \'tfx\' in the path %s' %
                       __file__)
  tfx_root_dir = os.path.sep.join(path_split[:tfx_indices[-1] + 1])
  absl.logging.info('Copying all content from install dir %s to temp dir %s',
                    tfx_root_dir, tmp_dir)
  shutil.copytree(tfx_root_dir, os.path.join(tmp_dir, 'tfx'))
  # Source directory default permission is 0555 but we need to be able to create
  # new setup.py file.
  os.chmod(tmp_dir, 0o720)
  setup_file = os.path.join(tmp_dir, 'setup.py')
  absl.logging.info('Generating a temp setup file at %s', setup_file)
  install_requires = dependencies.make_required_install_packages()
  io_utils.write_string_file(
      setup_file,
      _ephemeral_setup_file.format(
          version=version.__version__, install_requires=install_requires))

  # Create the package. Run the build with `cwd=tmp_dir` rather than
  # os.chdir(), so this process's working directory is never changed (the
  # old chdir was not restored if the build raised).
  temp_log = os.path.join(tmp_dir, 'setup.log')
  with open(temp_log, 'w') as f:
    absl.logging.info('Creating temporary sdist package, logs available at %s',
                      temp_log)
    cmd = [sys.executable, setup_file, 'sdist']
    returncode = subprocess.call(cmd, stdout=f, stderr=f, cwd=tmp_dir)
  if returncode != 0:
    # Surface build failures instead of silently ignoring the exit status.
    absl.logging.warning('sdist build exited with code %d; see %s',
                         returncode, temp_log)

  # Return the package dir+filename
  dist_dir = os.path.join(tmp_dir, 'dist')
  files = tf.io.gfile.listdir(dist_dir)
  if not files:
    raise RuntimeError('Found no package files in %s' % dist_dir)
  elif len(files) > 1:
    raise RuntimeError('Found multiple package files in %s' % dist_dir)

  return os.path.join(dist_dir, files[0])
Пример #34
0
 def testGetOnlyFileInDir(self):
   """get_only_uri_in_dir should return the single file inside a directory."""
   only_file = os.path.join(self._base_dir, 'file', 'path')
   io_utils.write_string_file(only_file, 'testing')
   containing_dir = os.path.dirname(only_file)
   self.assertEqual(only_file, io_utils.get_only_uri_in_dir(containing_dir))
Пример #35
0
 def testDeleteDir(self):
   """Deleting a file's parent directory should also remove the file."""
   file_path = os.path.join(self._base_dir, 'file', 'path')
   io_utils.write_string_file(file_path, 'testing')
   # tf.gfile.Exists is the TF 1.x-only API; tf.io.gfile.exists works on
   # TF 1.13+ and TF 2.x.
   self.assertTrue(tf.io.gfile.exists(file_path))
   io_utils.delete_dir(os.path.dirname(file_path))
   self.assertFalse(tf.io.gfile.exists(file_path))