def get_schema(self):
    """Copies the latest SchemaGen schema to the current dir and prints it.

    Reads the pipeline root from the pipeline's stored `pipeline_args.json`,
    picks the SchemaGen execution directory with the highest numeric id,
    copies its `schema.pbtxt` into the current working directory (overwriting
    an existing file) and echoes the destination path and schema contents.

    Exits via `sys.exit` with an explanatory message when the pipeline root
    does not exist, when no `SchemaGen` directory is found under it, or when
    no numbered SchemaGen execution directory is found.
    """
    pipeline_name = self.flags_dict[labels.PIPELINE_NAME]

    # Check if pipeline exists.
    self._check_pipeline_existence(pipeline_name)

    # Path to pipeline args.
    pipeline_args_path = os.path.join(
        self._handler_home_dir, pipeline_name, 'pipeline_args.json')

    # Get pipeline_root.
    with open(pipeline_args_path, 'r') as f:
        pipeline_args = json.load(f)

    # Check if pipeline root created. If not, it means that the user has not
    # created a run yet or the pipeline is still running for the first time.
    pipeline_root = pipeline_args[labels.PIPELINE_ROOT]
    if not tf.io.gfile.exists(pipeline_root):
        sys.exit(
            'Create a run before inferring schema. If pipeline is already '
            'running, then wait for it to successfully finish.')

    # If pipeline_root exists, then check if SchemaGen output exists.
    # Directory entries are normalised because some filesystem backends may
    # report them with a trailing slash, which would defeat the membership
    # test below.
    components = [
        entry.rstrip('/') for entry in tf.io.gfile.listdir(pipeline_root)
    ]
    if 'SchemaGen' not in components:
        sys.exit(
            'Either SchemaGen component does not exist or pipeline is still '
            'running. If pipeline is running, then wait for it to '
            'successfully finish.')

    # Get the latest SchemaGen output.
    component_output_dir = os.path.join(pipeline_root, 'SchemaGen')
    # The URI for execution 1 is generated only to learn the directory that
    # holds all per-execution schema folders.
    schema1_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        component_output_dir, 'schema', 1)
    schema_dir = os.path.join(os.path.dirname(schema1_uri), '')

    # Keep only numeric execution ids so that stray entries (or a trailing
    # slash) cannot make int() raise while selecting the latest one.
    execution_ids = [
        entry.rstrip('/') for entry in tf.io.gfile.listdir(schema_dir)
    ]
    execution_ids = [entry for entry in execution_ids if entry.isdigit()]
    if not execution_ids:
        # max() would raise a bare ValueError on an empty sequence.
        sys.exit(
            'SchemaGen has not produced a schema yet. If pipeline is '
            'running, then wait for it to successfully finish.')
    latest_schema_folder = max(execution_ids, key=int)

    # Copy schema to current dir.
    latest_schema_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        component_output_dir, 'schema', latest_schema_folder)
    latest_schema_path = os.path.join(latest_schema_uri, 'schema.pbtxt')
    curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
    io_utils.copy_file(latest_schema_path, curr_dir_path, overwrite=True)

    # Print schema and path to schema
    click.echo('Path to schema: {}'.format(curr_dir_path))
    click.echo('*********SCHEMA FOR {}**********'.format(
        pipeline_name.upper()))
    with open(curr_dir_path, 'r') as f:
        click.echo(f.read())
def _prepare_output_artifacts(
    self,
    input_artifacts: Dict[Text, List[types.Artifact]],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    execution_id: int,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> Dict[Text, List[types.Artifact]]:
    """Overrides BaseDriver._prepare_output_artifacts().

    Unwraps `output_dict`, requires it to hold exactly one entry, and fills in
    the single examples artifact: its URI is generated under
    `<pipeline_root>/<component_id>` for `execution_id`, and the fingerprint
    and span custom properties are copied from `exec_properties`.

    Args:
      input_artifacts: Ignored.
      output_dict: Output channels keyed by name; must contain the examples
        key.
      exec_properties: Must contain the fingerprint and span properties.
      execution_id: Execution id used when generating the output URI.
      pipeline_info: Supplies the pipeline root.
      component_info: Supplies the component id.

    Returns:
      The unwrapped output dict with the examples artifact updated in place.

    Raises:
      RuntimeError: If the unwrapped output dict does not have exactly one
        entry.
    """
    del input_artifacts  # Unused.

    output_artifacts = channel_utils.unwrap_channel_dict(output_dict)
    if len(output_artifacts) != 1:
        raise RuntimeError('Multiple output artifacts are not supported.')

    component_dir = os.path.join(
        pipeline_info.pipeline_root, component_info.component_id)
    examples = artifact_utils.get_single_instance(
        output_artifacts[utils.EXAMPLES_KEY])
    examples.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        component_dir, utils.EXAMPLES_KEY, execution_id)

    # Fingerprint first, then span, each copied straight from exec_properties.
    for property_name in (utils.FINGERPRINT_PROPERTY_NAME,
                          utils.SPAN_PROPERTY_NAME):
        examples.set_string_custom_property(
            property_name, exec_properties[property_name])

    base_driver._prepare_output_paths(examples)  # pylint: disable=protected-access
    return output_artifacts
