def test_basics(self):
    """Run the distribopt example with SciPy mocked out and check its output.

    ``scipy.optimize.minimize`` is replaced by a mock whose result has
    ``x == np.ones(3)``, so every plant's reported production is expected to
    be ``np.ones(3)``.
    """
    import shutil  # local: only used to register temp-dir cleanup

    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_file(os.path.join(temp_folder, 'input.txt'), FILE_CONTENTS)

    # Run pipeline
    # Avoid dependency on SciPy
    scipy_mock = MagicMock()
    result_mock = MagicMock(x=np.ones(3))
    scipy_mock.optimize.minimize = MagicMock(return_value=result_mock)
    modules = {'scipy': scipy_mock, 'scipy.optimize': scipy_mock.optimize}

    with patch.dict('sys.modules', modules):
        from apache_beam.examples.complete import distribopt
        distribopt.run([
            '--input=%s/input.txt' % temp_folder,
            '--output',
            os.path.join(temp_folder, 'result')])

    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        lines = result_file.readlines()

    # Only 1 result
    self.assertEqual(len(lines), 1)

    # parse result line and verify optimum
    optimum = make_tuple(lines[0])
    self.assertAlmostEqual(optimum['cost'], 454.39597, places=3)
    self.assertDictEqual(optimum['mapping'], EXPECTED_MAPPING)
    production = optimum['production']
    for plant in ['A', 'B', 'C']:
        np.testing.assert_almost_equal(production[plant], np.ones(3))
def test_basics(self):
    """Run distribopt (without saving the main session) on mocked SciPy.

    ``scipy.optimize.minimize`` is replaced by a mock whose result has
    ``x == np.ones(3)``, so every plant's reported production is expected to
    be ``np.ones(3)``.
    """
    import shutil  # local: only used to register temp-dir cleanup

    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_file(os.path.join(temp_folder, 'input.txt'), FILE_CONTENTS)

    # Run pipeline
    # Avoid dependency on SciPy
    scipy_mock = MagicMock()
    result_mock = MagicMock(x=np.ones(3))
    scipy_mock.optimize.minimize = MagicMock(return_value=result_mock)
    modules = {'scipy': scipy_mock, 'scipy.optimize': scipy_mock.optimize}

    with patch.dict('sys.modules', modules):
        from apache_beam.examples.complete import distribopt
        distribopt.run([
            '--input=%s/input.txt' % temp_folder,
            '--output',
            os.path.join(temp_folder, 'result')
        ], save_main_session=False)

    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        lines = result_file.readlines()

    # Only 1 result
    self.assertEqual(len(lines), 1)

    # parse result line and verify optimum
    optimum = make_tuple(lines[0])
    self.assertAlmostEqual(optimum['cost'], 454.39597, places=3)
    self.assertDictEqual(optimum['mapping'], EXPECTED_MAPPING)
    production = optimum['production']
    for plant in ['A', 'B', 'C']:
        np.testing.assert_almost_equal(production[plant], np.ones(3))
def test_enrich(self):
    """Compare run_enrich_pipeline's per-borough passenger sums with pandas."""
    # Compute expected result
    rides = pd.concat(
        pd.read_csv(path) for path in glob.glob(self.input_path))
    zones = pd.read_csv(self.lookup_path)
    rides = rides.merge(
        zones.set_index('LocationID').Borough,
        right_index=True,
        left_on='DOLocationID',
        how='left')
    expected_counts = rides.groupby('Borough').passenger_count.sum()

    taxiride.run_enrich_pipeline(
        beam.Pipeline(), self.input_path, self.output_path, self.lookup_path)

    # Parse result file and compare.
    # TODO(BEAM-12379): taxiride examples should produce int sums, not floats
    # Anchor on the last comma of the whole line: a borough name containing a
    # space (NYC has "Staten Island") would be cut to its last word by a \S+
    # prefix. Presumably such names can occur in the lookup data - verify.
    row_re = re.compile(r'^(.+),([0-9.]+)$')
    results = []
    with open_shards(f'{self.output_path}-*') as result_file:
        for line in result_file:
            match = row_re.match(line.strip())
            if match is not None:
                results.append((match.group(1), int(float(match.group(2)))))
            elif line.strip():
                self.assertEqual(line.strip(), 'Borough,passenger_count')
    self.assertEqual(sorted(results), sorted(expected_counts.items()))
def get_results(self, temp_path):
    """Collect ``(word, count)`` pairs from the ``<temp_path>.result`` shards.

    Lines that do not contain a ``<letters>: <digits>`` pattern are skipped.
    """
    pattern = re.compile(r'([A-Za-z]+): ([0-9]+)')
    pairs = []
    with open_shards(temp_path + '.result-*-of-*') as shard_lines:
        for text in shard_lines:
            found = pattern.search(text)
            if found is None:
                continue
            word, count = found.groups()
            pairs.append((word, int(count)))
    return pairs
