def test_enrich(self):
        """Check the enrich pipeline's per-borough passenger totals.

        The expected totals are computed directly with pandas: every ride
        CSV matching ``self.input_path`` is concatenated, left-joined to the
        zone lookup (``DOLocationID`` against ``LocationID``), and
        ``passenger_count`` is summed per ``Borough``. The pipeline is then
        run and its output shards are parsed into ``(borough, total)`` pairs
        which are compared with the expectation, ignoring row order.

        Any non-blank output line that is not a data row must be the CSV
        header ``Borough,passenger_count``.
        """
        # Compute expected result
        rides = pd.concat(
            pd.read_csv(path) for path in glob.glob(self.input_path))
        zones = pd.read_csv(self.lookup_path)
        rides = rides.merge(zones.set_index('LocationID').Borough,
                            right_index=True,
                            left_on='DOLocationID',
                            how='left')
        expected_counts = rides.groupby('Borough').passenger_count.sum()

        taxiride.run_enrich_pipeline(beam.Pipeline(), self.input_path,
                                     self.output_path, self.lookup_path)

        # Parse result file and compare.
        # The pattern is anchored to the whole line and lets the key contain
        # spaces: an unanchored `\S+` captures only the last word of a
        # multi-word borough (e.g. "Island" for "Staten Island"), which could
        # never equal the pandas-computed key.
        row_pattern = re.compile(r'\s*(\S.*),([0-9.]+)\s*$')
        # TODO(BEAM-XXXX): taxiride examples should produce int sums, not floats
        results = []
        with open_shards(f'{self.output_path}-*') as result_file:
            for line in result_file:
                match = row_pattern.match(line)
                if match is not None:
                    # Sums are written as floats; truncate to int so they
                    # compare equal to the integer pandas totals.
                    results.append(
                        (match.group(1), int(float(match.group(2)))))
                elif line.strip():
                    self.assertEqual(line.strip(), 'Borough,passenger_count')
        self.assertEqual(sorted(results), sorted(expected_counts.items()))
# Example #2
# 0
  def test_enrich(self):
    """Run the enrich pipeline on the 2018 rides and compare to the truth CSV.

    The pipeline output shards are read back through ``FileSystems``,
    concatenated, sorted by ``Borough`` and compared frame-for-frame with
    ``data/taxiride_2018_enrich_truth.csv`` located next to this file.
    """
    # Standard workers OOM with the enrich pipeline
    worker_options = self.test_pipeline.get_pipeline_options().view_as(
        WorkerOptions)
    worker_options.machine_type = 'e2-highmem-2'

    taxiride.run_enrich_pipeline(
        self.test_pipeline,
        'gs://apache-beam-samples/nyc_taxi/2018/*.csv',
        self.output_path)

    # Load the reference totals; lines starting with '#' are skipped.
    truth_path = os.path.join(
        os.path.dirname(__file__), 'data', 'taxiride_2018_enrich_truth.csv')
    expected = pd.read_csv(truth_path, comment='#')
    expected = expected.sort_values('Borough').reset_index(drop=True)

    # Collect every output shard whose path starts with the output prefix.
    match_result = FileSystems.match([f'{self.output_path}*'])[0]
    shard_frames = []
    for metadata in match_result.metadata_list:
      with FileSystems.open(metadata.path) as shard_file:
        shard_frames.append(pd.read_csv(shard_file))
    result = pd.concat(shard_frames)
    # Sort and reindex so the comparison does not depend on shard order.
    result = result.sort_values('Borough').reset_index(drop=True)

    pd.testing.assert_frame_equal(expected, result)