def testPrepareOutputArtifacts(self):
    """Checks the URI and custom properties set on the examples artifact."""
    output_dict = {
        utils.EXAMPLES_KEY:
            channel_utils.as_channel([standard_artifacts.Examples()])
    }
    exec_properties = {
        utils.SPAN_PROPERTY_NAME: 2,
        utils.VERSION_PROPERTY_NAME: 1,
        utils.FINGERPRINT_PROPERTY_NAME: 'fp'
    }
    pipeline_info = data_types.PipelineInfo(
        pipeline_name='name', pipeline_root=self._test_dir, run_id='rid')
    component_info = data_types.ComponentInfo(
        component_type='type',
        component_id='cid',
        pipeline_info=pipeline_info)

    # No input artifacts are needed; the driver discards them.
    output_artifacts = self._example_gen_driver._prepare_output_artifacts(
        {}, output_dict, exec_properties, 1, pipeline_info, component_info)
    prepared = artifact_utils.get_single_instance(
        output_artifacts[utils.EXAMPLES_KEY])

    # The URI must match what the base driver generates for execution 1.
    component_dir = os.path.join(self._test_dir, component_info.component_id)
    expected_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        component_dir, 'examples', 1)
    self.assertEqual(prepared.uri, expected_uri)

    # Custom properties are expected back as strings.
    expected_properties = (
        (utils.FINGERPRINT_PROPERTY_NAME, 'fp'),
        (utils.SPAN_PROPERTY_NAME, '2'),
        (utils.VERSION_PROPERTY_NAME, '1'),
    )
    for property_name, expected_value in expected_properties:
        self.assertEqual(
            prepared.get_string_custom_property(property_name),
            expected_value)
def _prepare_output_artifacts(
    self,
    input_artifacts: Dict[Text, List[types.Artifact]],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    execution_id: int,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> Dict[Text, List[types.Artifact]]:
    """Overrides BaseDriver._prepare_output_artifacts().

    Creates a new examples artifact from the type of the examples output
    channel, generates its URI under `<pipeline_root>/<component_id>` for
    `execution_id`, and records fingerprint, span and (when present) version
    as string custom properties taken from `exec_properties`.

    Args:
      input_artifacts: Ignored.
      output_dict: Output channels keyed by name; must contain the examples
        key.
      exec_properties: Must contain the fingerprint, span and version keys.
        A version of `None` means "no version" and is not recorded.
      execution_id: Execution id used when generating the output URI.
      pipeline_info: Supplies the pipeline root.
      component_info: Supplies the component id.

    Returns:
      A dict mapping the examples key to a one-element list holding the new
      artifact.
    """
    del input_artifacts  # Unused.

    example_artifact = output_dict[utils.EXAMPLES_KEY].type()
    base_output_dir = os.path.join(pipeline_info.pipeline_root,
                                   component_info.component_id)
    example_artifact.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        base_output_dir, utils.EXAMPLES_KEY, execution_id)

    example_artifact.set_string_custom_property(
        utils.FINGERPRINT_PROPERTY_NAME,
        exec_properties[utils.FINGERPRINT_PROPERTY_NAME])
    example_artifact.set_string_custom_property(
        utils.SPAN_PROPERTY_NAME,
        str(exec_properties[utils.SPAN_PROPERTY_NAME]))

    # TODO(b/162622803): add default behavior for when version spec not present.
    # Compare against None rather than relying on truthiness: version 0 is
    # falsy and would otherwise be dropped silently.
    version = exec_properties[utils.VERSION_PROPERTY_NAME]
    if version is not None:
        example_artifact.set_string_custom_property(
            utils.VERSION_PROPERTY_NAME, str(version))

    base_driver._prepare_output_paths(example_artifact)  # pylint: disable=protected-access
    return {utils.EXAMPLES_KEY: [example_artifact]}