def get_results(self, temp_path):
    """Return the ``(word, int_count)`` pairs parsed from the result shards."""
    shard_glob = temp_path + '.result-*-of-*'
    with open_shards(shard_glob) as result_file:
        matches = [
            re.search(r'([A-Za-z]+): ([0-9]+)', line) for line in result_file
        ]
    return [(m.group(1), int(m.group(2))) for m in matches if m is not None]
def assertFileEqual(self, filename, expected, encoding='utf-8'):
    # type: (typing.Text, list, typing.Optional[typing.Text]) -> None
    """Assert that the shards matching ``filename`` hold exactly ``expected``.

    ``expected`` is compared with the list of lines obtained by iterating the
    object returned by ``open_shards(filename, encoding=encoding)``.
    """
    with open_shards(filename, encoding=encoding) as shard_file:
        actual = list(shard_file)
    self.assertEqual(
        expected, actual, msg='file \'{}\' has wrong content'.format(filename))
def test_Pipeline_parts(self, test_data_dir, temp_dir):
    """Run Segment over input.json and compare both outputs with fixtures.

    Messages must equal ``expected_messages.json`` as a multiset. Segments are
    paired by ``seg_id``, and each expected segment's items must be a subset of
    the matching actual segment's items.
    """

    def same_items(left, right):
        # Multiset equality using ==. Needed because sorted() on dicts raises
        # TypeError on Python 3 (dicts are unorderable); also fine on Python 2.
        remaining = list(right)
        for item in left:
            try:
                remaining.remove(item)
            except ValueError:
                return False
        return not remaining

    source = pp.join(test_data_dir, 'input.json')
    messages_sink = pp.join(temp_dir, 'messages')
    segments_sink = pp.join(temp_dir, 'segments')
    expected_messages = pp.join(test_data_dir, 'expected_messages.json')
    expected_segments = pp.join(test_data_dir, 'expected_segments.json')

    with _TestPipeline() as p:
        messages = (
            p
            | beam.io.ReadFromText(file_pattern=source, coder=JSONDictCoder())
            | "MessagesAddKey" >> beam.Map(SegmentPipeline.groupby_fn)
            | "MessagesGroupByKey" >> beam.GroupByKey())
        segments = p | beam.Create([])
        segmented = messages | Segment(segments)

        messages = segmented[Segment.OUTPUT_TAG_MESSAGES]
        (messages
         | "WriteToMessagesSink" >> beam.io.WriteToText(
             file_path_prefix=messages_sink,
             num_shards=1,
             coder=JSONDictCoder()))

        segments = segmented[Segment.OUTPUT_TAG_SEGMENTS]
        (segments
         | "WriteToSegmentsSink" >> beam.io.WriteToText(
             file_path_prefix=segments_sink,
             num_shards=1,
             coder=JSONDictCoder()))

        # NOTE(review): if _TestPipeline's __exit__ also runs the pipeline (as
        # Beam's TestPipeline does), this explicit run() executes it twice.
        p.run()

    with nlj.open(expected_messages) as expected:
        with open_shards('%s*' % messages_sink) as output:
            assert same_items(expected, nlj.load(output))

    with nlj.open(expected_segments) as expected_output:
        with open_shards('%s*' % segments_sink) as actual_output:
            expected_sorted = sorted(expected_output, key=lambda x: x['seg_id'])
            actual_sorted = sorted(
                nlj.load(actual_output), key=lambda x: x['seg_id'])

    # zip() stops at the shorter sequence; check counts first so that missing
    # or extra segments cannot slip through.
    assert len(expected_sorted) == len(actual_sorted)
    for expected, actual in zip(expected_sorted, actual_sorted):
        assert set(expected.items()).issubset(set(actual.items()))
def test_gcp_sink(self, temp_dir):
    """Write generated messages through GCPSink and read them back unchanged."""
    messages = list(MessageGenerator().messages())
    dest = pp.join(temp_dir, 'messages.json')

    with _TestPipeline() as p:
        created = p | beam.Create(messages)
        created | GCPSink(dest)
        p.run()

    with open_shards('%s*' % dest) as output:
        assert sorted(messages) == sorted(nlj.load(output))
def test_gcp_sink(self, temp_dir):
    """Round-trip generated messages through GCPSink, ordered by timestamp.

    The read-back dicts pass through ``fix_keys`` before comparison.
    """
    messages = list(MessageGenerator().messages())
    dest = pp.join(temp_dir, 'messages.json')

    with _TestPipeline() as p:
        (p | beam.Create(messages) | GCPSink(dest))
        p.run()

    def by_timestamp(msg):
        return msg[b'timestamp']

    with open_shards('%s*' % dest) as output:
        want = sorted(messages, key=by_timestamp)
        got = sorted([fix_keys(d) for d in nlj.load(output)], key=by_timestamp)
        assert want == got
