def test_multi_key_sorts(self):
    expected_ids = ['a_1', 'a_2', 'a_4', 'a_6', 'a_8', 'a_11',
                    'a_3', 'a_5', 'a_7', 'a_9', 'a_10']
    expected_pids = ['p_1', 'p_1', 'p_1', 'p_1', 'p_1', 'p_1',
                     'p_2', 'p_2', 'p_2', 'p_2', 'p_2']
    expected_vals1 = ['100', '101', '101', '102', '102', '104',
                      '101', '102', '102', '103', '104']
    expected_vals2 = ['100', '102', '101', '102', '104', '104',
                      '101', '102', '103', '105', '105']

    # sorting by 'created_at' then 'patient_id' should be equivalent to a
    # single multi-key sort on ('patient_id', 'created_at')
    ds1 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
    ds1.sort('created_at')
    ds1.sort('patient_id')
    self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds1.index_)
    self.assertListEqual(expected_ids, ds1.field_by_name('id'))
    self.assertListEqual(expected_pids, ds1.field_by_name('patient_id'))
    self.assertListEqual(expected_vals1, ds1.field_by_name('created_at'))
    self.assertListEqual(expected_vals2, ds1.field_by_name('updated_at'))
    # for i in range(ds1.row_count()):
    #     utils.print_diagnostic_row("{}".format(i), ds1, i, ds1.names_)

    ds2 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
    ds2.sort(('patient_id', 'created_at'))
    self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds2.index_)
    self.assertListEqual(expected_ids, ds2.field_by_name('id'))
    self.assertListEqual(expected_pids, ds2.field_by_name('patient_id'))
    self.assertListEqual(expected_vals1, ds2.field_by_name('created_at'))
    self.assertListEqual(expected_vals2, ds2.field_by_name('updated_at'))
def test_single_key_sorts(self):
    # a single key passed as a string and as a 1-tuple should sort identically
    ds1 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
    ds1.sort('patient_id')
    self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds1.index_)

    ds2 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
    ds2.sort(('patient_id',))
    self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds2.index_)
def test_construction_with_early_filter(self):
    s = io.StringIO(small_dataset)
    ds = csvdataset.Dataset(s, early_filter=('bar', lambda x: x in ('a',)),
                            verbose=False)

    # field names and fields must match in length
    self.assertEqual(len(ds.names_), len(ds.fields_))

    self.assertEqual(ds.row_count(), 2)
    self.assertEqual(ds.names_, ['id', 'patient_id', 'foo', 'bar'])
    expected_values = [
        (0, ['0aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
             '11111111111111111111111111111111', '', 'a']),
        (2, ['02222222222222222222222222222222',
             '11111111111111111111111111111111', 'False', 'a'])]

    # value works as expected
    for row in range(len(expected_values)):
        for col in range(len(expected_values[0][1])):
            self.assertEqual(ds.value(row, col), expected_values[row][1][col])

    # value_from_fieldname works as expected
    sorted_names = sorted(ds.names_)
    for n in sorted_names:
        index = ds.names_.index(n)
        for row in range(len(expected_values)):
            self.assertEqual(ds.value_from_fieldname(row, n),
                             expected_values[row][1][index])
def test_sort(self):
    s = io.StringIO(small_dataset)
    ds = csvdataset.Dataset(s, verbose=False)
    ds.sort(('patient_id', 'id'))
    row_permutations = [2, 0, 1]
    # the sort should reorder the index to the expected permutation
    self.assertListEqual(row_permutations, ds.index_)
def import_with_schema(timestamp, dest_file_name, schema_file, files, overwrite,
                       include, exclude, chunk_row_size=1 << 20):
    print(timestamp)
    print(schema_file)
    print(files)

    # the schema can be supplied either as a file name or an open StringIO
    schema = None
    if isinstance(schema_file, str):
        with open(schema_file, encoding='utf-8') as sf:
            schema = load_schema(sf)
    elif isinstance(schema_file, StringIO):
        schema = load_schema(schema_file)

    any_parts_present = False
    for sk in schema.keys():
        if sk in files:
            any_parts_present = True
    if not any_parts_present:
        raise ValueError(
            "none of the data sources in 'files' contain relevant data to the schema")

    # check whether any table named in include/exclude is missing from the input files
    input_file_tables = set(files.keys())
    include_tables, exclude_tables = set(include.keys()), set(exclude.keys())
    if include_tables and not include_tables.issubset(input_file_tables):
        extra_tables = include_tables.difference(input_file_tables)
        raise ValueError(
            "-n/--include: the following include table(s) are not part of any input files: {}"
            .format(extra_tables))
    if exclude_tables and not exclude_tables.issubset(input_file_tables):
        extra_tables = exclude_tables.difference(input_file_tables)
        raise ValueError(
            "-x/--exclude: the following exclude table(s) are not part of any input files: {}"
            .format(extra_tables))

    stop_after = {}
    reserved_column_names = ('j_valid_from', 'j_valid_to')
    datastore = per.DataStore()

    if overwrite:
        mode = 'w'
    else:
        mode = 'r+'
    if isinstance(dest_file_name, str) and not os.path.exists(dest_file_name):
        mode = 'w'

    with h5py.File(dest_file_name, mode) as hf:
        # first pass: validate each table against the schema and the
        # include/exclude field lists before importing anything
        for sk in schema.keys():
            if sk in reserved_column_names:
                msg = "{} is a reserved column name: reserved names are {}"
                raise ValueError(msg.format(sk, reserved_column_names))

            if sk not in files:
                continue

            fields = schema[sk].fields
            with open(files[sk], encoding='utf-8') as f:
                ds = dataset.Dataset(f, stop_after=1)
            names = set([n.strip() for n in ds.names_])
            missing_names = names.difference(fields.keys())
            if len(missing_names) > 0:
                msg = "The following fields are present in {} but not part of the schema: {}"
                print("Warning:", msg.format(files[sk], missing_names))
                # raise ValueError(msg.format(files[sk], missing_names))

            # check that included/excluded fields are present in the file
            include_missing_names = set(include.get(sk, [])).difference(names)
            if len(include_missing_names) > 0:
                msg = "The following include fields are not part of the {}: {}"
                raise ValueError(msg.format(files[sk], include_missing_names))

            exclude_missing_names = set(exclude.get(sk, [])).difference(names)
            if len(exclude_missing_names) > 0:
                msg = "The following exclude fields are not part of the {}: {}"
                raise ValueError(msg.format(files[sk], exclude_missing_names))

        # second pass: import each table and stamp it with validity timestamps
        for sk in schema.keys():
            if sk not in files:
                continue

            fields = schema[sk].fields

            DatasetImporter(datastore, files[sk], hf, sk, schema[sk], timestamp,
                            include=include, exclude=exclude,
                            stop_after=stop_after.get(sk, None),
                            chunk_row_size=chunk_row_size)

            print(sk, hf.keys())
            table = hf[sk]
            ids = datastore.get_reader(table[list(table.keys())[0]])

            jvf = datastore.get_timestamp_writer(table, 'j_valid_from')
            ftimestamp = utils.string_to_datetime(timestamp).timestamp()
            valid_froms = np.full(len(ids), ftimestamp)
            jvf.write(valid_froms)

            jvt = datastore.get_timestamp_writer(table, 'j_valid_to')
            valid_tos = np.full(len(ids), ops.MAX_DATETIME.timestamp())
            jvt.write(valid_tos)

        print(hf.keys())
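def _example_import_with_schema():
    # Illustrative sketch only (never called): one way this importer might be
    # invoked. The schema path, CSV paths, table names ('patients',
    # 'assessments') and excluded field are hypothetical placeholders, and the
    # timestamp string is assumed to be in whatever format
    # utils.string_to_datetime accepts.
    files = {'patients': 'patients.csv', 'assessments': 'assessments.csv'}
    import_with_schema(timestamp='2020-05-10 12:00:00',
                       dest_file_name='output.hdf5',
                       schema_file='schema.json',
                       files=files,
                       overwrite=True,
                       include={},  # empty dict: no include restriction
                       exclude={'assessments': ['other_symptoms']})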
def split_data(patient_data, assessment_data, bucket_size=500000, territories=None):
    with open(patient_data) as f:
        p_ds = csvdataset.Dataset(f, keys=('id', 'created_at'),
                                  show_progress_every=500000)
        # show_progress_every=500000, stop_after=500000)
    p_ds.sort(('created_at', 'id'))
    p_ids = p_ds.field_by_name('id')
    p_dts = p_ds.field_by_name('created_at')

    # put patient ids into buckets of at most bucket_size rows
    buckets = dict()
    bucket_index = 0
    bucket_count = 0
    for i_r in range(p_ds.row_count()):
        if bucket_index == bucket_size:
            bucket_index = 0
            bucket_count += 1
        buckets[p_ids[i_r]] = bucket_count
        bucket_index += 1

    filenames = list()
    for b in range(bucket_count + 1):
        destination_filename = patient_data[:-4] + f"_{b:04d}" + ".csv"
        filenames.append(destination_filename)
    print(filenames)

    sorted_indices = p_ds.index_
    del p_ds

    patient_splitter(patient_data, filenames, sorted_indices, bucket_size)

    print('buckets:', bucket_index)

    with open(assessment_data) as f:
        a_ds = csvdataset.Dataset(f, keys=('patient_id', 'other_symptoms'),
                                  show_progress_every=500000)

    print(utils.build_histogram(buckets.values()))

    print('associating assessments with patients')
    orphaned_assessments = 0
    a_buckets = list()
    a_pids = a_ds.field_by_name('patient_id')
    a_os = a_ds.field_by_name('other_symptoms')
    for i_r in range(a_ds.row_count()):
        if a_pids[i_r] in buckets:
            a_buckets.append(buckets[a_pids[i_r]])
        else:
            orphaned_assessments += 1
            a_buckets.append(-1)

    del a_ds

    print('orphaned_assessments:', orphaned_assessments)

    print(f'{bucket_count + 1} buckets')
    for i in range(bucket_count + 1):
        print('bucket', i)
        destination_filename = assessment_data[:-4] + f"_{i:04d}" + ".csv"
        print(destination_filename)
        # with open(assessment_data) as f:
        #     a_ds = dataset.Dataset(f, filter_fn=lambda j: a_buckets[j] == i,
        #                            show_progress_every=500000)
        # del a_ds
        assessment_splitter(assessment_data, destination_filename, a_buckets, i)

    print('done!')
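def _example_split_data():
    # Illustrative sketch only (never called): splits a patient CSV and its
    # matching assessment CSV into buckets of up to 500,000 patients each.
    # The file names are hypothetical placeholders; both are expected to end
    # in '.csv' because the bucket suffix is spliced in before the final four
    # characters, producing e.g. 'patients_0000.csv', 'assessments_0000.csv'.
    split_data('patients.csv', 'assessments.csv', bucket_size=500000)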