def artifact_dir(self, artifact_root: str, artifact_subdir: str = "") -> str:
    """Builds the path to an artifact subdirectory under the pipeline root.

    Example: for the transformed examples of the train split,
    `self.artifact_dir('Transform.AutoData/transformed_examples', 'train/*')`
    would return the path
    '<pipeline_root>/Transform.AutoData/transformed_examples/4/train/*'.

    Assumes there is a single execution per component.

    Args:
      artifact_root: Root subdirectory to specify the component's artifacts.
      artifact_subdir: Optional subdirectory to append after the execution ID.

    Returns:
      The full path to the artifact subdir under the pipeline root.

    Raises:
      ValueError: If the artifact root does not contain exactly one entry.
    """
    execution_dirs = tf.io.gfile.listdir(
        os.path.join(self.pipeline_root, artifact_root))
    if len(execution_dirs) != 1:
        raise ValueError(
            f"Expected a single artifact dir, got: {execution_dirs}")

    # The only entry is taken as the execution ID of the component's run.
    execution_id = execution_dirs[0]
    execution_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        self.pipeline_root, artifact_root, execution_id)
    return os.path.join(execution_uri, artifact_subdir)
def testPipelineSchemaSuccessfulRun(self):
    """Checks that get_schema() prints and copies a fake SchemaGen schema."""
    # First create a pipeline.
    creation_flags = {
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_DSL_PATH: self.pipeline_path
    }
    beam_handler.BeamHandler(creation_flags).create_pipeline()

    # A second handler, addressed by pipeline name, is the one under test.
    schema_flags = {
        labels.ENGINE_FLAG: self.engine,
        labels.PIPELINE_NAME: self.pipeline_name,
    }
    handler = beam_handler.BeamHandler(schema_flags)

    # Create fake schema in pipeline root.
    schemagen_dir = os.path.join(self.pipeline_root, 'SchemaGen')
    fake_schema_dir = base_driver._generate_output_uri(  # pylint: disable=protected-access
        schemagen_dir, 'schema', 3)
    tf.io.gfile.makedirs(fake_schema_dir)
    fake_schema_file = os.path.join(fake_schema_dir, 'schema.pbtxt')
    with open(fake_schema_file, 'w') as f:
        f.write('SCHEMA')

    with self.captureWritesToStream(sys.stdout) as captured:
        handler.get_schema()
        curr_dir_path = os.path.join(os.getcwd(), 'schema.pbtxt')
        expected_path_line = 'Path to schema: {}'.format(curr_dir_path)
        self.assertIn(expected_path_line, captured.contents())
        expected_banner = '*********SCHEMA FOR {}**********'.format(
            self.pipeline_name.upper())
        self.assertIn(expected_banner, captured.contents())
        self.assertTrue(tf.io.gfile.exists(curr_dir_path))
def assertInfraValidatorPassed(self) -> None:
    """Asserts every InfraValidator blessing output has an INFRA_BLESSED file.

    Lists the entries under `<pipeline_root>/InfraValidator/blessing`,
    requires at least one, and checks that an `INFRA_BLESSED` file exists
    under the blessing URI generated for each entry.
    """
    component_dir = os.path.join(self._pipeline_root, 'InfraValidator')
    blessing_root = os.path.join(
        self._pipeline_root, 'InfraValidator', 'blessing')

    execution_ids = tf.io.gfile.listdir(blessing_root)
    self.assertGreaterEqual(len(execution_ids), 1)

    for execution_id in execution_ids:
        execution_uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
            component_dir, 'blessing', execution_id)
        marker_path = os.path.join(execution_uri, 'INFRA_BLESSED')
        self.assertTrue(tf.io.gfile.exists(marker_path))
def _prepare_output_artifacts(
    self,
    input_artifacts: Dict[Text, List[types.Artifact]],
    output_dict: Dict[Text, types.Channel],
    exec_properties: Dict[Text, Any],
    execution_id: int,
    pipeline_info: data_types.PipelineInfo,
    component_info: data_types.ComponentInfo,
) -> Dict[Text, List[types.Artifact]]:
    """Overrides BaseDriver._prepare_output_artifacts().

    Creates a new examples artifact from the type of the examples output
    channel, generates its URI under `<pipeline_root>/<component_id>` for
    `execution_id`, and passes `exec_properties` together with the artifact's
    `mlmd_artifact` to `_update_output_artifact`.

    Args:
      input_artifacts: Ignored.
      output_dict: Output channels keyed by name; must contain the examples
        key.
      exec_properties: Execution properties forwarded to
        `_update_output_artifact`.
      execution_id: Execution id used when generating the output URI.
      pipeline_info: Supplies the pipeline root.
      component_info: Supplies the component id.

    Returns:
      A dict mapping the examples key to a one-element list holding the new
      artifact.
    """
    del input_artifacts  # Unused.

    examples = output_dict[utils.EXAMPLES_KEY].type()
    component_dir = os.path.join(
        pipeline_info.pipeline_root, component_info.component_id)
    examples.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        component_dir, utils.EXAMPLES_KEY, execution_id)

    _update_output_artifact(exec_properties, examples.mlmd_artifact)
    base_driver._prepare_output_paths(examples)  # pylint: disable=protected-access

    return {utils.EXAMPLES_KEY: [examples]}