def test_estimate_pi_output_file(self):
    """Run estimate_pi and check the estimate written to its output file."""
    import shutil  # local: only used to register temp-dir cleanup

    test_pipeline = TestPipeline(is_integration_test=True)
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    extra_opts = {'output': os.path.join(temp_folder, 'result')}
    estimate_pi.run(test_pipeline.get_full_options_as_args(**extra_opts))

    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        [_, _, estimated_pi] = json.loads(result_file.read().strip())

    # Note: Probabilistically speaking this test can fail with a probability
    # that is very small (VERY) given that we run at least 100 thousand
    # trials.
    # Separate range asserts report the actual value when they fail.
    self.assertGreaterEqual(estimated_pi, 3.125)
    self.assertLessEqual(estimated_pi, 3.155)
def _run_pipeline(self, tag_field, tag_value, dest, expected, args=None):
    """Run pipe_template's main with tag options and compare its output.

    Args:
      tag_field: forwarded as ``--tag_field``.
      tag_value: forwarded as ``--tag_value``.
      dest: output prefix; every file matching ``dest*`` is read back.
      expected: dicts expected in the output, compared after sorting both
        sides by their ``'idx'`` key.
      args: optional extra command-line arguments, placed before the
        generated ones. The caller's sequence is not modified.
    """
    # Build a fresh list. The old ``args=[]`` default was extended in place by
    # ``+=``, so options accumulated across calls and leaked into callers' lists.
    args = list(args or []) + [
        '--tag_field=%s' % tag_field,
        '--tag_value=%s' % tag_value,
        '--dest=%s' % dest,
        '--wait',
    ]
    pipe_template.__main__.run(args)
    with open_shards('%s*' % dest) as output:
        assert sorted(expected, key=lambda x: x['idx']) == sorted(
            nlj.load(output), key=lambda x: x['idx'])
def test_basics(self):
    """Run AvgPrice on SAMPLE_DATA and compare the parsed ``name: value`` lines."""
    temp_path = self.create_temp_file(self.SAMPLE_DATA)
    expected_words = [("abc", 100), ("rothko", 750)]
    AvgPrice.run(
        ["--input=%s*" % temp_path, "--output=%s.result" % temp_path])

    price_line = re.compile(r"([a-zA-Z]+): ([0-9]+)")
    with open_shards(temp_path + ".result-*-of-*") as result_file:
        found = [price_line.search(line) for line in result_file]
    results = [(m.group(1), int(m.group(2))) for m in found if m is not None]
    self.assertEqual(sorted(results), sorted(expected_words))
def main_test(self):
    """Run ``mn`` over the test data and parse its ``word: count`` output.

    NOTE(review): the parsed ``results`` are never checked; the only assertion
    is ``assertEqual(1, 1)``, which always passes. The method name also lacks a
    ``test`` prefix, so default test discovery probably does not collect it.
    Confirm the intended expectations.
    """
    mn.run([
        '--input=%s*' % self.test_data_path,
        '--output=%s.result' % self.test_data_path
    ])
    # Parse result file and compare.
    pattern = re.compile(r'([a-z]+): ([0-9]+)')
    results = []
    with open_shards(self.test_data_path + '.result-*-of-*') as result_file:
        for line in result_file:
            hit = pattern.search(line)
            if hit is not None:
                results.append((hit.group(1), int(hit.group(2))))
    self.assertEqual(1, 1)
def _run_segment(self, messages_in, segments_in, temp_dir):
    """Run Segment over in-memory messages/segments and read both outputs back.

    Asserts that the written messages contain ``messages_in`` (via
    ``list_contains``).

    Returns:
      ``(messages, segments)``: the written messages sorted by
      ``(ssvid, timestamp)``, and the written segments in read order.
    """
    messages_file = pp.join(temp_dir, '_run_segment', 'messages')
    segments_file = pp.join(temp_dir, '_run_segment', 'segments')

    with _TestPipeline() as p:
        messages = (
            p
            | 'CreateMessages' >> beam.Create(messages_in)
            | 'AddKeyMessages' >> beam.Map(self.groupby_fn)
            | "MessagesGroupByKey" >> beam.GroupByKey()
        )
        segments = (
            p
            | 'CreateSegments' >> beam.Create(segments_in)
            | 'AddKeySegments' >> beam.Map(self.groupby_fn)
            | "SegmentsGroupByKey" >> beam.GroupByKey()
        )
        segmented = messages | "Segment" >> Segment(segments)

        # Select both outputs through Segment's tag constants rather than a
        # hard-coded 'messages' literal for one of them.
        messages = segmented[Segment.OUTPUT_TAG_MESSAGES]
        segments = segmented[Segment.OUTPUT_TAG_SEGMENTS]

        messages | "WriteMessages" >> beam.io.WriteToText(
            messages_file, coder=JSONDictCoder())
        segments | "WriteSegments" >> beam.io.WriteToText(
            segments_file, coder=JSONDictCoder())

        p.run()

    with open_shards('%s*' % messages_file) as output:
        messages = sorted(
            nlj.load(output), key=lambda m: (m['ssvid'], m['timestamp']))

    with open_shards('%s*' % segments_file) as output:
        segments = list(nlj.load(output))

    assert list_contains(messages, messages_in)

    return messages, segments
