コード例 #1
0
ファイル: ppi_dataloader.py プロジェクト: sailfish009/atom3d
def _save_graphs(sharded, shard_num, out_dir):
    """Build residue-pair graphs for one shard and save each pair to disk.

    For every ensemble in shard ``shard_num``, the positive pairs are the rows
    of the shard's ``'neighbors'`` table whose ``ensemble0`` matches, the
    negatives come from ``nb.get_negatives``, and ``create_labels`` combines
    them. Each pair whose two graphs can both be built is saved as a
    two-graph ``Batch`` to ``<out_dir>/data_<shard_num>_<k>.pt``, where ``k``
    counts the pairs saved so far for this shard.

    Args:
        sharded: sharded dataset exposing ``read_shard``.
        shard_num: index of the shard to process.
        out_dir: directory that receives the ``.pt`` files.
    """
    print(f'Processing shard {shard_num:}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    # Group by the bare column name: a one-element list key yields 1-tuple
    # group keys in pandas >= 2.0, which breaks the ``ensemble0`` comparison.
    for ensemble_name, target_df in shard.groupby('ensemble'):
        _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10,
                               neg_pos_ratio=1)

        for _, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            # Skip pairs where either side could not be converted.
            if graph1 is None or graph2 is None:
                continue

            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1
コード例 #2
0
ファイル: data.py プロジェクト: maschka/atom3d
def dataset_generator(dataset, indices, shuffle=True):
    """Yield residue-pair graphs for the items of ``dataset`` at ``indices``.

    Each item must provide ``'atoms_neighbors'`` (rows with ``ensemble0``
    used as positive pairs) and ``'atoms_pairs'`` (atoms grouped by an
    ``'ensemble'`` column). For every ensemble, positives and the negatives
    from ``nb.get_negatives`` are combined by ``create_labels`` and converted
    with ``df_to_graph``. Pairs where either graph is ``None`` are skipped.

    Args:
        dataset: indexable dataset returning dict-like items.
        indices: indices of ``dataset`` to visit, in order.
        shuffle: if True, visit the ensembles within each item in random
            order; otherwise visit them in sorted ensemble-name order.

    Yields:
        ``(graph1, graph2)`` tuples, one per kept residue pair.
    """
    for idx in indices:
        data = dataset[idx]

        neighbors = data['atoms_neighbors']
        pairs = data['atoms_pairs']

        # Shuffle the groups themselves: ``groupby`` sorts its keys, so
        # reordering rows beforehand would not change the iteration order.
        # Grouping by the bare column name keeps scalar keys (a one-element
        # list key yields 1-tuples in pandas >= 2.0).
        groups = list(pairs.groupby('ensemble'))
        if shuffle:
            random.shuffle(groups)

        for ensemble_name, target_df in groups:
            _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
            positives = neighbors[neighbors.ensemble0 == ensemble_name]
            negatives = nb.get_negatives(positives, bound1, bound2)
            negatives['label'] = 0
            labels = create_labels(positives, negatives, num_pos=10,
                                   neg_pos_ratio=1)

            for _, row in labels.iterrows():
                label = float(row['label'])
                chain_res1 = row[['chain0', 'residue0']].values
                chain_res2 = row[['chain1', 'residue1']].values
                graph1 = df_to_graph(bound1, chain_res1, label)
                graph2 = df_to_graph(bound2, chain_res2, label)
                if graph1 is None or graph2 is None:
                    continue
                yield graph1, graph2
