Example #1
 def _verify_output(self):
     # Test best hparams.
     best_hparams_path = os.path.join(self._best_hparams.uri,
                                      'best_hyperparameters.txt')
     self.assertTrue(fileio.exists(best_hparams_path))
     best_hparams_config = json.loads(
         file_io.read_file_to_string(best_hparams_path))
     best_hparams = HyperParameters.from_config(best_hparams_config)
     self.assertIn(best_hparams.get('learning_rate'), (1e-1, 1e-3))
     self.assertBetween(best_hparams.get('num_layers'), 1, 5)
Example #2
 def testSavePipeline(self):
     flags_dict = {
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path
     }
     handler = local_handler.LocalHandler(flags_dict)
     handler._save_pipeline({labels.PIPELINE_NAME: self.pipeline_name})
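     # Saving the pipeline should create an entry named after it under the handler home.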
     self.assertTrue(
         fileio.exists(
             os.path.join(handler._handler_home_dir, self.pipeline_name)))
Example #3
 def _assert_infra_validator_passed(self, pipeline_name: Text):
     artifacts = self._get_artifacts_with_type_and_pipeline(
         type_name='InfraBlessing', pipeline_name=pipeline_name)
     self.assertGreaterEqual(len(artifacts), 1)
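     # Every InfraBlessing artifact should carry an INFRA_BLESSED marker file.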
     for artifact in artifacts:
         blessed = os.path.join(artifact.uri, 'INFRA_BLESSED')
         self.assertTrue(
             fileio.exists(blessed),
             'Expected InfraBlessing results cannot be found under path %s for '
             'artifact %s' % (blessed, artifact))
Example #4
    def testDoWithBlessedModel(self):
        # Create exec properties.
        exec_properties = {
            'blessed_model': os.path.join(self._source_data_dir,
                                          'trainer/blessed'),
            'blessed_model_id': 123,
            'current_component_id': self.component_id,
        }

        # Run executor.
        model_validator = executor.Executor(self._context)
        model_validator.Do(self._input_dict, self._output_dict,
                           exec_properties)

        # Check model validator outputs.
        self.assertTrue(fileio.exists(os.path.join(self._tmp_dir)))
        self.assertTrue(
            fileio.exists(
                os.path.join(self._blessing.uri, constants.BLESSED_FILE_NAME)))
Example #5
    def testInvokeTFLiteRewriterWithAssetsSucceeds(self, converter):
        m = self.ConverterMock()
        converter.return_value = m

        src_model_path = tempfile.mkdtemp()
        dst_model_path = tempfile.mkdtemp()

        saved_model_path = os.path.join(
            src_model_path, tf.saved_model.SAVED_MODEL_FILENAME_PBTXT)
        with fileio.open(saved_model_path, 'wb') as f:
            f.write(six.ensure_binary('saved_model'))

        assets_dir = os.path.join(src_model_path,
                                  tf.saved_model.ASSETS_DIRECTORY)
        fileio.mkdir(assets_dir)
        assets_file_path = os.path.join(assets_dir, 'assets_file')
        with fileio.open(assets_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_file'))

        assets_extra_dir = os.path.join(src_model_path, EXTRA_ASSETS_DIRECTORY)
        fileio.mkdir(assets_extra_dir)
        assets_extra_file_path = os.path.join(assets_extra_dir,
                                              'assets_extra_file')
        with fileio.open(assets_extra_file_path, 'wb') as f:
            f.write(six.ensure_binary('assets_extra_file'))

        src_model = rewriter.ModelDescription(rewriter.ModelType.SAVED_MODEL,
                                              src_model_path)
        dst_model = rewriter.ModelDescription(rewriter.ModelType.TFLITE_MODEL,
                                              dst_model_path)

        tfrw = tflite_rewriter.TFLiteRewriter(name='myrw',
                                              filename='fname',
                                              enable_quantization=True)
        tfrw.perform_rewrite(src_model, dst_model)

        converter.assert_called_once_with(saved_model_path=mock.ANY,
                                          enable_quantization=True)
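        # The rewritten model and both asset directories should be copied to the destination path.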
        expected_model = os.path.join(dst_model_path, 'fname')
        self.assertTrue(fileio.exists(expected_model))
        with fileio.open(expected_model, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'model')

        expected_assets_file = os.path.join(dst_model_path,
                                            tf.saved_model.ASSETS_DIRECTORY,
                                            'assets_file')
        with fileio.open(expected_assets_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()), 'assets_file')

        expected_assets_extra_file = os.path.join(dst_model_path,
                                                  EXTRA_ASSETS_DIRECTORY,
                                                  'assets_extra_file')
        with fileio.open(expected_assets_extra_file, 'rb') as f:
            self.assertEqual(six.ensure_text(f.readline()),
                             'assets_extra_file')
Example #6
 def testPipelineSavePipelineArgs(self):
   os.environ['TFX_JSON_EXPORT_PIPELINE_ARGS_PATH'] = self._tmp_file
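   # With the export path set, constructing the pipeline should dump its args to that file.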
   pipeline.Pipeline(
       pipeline_name='a',
       pipeline_root='b',
       log_root='c',
       components=[
           _make_fake_component_instance('component_a', _OutputTypeA, {}, {})
       ],
       metadata_connection_config=self._metadata_connection_config)
   self.assertTrue(fileio.exists(self._tmp_file))
Example #7
 def _valid_create_and_check(self, pipeline_path: Text,
                             pipeline_name: Text) -> None:
   result = self.runner.invoke(cli_group, [
       'pipeline', 'create', '--engine', 'kubeflow', '--pipeline_path',
       pipeline_path, '--endpoint', self._endpoint
   ])
   absl.logging.info('[CLI] %s', result.output)
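   # A successful create should build the pipeline package and report success.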
   self.assertIn('Creating pipeline', result.output)
   self.assertTrue(fileio.exists(self._pipeline_package_path))
   self.assertIn('Pipeline "{}" created successfully.'.format(pipeline_name),
                 result.output)
Example #8
  def _check_pipeline_package_path(self, pipeline_name: Text) -> None:
    # When unset, search for the workflow file in the current dir.
    if not self.flags_dict[labels.PIPELINE_PACKAGE_PATH]:
      self.flags_dict[labels.PIPELINE_PACKAGE_PATH] = os.path.join(
          os.getcwd(), '{}.tar.gz'.format(pipeline_name))

    pipeline_package_path = self.flags_dict[labels.PIPELINE_PACKAGE_PATH]
    if not fileio.exists(pipeline_package_path):
      sys.exit(
          'Pipeline package not found at {}. When --package_path is unset, it will try to find the workflow file, "<pipeline_name>.tar.gz" in the current directory.'
          .format(pipeline_package_path))
Example #9
    def list_pipelines(self) -> None:
        """List all the pipelines in the environment."""
        if not fileio.exists(self._handler_home_dir):
            click.echo('No pipelines to display.')
            return
        pipelines_list = fileio.listdir(self._handler_home_dir)

        # Print every pipeline name in a new line.
        click.echo('-' * 30)
        click.echo('\n'.join(pipelines_list))
        click.echo('-' * 30)
Example #10
 def assertInfraValidatorPassed(self) -> None:
   infra_validator_path = os.path.join(self._pipeline_root, 'InfraValidator')
   blessing_path = os.path.join(self._pipeline_root, 'InfraValidator',
                                'blessing')
   executions = fileio.listdir(blessing_path)
   self.assertGreaterEqual(len(executions), 1)
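   # Each blessing execution should have written an INFRA_BLESSED marker file.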
   for exec_id in executions:
     blessing_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
         infra_validator_path, 'blessing', exec_id)
     blessed = os.path.join(blessing_uri, 'INFRA_BLESSED')
     self.assertTrue(fileio.exists(blessed))
Example #11
 def _verify_example_split(self, split_name):
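     # Each split should be materialized under 'Split-<split_name>' with classification outputs.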
     self.assertTrue(
         fileio.exists(
             os.path.join(self._output_examples_dir,
                          f'Split-{split_name}')))
     results = self._get_results(
         os.path.join(self._output_examples_dir, f'Split-{split_name}'),
         executor._EXAMPLES_FILE_NAME, tf.train.Example)
     self.assertTrue(results)
     self.assertIn('classify_label', results[0].features.feature)
     self.assertIn('classify_score', results[0].features.feature)
Example #12
    def _check_pipeline_existence(self,
                                  pipeline_name: Text,
                                  required: bool = True) -> None:
        """Check if pipeline folder exists and if not, exit system.

    Args:
      pipeline_name: Name of the pipeline.
      required: Set it as True if pipeline needs to exist else set it to False.
    """
        handler_pipeline_path = os.path.join(self._handler_home_dir,
                                             pipeline_name)
        # Check if pipeline folder exists.
        exists = fileio.exists(handler_pipeline_path)
        if required and not exists:
            # Check the pipeline directory layout used prior to 0.25 and move files
            # to the new location automatically.
            old_handler_pipeline_path = os.path.join(
                self._get_deprecated_handler_home(), pipeline_name)
            if fileio.exists(old_handler_pipeline_path):
                fileio.makedirs(os.path.dirname(handler_pipeline_path))
                fileio.rename(old_handler_pipeline_path, handler_pipeline_path)
                engine_flag = self.flags_dict[labels.ENGINE_FLAG]
                handler_home_variable = engine_flag.upper() + '_HOME'
                click.echo((
                    '[WARNING] Pipeline "{pipeline_name}" was found in "{old_path}", '
                    'but the location that TFX stores pipeline information was moved '
                    'since TFX 0.25.0.\n'
                    '[WARNING] Your files in "{old_path}" were automatically moved to '
                    'the new location, "{new_path}".\n'
                    '[WARNING] If you want to keep the files at the old location, set '
                    '`{handler_home}` environment variable to "{old_handler_home}".'
                ).format(pipeline_name=pipeline_name,
                         old_path=old_handler_pipeline_path,
                         new_path=handler_pipeline_path,
                         handler_home=handler_home_variable,
                         old_handler_home=self._get_deprecated_handler_home()),
                           err=True)
            else:
                sys.exit('Pipeline "{}" does not exist.'.format(pipeline_name))
        elif not required and exists:
            sys.exit('Pipeline "{}" already exists.'.format(pipeline_name))
Example #13
  def _testDo(self, payload_format):
    exec_properties = {
        utils.INPUT_BASE_KEY: self._input_data_dir,
        utils.INPUT_CONFIG_KEY: self._input_config,
        utils.OUTPUT_CONFIG_KEY: self._output_config,
        utils.OUTPUT_DATA_FORMAT_KEY: payload_format,
    }

    output_data_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Create output dict.
    self.examples = standard_artifacts.Examples()
    self.examples.uri = output_data_dir
    output_dict = {utils.EXAMPLES_KEY: [self.examples]}

    # Run executor.
    import_example_gen = executor.Executor()
    import_example_gen.Do({}, output_dict, exec_properties)

    self.assertEqual(
        artifact_utils.encode_split_names(['train', 'eval']),
        self.examples.split_names)

    # Check import_example_gen outputs.
    train_output_file = os.path.join(self.examples.uri, 'train',
                                     'data_tfrecord-00000-of-00001.gz')
    eval_output_file = os.path.join(self.examples.uri, 'eval',
                                    'data_tfrecord-00000-of-00001.gz')
    self.assertTrue(fileio.exists(train_output_file))
    self.assertTrue(fileio.exists(eval_output_file))
    self.assertGreater(
        fileio.open(train_output_file).size(),
        fileio.open(eval_output_file).size())
Example #14
  def read(self):
    if not self._has_value:
      file_path = self.uri
      # Make sure the file exists before attempting to read it.
      if not fileio.exists(file_path):
        raise RuntimeError(
            'Given path does not exist or is not a valid file: %s' % file_path)

      serialized_value = fileio.open(file_path, 'rb').read()
      self._has_value = True
      self._value = self.decode(serialized_value)
    return self._value
Example #15
  def _assertHyperparametersAreWritten(self, pipeline_name):
    """Make sure the tuner execution and hyperpearameters output."""
    # There must be only one execution of Tuner.
    tuner_output_base_dir = os.path.join(
        self._pipeline_root(pipeline_name), 'Tuner', 'best_hyperparameters')
    tuner_outputs = fileio.listdir(tuner_output_base_dir)
    self.assertEqual(1, len(tuner_outputs))

    # There must be only one best hyperparameters.
    best_hyperparameters_uri = os.path.join(tuner_output_base_dir,
                                            tuner_outputs[0])
    self.assertTrue(fileio.exists(best_hyperparameters_uri))
Example #16
 def testCreatePipeline(self):
   flags_dict = {
       labels.ENGINE_FLAG: self.engine,
       labels.PIPELINE_DSL_PATH: self.pipeline_path
   }
   handler = kubeflow_v2_handler.KubeflowV2Handler(flags_dict)
   handler.create_pipeline()
   handler_pipeline_path = os.path.join(
       handler._handler_home_dir, self.pipeline_args[labels.PIPELINE_NAME], '')
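   # pipeline_args.json should be written under the pipeline's directory in the handler home.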
   self.assertTrue(
       fileio.exists(
           os.path.join(handler_pipeline_path, 'pipeline_args.json')))
Example #17
 def testSavePipeline(self):
     flags_dict = {
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path
     }
     handler = airflow_handler.AirflowHandler(flags_dict)
     pipeline_args = handler._extract_pipeline_args()
     handler._save_pipeline(pipeline_args)
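     # The saved pipeline should show up under the handler home, keyed by pipeline name.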
     self.assertTrue(
         fileio.exists(
             os.path.join(handler._handler_home_dir,
                          self.pipeline_args[labels.PIPELINE_NAME])))
Example #18
    def testTaxiPipelineWarmstart(self):
        BeamDagRunner().run(
            taxi_pipeline_warmstart._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(10, execution_count)

        self.assertPipelineExecution()

        # Run pipeline again.
        BeamDagRunner().run(
            taxi_pipeline_warmstart._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        with metadata.Metadata(metadata_config) as m:
            # 10 more executions.
            self.assertEqual(20, len(m.store.get_executions()))

        # Two trainer outputs.
        self.assertExecutedTwice('Trainer')
Example #19
    def testDoValidation(self, exec_properties, blessed, has_baseline):
        source_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
            self._testMethodName)

        # Create input dict.
        examples = standard_artifacts.Examples()
        examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
        examples.split_names = artifact_utils.encode_split_names(
            ['train', 'eval'])
        model = standard_artifacts.Model()
        baseline_model = standard_artifacts.Model()
        model.uri = os.path.join(source_data_dir, 'trainer/current')
        baseline_model.uri = os.path.join(source_data_dir, 'trainer/previous/')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        schema = standard_artifacts.Schema()
        schema.uri = os.path.join(source_data_dir, 'schema_gen')
        input_dict = {
            EXAMPLES_KEY: [examples],
            MODEL_KEY: [model],
            SCHEMA_KEY: [schema],
        }
        if has_baseline:
            input_dict[BASELINE_MODEL_KEY] = [baseline_model]

        # Create output dict.
        eval_output = standard_artifacts.ModelEvaluation()
        eval_output.uri = os.path.join(output_data_dir, 'eval_output')
        blessing_output = standard_artifacts.ModelBlessing()
        blessing_output.uri = os.path.join(output_data_dir, 'blessing_output')
        output_dict = {
            EVALUATION_KEY: [eval_output],
            BLESSING_KEY: [blessing_output],
        }

        # List needs to be serialized before being passed into Do function.
        exec_properties[EXAMPLE_SPLITS_KEY] = json_utils.dumps(None)

        # Run executor.
        evaluator = executor.Executor()
        evaluator.Do(input_dict, output_dict, exec_properties)

        # Check evaluator outputs.
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'eval_config.json')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri,
                                                   'metrics')))
        self.assertTrue(fileio.exists(os.path.join(eval_output.uri, 'plots')))
        self.assertTrue(
            fileio.exists(os.path.join(eval_output.uri, 'validations')))
        if blessed:
            self.assertTrue(
                fileio.exists(os.path.join(blessing_output.uri, 'BLESSED')))
        else:
            self.assertTrue(
                fileio.exists(os.path.join(blessing_output.uri,
                                           'NOT_BLESSED')))
Example #20
    def resolve(self, pipeline_root: Text):
        # Package the given user module file as a Python wheel.
        module_file = self.component.spec.exec_properties[self.module_file_key]

        # Perform validation on the given `module_file`.
        if not module_file:
            return None
        elif not isinstance(module_file, Text):
            # TODO(b/187753042): Deprecate and remove usage of RuntimeParameters for
            # `module_file` parameters and remove this code path.
            logging.warning(
                'Module file %r for component %s is not a path string; '
                'skipping Python user module wheel packaging.', module_file,
                self.component)
            return None
        elif not fileio.exists(module_file):
            raise ValueError(
                'Specified module file %r for component %s does not exist.' %
                (module_file, self.component))

        # Perform validation on the `pipeline_root`.
        if not pipeline_root:
            logging.warning(
                'No pipeline root provided; skipping Python user module '
                'wheel packaging for component %s.', self.component)
            return None
        pipeline_root_exists = fileio.exists(pipeline_root)
        if not pipeline_root_exists:
            fileio.makedirs(pipeline_root)

        # Perform packaging of the user module.
        dist_file_path, user_module_path = package_user_module_file(
            self.component.id, module_file, pipeline_root)

        # Set the user module key to point to a module in this wheel, and clear the
        # module path key before returning.
        self.component.spec.exec_properties[
            self.module_path_key] = user_module_path
        self.component.spec.exec_properties[self.module_file_key] = None
        return dist_file_path
Example #21
    def testGetSchema(self):
        flags_dict = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_DSL_PATH: self.pipeline_path,
            labels.ENDPOINT: self.endpoint,
            labels.IAP_CLIENT_ID: self.iap_client_id,
            labels.NAMESPACE: self.namespace,
            labels.PIPELINE_PACKAGE_PATH: self.pipeline_package_path
        }
        handler = kubeflow_handler.KubeflowHandler(flags_dict)
        handler.create_pipeline()

        flags_dict = {
            labels.ENGINE_FLAG: self.engine,
            labels.PIPELINE_NAME: self.pipeline_name,
        }

        # No pipeline root
        handler = kubeflow_handler.KubeflowHandler(flags_dict)
        with self.assertRaises(SystemExit) as err:
            handler.get_schema()
        self.assertEqual(
            str(err.exception),
            'Create a run before inferring schema. If pipeline is already running, then wait for it to successfully finish.'
        )

        # No SchemaGen output.
        fileio.makedirs(self.pipeline_root)
        with self.assertRaises(SystemExit) as err:
            handler.get_schema()
        self.assertEqual(
            str(err.exception),
            'Either SchemaGen component does not exist or pipeline is still running. If pipeline is running, then wait for it to successfully finish.'
        )

        # Successful pipeline run.
        # Create fake schema in pipeline root.
        component_output_dir = os.path.join(self.pipeline_root, 'SchemaGen')
        schema_path = base_driver._generate_output_uri(  # pylint: disable=protected-access
            component_output_dir, 'schema', 3)
        fileio.makedirs(schema_path)
        with open(os.path.join(schema_path, 'schema.pbtxt'), 'w') as f:
            f.write('SCHEMA')
        with self.captureWritesToStream(sys.stdout) as captured:
            handler.get_schema()
            curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
            self.assertIn('Path to schema: {}'.format(curr_dir_path),
                          captured.contents())
            self.assertIn(
                '*********SCHEMA FOR {}**********'.format(
                    self.pipeline_name.upper()), captured.contents())
            self.assertTrue(fileio.exists(curr_dir_path))
Example #22
    def _update_execution_proto(
        self,
        execution: metadata_store_pb2.Execution,
        pipeline_info: Optional[data_types.PipelineInfo] = None,
        component_info: Optional[data_types.ComponentInfo] = None,
        state: Optional[Text] = None,
        exec_properties: Optional[Dict[Text, Any]] = None,
    ) -> metadata_store_pb2.Execution:
        """Updates the execution proto with given type and state."""
        if state is not None:
            execution.properties[
                _EXECUTION_TYPE_KEY_STATE].string_value = tf.compat.as_text(
                    state)
        # Forward-compatible change to leverage built-in schema to track states.
        if state == EXECUTION_STATE_CACHED:
            execution.last_known_state = metadata_store_pb2.Execution.CACHED
        elif state == EXECUTION_STATE_COMPLETE:
            execution.last_known_state = metadata_store_pb2.Execution.COMPLETE
        elif state == EXECUTION_STATE_NEW:
            execution.last_known_state = metadata_store_pb2.Execution.RUNNING

        exec_properties = exec_properties or {}
        # TODO(ruoyu): Enforce a formal rule for execution schema change.
        for k, v in exec_properties.items():
            # We always convert execution properties to unicode.
            execution.properties[k].string_value = tf.compat.as_text(
                tf.compat.as_str_any(v))
        # We also need to checksum UDF file to identify different binary being
        # used. Do we have a better way to checksum a file than hashlib.md5?
        # TODO(ruoyu): Find a better place / solution to the checksum logic.
        # TODO(ruoyu): SHA instead of MD5.
        if 'module_file' in exec_properties and exec_properties[
                'module_file'] and fileio.exists(
                    exec_properties['module_file']):
            contents = file_io.read_file_to_string(
                exec_properties['module_file'])
            execution.properties[
                'checksum_md5'].string_value = tf.compat.as_text(
                    tf.compat.as_str_any(
                        hashlib.md5(tf.compat.as_bytes(contents)).hexdigest()))
        if pipeline_info:
            execution.properties[
                'pipeline_name'].string_value = pipeline_info.pipeline_name
            execution.properties[
                'pipeline_root'].string_value = pipeline_info.pipeline_root
            if pipeline_info.run_id:
                execution.properties[
                    'run_id'].string_value = pipeline_info.run_id
        if component_info:
            execution.properties[
                'component_id'].string_value = component_info.component_id
        return execution
Example #23
 def testCreatePipelineWithFlags(self):
     flags_dict = {
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path
     }
     handler = local_handler.LocalHandler(flags_dict)
     # Pipeline creation should not be affected by additional flags.
     with mock.patch.object(sys,
                            'argv',
                            new=(sys.argv + ['--unexpected_flag'])):
         handler.create_pipeline()
     self.assertTrue(
         fileio.exists(handler._get_pipeline_args_path(self.pipeline_name)))
Example #24
    def testIrisPipelineBeam(self):
        BeamDagRunner().run(
            iris_pipeline_beam._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(8, execution_count)  # 7 components + 1 resolver

        self.assertPipelineExecution()
Example #25
    def testTaxiPipelineBeam(self):
        LocalDagRunner().run(
            taxi_pipeline_local._create_pipeline(
                pipeline_name=self._pipeline_name,
                data_root=self._data_root,
                module_file=self._module_file,
                serving_model_dir=self._serving_model_dir,
                pipeline_root=self._pipeline_root,
                metadata_path=self._metadata_path,
                beam_pipeline_args=[]))

        self.assertTrue(fileio.exists(self._serving_model_dir))
        self.assertTrue(fileio.exists(self._metadata_path))
        metadata_config = metadata.sqlite_metadata_connection_config(
            self._metadata_path)
        with metadata.Metadata(metadata_config) as m:
            artifact_count = len(m.store.get_artifacts())
            execution_count = len(m.store.get_executions())
            self.assertGreaterEqual(artifact_count, execution_count)
            self.assertEqual(10, execution_count)

        self.assertPipelineExecution()
Example #26
    def list_pipelines(self) -> None:
        """List all the pipelines in the environment."""
        # There is no managed storage for pipeline packages, so CLI consults
        # local dir to list pipelines.
        if not fileio.exists(self._handler_home_dir):
            click.echo('No pipelines to display.')
            return
        pipelines_list = fileio.listdir(self._handler_home_dir)

        # Print every pipeline name in a new line.
        click.echo('-' * 30)
        click.echo('\n'.join(pipelines_list))
        click.echo('-' * 30)
Example #27
    def test_do_with_empty_transform_splits(self):
        self._exec_properties['splits_config'] = proto_utils.proto_to_json(
            transform_pb2.SplitsConfig(analyze=['train'], transform=[]))
        self._exec_properties['module_file'] = self._module_file
        self._output_dict[executor.TRANSFORMED_EXAMPLES_KEY] = (
            self._transformed_example_artifacts[:1])

        self._transform_executor.Do(self._input_dict, self._output_dict,
                                    self._exec_properties)
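        # With an empty transform split list no transformed examples should be written,
        # but the transform graph itself should still be produced.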
        self.assertFalse(
            fileio.exists(
                os.path.join(self._transformed_example_artifacts[0].uri,
                             'train')))
        self.assertFalse(
            fileio.exists(
                os.path.join(self._transformed_example_artifacts[0].uri,
                             'eval')))
        path_to_saved_model = os.path.join(
            self._transformed_output.uri,
            tft.TFTransformOutput.TRANSFORM_FN_DIR,
            tf.saved_model.SAVED_MODEL_FILENAME_PB)
        self.assertTrue(fileio.exists(path_to_saved_model))
Example #28
 def testCreatePipeline(self):
     flags_dict = {
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path
     }
     handler = airflow_handler.AirflowHandler(flags_dict)
     handler.create_pipeline()
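     # Creating the pipeline should copy the DSL file into the pipeline info directory.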
     handler_pipeline_path = handler._get_pipeline_info_path(
         self.pipeline_name)
     self.assertTrue(
         fileio.exists(
             os.path.join(handler_pipeline_path,
                          'test_pipeline_airflow_1.py')))
Example #29
 def testCreatePipeline(self):
     flags_dict = {
         labels.ENGINE_FLAG: self.engine,
         labels.PIPELINE_DSL_PATH: self.pipeline_path
     }
     handler = kubeflow_v2_handler.KubeflowV2Handler(flags_dict)
     handler.create_pipeline()
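     # Creating the pipeline should emit the compiled Kubeflow V2 output file.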
     handler_pipeline_path = os.path.join(handler._handler_home_dir,
                                          self.pipeline_name)
     self.assertTrue(
         fileio.exists(
             os.path.join(handler_pipeline_path,
                          kubeflow_v2_dag_runner_patcher._OUTPUT_FILENAME)))
Example #30
    def testRunStatisticsGen(self):
        # Prepare the paths
        test_data_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), 'components',
            'testdata')
        output_data_dir = os.path.join(
            os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', tempfile.mkdtemp()),
            self._testMethodName)
        statistics_split_names_path = os.path.join(output_data_dir,
                                                   'statistics.properties',
                                                   'split_names')
        fileio.makedirs(output_data_dir)

        # Run StatisticsGen
        run_component.run_component(
            full_component_class_name='tfx.components.StatisticsGen',
            examples_uri=os.path.join(test_data_dir, 'csv_example_gen'),
            examples_split_names=artifact_utils.encode_split_names(
                ['train', 'eval']),
            # Testing that we can set non-string artifact properties
            examples_version='1',
            statistics_path=output_data_dir,
            statistics_split_names_path=statistics_split_names_path,
        )

        # Check the statistics_gen outputs
        self.assertTrue(
            fileio.exists(
                os.path.join(output_data_dir, 'Split-train',
                             'FeatureStats.pb')))
        self.assertTrue(
            fileio.exists(
                os.path.join(output_data_dir, 'Split-eval',
                             'FeatureStats.pb')))
        self.assertTrue(os.path.exists(statistics_split_names_path))
        self.assertEqual(
            pathlib.Path(statistics_split_names_path).read_text(),
            '["train", "eval"]')