def test_clean_up_pipeline_args(self):
    # camelCase Dataflow flags should be normalised to snake_case, separated
    # "--flag value" pairs joined into "--flag=value", and unknown flags dropped.
    pipeline_args = [
        "--project=1",
        "--tempLocation=gs://tmp",
        "--jobName", "test",
        "--extra=1"
    ]
    expected = [
        "--project=1",
        "--temp_location=gs://tmp",
        "--job_name=test"
    ]
    self.assertEqual(clean_up_pipeline_args(pipeline_args), expected)
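
# The helper under test is defined elsewhere in this package. The sketch below is a
# hypothetical re-implementation inferred only from the expected values in the test
# above, not the library's actual code: it normalises camelCase Dataflow flags to
# the snake_case names Beam's PipelineOptions expects, folds separated
# "--flag value" pairs into "--flag=value", and drops flags it does not recognise.
# The hard-coded _KNOWN_SNAKE_CASE_FLAGS set is an assumption made for illustration.
import re

_KNOWN_SNAKE_CASE_FLAGS = {"project", "temp_location", "job_name", "setup_file"}


def _clean_up_pipeline_args_sketch(pipeline_args):
    """Illustrative sketch of clean_up_pipeline_args; not the library implementation."""
    # Fold separated "--flag value" pairs into "--flag=value" form.
    joined = []
    i = 0
    while i < len(pipeline_args):
        arg = pipeline_args[i]
        if (arg.startswith("--") and "=" not in arg
                and i + 1 < len(pipeline_args)
                and not pipeline_args[i + 1].startswith("--")):
            joined.append("%s=%s" % (arg, pipeline_args[i + 1]))
            i += 2
        else:
            joined.append(arg)
            i += 1
    # Convert camelCase flag names to snake_case and drop unrecognised flags.
    cleaned = []
    for arg in joined:
        name, _, value = arg.lstrip("-").partition("=")
        snake = re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
        if snake in _KNOWN_SNAKE_CASE_FLAGS:
            cleaned.append("--%s=%s" % (snake, value))
    return cleaned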
def generate_statistics_from_tfrecord(
        pipeline_args,   # type: List[str]
        data_location,   # type: str
        output_path,     # type: str
        stats_options    # type: StatsOptions
):
    # type: (...) -> statistics_pb2.DatasetFeatureStatisticsList
    """Generate a stats file from a TFRecord dataset using TFDV.

    :param pipeline_args: un-parsed Dataflow arguments
    :param data_location: input data dir containing tfrecord files
    :param output_path: output path for the stats file
    :param stats_options: tfdv.StatsOptions used for statistics generation
    :return: a DatasetFeatureStatisticsList proto
    """
    assert_not_empty_string(data_location)
    assert_not_empty_string(output_path)

    args_in_snake_case = clean_up_pipeline_args(pipeline_args)
    pipeline_options = PipelineOptions(flags=args_in_snake_case)

    all_options = pipeline_options.get_all_options()

    if all_options["job_name"] is None:
        gcloud_options = pipeline_options.view_as(GoogleCloudOptions)
        gcloud_options.job_name = "generatestats-%s" % str(int(time.time()))

    if all_options["setup_file"] is None:
        setup_file_path = create_setup_file()
        setup_options = pipeline_options.view_as(SetupOptions)
        setup_options.setup_file = setup_file_path

    input_files = os.path.join(data_location, "*.tfrecords*")
    return tfdv.generate_statistics_from_tfrecord(
        data_location=input_files,
        output_path=output_path,
        stats_options=stats_options,
        pipeline_options=pipeline_options)
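
# Example invocation (illustrative only): the project id, bucket paths, and region
# below are placeholders, and tfdv.StatsOptions() is used with its defaults. In
# practice the pipeline_args would normally come straight from the command line;
# note that camelCase flags such as --tempLocation are accepted because they get
# normalised by clean_up_pipeline_args before PipelineOptions parses them.
if __name__ == "__main__":
    import tensorflow_data_validation as tfdv

    stats = generate_statistics_from_tfrecord(
        pipeline_args=[
            "--runner=DataflowRunner",
            "--project=my-gcp-project",           # placeholder project id
            "--tempLocation=gs://my-bucket/tmp",  # placeholder temp bucket
            "--region=europe-west1",              # placeholder region
        ],
        data_location="gs://my-bucket/dataset/train",      # dir of *.tfrecords* files
        output_path="gs://my-bucket/dataset/train.stats",  # where the proto is written
        stats_options=tfdv.StatsOptions())

    # The returned proto can be inspected directly.
    print("Generated statistics for %d dataset slice(s)" % len(stats.datasets))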