def test_basics(self):
    """Run wordcount on SAMPLE_TEXT and compare counts with a local tally."""
    temp_path = self.create_temp_file(self.SAMPLE_TEXT)
    expected_words = collections.Counter(
        re.findall(r'[\w\']+', self.SAMPLE_TEXT, re.UNICODE))

    wordcount.run(
        ['--input=%s*' % temp_path, '--output=%s.result' % temp_path],
        save_main_session=False)

    # Parse result file and compare.
    line_re = re.compile(r'(\S+): ([0-9]+)')
    with open_shards(temp_path + '.result-*-of-*') as result_file:
        parsed = [line_re.search(line) for line in result_file]
    results = [(m.group(1), int(m.group(2))) for m in parsed if m is not None]
    self.assertEqual(sorted(results), sorted(expected_words.items()))
def test_basics(self):
    """Run wordcount_minimal on SAMPLE_TEXT and compare the word counts."""
    temp_path = self.create_temp_file(self.SAMPLE_TEXT)
    expected_words = collections.defaultdict(int)
    for word in re.findall(r'\w+', self.SAMPLE_TEXT):
        expected_words[word] += 1
    wordcount_minimal.run([
        '--input=%s*' % temp_path,
        '--output=%s.result' % temp_path])
    # Parse result file and compare.
    # Parse words with the same \w+ class used to build expected_words. The
    # old [a-z]+ pattern could not match capitals, digits or underscores
    # (e.g. "Abc: 1" was parsed as "bc").
    results = []
    with open_shards(temp_path + '.result-*-of-*') as result_file:
        for line in result_file:
            match = re.search(r'(\w+): ([0-9]+)', line)
            if match is not None:
                results.append((match.group(1), int(match.group(2))))
    self.assertEqual(sorted(results), sorted(expected_words.items()))
def test_basics(self):
    """Run wordcount_minimal on SAMPLE_TEXT and compare the word counts.

    Result lines are decoded from UTF-8 only when they arrive as ``bytes``.
    """
    temp_path = self.create_temp_file(self.SAMPLE_TEXT)
    expected_words = collections.defaultdict(int)
    for word in re.findall(r'\w+', self.SAMPLE_TEXT):
        expected_words[word] += 1
    wordcount_minimal.run([
        '--input=%s*' % temp_path,
        '--output=%s.result' % temp_path])
    # Parse result file and compare.
    results = []
    with open_shards(temp_path + '.result-*-of-*') as result_file:
        for line in result_file:
            # Text-mode reads give str on Python 3, which has no .decode();
            # decode only bytes lines (the Python 2 behaviour).
            if isinstance(line, bytes):
                line = line.decode('utf-8')
            # Same \w+ class as the tokenizer that built expected_words.
            match = re.search(r'(\w+): ([0-9]+)', line)
            if match is not None:
                results.append((match.group(1), int(match.group(2))))
    self.assertEqual(sorted(results), sorted(expected_words.items()))
def test_mergecontacts(self):
    """Merge the email/phone/snailmail inputs and check the TSV output.

    ``assert_results=(2, 1, 3)`` is passed through to ``mergecontacts.run``.
    """
    inputs = {}
    for kind, data in (('email', self.CONTACTS_EMAIL),
                       ('phone', self.CONTACTS_PHONE),
                       ('snailmail', self.CONTACTS_SNAILMAIL)):
        inputs[kind] = self.create_temp_file(data)
    result_prefix = self.create_temp_file('')

    argv = [
        '--input_email=%s' % inputs['email'],
        '--input_phone=%s' % inputs['phone'],
        '--input_snailmail=%s' % inputs['snailmail'],
        '--output_tsv=%s.tsv' % result_prefix,
        '--output_stats=%s.stats' % result_prefix,
    ]
    mergecontacts.run(argv, assert_results=(2, 1, 3))

    with open_shards('%s.tsv-*-of-*' % result_prefix) as tsv_file:
        contents = tsv_file.read()
    self.assertEqual(self.EXPECTED_TSV, self.normalize_tsv_results(contents))
