Example #1
0
  def expand(self, input_or_inputs):
    """Tokenize GitHub code/docstring pairs and persist every outcome.

    Runs `self.query_string` (standard SQL) as a BigQuery read, splits each
    row's 'repo_path', tokenizes code/docstring pairs and extracts function
    info for `self.data_columns[2:]`. Rows emitted on the 'err_rows' output
    of either the tokenize or the extract step are written to the BigQuery
    table `self.output_table + '_failed'`. Successful rows are flattened;
    their 'function_tokens' and 'docstring_tokens' values are written as
    text shards under `<self.storage_bucket>/raw_data/data` (suffixes
    '.function' / '.docstring'), and the rows themselves are written to
    `self.output_table`.

    Args:
      input_or_inputs: what this composite is applied to; presumably the
        pipeline root, since the first step is a BigQuery read.

    Returns:
      The result of applying the final "Save Tokens" WriteToBigQuery.
    """
    def write_failed_rows(failed, label):
      # Both failure sinks target the same table and schema; only the
      # transform label differs.
      return failed | label >> io.WriteToBigQuery(
          project=self.project,
          dataset=self.output_dataset,
          table=self.output_table + '_failed',
          schema=self.create_failed_output_schema())

    def write_token_shards(rows, field, filter_label, write_label, suffix):
      # `field` is a parameter of this helper, so each lambda built here
      # closes over its own value.
      selected = rows | filter_label >> beam.Map(lambda row: row[field])
      return selected | write_label >> io.WriteToText(
          '{}/raw_data/data'.format(self.storage_bucket),
          file_name_suffix=suffix,
          num_shards=self.num_shards)

    github_rows = input_or_inputs | "Read Github Dataset" >> io.Read(
        io.BigQuerySource(query=self.query_string, use_standard_sql=True))
    repo_split = github_rows | "Split 'repo_path'" >> beam.ParDo(
        SplitRepoPath())
    tokenized = repo_split | "Tokenize Code/Docstring Pairs" >> beam.ParDo(
        TokenizeCodeDocstring()).with_outputs('err_rows', main='rows')
    _ = write_failed_rows(tokenized.err_rows, "Failed Row Tokenization")

    func_info = tokenized.rows | "Extract Function Info" >> beam.ParDo(
        ExtractFuncInfo(self.data_columns[2:])).with_outputs(
            'err_rows', main='rows')
    _ = write_failed_rows(func_info.err_rows, "Failed Function Info")

    flat_rows = func_info.rows | "Flatten Rows" >> beam.FlatMap(lambda x: x)

    # NOTE(review): PCollections are unordered in Beam, so nothing here
    # guarantees that line i of a '.function' shard pairs with line i of the
    # matching '.docstring' shard -- confirm how downstream consumers pair
    # these files.
    _ = write_token_shards(flat_rows, 'function_tokens',
                           "Filter Function tokens", "Write Function tokens",
                           '.function')
    _ = write_token_shards(flat_rows, 'docstring_tokens',
                           "Filter Docstring tokens", "Write Docstring tokens",
                           '.docstring')

    return flat_rows | "Save Tokens" >> io.WriteToBigQuery(
        project=self.project,
        dataset=self.output_dataset,
        table=self.output_table,
        schema=self.create_output_schema())
