Code example #1
    def test_multi_key_sorts(self):
        expected_ids =\
            ['a_1', 'a_2', 'a_4', 'a_6', 'a_8', 'a_11', 'a_3', 'a_5', 'a_7', 'a_9', 'a_10']
        expected_pids =\
            ['p_1', 'p_1', 'p_1', 'p_1', 'p_1', 'p_1', 'p_2', 'p_2', 'p_2', 'p_2', 'p_2']
        expected_vals1 =\
            ['100', '101', '101', '102', '102', '104', '101', '102', '102', '103', '104']
        expected_vals2 =\
            ['100', '102', '101', '102', '104', '104', '101', '102', '103', '105', '105']

        ds1 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
        ds1.sort('created_at')
        ds1.sort('patient_id')
        self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds1.index_)
        self.assertListEqual(expected_ids, ds1.field_by_name('id'))
        self.assertListEqual(expected_pids, ds1.field_by_name('patient_id'))
        self.assertListEqual(expected_vals1, ds1.field_by_name('created_at'))
        self.assertListEqual(expected_vals2, ds1.field_by_name('updated_at'))
        # for i in range(ds1.row_count()):
        #     utils.print_diagnostic_row("{}".format(i), ds1, i, ds1.names_)

        ds2 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
        ds2.sort(('patient_id', 'created_at'))
        self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds2.index_)
        self.assertListEqual(expected_ids, ds2.field_by_name('id'))
        self.assertListEqual(expected_pids, ds2.field_by_name('patient_id'))
        self.assertListEqual(expected_vals1, ds2.field_by_name('created_at'))
        self.assertListEqual(expected_vals2, ds2.field_by_name('updated_at'))
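
The test above depends on the sort being stable: sorting by created_at and then by patient_id must yield the same permutation as the single multi-key call ds2.sort(('patient_id', 'created_at')). A minimal standalone sketch of that equivalence using Python's stable built-in sort (not the csvdataset API):

# Standalone sketch (not part of the test suite): two stable single-key
# sorts, applied minor key first, produce the same ordering as one
# multi-key sort.
rows = [('p_2', '101'), ('p_1', '102'), ('p_1', '100'), ('p_2', '101')]

order = sorted(range(len(rows)), key=lambda i: rows[i][1])   # created_at first
order = sorted(order, key=lambda i: rows[i][0])              # then patient_id

multi_key_order = sorted(range(len(rows)), key=lambda i: rows[i])

assert order == multi_key_order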
Code example #2
    def test_single_key_sorts(self):
        ds1 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
        ds1.sort('patient_id')
        self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds1.index_)

        ds2 = csvdataset.Dataset(io.StringIO(sorting_dataset), verbose=False)
        ds2.sort(('patient_id',))
        self.assertListEqual([0, 1, 3, 5, 7, 10, 2, 4, 6, 8, 9], ds2.index_)
Code example #3
    def test_construction_with_early_filter(self):
        s = io.StringIO(small_dataset)
        ds = csvdataset.Dataset(s, early_filter=('bar', lambda x: x in ('a',)), verbose=False)

        # field names and fields must match in length
        self.assertEqual(len(ds.names_), len(ds.fields_))

        self.assertEqual(ds.row_count(), 2)

        self.assertEqual(ds.names_, ['id', 'patient_id', 'foo', 'bar'])

        expected_values = [(0, ['0aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', '11111111111111111111111111111111', '', 'a']),
                           (2, ['02222222222222222222222222222222', '11111111111111111111111111111111', 'False', 'a'])]

        # value works as expected
        for row in range(len(expected_values)):
            for col in range(len(expected_values[0][1])):
                self.assertEqual(ds.value(row, col), expected_values[row][1][col])

        # value_from_fieldname works as expected
        sorted_names = sorted(ds.names_)
        for n in sorted_names:
            index = ds.names_.index(n)
            for row in range(len(expected_values)):
                self.assertEqual(ds.value_from_fieldname(row, n), expected_values[row][1][index])
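
Here early_filter is a (field_name, predicate) pair that discards non-matching rows while the CSV is being read. A rough standalone illustration of the same idea using the standard csv module (a sketch of the concept, not the csvdataset implementation):

import csv
import io

# Standalone sketch (not the csvdataset implementation): keep only rows
# whose value for the named field satisfies the predicate.
def read_with_early_filter(source, early_filter):
    field_name, predicate = early_filter
    return [row for row in csv.DictReader(source) if predicate(row[field_name])]

rows = read_with_early_filter(io.StringIO("id,bar\n1,a\n2,b\n3,a\n"),
                              ('bar', lambda x: x in ('a',)))
assert [r['id'] for r in rows] == ['1', '3']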
Code example #4
    def test_sort(self):
        s = io.StringIO(small_dataset)
        ds = csvdataset.Dataset(s, verbose=False)

        ds.sort(('patient_id', 'id'))
        row_permutations = [2, 0, 1]
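
The snippet stops after computing row_permutations; the assertions that follow are not shown. One plausible continuation of the test body (hypothetical, not the repository's actual assertions) compares each sorted field with an unsorted copy reordered through the permutation:

        # Hypothetical continuation (not in the original test): every field of
        # the sorted dataset should equal the unsorted field values reordered
        # by row_permutations.
        unsorted_ds = csvdataset.Dataset(io.StringIO(small_dataset), verbose=False)
        for name in ds.names_:
            expected = [unsorted_ds.field_by_name(name)[p] for p in row_permutations]
            self.assertListEqual(expected, ds.field_by_name(name))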
Code example #5
def import_with_schema(timestamp,
                       dest_file_name,
                       schema_file,
                       files,
                       overwrite,
                       include,
                       exclude,
                       chunk_row_size=1 << 20):

    print(timestamp)
    print(schema_file)
    print(files)

    schema = None
    if isinstance(schema_file, str):
        with open(schema_file, encoding='utf-8') as sf:
            schema = load_schema(sf)
    elif isinstance(schema_file, StringIO):
        schema = load_schema(schema_file)

    any_parts_present = False
    for sk in schema.keys():
        if sk in files:
            any_parts_present = True
    if not any_parts_present:
        raise ValueError(
            "none of the data sources in 'files' contain relevant data to the schema"
        )

    # check whether any table named in include/exclude is missing from the input files
    input_file_tables = set(files.keys())
    include_tables, exclude_tables = set(include.keys()), set(exclude.keys())
    if include_tables and not include_tables.issubset(input_file_tables):
        extra_tables = include_tables.difference(input_file_tables)
        raise ValueError(
            "-n/--include: the following include table(s) are not part of any input files: {}"
            .format(extra_tables))

    if exclude_tables and not exclude_tables.issubset(input_file_tables):
        extra_tables = exclude_tables.difference(input_file_tables)
        raise ValueError(
            "-x/--exclude: the following exclude table(s) are not part of any input files: {}"
            .format(extra_tables))

    stop_after = {}
    reserved_column_names = ('j_valid_from', 'j_valid_to')
    datastore = per.DataStore()

    if overwrite:
        mode = 'w'
    else:
        mode = 'r+'

    if isinstance(dest_file_name, str) and not os.path.exists(dest_file_name):
        mode = 'w'

    with h5py.File(dest_file_name, mode) as hf:
        for sk in schema.keys():
            if sk in reserved_column_names:
                msg = "{} is a reserved column name: reserved names are {}"
                raise ValueError(msg.format(sk, reserved_column_names))

            if sk not in files:
                continue

            fields = schema[sk].fields

            with open(files[sk], encoding='utf-8') as f:
                ds = dataset.Dataset(f, stop_after=1)
            names = set([n.strip() for n in ds.names_])
            missing_names = names.difference(fields.keys())
            if len(missing_names) > 0:
                msg = "The following fields are present in {} but not part of the schema: {}"
                print("Warning:", msg.format(files[sk], missing_names))
                # raise ValueError(msg.format(files[sk], missing_names))

            # check that the included/excluded fields are present in the file
            include_missing_names = set(include.get(sk, [])).difference(names)
            if len(include_missing_names) > 0:
                msg = "The following include fields are not part of the {}: {}"
                raise ValueError(msg.format(files[sk], include_missing_names))

            exclude_missing_names = set(exclude.get(sk, [])).difference(names)
            if len(exclude_missing_names) > 0:
                msg = "The following exclude fields are not part of the {}: {}"
                raise ValueError(msg.format(files[sk], exclude_missing_names))

        for sk in schema.keys():
            if sk not in files:
                continue

            fields = schema[sk].fields

            DatasetImporter(datastore,
                            files[sk],
                            hf,
                            sk,
                            schema[sk],
                            timestamp,
                            include=include,
                            exclude=exclude,
                            stop_after=stop_after.get(sk, None),
                            chunk_row_size=chunk_row_size)

            print(sk, hf.keys())
            table = hf[sk]
            ids = datastore.get_reader(table[list(table.keys())[0]])
            jvf = datastore.get_timestamp_writer(table, 'j_valid_from')
            ftimestamp = utils.string_to_datetime(timestamp).timestamp()
            valid_froms = np.full(len(ids), ftimestamp)
            jvf.write(valid_froms)
            jvt = datastore.get_timestamp_writer(table, 'j_valid_to')
            valid_tos = np.full(len(ids), ops.MAX_DATETIME.timestamp())
            jvt.write(valid_tos)

        print(hf.keys())
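
After each table is imported, the function appends the two reserved journaling columns: j_valid_from is filled with the import timestamp and j_valid_to with a far-future sentinel, giving every imported row a validity interval. A condensed standalone sketch of that step using plain numpy (with a placeholder for ops.MAX_DATETIME rather than the DataStore writers):

import datetime
import numpy as np

# Standalone sketch (not the DataStore API): each row of a freshly imported
# table is valid from the import timestamp until an open-ended future date.
row_count = 4
import_ts = datetime.datetime(2020, 6, 1, tzinfo=datetime.timezone.utc).timestamp()
far_future = datetime.datetime(9999, 12, 31, tzinfo=datetime.timezone.utc).timestamp()  # stand-in for ops.MAX_DATETIME

valid_from = np.full(row_count, import_ts)
valid_to = np.full(row_count, far_future)
assert (valid_from < valid_to).all()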
Code example #6
def split_data(patient_data,
               assessment_data,
               bucket_size=500000,
               territories=None):

    with open(patient_data) as f:
        p_ds = csvdataset.Dataset(f,
                                  keys=('id', 'created_at'),
                                  show_progress_every=500000)
        # show_progress_every=500000, stop_after=500000)
        p_ds.sort(('created_at', 'id'))
        p_ids = p_ds.field_by_name('id')
        p_dts = p_ds.field_by_name('created_at')

    # put assessment ids into buckets
    buckets = dict()
    bucket_index = 0
    bucket_count = 0
    for i_r in range(p_ds.row_count()):
        if bucket_index == bucket_size:
            bucket_index = 0
            bucket_count += 1
        buckets[p_ids[i_r]] = bucket_count
        bucket_index += 1

    filenames = list()
    for b in range(bucket_count + 1):
        destination_filename = patient_data[:-4] + f"_{b:04d}" + ".csv"
        filenames.append(destination_filename)
    print(filenames)
    sorted_indices = p_ds.index_
    del p_ds

    patient_splitter(patient_data, filenames, sorted_indices, bucket_size)

    print('buckets:', bucket_index)
    with open(assessment_data) as f:
        a_ds = csvdataset.Dataset(f,
                                  keys=('patient_id', 'other_symptoms'),
                                  show_progress_every=500000)

    print(utils.build_histogram(buckets.values()))

    print('associating assessments with patients')
    orphaned_assessments = 0
    a_buckets = list()
    a_pids = a_ds.field_by_name('patient_id')
    a_os = a_ds.field_by_name('other_symptoms')
    for i_r in range(a_ds.row_count()):
        if a_pids[i_r] in buckets:
            a_buckets.append(buckets[a_pids[i_r]])
        else:
            orphaned_assessments += 1
            a_buckets.append(-1)

    del a_ds
    print('orphaned_assessments:', orphaned_assessments)

    print(f'{bucket_count + 1} buckets')
    for i in range(bucket_count + 1):
        print('bucket', i)
        destination_filename = assessment_data[:-4] + f"_{i:04d}" + ".csv"
        print(destination_filename)
        # with open(assessment_data) as f:
        #     a_ds = dataset.Dataset(f, filter_fn=lambda j: a_buckets[j] == i, show_progress_every=500000)
        #
        # del a_ds
        assessment_splitter(assessment_data, destination_filename, a_buckets,
                            i)

    print('done!')
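
The bucketing in split_data assigns each patient, in (created_at, id) order, to a bucket of bucket_size consecutive rows, and each assessment then inherits its patient's bucket (-1 when the patient is missing). The mapping reduces to integer division over the sorted row position, roughly as below (standalone sketch, not the function's exact code):

# Standalone sketch of the bucket assignment: patient number i_r in sorted
# order lands in bucket i_r // bucket_size, and each assessment inherits
# its patient's bucket, with -1 marking orphaned assessments.
def assign_buckets(sorted_patient_ids, assessment_patient_ids, bucket_size):
    buckets = {pid: i_r // bucket_size
               for i_r, pid in enumerate(sorted_patient_ids)}
    return [buckets.get(pid, -1) for pid in assessment_patient_ids]

assert assign_buckets(['p1', 'p2', 'p3'], ['p3', 'p1', 'px'], 2) == [1, 0, -1]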