Exemple #1
0
 def test_pass_indexes_num_elems(self):
     """``pass_indexes`` for gender 'any' must return 729 indexes."""
     cfg = self.config
     cfg.attributes.observables.types = {'gender': 'any'}
     cfg.observables_dict = load_observables_dict(cfg)
     passed = pass_indexes(cfg, 'gender', 'any', 'any')
     expected_count = 729
     self.assertEqual(len(passed), expected_count)
Exemple #2
0
 def test_incorrect_observables(self):
     """``get_indexes`` must raise ValueError for the observable 'some_obs'."""
     cfg = self.config
     cfg.observables_dict = load_observables_dict(cfg)
     cfg.attributes.observables.types = {'some_obs': ''}
     with self.assertRaises(ValueError):
         get_indexes(cfg)
Exemple #3
0
 def test_get_indexes_age_and_m(self):
     """``get_indexes`` with age (20, 22.1) and gender 'M' must return 14 indexes."""
     cfg = self.config
     cfg.observables_dict = load_observables_dict(cfg)
     cfg.attributes.observables.types = {
         'age': (20, 22.1),
         'gender': 'M',
     }
     selected = get_indexes(cfg)
     self.assertEqual(len(selected), 14)
Exemple #4
0
 def test_incorrect_variable(self):
     """``pass_indexes`` must raise ValueError for 'age' with value 100."""
     cfg = self.config
     cfg.observables_dict = load_observables_dict(cfg)
     # Resolve the enum value up front so only the call itself is guarded.
     any_value = CommonTypes.any.value
     with self.assertRaises(ValueError):
         pass_indexes(cfg, 'age', 100, any_value)
Exemple #5
0
 def test_len_list_pass_indexes(self):
     """``pass_indexes`` for 'age' with value 18 must return 19 indexes."""
     cfg = self.config
     cfg.observables_dict = load_observables_dict(cfg)
     passed = pass_indexes(cfg, 'age', 18, CommonTypes.any.value)
     self.assertEqual(len(passed), 19)
Exemple #6
0
 def test_get_indexes_num_f_m(self):
     """``get_indexes`` for gender ['F', 'M'] must return 729 indexes."""
     cfg = self.config
     cfg.attributes.observables.types = {'gender': ['F', 'M']}
     cfg.observables_dict = load_observables_dict(cfg)
     selected = get_indexes(cfg)
     expected_count = 729
     self.assertEqual(len(selected), expected_count)
Exemple #7
0
 def test_subset_observables(self):
     """After subsetting to indexes 0..4, 'gender' must be five 'M' entries."""
     cfg = self.config
     cfg.observables_dict = load_observables_dict(cfg)
     cfg.observables_categorical_dict = load_observables_categorical_dict(cfg)
     cfg.attributes_indexes = list(range(5))
     subset_observables(cfg)
     expected_gender = ['M'] * 5
     self.assertEqual(cfg.observables_dict['gender'], expected_gender)
Exemple #8
0
    def test_load_attributes_dict_check_sum_smoke(self):
        """The integer-converted 'smoke' values must add up to 188."""
        observables = load_observables_dict(self.config)
        smoke_as_int = [int(value) for value in observables['smoke']]
        total_smoke = sum(smoke_as_int)
        self.assertEqual(188, total_smoke)
Exemple #9
0
 def test_load_attributes_dict_age_range(self):
     """The spread between the largest and smallest 'age' value must be 80."""
     observables = load_observables_dict(self.config)
     oldest = max(observables['age'])
     youngest = min(observables['age'])
     self.assertEqual(oldest - youngest, 80)
Exemple #10
0
 def test_load_attributes_dict_num_elems(self):
     """The loaded 'age' observable must contain 729 entries."""
     observables = load_observables_dict(self.config)
     age_values = observables['age']
     self.assertEqual(len(age_values), 729)