Example #2
0
    def expand(self, input_or_inputs):
        """Tokenize GitHub code/docstring pairs and write them to BigQuery.

        Runs `self.query_string` (standard SQL) as a BigQuery read, splits
        each row's 'repo_path', tokenizes code/docstring pairs and extracts
        function info for `self.data_columns[2:]`. Rows emitted on the
        'err_rows' output of either step go to `self.output_table +
        '_failed'`. Successful rows are flattened and written to
        `self.output_table`. Every BigQuery write uses `self.batch_size`.

        Args:
            input_or_inputs: what this composite is applied to; presumably
                the pipeline root, since the first step is a BigQuery read.

        Returns:
            The result of applying the final "Save Tokens" WriteToBigQuery.
        """
        def write_failures(failed_rows, label):
            # Shared sink for both error outputs; only the label differs.
            return failed_rows | label >> io.WriteToBigQuery(
                project=self.project,
                dataset=self.output_dataset,
                table=self.output_table + '_failed',
                schema=self.create_failed_output_schema(),
                batch_size=self.batch_size)

        source_rows = input_or_inputs | "Read Github Dataset" >> io.Read(
            io.BigQuerySource(query=self.query_string,
                              use_standard_sql=True))
        repo_split = source_rows | "Split 'repo_path'" >> beam.ParDo(
            SplitRepoPath())
        tokenized = repo_split | "Tokenize Code/Docstring Pairs" >> beam.ParDo(
            TokenizeCodeDocstring()).with_outputs('err_rows', main='rows')
        _ = write_failures(tokenized.err_rows, "Failed Row Tokenization")

        extracted = tokenized.rows | "Extract Function Info" >> beam.ParDo(
            ExtractFuncInfo(self.data_columns[2:])).with_outputs(
                'err_rows', main='rows')
        _ = write_failures(extracted.err_rows, "Failed Function Info")

        flattened = extracted.rows | "Flatten Rows" >> beam.FlatMap(lambda x: x)
        return flattened | "Save Tokens" >> io.WriteToBigQuery(
            project=self.project,
            dataset=self.output_dataset,
            table=self.output_table,
            schema=self.create_output_schema(),
            batch_size=self.batch_size)
Example #3
0
 def expand(self, input_or_inputs):
     """Read, tokenize and extract function info, then write to BigQuery.

     Steps: BigQuery read of `self.query_string` (standard SQL), split
     'repo_path', tokenize code/docstring pairs, extract function info for
     `self.data_columns[2:]`, flatten, and write to `self.output_table`.

     Args:
         input_or_inputs: what this composite is applied to; presumably the
             pipeline root, since the first step is a BigQuery read.

     Returns:
         The result of applying the final "Write to BigQuery" transform.
     """
     bq_rows = input_or_inputs | "Read BigQuery Rows" >> io.Read(
         io.BigQuerySource(query=self.query_string,
                           use_standard_sql=True))
     repo_split = bq_rows | "Split 'repo_path'" >> beam.ParDo(SplitRepoPath())
     tokenized = repo_split | "Tokenize Code/Docstring Pairs" >> beam.ParDo(
         TokenizeCodeDocstring())
     func_infos = tokenized | "Extract Function Info" >> beam.ParDo(
         ExtractFuncInfo(self.data_columns[2:]))
     flattened = func_infos | "Flatten Rows" >> beam.FlatMap(lambda x: x)
     return flattened | "Write to BigQuery" >> io.WriteToBigQuery(
         project=self.project,
         dataset=self.output_dataset,
         table=self.output_table,
         schema=self.create_output_schema())