コード例 #3
0
    def filter_fn(df):
        """Keep only the ensembles of ``df`` whose SCOP class pairs are all new.

        For each ensemble, the (pdb_code, chain) entries of both bound
        subunits are looked up in ``scop_index``. Every cross-subunit
        combination of classes is sorted into a pair, and the ensemble is
        dropped if any such pair is in ``scop_pairs``. ``scop_index`` and
        ``scop_pairs`` are free variables from the enclosing function (not
        visible here).

        Returns:
            The rows of ``df`` belonging to kept ensembles, in original order.
        """
        def _scop_classes(bdf):
            # Unique SCOP classes of all chains in one subunit.
            chains = bdf[['structure', 'chain']].drop_duplicates()
            chains = chains.assign(pdb_code=chains['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower()))
            found = [scop_index.loc[(pc, c)].values
                     for pc, c in chains[['pdb_code', 'chain']].to_numpy()
                     if (pc, c) in scop_index]
            return list(np.unique(np.concatenate(found))) if found else []

        kept = []
        # Bare column name: a one-element list key yields 1-tuple keys in
        # pandas >= 2.0, which would not match df['ensemble'] values.
        for e, ensemble in df.groupby('ensemble'):
            _, (bdf0, bdf1, _, _) = nb.get_subunits(ensemble)
            scop0 = _scop_classes(bdf0)
            scop1 = _scop_classes(bdf1)
            if not any(tuple(sorted((a, b))) in scop_pairs
                       for a in scop0 for b in scop1):
                kept.append(e)
        # isin always yields a boolean mask (also for an empty df), so this
        # is row filtering rather than accidental column selection.
        return df[df['ensemble'].isin(kept)]
コード例 #4
0
ファイル: feature_ppi.py プロジェクト: sailfish009/atom3d
def get_data_stats(sharded_list):
    """Count positive neighbor pairs per ensemble across sharded datasets.

    Args:
        sharded_list: sequence of sharded datasets. Each must expose
            ``get_num_shards()`` and ``read_shard(shard_num, key=...)`` for
            the ``'neighbors'`` and ``'structures'`` keys.

    Returns:
        DataFrame with columns ``sharded`` (position in ``sharded_list``),
        ``shard_num``, ``ensemble`` and ``num_pos``. ``num_pos`` is the number
        of neighbor rows whose ``ensemble0`` equals the ensemble. Rows are
        sorted by ``sharded`` ascending, then ``num_pos`` descending. The
        frame's ``describe()`` summary is also printed.
    """
    data = []
    for i, sharded in enumerate(sharded_list):
        for shard_num in range(sharded.get_num_shards()):
            shard_neighbors_df = sharded.read_shard(shard_num, key='neighbors')
            shard_structs_df = sharded.read_shard(shard_num, key='structures')

            for ensemble_name in shard_structs_df.ensemble.unique():
                ensemble_df = shard_structs_df[shard_structs_df.ensemble ==
                                               ensemble_name]
                # NOTE(review): the subunits are not used by these stats; the
                # call is kept only in case get_subunits validates the
                # ensemble. Remove it if it has no side effects.
                nb.get_subunits(ensemble_df)
                # Positives: neighbor rows recorded for this ensemble.
                pos_neighbors_df = shard_neighbors_df[
                    shard_neighbors_df.ensemble0 == ensemble_name]
                data.append((i, shard_num, ensemble_name,
                             pos_neighbors_df.shape[0]))

    df = pd.DataFrame(data,
                      columns=['sharded', 'shard_num', 'ensemble', 'num_pos'])
    df = df.sort_values(['sharded', 'num_pos'],
                        ascending=[True, False]).reset_index(drop=True)
    print(df.describe())
    return df
コード例 #5
0
ファイル: ppi_dataloader.py プロジェクト: everyday847/atom3d
def dataset_generator(sharded, shard_indices, shuffle=True):
    """Yield residue-pair graphs for the given shards of a sharded dataset.

    For every ensemble in each shard, the positives (rows of the shard's
    ``'neighbors'`` table with matching ``ensemble0``) and the negatives from
    ``nb.get_negatives`` are combined by ``create_labels``. The labelled
    pairs are always visited in random order. Pairs where either graph
    cannot be built (``df_to_graph`` returns None) are skipped.

    Args:
        sharded: sharded dataset exposing ``read_shard``.
        shard_indices: shard numbers to read, in order.
        shuffle: if True, visit the ensembles of each shard in random order;
            otherwise visit them in sorted ensemble-name order.

    Yields:
        ``(graph1, graph2)`` tuples.
    """
    for shard_idx in shard_indices:
        shard = sharded.read_shard(shard_idx)
        neighbors = sharded.read_shard(shard_idx, 'neighbors')

        # Shuffle the groups themselves: ``groupby`` sorts its keys, so
        # reordering rows and regrouping would restore sorted order. Grouping
        # by the bare column name keeps scalar keys in pandas >= 2.0.
        groups = list(shard.groupby('ensemble'))
        if shuffle:
            random.shuffle(groups)

        for ensemble_name, target_df in groups:
            _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
            positives = neighbors[neighbors.ensemble0 == ensemble_name]
            negatives = nb.get_negatives(positives, bound1, bound2)
            negatives['label'] = 0
            labels = create_labels(positives, negatives, neg_pos_ratio=1)
            # Pair order is randomized regardless of ``shuffle``.
            labels = labels.sample(frac=1)

            for _, row in labels.iterrows():
                label = float(row['label'])
                chain_res1 = row[['chain0', 'residue0']].values
                chain_res2 = row[['chain1', 'residue1']].values
                # Build the second graph only if the first one succeeded.
                graph1 = df_to_graph(bound1, chain_res1, label)
                if graph1 is None:
                    continue
                graph2 = df_to_graph(bound2, chain_res2, label)
                if graph2 is None:
                    continue
                yield graph1, graph2
コード例 #6
0
    def __iter__(self):
        """Yield voxelized residue-pair examples for every dataset item.

        For each item of ``self._lmdb_dataset``:
        - Positives are the rows of ``item['atoms_neighbors']``; negatives
          come from ``nb.get_negatives``.
        - Both sets are restricted to residues whose atoms have blank
          ``hetero`` and ``insertion_code`` fields.
        - Both are resampled to the sizes returned by ``self._num_to_use``,
          then voxelized around each pair's centers with ``self._voxelize``.

        Unbound subunits are used when ``nb.get_subunits`` provides them,
        bound subunits otherwise.

        Yields:
            dicts with keys ``'feature_left'``, ``'feature_right'``,
            ``'label'`` (1 for positives, 0 for negatives) and ``'id'``
            (``'<item id>/<res0>/<res1>'``), with positives and negatives
            merged by ``intersperse``.
        """
        def restrict(pairs_df, allowed):
            # Keep pairs whose two residues are both in the allowed sets.
            left_ok = pairs_df.residue0.isin(allowed[0])
            right_ok = pairs_df.residue1.isin(allowed[1])
            return pairs_df[left_ok & right_ok]

        def resample(pairs_df, wanted):
            # Use the pairs as-is when the count already matches; otherwise
            # draw ``wanted`` rows with replacement.
            if pairs_df.shape[0] != wanted:
                pairs_df = pairs_df.sample(wanted, replace=True)
            return pairs_df.reset_index(drop=True)

        def featurize(pair_centers, label, item, subunits):
            examples = []
            for res0, res1, center0, center1 in pair_centers:
                left, right = self._voxelize(subunits[0], subunits[1],
                                             center0, center1)
                examples.append({
                    'feature_left': left,
                    'feature_right': right,
                    'label': label,
                    'id': '{:}/{:}/{:}'.format(item['id'], res0, res1),
                })
            return examples

        for index in range(len(self._lmdb_dataset)):
            item = self._lmdb_dataset[index]

            _, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(item['atoms_pairs'])
            if udf0 is None:
                subunits = [bdf0, bdf1]
            else:
                subunits = [udf0, udf1]

            positives = item['atoms_neighbors']
            negatives = nb.get_negatives(positives, subunits[0], subunits[1])

            # Per subunit: residues with empty hetero flag and insertion code.
            allowed = [
                sub[(sub.hetero == ' ')
                    & (sub.insertion_code == ' ')].residue.unique()
                for sub in subunits
            ]
            positives = restrict(positives, allowed)
            negatives = restrict(negatives, allowed)

            pos_count, neg_count = self._num_to_use(positives.shape[0],
                                                    negatives.shape[0])
            pos_sampled = resample(positives, pos_count)
            neg_sampled = resample(negatives, neg_count)

            pos_centers = self._get_res_pair_ca_coords(pos_sampled, subunits)
            neg_centers = self._get_res_pair_ca_coords(neg_sampled, subunits)

            pos_examples = featurize(pos_centers, 1, item, subunits)
            neg_examples = featurize(neg_centers, 0, item, subunits)
            for example in intersperse(pos_examples, neg_examples):
                yield example
コード例 #7
0
ファイル: feature_ppi.py プロジェクト: sailfish009/atom3d
def dataset_generator(sharded,
                      grid_config,
                      shuffle=True,
                      repeat=None,
                      max_num_ensembles=None,
                      testing=False,
                      use_shard_nums=None,
                      random_seed=None):
    """Yield voxelized ``(id, grid, label)`` examples from a sharded dataset.

    For each ensemble:
    - Positives are the ``'neighbors'`` rows with matching ``ensemble0``;
      negatives come from ``nb.get_negatives``.
    - Both sets are restricted to residues whose atoms have blank ``hetero``
      and ``insertion_code`` fields.
    - Both are resampled to the sizes chosen by ``__num_to_use``, then
      featurized with ``df_to_feature``.

    Unbound subunits are used when available, bound ones otherwise. An
    ensemble already visited in its shard is skipped until every ensemble of
    that shard has been visited, at which point the shard's history resets.

    Args:
        sharded: sharded dataset exposing ``get_num_shards`` and
            ``read_shard(shard_num, key=...)``.
        grid_config: forwarded to ``__num_to_use`` and ``df_to_feature``.
        shuffle: randomize shard, ensemble and pair order.
        repeat: number of passes over the shards (None means 1).
        max_num_ensembles: stop after this many ensembles in total
            (None means no limit).
        testing: forwarded to ``__num_to_use``.
        use_shard_nums: shard numbers to use (any sequence of ints);
            defaults to all shards.
        random_seed: forwarded to ``df_to_feature``.

    Yields:
        ``('<ensemble>/<res0>/<res1>', grid, np.array([label]))`` tuples with
        label 1 for positives and 0 for negatives, merged by
        ``util.intersperse``.
    """
    if use_shard_nums is None:
        all_shard_nums = np.arange(sharded.get_num_shards())
    else:
        # asarray so a plain list/tuple can be permuted by fancy indexing.
        all_shard_nums = np.asarray(use_shard_nums)

    seen = col.defaultdict(set)
    ensemble_count = 0

    if repeat is None:
        repeat = 1
    for _ in range(repeat):
        if shuffle:
            p = np.random.permutation(len(all_shard_nums))
            all_shard_nums = all_shard_nums[p]

        for shard_num in all_shard_nums:
            shard_neighbors_df = sharded.read_shard(shard_num, key='neighbors')
            shard_structs_df = sharded.read_shard(shard_num, key='structures')

            ensemble_names = shard_structs_df.ensemble.unique()
            # Every ensemble of this shard has been visited: start over.
            if len(seen[shard_num]) == len(ensemble_names):
                seen[shard_num] = set()

            if shuffle:
                p = np.random.permutation(len(ensemble_names))
                ensemble_names = ensemble_names[p]

            for ensemble_name in ensemble_names:
                if (max_num_ensembles is not None
                        and ensemble_count >= max_num_ensembles):
                    return
                if ensemble_name in seen[shard_num]:
                    continue
                seen[shard_num].add(ensemble_name)
                ensemble_count += 1

                ensemble_df = shard_structs_df[shard_structs_df.ensemble ==
                                               ensemble_name]
                # Subunits: prefer unbound when available.
                _, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble_df)
                structs_df = [udf0, udf1] if udf0 is not None else [bdf0, bdf1]
                # Positives: neighbor rows recorded for this ensemble.
                pos_neighbors_df = shard_neighbors_df[
                    shard_neighbors_df.ensemble0 == ensemble_name]
                neg_neighbors_df = nb.get_negatives(pos_neighbors_df,
                                                    structs_df[0],
                                                    structs_df[1])

                # Throw away residues with non-empty hetero/insertion_code.
                non_heteros = [
                    df[(df.hetero == ' ')
                       & (df.insertion_code == ' ')].residue.unique()
                    for df in structs_df
                ]
                pos_neighbors_df = pos_neighbors_df[
                    pos_neighbors_df.residue0.isin(non_heteros[0])
                    & pos_neighbors_df.residue1.isin(non_heteros[1])]
                neg_neighbors_df = neg_neighbors_df[
                    neg_neighbors_df.residue0.isin(non_heteros[0])
                    & neg_neighbors_df.residue1.isin(non_heteros[1])]

                # Sample pos and neg pairs up/down to the requested counts.
                num_pos = pos_neighbors_df.shape[0]
                num_neg = neg_neighbors_df.shape[0]
                num_pos_to_use, num_neg_to_use = __num_to_use(
                    num_pos, num_neg, testing, grid_config)

                if shuffle:
                    pos_neighbors_df = pos_neighbors_df.sample(
                        frac=1).reset_index(drop=True)
                    neg_neighbors_df = neg_neighbors_df.sample(
                        frac=1).reset_index(drop=True)
                if pos_neighbors_df.shape[0] == num_pos_to_use:
                    pos_samples_df = pos_neighbors_df.reset_index(drop=True)
                else:
                    pos_samples_df = pos_neighbors_df.sample(
                        num_pos_to_use, replace=True).reset_index(drop=True)
                if neg_neighbors_df.shape[0] == num_neg_to_use:
                    neg_samples_df = neg_neighbors_df.reset_index(drop=True)
                else:
                    neg_samples_df = neg_neighbors_df.sample(
                        num_neg_to_use, replace=True).reset_index(drop=True)

                pos_pairs_cas = __get_res_pair_ca_coords(
                    pos_samples_df, structs_df)
                neg_pairs_cas = __get_res_pair_ca_coords(
                    neg_samples_df, structs_df)

                pos_features = []
                for (res0, res1, center0, center1) in pos_pairs_cas:
                    fgrid = df_to_feature(structs_df[0], structs_df[1],
                                          center0, center1, grid_config,
                                          random_seed)
                    pos_features.append(
                        ('{:}/{:}/{:}'.format(ensemble_name, res0,
                                              res1), fgrid, np.array([1])))

                neg_features = []
                for (res0, res1, center0, center1) in neg_pairs_cas:
                    fgrid = df_to_feature(structs_df[0], structs_df[1],
                                          center0, center1, grid_config,
                                          random_seed)
                    neg_features.append(
                        ('{:}/{:}/{:}'.format(ensemble_name, res0,
                                              res1), fgrid, np.array([0])))

                for f in util.intersperse(pos_features, neg_features):
                    yield f
コード例 #8
0
def form_scop_pair_filter_against(sharded, level):
    """Remove pairs that have matching scop classes in both subunits.

    Every ensemble in ``sharded`` contributes the sorted pairs formed by one
    SCOP class (at ``level``) from each of its two bound subunits. The
    returned ``filter_fn(df)`` keeps only the ensembles of ``df`` none of
    whose SCOP class pairs occur in that reference set.

    Args:
        sharded: sharded dataset whose ensembles define the reference pairs.
        level: key selecting the SCOP level from ``scop.get_scop_index()``.

    Returns:
        A function mapping a DataFrame with an ``'ensemble'`` column to the
        rows of kept ensembles, in their original order.
    """
    scop_index = scop.get_scop_index()[level]

    def _scop_classes(bdf):
        """Return the unique SCOP classes of the chains in subunit ``bdf``."""
        chains = bdf[['structure', 'chain']].drop_duplicates()
        chains = chains.assign(pdb_code=chains['structure'].apply(
            lambda x: fi.get_pdb_code(x).lower()))
        found = [scop_index.loc[(pc, c)].values
                 for pc, c in chains[['pdb_code', 'chain']].to_numpy()
                 if (pc, c) in scop_index]
        return list(np.unique(np.concatenate(found))) if found else []

    def _ensemble_scop_pairs(ensemble):
        """Return sorted (class, class) pairs across the two bound subunits."""
        _, (bdf0, bdf1, _, _) = nb.get_subunits(ensemble)
        scop0 = _scop_classes(bdf0)
        scop1 = _scop_classes(bdf1)
        return [tuple(sorted((a, b))) for a in scop0 for b in scop1]

    scop_pairs = set()
    for _, shard in sh.iter_shards(sharded):
        # Bare column name: a one-element list key yields 1-tuple group keys
        # in pandas >= 2.0.
        for _, ensemble in shard.groupby('ensemble'):
            scop_pairs.update(_ensemble_scop_pairs(ensemble))

    def filter_fn(df):
        """Keep rows of ``df`` whose ensemble shares no SCOP pair with ``sharded``."""
        kept = [e for e, ensemble in df.groupby('ensemble')
                if not any(p in scop_pairs
                           for p in _ensemble_scop_pairs(ensemble))]
        # isin always yields a boolean mask (also for an empty df), so this
        # is row filtering rather than accidental column selection.
        return df[df['ensemble'].isin(kept)]

    return filter_fn
コード例 #9
0
def compute_all_bsa_main(sharded_path, ensemble):
    """Print the buried surface area between one ensemble's bound subunits.

    Args:
        sharded_path: path passed to ``sh.Sharded.load``.
        ensemble: name of the ensemble to read from the loaded dataset.
    """
    dataset = sh.Sharded.load(sharded_path)
    ensemble_df = dataset.read_ensemble(ensemble)
    _, (bound0, bound1, _, _) = nb.get_subunits(ensemble_df)
    print(compute_bsa(bound0, bound1))
コード例 #10
0
def _bsa_db(sharded, shard_num, output_bsa):
    """Compute buried surface area (BSA) for one shard and append it to a CSV.

    Ensembles already listed in ``output_bsa`` are skipped. For the others:
    - The ASA of each bound subunit is computed once per subunit name
      (cached for this call) and passed to ``bsa.compute_bsa``.
    - The new rows are concatenated with the CSV's current contents.
    - The result is written to a temporary file that then replaces
      ``output_bsa``.

    Reading and writing ``output_bsa`` both happen while holding ``db_sem``.
    Time spent reading, processing, writing and waiting on ``db_sem`` is
    logged.

    Args:
        sharded: sharded dataset exposing ``read_shard``.
        shard_num: shard to process.
        output_bsa: path of the results CSV (rows carry an ``ensemble``
            column).
    """
    logger.info(f'Processing shard {shard_num:}')
    start_time = timeit.default_timer()
    start_time_reading = timeit.default_timer()
    shard = sharded.read_shard(shard_num)
    elapsed_reading = timeit.default_timer() - start_time_reading

    start_time_waiting = timeit.default_timer()
    with db_sem:
        start_time_reading = timeit.default_timer()
        if os.path.exists(output_bsa):
            curr_bsa_db = pd.read_csv(output_bsa).set_index(['ensemble'])
        else:
            curr_bsa_db = None
        tmp_elapsed_reading = timeit.default_timer() - start_time_reading
    # Time blocked on db_sem, excluding the read done while holding it.
    elapsed_waiting = timeit.default_timer() - start_time_waiting - \
        tmp_elapsed_reading
    elapsed_reading += tmp_elapsed_reading

    start_time_processing = timeit.default_timer()
    all_results = []
    cache = {}
    for e, ensemble in shard.groupby('ensemble'):
        if (curr_bsa_db is not None) and (e in curr_bsa_db.index):
            continue
        (name0, name1, _, _), (bdf0, bdf1, _, _) = nb.get_subunits(ensemble)

        try:
            # We use bound for individual subunits in bsa computation, as
            # sometimes the actual structure between bound and unbound differ.
            if name0 not in cache:
                cache[name0] = bsa._compute_asa(bdf0)
            if name1 not in cache:
                cache[name1] = bsa._compute_asa(bdf1)
            result = bsa.compute_bsa(bdf0, bdf1, cache[name0], cache[name1])
            result['ensemble'] = e
            all_results.append(result)
        except AssertionError as err:
            # Bind the exception as ``err`` so ``e`` still names the ensemble
            # in the message below.
            logger.warning(err)
            logger.warning(f'Failed BSA on {e:}')

    to_add = pd.concat(all_results, axis=1).T if all_results else None
    elapsed_processing = timeit.default_timer() - start_time_processing

    if to_add is None:
        elapsed_writing = 0
    else:
        start_time_waiting = timeit.default_timer()
        with db_sem:
            start_time_writing = timeit.default_timer()
            # Update db in case it has updated since last run.
            if os.path.exists(output_bsa):
                curr_bsa_db = pd.read_csv(output_bsa)
                new_bsa_db = pd.concat([curr_bsa_db, to_add])
            else:
                new_bsa_db = to_add
            tmp_path = output_bsa + f'.tmp{shard_num:}'
            new_bsa_db.to_csv(tmp_path, index=False)
            # os.replace overwrites an existing destination on every
            # platform; os.rename raises on Windows when it exists.
            os.replace(tmp_path, output_bsa)
            elapsed_writing = timeit.default_timer() - start_time_writing
        elapsed_waiting += timeit.default_timer() - start_time_waiting - \
            elapsed_writing
    elapsed = timeit.default_timer() - start_time

    logger.info(
        f'For {len(all_results):03d} pairs buried in shard {shard_num:} spent '
        f'{elapsed_reading:05.2f} reading, '
        f'{elapsed_processing:05.2f} processing, '
        f'{elapsed_writing:05.2f} writing, '
        f'{elapsed_waiting:05.2f} waiting, and '
        f'{elapsed:05.2f} overall.')