Exemple #11
0
def load_betas_adj(config):
    """Populate ``config`` with exog-adjusted betas, using an on-disk cache.

    Cache hit (both the pickle and the ``.npz`` file exist): reads
    ``config.betas_adj_dict`` and ``config.betas_adj_data`` from disk.

    Cache miss: loads the betas through ``load_betas``, fits an OLS
    regression of every row of ``config.betas_data`` on the selected
    ``cells`` / ``observables`` columns, stores ``residuals + row mean`` in
    ``config.betas_adj_data`` and writes both cache files.

    Side effects on the cache-miss path:
        * ``config.experiment.data_params`` is replaced by ``{}`` before
          ``load_betas`` is called and is not restored afterwards.
        * ``config.betas_data`` is deleted at the end.

    Raises:
        ValueError: if ``config.experiment.data_params`` is empty, or if
            not every requested cell / observable type is present in the
            loaded dictionaries.
    """
    base_path = get_data_base_path(config)
    fn_dict = base_path + '/' + 'betas_adj_dict.pkl'

    if not config.experiment.data_params:
        raise ValueError('Exog for residuals is empty.')
    # Keep a reference to the original params: the attribute on
    # config.experiment is overwritten below on the cache-miss path.
    data_params = config.experiment.data_params
    suffix = '_' + config.experiment.get_data_params_str()

    fn_data = base_path + '/' + 'betas_adj' + suffix + '.npz'

    if os.path.isfile(fn_dict) and os.path.isfile(fn_data):

        # NOTE(review): pickle must only be used on trusted cache files.
        with open(fn_dict, 'rb') as f:
            config.betas_adj_dict = pickle.load(f)

        # NpzFile keeps the archive open until closed; the array itself is
        # read fully on item access, so it stays valid after the block.
        with np.load(fn_data) as data:
            config.betas_adj_data = data['data']

    else:

        # NOTE(review): data_params is left as {} after this call.
        config.experiment.data_params = {}
        load_betas(config)

        config.betas_adj_dict = config.betas_dict
        with open(fn_dict, 'wb') as f:
            pickle.dump(config.betas_adj_dict, f, pickle.HIGHEST_PROTOCOL)

        exog_dict = {}

        if 'cells' in data_params:

            cells_dict = load_cells_dict(config)
            requested_cells = data_params['cells']

            if isinstance(requested_cells, list):
                # Filter into a new dict so the loaded dict is not mutated.
                selected_cells = {
                    key: values
                    for key, values in cells_dict.items()
                    if key in requested_cells
                }

                if len(selected_cells) != len(requested_cells):
                    raise ValueError('Wrong number of cells types.')

                exog_dict.update(selected_cells)

        if 'observables' in data_params:

            observables_dict = load_observables_dict(config)
            requested_obs = data_params['observables']

            if isinstance(requested_obs, list):
                selected_obs = {
                    key: values
                    for key, values in observables_dict.items()
                    if key in requested_obs
                }

                if len(selected_obs) != len(requested_obs):
                    raise ValueError('Wrong number of observables types.')

                exog_dict.update(selected_obs)

        exog_df = pd.DataFrame(exog_dict)

        num_cpgs, num_subjects = config.betas_data.shape[:2]
        config.betas_adj_data = np.zeros((num_cpgs, num_subjects),
                                         dtype=np.float32)

        for cpg, row in tqdm(config.betas_dict.items(),
                             mininterval=60.0,
                             desc='betas_adj_data creating'):
            betas = config.betas_data[row, :]
            mean = np.mean(betas)

            endog_df = pd.DataFrame({cpg: betas})
            reg_res = sm.OLS(endog=endog_df, exog=exog_df).fit()

            # Adjusted betas are the regression residuals shifted back by
            # the row mean.
            residuals = np.asarray(reg_res.resid, dtype=np.float32)
            config.betas_adj_data[row] = residuals + mean

        np.savez_compressed(fn_data, data=config.betas_adj_data)

        # Clear data
        del config.betas_data