def _run_pipeline(self, source, messages_sink, segments_sink, expected,
                  args=None):
    """Run the segment pipeline and compare the messages output to a fixture.

    Args:
      source: forwarded as ``--source``.
      messages_sink: forwarded as ``--dest``; files matching
        ``messages_sink*`` are read back and compared.
      segments_sink: forwarded as ``--segments``.
      expected: path to newline-delimited JSON holding the expected messages.
      args: optional extra command-line arguments, placed before the
        generated ones. The caller's sequence is not modified.
    """

    def same_items(left, right):
        # Multiset equality using ==. sorted() on dicts raises TypeError on
        # Python 3, so the old sorted(...) == sorted(...) check cannot run there.
        remaining = list(right)
        for item in left:
            try:
                remaining.remove(item)
            except ValueError:
                return False
        return not remaining

    # Build a fresh list. The old ``args=[]`` default was extended in place by
    # ``+=``, so options accumulated across calls.
    args = list(args or []) + [
        '--source=%s' % source,
        '--source_schema={"fields": []}',
        '--dest=%s' % messages_sink,
        '--segments=%s' % segments_sink,
        '--wait',
    ]
    pipe_segment_run(args)

    with nlj.open(expected) as expected_messages:
        with open_shards('%s*' % messages_sink) as output:
            assert same_items(expected_messages, nlj.load(output))
def test_mergecontacts(self):
    """Check mergecontacts' TSV output for the three sample contact inputs."""
    input_flags = [
        ('input_email', self.create_temp_file(self.CONTACTS_EMAIL)),
        ('input_phone', self.create_temp_file(self.CONTACTS_PHONE)),
        ('input_snailmail', self.create_temp_file(self.CONTACTS_SNAILMAIL)),
    ]
    result_prefix = self.create_temp_file('')

    argv = ['--%s=%s' % (flag, path) for flag, path in input_flags]
    argv.append('--output_tsv=%s.tsv' % result_prefix)
    argv.append('--output_stats=%s.stats' % result_prefix)
    mergecontacts.run(argv, assert_results=(2, 1, 3))

    tsv_glob = '%s.tsv-*-of-*' % result_prefix
    with open_shards(tsv_glob) as handle:
        contents = handle.read()
    normalized = self.normalize_tsv_results(contents)
    self.assertEqual(self.EXPECTED_TSV, normalized)
def test_custom_ptransform_output_files_on_small_input(self):
    """Run custom_ptransform on WORDS and check the single formatted result."""
    import shutil  # local: only used to register temp-dir cleanup

    EXPECTED_RESULT = "('CAT DOG CAT CAT DOG', 2)"
    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_content_input_file(
        os.path.join(temp_folder, 'input.txt'), ' '.join(self.WORDS))
    custom_ptransform.run([
        '--input=%s/input.txt' % temp_folder,
        '--output',
        os.path.join(temp_folder, 'result'),
    ])
    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        result = result_file.read().strip()
    self.assertEqual(result, EXPECTED_RESULT)
def test_output_file_format(self):
    """Check each output line holds ``grid_size`` three-integer tuples."""
    grid_size = 5
    self.run_example(grid_size)

    # Parse the results from the file, and ensure it was written in the proper
    # format.
    shard_pattern = self.test_files['output_coord_file_name'] + '-*-of-*'
    with open_shards(shard_pattern) as result_file:
        output_lines = result_file.readlines()

    # Should have a line for each x-coordinate.
    self.assertEqual(grid_size, len(output_lines))

    coordinate_re = re.compile(r'(\(\d+, \d+, \d+\))')
    for line in output_lines:
        coordinates = coordinate_re.findall(line)
        # Every line must carry grid_size coordinates.
        self.assertTrue(coordinates)
        self.assertEqual(grid_size, len(coordinates))
def test_output_file_format(self):
    """Verify the coordinate output has grid_size lines of grid_size tuples."""
    grid_size = 5
    self.run_example(grid_size)

    output_prefix = self.test_files['output_coord_file_name']
    with open_shards(output_prefix + '-*-of-*') as result_file:
        output_lines = result_file.readlines()

    self.assertEqual(grid_size, len(output_lines))  # one line per x-coordinate

    # Tuples of the form "(a, b, c)" found on each line, in line order.
    per_line = [
        re.findall(r'(\(\d+, \d+, \d+\))', line) for line in output_lines
    ]
    for coordinates in per_line:
        self.assertTrue(coordinates)
        self.assertEqual(grid_size, len(coordinates))