def run(argv=None):
    """Main function.

    Builds and runs the Apache Beam pipeline that generates the LTV
    predictions. The input event log is read from BigQuery using the query,
    project and temp GCS location taken from RuntimeOptions (it is not a CSV
    file). Model-fit validation and per-customer predictions are computed by
    helpers in `c`. Results are written, with WRITE_TRUNCATE and
    CREATE_IF_NEEDED, to these tables in the configured output project and
    dataset:
    validation_params, prediction_params, prediction_by_customer,
    prediction_summary and prediction_summary_extra_dimension.

    Args:
        argv: Command-line arguments. The argparse parser declares no
            arguments of its own, so all of them are returned as unknown
            args and handed to PipelineOptions.
    """
    parser = argparse.ArgumentParser()
    _, pipeline_args = parser.parse_known_args(argv)
    options = pipeline_options.PipelineOptions(pipeline_args)
    runtime_options = options.view_as(RuntimeOptions)

    with beam.Pipeline(options=options) as pipeline:
        # NOTE: `options` is rebound here. Before this line it is the
        # PipelineOptions object; the right-hand side (including
        # options.get_all_options()) is evaluated before the rebinding.
        # Afterwards it is a single-element PCollection holding the options
        # dict, with ValueProviders resolved via .get() and then passed
        # through c.set_extra_options. Every later `options` reference in
        # this block is that PCollection, consumed via
        # pvalue.AsSingleton(options).
        options = (pipeline
                   | 'Create single element Stream containing options dict' >>
                   beam.Create([options.get_all_options()])
                   | beam.Map(
                       lambda x: {
                           k: v.get() if isinstance(
                               v, value_provider.ValueProvider) else v
                           for (k, v) in x.items()
                       })
                   | beam.Map(c.set_extra_options))

        # Raw event log from the configured BigQuery query (standard SQL).
        full_elog = (
            pipeline
            | bq_mod.ReadFromBigQuery(
                project=getattr(runtime_options, c._OPTION_INPUT_BQ_PROJECT),
                query=getattr(runtime_options, c._OPTION_INPUT_BQ_QUERY),
                gcs_location=getattr(runtime_options,
                                     c._OPTION_TEMP_GCS_LOCATION),
                use_standard_sql=True)
            | beam.FlatMap(
                c.bq_row_to_list,
                pvalue.AsSingleton(options))  # (customer_id, date_str, date,
            #  sales, extra_dimension?)
        )

        # Keep rows with positive sales and merge rows that share the same
        # (customer_id, date_str) key.
        full_elog_merged = (
            full_elog
            | beam.Filter(lambda x: x[3] > 0)  # sales > 0
            | beam.Map(lambda x: ((x[0], x[1]), x))  # key: (customer_id, date_str)
            | 'Group full elog by customer and date' >> beam.GroupByKey()
            | beam.Map(c.merge_full_elog_by_customer_and_date)  # (customer_id,
            #  date_str, date,
            #  sales)
        )

        # Global min/max over the merged rows' dates, converted to a dict by
        # c.min_max_dates_dict.
        min_max_dates = (
            full_elog_merged
            | beam.Map(lambda x: x[2])  # date
            | beam.CombineGlobally(c.MinMaxDatesFn())
            | beam.Map(c.min_max_dates_dict))

        # Date boundaries used below as a singleton side input by the cohort,
        # calibration/holdout and model steps.
        limits_dates = (min_max_dates
                        | beam.FlatMap(c.limit_dates_boundaries,
                                       pvalue.AsSingleton(options)))

        # Distinct customer ids selected by c.filter_customers_in_cohort.
        cohort = (full_elog_merged
                  | beam.FlatMap(c.filter_customers_in_cohort,
                                 pvalue.AsSingleton(limits_dates))
                  | 'Distinct Customer IDs in Cohort' >> util.Distinct())

        cohort_count = (
            cohort
            | 'Count cohort entries' >> beam.combiners.Count.Globally())

        # (customer_id, 1) pairs so the cohort can be used as an AsDict side
        # input.
        cohort_set = (cohort | beam.Map(lambda x: (x, 1)))

        all_customer_ids = (
            full_elog_merged
            | beam.Map(lambda x: x[0])  # key: customer_id
            | 'Distinct all Customer IDs' >> util.Distinct())

        all_customer_ids_count = (
            all_customer_ids
            | 'Count all customers' >> beam.combiners.Count.Globally())

        # Pattern used repeatedly below: a one-element PCollection ([1])
        # seeds a FlatMap that runs once, with every real input supplied as
        # a side input.
        num_customers = (
            pipeline
            | 'Create single elem Stream I' >> beam.Create([1])
            | beam.FlatMap(c.count_customers, pvalue.AsSingleton(cohort_count),
                           pvalue.AsSingleton(all_customer_ids_count),
                           pvalue.AsSingleton(options)))

        # Merged records of cohort customers, filtered by
        # c.filter_cohort_records_in_cal_hol (calibration + holdout).
        cal_hol_elog = (full_elog_merged
                        | beam.FlatMap(c.filter_cohort_records_in_cal_hol,
                                       pvalue.AsDict(cohort_set),
                                       pvalue.AsSingleton(limits_dates)))

        cal_hol_elog_count = (
            cal_hol_elog
            | 'Count cal hol elog entries' >> beam.combiners.Count.Globally())

        calibration = (cal_hol_elog
                       | beam.FlatMap(c.filter_records_in_calibration,
                                      pvalue.AsSingleton(limits_dates)))

        num_txns_total = (
            full_elog_merged
            | beam.FlatMap(c.filter_records_in_cal_hol,
                           pvalue.AsSingleton(limits_dates))
            | 'Count num txns total' >> beam.combiners.Count.Globally())

        num_txns = (pipeline
                    | 'Create single elem Stream II' >> beam.Create([1])
                    | beam.FlatMap(c.count_txns,
                                   pvalue.AsSingleton(cal_hol_elog_count),
                                   pvalue.AsSingleton(num_txns_total),
                                   pvalue.AsSingleton(options)))

        # Calibration-period per-customer statistics.
        calcbs = (
            calibration
            | beam.Map(lambda x: (x[0], x))
            | 'Group calibration elog by customer id' >> beam.GroupByKey()
            | beam.FlatMap(
                c.create_cal_cbs, pvalue.AsSingleton(options),
                pvalue.AsSingleton(limits_dates)
            )  # (customer_id, number_of_transactions, average_order_value,
            #  frequency, recency, total_time_observed)
        )

        # Earliest date (element 2) per customer within cal_hol_elog.
        first_transaction_dates_by_customer = (
            cal_hol_elog
            | beam.Map(lambda x: (x[0], x))  # customer_id
            | 'Group cal hol elog by customer id' >> beam.GroupByKey()
            | beam.Map(lambda x: (x[0], min(map(operator.itemgetter(2), x[1])))
                       )  # item 2 -> date
        )

        # Count of non-first records per time unit number.
        cal_hol_elog_repeat = (
            cal_hol_elog
            | beam.FlatMap(c.filter_first_transaction_date_records,
                           pvalue.AsDict(first_transaction_dates_by_customer))
            | beam.FlatMap(
                c.calculate_time_unit_numbers,  # (customer_id, date,
                #  time_unit_number)
                pvalue.AsSingleton(options),
                pvalue.AsSingleton(limits_dates))
            | beam.Map(lambda x: (x[2], 1))  # key: time_unit_number
            | 'Group cal hol elog repeat by time unit number' >>
            beam.GroupByKey()
            | beam.Map(lambda x:
                       (x[0], sum(x[1])))  # (time_unit_number, occurrences)
        )

        repeat_tx = (
            pipeline
            | 'Create single elem Stream III' >> beam.Create([1])
            | beam.FlatMap(c.calculate_cumulative_repeat_transactions,
                           pvalue.AsIter(cal_hol_elog_repeat)
                           )  # (time_unit_number, repeat_transactions,
            #  repeat_transactions_cumulative)
        )

        model_validation = (
            pipeline
            | 'Create single elem Stream IV' >> beam.Create([1])
            | beam.FlatMap(
                c.calculate_model_fit_validation, pvalue.AsSingleton(options),
                pvalue.AsSingleton(limits_dates), pvalue.AsIter(calcbs),
                pvalue.AsIter(repeat_tx), pvalue.AsSingleton(num_customers),
                pvalue.AsSingleton(num_txns)))

        # Presumably fails the pipeline when the validation MAPE is
        # unacceptable -- behavior lives in c.raise_error_if_invalid_mape.
        _ = (model_validation | beam.Map(c.raise_error_if_invalid_mape))

        # Element 0 of each validation result goes to validation_params.
        _ = (model_validation
             | beam.Map(lambda x: x[0])
             | 'Write to validation_params table' >> io.WriteToBigQuery(
                 table=c.TableValueProvider(
                     getattr(runtime_options, c._OPTION_OUTPUT_BQ_PROJECT),
                     getattr(runtime_options, c._OPTION_OUTPUT_BQ_DATASET),
                     'validation_params'),
                 custom_gcs_temp_location=getattr(runtime_options,
                                                  c._OPTION_TEMP_GCS_LOCATION),
                 validate=False,
                 schema={
                     'fields': [{
                         'name': 'calibration_start_date',
                         'type': 'STRING'
                     }, {
                         'name': 'calibration_end_date',
                         'type': 'STRING'
                     }, {
                         'name': 'cohort_end_date',
                         'type': 'STRING'
                     }, {
                         'name': 'holdout_end_date',
                         'type': 'STRING'
                     }, {
                         'name': 'model_time_granularity',
                         'type': 'STRING'
                     }, {
                         'name':
                         'model',
                         'type':
                         'RECORD',
                         'fields': [
                             {
                                 'name': 'frequency_model',
                                 'type': 'STRING'
                             },
                             {
                                 'name': 'num_customers_cohort',
                                 'type': 'INTEGER'
                             },
                             {
                                 'name': 'perc_customers_cohort',
                                 'type': 'FLOAT'
                             },
                             {
                                 'name': 'num_transactions_validation',
                                 'type': 'INTEGER'
                             },
                             {
                                 'name': 'perc_transactions_validation',
                                 'type': 'FLOAT'
                             },
                             {
                                 'name': 'mape',
                                 'type': 'FLOAT'
                             },
                         ]
                     }]
                 },
                 write_disposition=io.BigQueryDisposition.WRITE_TRUNCATE,
                 create_disposition=io.BigQueryDisposition.CREATE_IF_NEEDED))

        # Per-customer statistics over the whole merged event log.
        fullcbs_without_extra_dimension = (
            full_elog_merged
            | beam.Map(lambda x: (x[0], x))  # key: customer_id
            | 'Group full merged elog by customer id' >> beam.GroupByKey()
            | beam.FlatMap(
                c.create_fullcbs, pvalue.AsSingleton(options),
                pvalue.AsSingleton(min_max_dates)
            )  # (customer_id, number_of_transactions, historical_aov,
            #  frequency, recency, total_time_observed)
        )

        # Extra-dimension handling: per customer, pick the top extra
        # dimension (c.extract_top_extra_dimension) and attach it to the
        # fullcbs rows.
        full_elog_if_extra_dimension = (
            full_elog
            | 'Discard records if no extra dimension' >> beam.FlatMap(
                c.discard_if_no_extra_dimension, pvalue.AsSingleton(options)))

        extra_dimensions_stats = (
            full_elog_if_extra_dimension
            | beam.Map(lambda x: (
                (x[0], x[4]), x))  # key: (customer_id, extra_dimension)
            | 'Group full elog by customer id and extra dimension' >>
            beam.GroupByKey()
            | beam.Map(
                c.create_extra_dimensions_stats
            )  # (customer_id, extra_dimension, dimension_count, tot_sales,
            #  max_dimension_date)
        )

        top_dimension_per_customer = (
            extra_dimensions_stats
            | beam.Map(lambda x: (x[0], x))  # customer_id
            |
            'Group extra dimension stats by customer id' >> beam.GroupByKey()
            | beam.Map(
                c.extract_top_extra_dimension
            )  # (customer_id, extra_dimension, dimension_count, tot_sales,
            #  max_dimension_date)
        )

        customer_dimension_map = (
            top_dimension_per_customer
            | beam.Map(lambda x:
                       (x[0], x[1]))  # (customer_id, extra_dimension)
        )

        fullcbs = (
            fullcbs_without_extra_dimension
            | beam.FlatMap(
                c.add_top_extra_dimension_to_fullcbs,
                pvalue.AsSingleton(options),
                pvalue.AsDict(customer_dimension_map)
            )  # (customer_id, number_of_transactions, historical_aov,
            #  frequency, recency, total_time_observed,
            #  extra_dimension?)
        )

        prediction = (
            pipeline
            | 'Create single elem Stream V' >> beam.Create([1])
            | beam.FlatMap(
                c.calculate_prediction, pvalue.AsSingleton(options),
                pvalue.AsIter(fullcbs), pvalue.AsSingleton(num_customers),
                pvalue.AsSingleton(num_txns)
            )  # [customer_id, p_alive, predicted_purchases, future_aov,
            #  historical_aov, expected_value, frequency, recency,
            #  total_time_observed, extra_dimension?], prediction_params
        )

        prediction_by_customer_no_segments = (
            prediction
            | beam.FlatMap(lambda x: x[0])  # Extract predictions by customer
        )

        _ = (
            prediction
            | beam.Map(lambda x: x[1])  # Extract prediction params
            | 'Write to prediction_params table' >> io.WriteToBigQuery(
                table=c.TableValueProvider(
                    getattr(runtime_options, c._OPTION_OUTPUT_BQ_PROJECT),
                    getattr(runtime_options, c._OPTION_OUTPUT_BQ_DATASET),
                    'prediction_params'),
                custom_gcs_temp_location=getattr(runtime_options,
                                                 c._OPTION_TEMP_GCS_LOCATION),
                validate=False,
                schema={
                    'fields': [{
                        'name': 'prediction_period',
                        'type': 'INTEGER'
                    }, {
                        'name': 'prediction_period_unit',
                        'type': 'STRING'
                    }, {
                        'name': 'model_time_granularity',
                        'type': 'STRING'
                    }, {
                        'name': 'customers_modeled',
                        'type': 'INTEGER'
                    }, {
                        'name': 'transactions_observed',
                        'type': 'INTEGER'
                    }, {
                        'name': 'frequency_model',
                        'type': 'STRING'
                    }, {
                        'name':
                        'bgnbd_model_params',
                        'type':
                        'RECORD',
                        'fields': [{
                            'name': 'a',
                            'type': 'FLOAT'
                        }, {
                            'name': 'b',
                            'type': 'FLOAT'
                        }, {
                            'name': 'r',
                            'type': 'FLOAT'
                        }, {
                            'name': 'alpha',
                            'type': 'FLOAT'
                        }]
                    }, {
                        'name':
                        'paretonbd_model_params',
                        'type':
                        'RECORD',
                        'fields': [{
                            'name': 'r',
                            'type': 'FLOAT'
                        }, {
                            'name': 's',
                            'type': 'FLOAT'
                        }, {
                            'name': 'alpha',
                            'type': 'FLOAT'
                        }, {
                            'name': 'beta',
                            'type': 'FLOAT'
                        }]
                    }, {
                        'name':
                        'gamma_gamma_params',
                        'type':
                        'RECORD',
                        'fields': [{
                            'name': 'p',
                            'type': 'FLOAT'
                        }, {
                            'name': 'q',
                            'type': 'FLOAT'
                        }, {
                            'name': 'v',
                            'type': 'FLOAT'
                        }]
                    }]
                },
                write_disposition=io.BigQueryDisposition.WRITE_TRUNCATE,
                create_disposition=io.BigQueryDisposition.CREATE_IF_NEEDED))

        num_rows = (full_elog_merged
                    | 'Count num rows in full elog merged' >>
                    beam.combiners.Count.Globally())

        # Single boolean: True when full_elog_merged has at most
        # c._SEGMENT_PREDICTION_THRESHOLD rows. It selects between the
        # "exact" and "hash" segmentation branches below.
        segment_predictions_exact = (
            pipeline
            | 'Create single elem Stream VII' >> beam.Create([1])
            | beam.FlatMap(
                lambda _, rows_count:
                [rows_count <= c._SEGMENT_PREDICTION_THRESHOLD],
                pvalue.AsSingleton(num_rows)))

        # Partition routes elements with a truthy x[1] to partition 0 (the
        # "exact" branch) and all others to partition 1 (the "hash" branch).
        # The hash branch reads x[0], so elements appear to be
        # (prediction, flag) pairs -- confirm against c.prediction_sharded.
        sharded_cust_predictions_no_segments_exact, \
            sharded_cust_predictions_no_segments_hash = (
                prediction_by_customer_no_segments
                | beam.FlatMap(
                    c.prediction_sharded,
                    pvalue.AsSingleton(options),
                    pvalue.AsSingleton(segment_predictions_exact)
                )  # [customer_id, p_alive, predicted_purchases, future_aov,
                   #  historical_aov, expected_value, frequency, recency,
                   #  total_time_observed, extra_dimension?]
                | beam.Partition(lambda x, _: 0 if x[1] else 1, 2)
            )

        # BEGIN of "exact" branch
        prediction_by_customer_exact = (
            pipeline
            | 'Create single elem Stream VIII' >> beam.Create([1])
            | beam.FlatMap(
                c.split_in_ntiles_exact, pvalue.AsSingleton(options),
                pvalue.AsIter(sharded_cust_predictions_no_segments_exact
                              ))  # [customer_id, p_alive, predicted_purchases,
            #  future_aov, historical_aov, expected_value,
            #  frequency, recency, total_time_observed,
            #  segment, extra_dimension?]
        )
        # END of "exact" branch

        # BEGIN of "hash" branch
        customer_count_by_expected_value = (
            sharded_cust_predictions_no_segments_hash
            | beam.Map(lambda x: (x[0][5], 1))  # (expected_value, 1)
            | 'Group customer predictions by expected value' >>
            beam.GroupByKey()
            | beam.Map(lambda x:
                       (x[0], sum(x[1])))  # expected_value, customers_count
        )

        hash_segment_limits = (
            pipeline
            | 'Create single elem Stream IX' >> beam.Create([1])
            | beam.FlatMap(c.expected_values_segment_limits,
                           pvalue.AsSingleton(options),
                           pvalue.AsIter(customer_count_by_expected_value),
                           pvalue.AsSingleton(all_customer_ids_count)))

        prediction_by_customer_hash = (
            sharded_cust_predictions_no_segments_hash
            | beam.Map(lambda x: x[0])
            | beam.FlatMap(c.split_in_ntiles_hash,
                           pvalue.AsSingleton(hash_segment_limits)
                           )  # [customer_id, p_alive, predicted_purchases,
            #  future_aov, historical_aov, expected_value,
            #  frequency, recency, total_time_observed,
            #  segment, extra_dimension?]
        )
        # END of "hash" branch

        prediction_by_customer = (
            # Only one of these two streams is expected to contain values,
            # presumably because every element carries the same single
            # segment_predictions_exact flag.
            (prediction_by_customer_exact, prediction_by_customer_hash)
            | beam.Flatten()
            | beam.Map(c.clean_nan_and_inf))

        # When no extra dimension exists, '' is appended so that every row
        # lines up with the 11 column names passed to c.list_to_dict.
        _ = (prediction_by_customer
             | beam.FlatMap(
                 lambda x, opts: [x + ['']]
                 if not opts[c._OPTION_EXTRA_DIMENSION_EXISTS] else [x],
                 pvalue.AsSingleton(options))
             | 'prediction_by_customer to Dict' >>
             beam.Map(c.list_to_dict, [
                 'customer_id', 'p_alive', 'predicted_purchases', 'future_aov',
                 'historical_aov', 'expected_value', 'frequency', 'recency',
                 'total_time_observed', 'segment', 'extra_dimension'
             ])
             | 'Write to prediction_by_customer table' >> io.WriteToBigQuery(
                 table=c.TableValueProvider(
                     getattr(runtime_options, c._OPTION_OUTPUT_BQ_PROJECT),
                     getattr(runtime_options, c._OPTION_OUTPUT_BQ_DATASET),
                     'prediction_by_customer'),
                 custom_gcs_temp_location=getattr(runtime_options,
                                                  c._OPTION_TEMP_GCS_LOCATION),
                 validate=False,
                 schema='customer_id:STRING, p_alive:FLOAT64'
                 ', predicted_purchases:FLOAT64'
                 ', future_aov:FLOAT64, historical_aov:FLOAT64'
                 ', expected_value:FLOAT64, frequency:INT64'
                 ', recency:FLOAT64'
                 ', total_time_observed:FLOAT64, segment:INT64'
                 ', extra_dimension:STRING',
                 write_disposition=io.BigQueryDisposition.WRITE_TRUNCATE,
                 create_disposition=io.BigQueryDisposition.CREATE_IF_NEEDED))

        prediction_summary_temp = (
            prediction_by_customer
            | beam.Map(lambda x: (x[9], x))  # key: segment
            | 'Group customer predictions by segment' >> beam.GroupByKey()
            | beam.FlatMap(
                c.generate_prediction_summary, pvalue.AsSingleton(
                    options))  # (segment, average_retention_probability,
            #  average_predicted_customer_value,
            #  average_predicted_order_value,
            #  average_predicted_purchases, total_customer_value,
            #  number_of_customers)
        )

        # Sum of total_customer_value over all segment summaries; used as a
        # singleton side input for percentage-of-total calculations.
        tot_equity = (
            prediction_summary_temp
            | beam.Map(lambda x: x[5])  # total_customer_value
            | beam.CombineGlobally(sum))

        prediction_summary = (
            prediction_summary_temp
            | beam.FlatMap(
                c.calculate_perc_of_total_customer_value,
                pvalue.AsSingleton(tot_equity), pvalue.AsSingleton(
                    options))  # (segment, average_retention_probability,
            #  average_predicted_customer_value,
            #  average_predicted_order_value,
            #  average_predicted_purchases,
            #  total_customer_value, number_of_customers,
            #  perc_of_total_customer_value)
        )

        # NOTE(review): number_of_customers is declared FLOAT64 here but
        # INT64 in the prediction_summary_extra_dimension table below --
        # confirm which type is intended before changing either schema.
        _ = (
            prediction_summary
            | 'prediction_summary to Dict' >> beam.Map(c.list_to_dict, [
                'segment', 'average_retention_probability',
                'average_predicted_customer_value',
                'average_predicted_order_value', 'average_predicted_purchases',
                'total_customer_value', 'number_of_customers',
                'perc_of_total_customer_value'
            ])
            | 'Write to prediction_summary table' >> io.WriteToBigQuery(
                table=c.TableValueProvider(
                    getattr(runtime_options, c._OPTION_OUTPUT_BQ_PROJECT),
                    getattr(runtime_options, c._OPTION_OUTPUT_BQ_DATASET),
                    'prediction_summary'),
                custom_gcs_temp_location=getattr(runtime_options,
                                                 c._OPTION_TEMP_GCS_LOCATION),
                validate=False,
                schema='segment:INT64 ,average_retention_probability:FLOAT64'
                ', average_predicted_customer_value:FLOAT64'
                ', average_predicted_order_value:FLOAT64'
                ', average_predicted_purchases:FLOAT64'
                ', total_customer_value:FLOAT64'
                ', number_of_customers:FLOAT64'
                ', perc_of_total_customer_value:FLOAT64',
                write_disposition=io.BigQueryDisposition.WRITE_TRUNCATE,
                create_disposition=io.BigQueryDisposition.CREATE_IF_NEEDED))

        # Summary per extra dimension value (x[10]) of the predictions.
        prediction_summary_extra_dimension = (
            prediction_by_customer
            | 'Discard prediction if there is not extra dimension' >>
            beam.FlatMap(c.discard_if_no_extra_dimension,
                         pvalue.AsSingleton(options))
            | beam.Map(lambda x: (x[10], x))  # extra dimension
            | 'Group customer predictions by extra dimension' >>
            beam.GroupByKey()
            | beam.FlatMap(c.generate_prediction_summary_extra_dimension,
                           pvalue.AsSingleton(tot_equity),
                           pvalue.AsSingleton(options)))

        _ = (prediction_summary_extra_dimension
             | 'prediction_summary_extra_dimension to Dict' >> beam.Map(
                 c.list_to_dict, [
                     'extra_dimension', 'average_retention_probability',
                     'average_predicted_customer_value',
                     'average_predicted_order_value',
                     'average_predicted_purchases', 'total_customer_value',
                     'number_of_customers', 'perc_of_total_customer_value'
                 ])
             | 'Write to prediction_summary_extra_dimension table' >>
             io.WriteToBigQuery(
                 table=c.TableValueProvider(
                     getattr(runtime_options, c._OPTION_OUTPUT_BQ_PROJECT),
                     getattr(runtime_options, c._OPTION_OUTPUT_BQ_DATASET),
                     'prediction_summary_extra_dimension'),
                 custom_gcs_temp_location=getattr(runtime_options,
                                                  c._OPTION_TEMP_GCS_LOCATION),
                 validate=False,
                 schema='extra_dimension:STRING'
                 ', average_retention_probability:FLOAT64'
                 ', average_predicted_customer_value:FLOAT64'
                 ', average_predicted_order_value:FLOAT64'
                 ', average_predicted_purchases:FLOAT64'
                 ', total_customer_value:FLOAT64'
                 ', number_of_customers:INT64'
                 ', perc_of_total_customer_value:FLOAT64',
                 write_disposition=io.BigQueryDisposition.WRITE_TRUNCATE,
                 create_disposition=io.BigQueryDisposition.CREATE_IF_NEEDED))