Exemple #12
0
def load_residuals(config):
    """Populate ``config`` with regression residuals, using an on-disk cache.

    Cache hit (the dict pickle, the missed-dict pickle and the ``.npz``
    file all exist): reads ``config.residuals_dict``,
    ``config.residuals_missed_dict`` and ``config.residuals_data`` from
    disk.

    Cache miss: loads the betas through ``load_betas``, fits the formula
    ``cpg ~ <exog terms>`` with ``smf.ols`` for every row of
    ``config.betas_data`` (skipping the subject positions listed in
    ``config.betas_missed_dict`` for that row), stores the residuals in
    ``config.residuals_data`` with NaN at the skipped positions, and writes
    all three cache files.

    Side effects on the cache-miss path:
        * ``config.experiment.data_params`` is reduced to its ``'part'`` and
          ``'norm'`` entries before ``load_betas`` is called and is not
          restored afterwards.
        * ``config.betas_data`` is deleted at the end.

    Raises:
        ValueError: if ``config.experiment.data_params`` is empty, or if
            not every requested cell / observable type is present in the
            loaded dictionaries.
    """
    if not config.experiment.data_params:
        raise ValueError('Exog for residuals is empty.')
    # Keep a reference to the original params: the attribute on
    # config.experiment is replaced below on the cache-miss path.
    data_params = config.experiment.data_params
    suffix = '_' + config.experiment.get_data_params_str()

    base_path = get_data_base_path(config)
    fn_dict = base_path + '/' + 'residuals_dict' + suffix + '.pkl'
    fn_missed_dict = base_path + '/' + 'residuals_missed_dict' + suffix + '.pkl'
    fn_data = base_path + '/' + 'residuals' + suffix + '.npz'

    # All three files are read on a cache hit, so all three must exist;
    # otherwise fall through and rebuild the cache.
    cache_files = (fn_dict, fn_missed_dict, fn_data)
    if all(os.path.isfile(path) for path in cache_files):

        # NOTE(review): pickle must only be used on trusted cache files.
        with open(fn_dict, 'rb') as f:
            config.residuals_dict = pickle.load(f)

        with open(fn_missed_dict, 'rb') as f:
            config.residuals_missed_dict = pickle.load(f)

        # NpzFile keeps the archive open until closed; the array itself is
        # read fully on item access, so it stays valid after the block.
        with np.load(fn_data) as data:
            config.residuals_data = data['data']

    else:

        # Only these keys are forwarded to load_betas.
        data_params_copy = copy.deepcopy(config.experiment.data_params)
        common_keys = ['part', 'norm']
        config.experiment.data_params = {
            key: data_params_copy[key]
            for key in common_keys
            if key in data_params_copy
        }

        load_betas(config)

        config.residuals_dict = config.betas_dict
        with open(fn_dict, 'wb') as f:
            pickle.dump(config.residuals_dict, f, pickle.HIGHEST_PROTOCOL)

        config.residuals_missed_dict = config.betas_missed_dict
        with open(fn_missed_dict, 'wb') as f:
            pickle.dump(config.residuals_missed_dict, f,
                        pickle.HIGHEST_PROTOCOL)

        exog_dict = {}

        if 'cells' in data_params:

            cells_dict = load_cells_dict(config)
            requested_cells = data_params['cells']

            if isinstance(requested_cells, list):
                # Filter into a new dict so the loaded dict is not mutated.
                selected_cells = {
                    key: values
                    for key, values in cells_dict.items()
                    if key in requested_cells
                }

                if len(selected_cells) != len(requested_cells):
                    raise ValueError('Wrong number of cells types.')

                exog_dict.update(selected_cells)

        if 'observables' in data_params:

            observables_dict = load_observables_dict(config)
            requested_obs = data_params['observables']

            if isinstance(requested_obs, list):
                selected_obs = {
                    key: values
                    for key, values in observables_dict.items()
                    if key in requested_obs
                }

                if len(selected_obs) != len(requested_obs):
                    raise ValueError('Wrong number of observables types.')

                exog_dict.update(selected_obs)

        # Categorical columns are wrapped in C(...) in the formula.
        exog_keys = [
            'C(' + exog_type + ')' if is_categorical(exog_data) else exog_type
            for exog_type, exog_data in exog_dict.items()
        ]
        formula = 'cpg ~ ' + ' + '.join(exog_keys)

        num_cpgs, num_subjects = config.betas_data.shape[:2]
        config.residuals_data = np.zeros((num_cpgs, num_subjects),
                                         dtype=np.float32)

        for cpg, row in tqdm(config.betas_dict.items(),
                             mininterval=60.0,
                             desc='residuals_data creating'):
            raw_betas = config.betas_data[row, :]
            missed_ids = config.betas_missed_dict[cpg]

            if len(missed_ids) > 0:
                # Build the lookup once per CpG instead of scanning the
                # missed list for every value of every column.
                missed_lookup = set(missed_ids)

                current_data_dict = {
                    key: [
                        value
                        for value_id, value in enumerate(values)
                        if value_id not in missed_lookup
                    ]
                    for key, values in exog_dict.items()
                }

                passed_ids = [
                    beta_id
                    for beta_id in range(len(raw_betas))
                    if beta_id not in missed_lookup
                ]
                betas = [raw_betas[beta_id] for beta_id in passed_ids]
            else:
                # Shallow copy is enough: only the 'cpg' key is added and
                # the column values themselves are never modified.
                current_data_dict = dict(exog_dict)
                betas = raw_betas
                passed_ids = list(range(len(betas)))

            current_data_dict['cpg'] = betas
            data_df = pd.DataFrame(current_data_dict)
            reg_res = smf.ols(formula=formula, data=data_df).fit()

            residuals = list(map(np.float32, reg_res.resid))

            # Scatter the residuals back to their original subject
            # positions.
            residuals_raw = np.zeros(num_subjects, dtype=np.float32)
            for position, subject_id in enumerate(passed_ids):
                residuals_raw[subject_id] = residuals[position]

            for missed_id in config.residuals_missed_dict[cpg]:
                residuals_raw[missed_id] = np.float32('nan')

            config.residuals_data[row] = residuals_raw

        np.savez_compressed(fn_data, data=config.residuals_data)

        # Clear data
        del config.betas_data