def test_basics(self):
    """Run wordcount_dataframe and compare its CSV output with a local tally.

    Any non-empty line that is not a ``word,count`` row must be the CSV header.
    """
    temp_path = self.create_temp_file(self.SAMPLE_TEXT)
    expected_words = collections.Counter(re.findall(r'[\w]+', self.SAMPLE_TEXT))

    wordcount_dataframe.run(
        ['--input=%s*' % temp_path, '--output=%s.result' % temp_path])

    results = []
    with open_shards(temp_path + '.result-*') as result_file:
        for raw in result_file:
            match = re.search(r'(\S+),([0-9]+)', raw)
            if match is None:
                if raw.strip():
                    self.assertEqual(raw.strip(), 'word,count')
                continue
            results.append((match.group(1), int(match.group(2))))
    self.assertEqual(sorted(results), sorted(expected_words.items()))
def test_multiple_output_pardo(self):
    """Check the chars, short-words and words outputs of multiple_output_pardo."""
    temp_path = self.create_temp_file(self.SAMPLE_TEXT)
    result_prefix = temp_path + '.result'

    multiple_output_pardo.run(
        ['--input=%s*' % temp_path, '--output=%s' % result_prefix])

    # Character count of SAMPLE_TEXT with its newlines removed.
    expected_char_count = len(self.SAMPLE_TEXT.replace('\n', ''))
    with open_shards(result_prefix + '-chars-*-of-*') as chars_file:
        contents = chars_file.read()
    self.assertEqual(expected_char_count, int(contents))

    word_outputs = (
        ('-short-words-*-of-*', self.EXPECTED_SHORT_WORDS),
        ('-words-*-of-*', self.EXPECTED_WORDS),
    )
    for suffix, expected in word_outputs:
        actual = self.get_wordcount_results(result_prefix + suffix)
        self.assertEqual(sorted(actual), sorted(expected))
def test_aggregation(self):
    """Compare run_aggregation_pipeline's per-DOLocationID sums with pandas."""
    # Compute expected result
    rides = pd.read_csv(self.input_path)
    expected_counts = rides.groupby('DOLocationID').passenger_count.sum()

    taxiride.run_aggregation_pipeline(
        beam.Pipeline(), self.input_path, self.output_path)

    # Parse result file and compare.
    # TODO(BEAM-12379): taxiride examples should produce int sums, not floats
    row_re = re.compile(r'(\S+),([0-9\.]+)')
    header = 'DOLocationID,passenger_count'
    results = []
    with open_shards(f'{self.output_path}-*') as result_file:
        for line in result_file:
            match = row_re.search(line)
            if match is None:
                if line.strip():
                    self.assertEqual(line.strip(), header)
                continue
            location, total = match.groups()
            results.append((int(location), int(float(total))))
    self.assertEqual(sorted(results), sorted(expected_counts.items()))
def test_basics_with_type_check(self):
    # Run the workflow with pipeline_type_check option. This will make sure
    # the typehints associated with all transforms will have non-default values
    # and therefore any custom coders will be used. In our case we want to make
    # sure the coder for the Player class will be used.
    temp_path = self.create_temp_file(self.SAMPLE_RECORDS)
    group_with_coder.run(
        ['--input=%s*' % temp_path, '--output=%s.result' % temp_path])

    # Each result line is "<name>,<points>".
    with open_shards(temp_path + '.result-*-of-*') as result_file:
        split_lines = [line.split(',') for line in result_file]
    results = [(name, int(points)) for name, points in split_lines]

    logging.info('result: %s', results)
    expected = [('x:ann', 15), ('x:fred', 9), ('x:joe', 60), ('x:mary', 8)]
    self.assertEqual(sorted(results), sorted(expected))
def test_multiple_output_pardo(self):
    """Check multiple_output_pardo's outputs without saving the main session."""
    temp_path = self.create_temp_file(self.SAMPLE_TEXT)
    result_prefix = temp_path + '.result'
    argv = ['--input=%s*' % temp_path, '--output=%s' % result_prefix]
    multiple_output_pardo.run(argv, save_main_session=False)

    # Characters per line, summed, excluding the newline separators.
    expected_char_count = sum(
        len(part) for part in self.SAMPLE_TEXT.split('\n'))
    chars_glob = result_prefix + '-chars-*-of-*'
    with open_shards(chars_glob) as f:
        contents = f.read()
    self.assertEqual(expected_char_count, int(contents))

    short_glob = result_prefix + '-short-words-*-of-*'
    short_words = self.get_wordcount_results(short_glob)
    self.assertEqual(sorted(short_words), sorted(self.EXPECTED_SHORT_WORDS))

    words_glob = result_prefix + '-words-*-of-*'
    words = self.get_wordcount_results(words_glob)
    self.assertEqual(sorted(words), sorted(self.EXPECTED_WORDS))
