Exemplo n.º 1
0
 def test_clean_up_pipeline_args(self):
     """Check that clean_up_pipeline_args normalizes Dataflow flags.

     camelCase flag names become snake_case, a "--flag value" pair is
     joined into "--flag=value", and "--extra" is absent from the result.
     """
     raw_flags = [
         "--project=1",
         "--tempLocation=gs://tmp",
         "--jobName", "test",
         "--extra=1",
     ]
     cleaned = clean_up_pipeline_args(raw_flags)
     self.assertEqual(
         cleaned,
         ["--project=1", "--temp_location=gs://tmp", "--job_name=test"],
     )
Exemplo n.º 2
0
def generate_statistics_from_tfrecord(
        pipeline_args,  # type: List[str]
        data_location,  # type: str
        output_path,  # type: str
        stats_options,  # type: StatsOptions
        file_pattern="*.tfrecords*"  # type: str
):
    # type: (...) -> statistics_pb2.DatasetFeatureStatisticsList
    """
    Generate a stats file from a tfrecord dataset using TFDV.

    When ``pipeline_args`` does not set a job name or a setup file, a
    timestamped job name and a generated setup file are filled in.

    :param pipeline_args: un-parsed Dataflow arguments; they are converted
        to snake_case by ``clean_up_pipeline_args`` before being parsed
    :param data_location: input data dir containing tfrecord files
    :param output_path: output path for the stats file
    :param stats_options: TFDV ``StatsOptions``, passed through unchanged to
        ``tfdv.generate_statistics_from_tfrecord``
    :param file_pattern: glob, relative to ``data_location``, selecting the
        input files. Defaults to ``"*.tfrecords*"``.
    :return: a DatasetFeatureStatisticsList proto.
    """
    assert_not_empty_string(data_location)
    assert_not_empty_string(output_path)

    args_in_snake_case = clean_up_pipeline_args(pipeline_args)
    pipeline_options = PipelineOptions(flags=args_in_snake_case)

    # Typed views share their option storage with pipeline_options, so
    # defaults assigned through them reach the pipeline built below.
    gcloud_options = pipeline_options.view_as(GoogleCloudOptions)
    if gcloud_options.job_name is None:
        # Epoch-seconds suffix gives each run a distinct job name.
        gcloud_options.job_name = "generatestats-%d" % int(time.time())

    setup_options = pipeline_options.view_as(SetupOptions)
    if setup_options.setup_file is None:
        # NOTE(review): presumably the generated setup file lets remote
        # workers install this package's dependencies — confirm.
        setup_options.setup_file = create_setup_file()

    input_files = os.path.join(data_location, file_pattern)
    return tfdv.generate_statistics_from_tfrecord(
        data_location=input_files,
        output_path=output_path,
        stats_options=stats_options,
        pipeline_options=pipeline_options)