def test_top_wikipedia_sessions_output_files_on_small_input(self):
    """Run top_wikipedia_sessions on EDITS and compare the sorted output lines."""
    import shutil  # local: only used to register temp-dir cleanup

    test_pipeline = TestPipeline(is_integration_test=True)
    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_content_input_file(
        os.path.join(temp_folder, 'input.txt'), '\n'.join(self.EDITS))
    extra_opts = {
        'input': '%s/input.txt' % temp_folder,
        'output': os.path.join(temp_folder, 'result'),
        'sampling_threshold': '1.0'
    }
    top_wikipedia_sessions.run(
        test_pipeline.get_full_options_as_args(**extra_opts))

    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        result = result_file.read().strip().splitlines()

    self.assertEqual(self.EXPECTED, sorted(result, key=lambda x: x.split()[0]))
def test_basics_without_type_check(self):
    # Run the workflow without pipeline_type_check option. This will make sure
    # the typehints associated with all transforms will have default values and
    # therefore any custom coders will not be used. The default coder (pickler)
    # will be used instead.
    temp_path = self.create_temp_file(self.SAMPLE_RECORDS)
    argv = [
        '--no_pipeline_type_check',
        '--input=%s*' % temp_path,
        '--output=%s.result' % temp_path,
    ]
    group_with_coder.run(argv)

    # Parse "<player>,<score>" lines from every result shard.
    results = []
    with open_shards(temp_path + '.result-*-of-*') as result_file:
        for record in result_file:
            player, score = record.split(',')
            results.append((player, int(score)))
    logging.info('result: %s', results)
    want = sorted([('ann', 15), ('fred', 9), ('joe', 60), ('mary', 8)])
    self.assertEqual(sorted(results), want)
def test_coders_output_files_on_small_input(self):
    """Run the coders example on JSON SAMPLE_RECORDS and compare its output."""
    import shutil  # local: only used to register temp-dir cleanup

    test_pipeline = TestPipeline(is_integration_test=True)
    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_content_input_file(
        os.path.join(temp_folder, 'input.txt'),
        '\n'.join(map(json.dumps, self.SAMPLE_RECORDS)))
    extra_opts = {
        'input': '%s/input.txt' % temp_folder,
        'output': os.path.join(temp_folder, 'result')
    }
    coders.run(test_pipeline.get_full_options_as_args(**extra_opts))

    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        result = result_file.read().strip()

    self.assertEqual(
        sorted(self.EXPECTED_RESULT), sorted(self.format_result(result)))
def test_autocomplete_output_files_on_small_input(self):
    """Run autocomplete on WORDS and compare the produced prefixes."""
    import shutil  # local: only used to register temp-dir cleanup

    # (A leftover debug ``logging.error('SAVE_MAIN_SESSION')`` was removed: it
    # logged a meaningless message at ERROR level on every run.)
    test_pipeline = TestPipeline(is_integration_test=True)
    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    create_content_input_file(
        os.path.join(temp_folder, 'input.txt'), ' '.join(self.WORDS))
    extra_opts = {
        'input': '%s/input.txt' % temp_folder,
        'output': os.path.join(temp_folder, 'result')
    }

    autocomplete.run(test_pipeline.get_full_options_as_args(**extra_opts))

    # Load result file and compare.
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        result = result_file.read().strip()

    self.assertEqual(
        sorted(self.EXPECTED_PREFIXES), sorted(format_output_file(result)))
def test_basics(self):
    """Run tfidf over three tiny documents and compare the parsed triples."""
    import shutil  # local: only used to register temp-dir cleanup

    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_file(os.path.join(temp_folder, '1.txt'), 'abc def ghi')
    self.create_file(os.path.join(temp_folder, '2.txt'), 'abc def')
    self.create_file(os.path.join(temp_folder, '3.txt'), 'abc')
    tfidf.run([
        '--uris=%s/*' % temp_folder,
        '--output', os.path.join(temp_folder, 'result')])
    # Parse result file and compare.
    results = []
    with open_shards(os.path.join(
            temp_folder, 'result-*-of-*')) as result_file:
        for line in result_file:
            match = re.search(EXPECTED_LINE_RE, line)
            logging.info('Result line: %s', line)
            if match is not None:
                results.append(
                    (match.group(1), match.group(2), float(match.group(3))))
    logging.info('Computed results: %s', set(results))
    self.assertEqual(set(results), EXPECTED_RESULTS)
def test_basics(self):
    """Run tfidf over three tiny documents and compare the parsed triples."""
    import shutil  # local: only used to register temp-dir cleanup

    # Setup the files with expected content.
    temp_folder = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the folder after the test.
    self.addCleanup(shutil.rmtree, temp_folder, ignore_errors=True)
    self.create_file(os.path.join(temp_folder, '1.txt'), 'abc def ghi')
    self.create_file(os.path.join(temp_folder, '2.txt'), 'abc def')
    self.create_file(os.path.join(temp_folder, '3.txt'), 'abc')
    tfidf.run([
        '--uris=%s/*' % temp_folder,
        '--output',
        os.path.join(temp_folder, 'result')
    ])
    # Parse result file and compare.
    results = []
    with open_shards(os.path.join(temp_folder, 'result-*-of-*')) as result_file:
        for line in result_file:
            match = re.search(EXPECTED_LINE_RE, line)
            logging.info('Result line: %s', line)
            if match is not None:
                results.append(
                    (match.group(1), match.group(2), float(match.group(3))))
    logging.info('Computed results: %s', set(results))
    self.assertEqual(set(results), EXPECTED_RESULTS)
def test_run_tft_pipeline(input_args):
    """Run run_tft_pipeline end to end and verify the artifacts it writes.

    Checks the item-count output, the sku/customer vocabulary files under
    ``<tft_transform>/transform_fn/assets``, and the contents of the
    single-shard train and test TFRecord files.
    """

    def read_vocab(name):
        # One entry per line; a record's integer id is the entry's line index
        # (see the dict(enumerate(...)) lookups below).
        path = os.path.join(input_args.tft_transform, 'transform_fn', 'assets',
                            name)
        with open(path) as f:
            return f.read().split('\n')

    def read_examples(prefix):
        # Parse every record of '<prefix>-00000-of-00001' as a tf.train.Example.
        examples = []
        for record in tf.python_io.tf_record_iterator(
                path=prefix + '-00000-of-00001'):
            example = tf.train.Example()
            example.ParseFromString(record)
            examples.append(example)
        return examples

    def as_text(values):
        # bytes_list values are bytes on Python 3 and never equal the str
        # literals below; decode them (the decoded text still compares equal
        # on Python 2).
        return [v.decode('utf-8') if isinstance(v, bytes) else v for v in values]

    run_tft_pipeline(input_args)

    with open_shards(input_args.nitems_filename + '-*-of-*') as result_file:
        data = result_file.read().strip()
    assert data == '3'

    skus = read_vocab('skus_mapping')
    assert 'sku0' in skus
    assert 'sku1' in skus
    assert 'sku2' in skus
    inv_skus_mapping = dict(enumerate(skus))

    customers = read_vocab('customers_mapping')
    assert 'customer0' in customers
    assert 'customer2' in customers
    assert 'customer4' in customers
    inv_customers_mapping = dict(enumerate(customers))

    # train dataset
    train_expected = {
        'customer0': {
            'skus_list': ['sku0', 'sku1', 'sku2'],
            'actions_list': ['Browsed', 'Browsed', 'Browsed']
        },
        'customer2': {
            'skus_list': ['sku0', 'sku0', 'sku2'],
            'actions_list': ['Browsed', 'AddedToBasket', 'Browsed']
        },
        'customer4': {
            'skus_list': ['sku1', 'sku2'],
            'actions_list': ['Browsed', 'Browsed']
        }
    }
    train_examples = read_examples(input_args.output_train_filename)
    # Without this, an empty file would make the loop below pass vacuously.
    assert train_examples
    for example in train_examples:
        # We are processing protobuf messages
        feature = dict(example.features.feature)
        customer = inv_customers_mapping[
            feature['customer_id'].int64_list.value[0]]
        expected = train_expected[customer]

        skus_list = [
            inv_skus_mapping[x] for x in feature['skus_list'].int64_list.value
        ]
        assert sorted(skus_list) == sorted(expected['skus_list'])

        actions_list = as_text(feature['actions_list'].bytes_list.value)
        assert sorted(actions_list) == sorted(expected['actions_list'])

    # test dataset
    test_expected = {
        'customer0': {
            'skus_list': ['sku1'],
            'actions_list': ['Browsed']
        },
        'customer2': {
            'skus_list': ['sku0', 'sku0'],
            'actions_list': ['Browsed', 'AddedToBasket']
        }
    }
    test_examples = read_examples(input_args.output_test_filename)
    assert test_examples
    for example in test_examples:
        # We are processing protobuf messages
        feature = dict(example.features.feature)
        customer = inv_customers_mapping[
            feature['customer_id'].int64_list.value[0]]
        expected = test_expected[customer]
        trained = train_expected[customer]

        skus_list = [
            inv_skus_mapping[x] for x in feature['skus_list'].int64_list.value
        ]
        trained_skus_list = [
            inv_skus_mapping[x]
            for x in feature['trained_skus_list'].int64_list.value
        ]
        assert sorted(skus_list) == sorted(expected['skus_list'])
        assert sorted(trained_skus_list) == sorted(trained['skus_list'])

        actions_list = as_text(feature['actions_list'].bytes_list.value)
        trained_actions_list = as_text(
            feature['trained_actions_list'].bytes_list.value)
        assert sorted(actions_list) == sorted(expected['actions_list'])
        assert sorted(trained_actions_list) == sorted(trained['actions_list'])