コード例 #1
0
ファイル: ppi_dataloader.py プロジェクト: sailfish009/atom3d
def _save_graphs(sharded, shard_num, out_dir):
    """Convert one shard into residue-pair graph batches saved as ``.pt`` files.

    Reads the structure shard and its ``'neighbors'`` shard. For every
    ensemble it splits the structure into subunits with ``nb.get_subunits``,
    takes the neighbor rows whose ``ensemble0`` matches as positives, derives
    negatives with ``nb.get_negatives`` and builds labelled pairs with
    ``create_labels(..., num_pos=10, neg_pos_ratio=1)``. Each pair is turned
    into two graphs with ``df_to_graph``. Pairs where either graph is None are
    skipped. The rest are combined with ``Batch.from_data_list`` and written
    to ``out_dir/data_{shard_num}_{k}.pt``, where ``k`` counts only the pairs
    actually saved in this shard.

    Args:
        sharded: sharded dataset exposing ``read_shard(shard_num[, key])``.
        shard_num: index of the shard to process (also used in file names).
        out_dir: directory the ``.pt`` files are written to.
    """
    print(f'Processing shard {shard_num:}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    # Group by the scalar column name: grouping by ['ensemble'] makes pandas
    # >= 2.0 yield 1-tuple keys, which breaks the ensemble0 comparison below.
    for ensemble_name, target_df in shard.groupby('ensemble'):
        _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)

        for _, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            if graph1 is None or graph2 is None:
                continue

            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1
コード例 #2
0
ファイル: data.py プロジェクト: maschka/atom3d
def dataset_generator(dataset, indices, shuffle=True):
    """Generate graph pairs from an indexed PPI dataset.

    For each index in ``indices`` the item ``dataset[idx]`` is read. Its
    ``'atoms_pairs'`` table is grouped by ensemble, and its
    ``'atoms_neighbors'`` table supplies the positive residue pairs (rows
    whose ``ensemble0`` equals the ensemble name). Negatives come from
    ``nb.get_negatives`` and labelled pairs from
    ``create_labels(..., num_pos=10, neg_pos_ratio=1)``. For every labelled
    pair a graph is built from each bound subunit with ``df_to_graph``. Pairs
    where either graph is None are skipped.

    Args:
        dataset: indexable returning dict-like items holding
            ``'atoms_neighbors'`` and ``'atoms_pairs'`` DataFrames.
        indices: iterable of indices into ``dataset``.
        shuffle: if True, visit the ensembles of each item in random order;
            otherwise visit them in sorted ensemble order.

    Yields:
        ``(graph1, graph2)`` tuples as returned by ``df_to_graph``.
    """
    for idx in indices:
        data = dataset[idx]

        neighbors = data['atoms_neighbors']
        pairs = data['atoms_pairs']

        if shuffle:
            groups = [df for _, df in pairs.groupby('ensemble')]
            # pd.concat raises on an empty list, so only reorder when there is
            # at least one ensemble.
            if groups:
                random.shuffle(groups)
                pairs = pd.concat(groups).reset_index(drop=True)

        # Scalar key avoids pandas >= 2.0 returning 1-tuple group names, and
        # sort=False keeps the shuffled order (groupby sorts keys by default).
        for ensemble_name, target_df in pairs.groupby('ensemble', sort=not shuffle):
            _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
            positives = neighbors[neighbors.ensemble0 == ensemble_name]
            negatives = nb.get_negatives(positives, bound1, bound2)
            negatives['label'] = 0
            labels = create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)

            for _, row in labels.iterrows():
                label = float(row['label'])
                chain_res1 = row[['chain0', 'residue0']].values
                chain_res2 = row[['chain1', 'residue1']].values
                graph1 = df_to_graph(bound1, chain_res1, label)
                graph2 = df_to_graph(bound2, chain_res2, label)
                if graph1 is None or graph2 is None:
                    continue
                yield graph1, graph2
コード例 #3
0
ファイル: ppi_dataloader.py プロジェクト: everyday847/atom3d
def dataset_generator(sharded, shard_indices, shuffle=True):
    """Generate graph pairs from a sharded HDF dataset.

    For each shard index the structure shard and its ``'neighbors'`` shard are
    read. Ensembles are visited in random order when ``shuffle`` is True and
    in sorted order otherwise. For each ensemble:

    - positives are the neighbor rows whose ``ensemble0`` matches;
    - negatives come from ``nb.get_negatives``;
    - labelled pairs come from ``create_labels(..., neg_pos_ratio=1)``,
      randomly reordered with ``labels.sample(frac=1)`` regardless of
      ``shuffle``.

    Each pair yields one graph per bound subunit. Pairs where either graph is
    None are skipped, and the second graph is not built when the first is None.

    Args:
        sharded: sharded dataset exposing ``read_shard(shard_idx[, key])``.
        shard_indices: iterable of shard indices to read.
        shuffle: randomize the order in which ensembles are visited.

    Yields:
        ``(graph1, graph2)`` tuples as returned by ``df_to_graph``.
    """
    for shard_idx in shard_indices:
        shard = sharded.read_shard(shard_idx)

        neighbors = sharded.read_shard(shard_idx, 'neighbors')

        if shuffle:
            groups = [df for _, df in shard.groupby('ensemble')]
            # pd.concat raises on an empty list, so only reorder when the
            # shard has at least one ensemble.
            if groups:
                random.shuffle(groups)
                shard = pd.concat(groups).reset_index(drop=True)

        # sort=False keeps the shuffled order (groupby sorts keys by default,
        # which would undo the shuffle); the scalar key avoids pandas >= 2.0
        # returning 1-tuple group names.
        for ensemble_name, target_df in shard.groupby('ensemble', sort=not shuffle):
            _, (bound1, bound2, _, _) = nb.get_subunits(target_df)
            positives = neighbors[neighbors.ensemble0 == ensemble_name]
            negatives = nb.get_negatives(positives, bound1, bound2)
            negatives['label'] = 0
            labels = create_labels(positives, negatives, neg_pos_ratio=1)
            labels = labels.sample(frac=1)

            for _, row in labels.iterrows():
                label = float(row['label'])
                chain_res1 = row[['chain0', 'residue0']].values
                chain_res2 = row[['chain1', 'residue1']].values
                graph1 = df_to_graph(bound1, chain_res1, label)
                if graph1 is None:
                    continue
                graph2 = df_to_graph(bound2, chain_res2, label)
                if graph2 is None:
                    continue
                yield graph1, graph2
コード例 #4
0
    def __iter__(self):
        """Yield voxelized residue-pair examples for every dataset entry.

        Entries are read from ``self._lmdb_dataset`` by integer index. For
        each entry:

        - Subunits come from ``nb.get_subunits(item['atoms_pairs'])``. The
          unbound pair is used when ``udf0`` is not None, otherwise the bound
          pair.
        - Positives come from ``item['atoms_neighbors']`` and negatives from
          ``nb.get_negatives``.
        - Both sets are restricted to residue numbers that appear on rows with
          blank ``hetero`` and ``insertion_code``.
        - Counts are chosen by ``self._num_to_use``. A set is used as-is when
          its size already matches; otherwise rows are drawn with replacement.

        Each selected pair is voxelized with ``self._voxelize`` and emitted as
        a dict with keys ``feature_left``, ``feature_right``, ``label`` (1 for
        positives, 0 for negatives) and ``id`` (``'<item id>/<res0>/<res1>'``).
        Positives and negatives are combined through ``intersperse``.
        """

        def keep_standard(pairs_df, allowed):
            # Keep pairs whose residue0 is allowed in the first subunit and
            # whose residue1 is allowed in the second.
            in_left = pairs_df.residue0.isin(allowed[0])
            in_right = pairs_df.residue1.isin(allowed[1])
            return pairs_df[in_left & in_right]

        def draw(pairs_df, wanted):
            # Keep the rows as-is when the count already matches; otherwise
            # resample (with replacement) to exactly `wanted` rows.
            if pairs_df.shape[0] != wanted:
                pairs_df = pairs_df.sample(wanted, replace=True)
            return pairs_df.reset_index(drop=True)

        def featurize(pair_cas, label, item, left, right):
            # Voxelize each (res0, res1, center0, center1) pair into a dict.
            examples = []
            for res0, res1, center0, center1 in pair_cas:
                grid0, grid1 = self._voxelize(left, right, center0, center1)
                examples.append({
                    'feature_left': grid0,
                    'feature_right': grid1,
                    'label': label,
                    'id': '{:}/{:}/{:}'.format(item['id'], res0, res1),
                })
            return examples

        for position in range(len(self._lmdb_dataset)):
            item = self._lmdb_dataset[position]

            # Subunits: prefer the unbound structures when present.
            _, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(item['atoms_pairs'])
            if udf0 is not None:
                structs_df = [udf0, udf1]
            else:
                structs_df = [bdf0, bdf1]
            left, right = structs_df

            positives = item['atoms_neighbors']
            negatives = nb.get_negatives(positives, left, right)

            # Residue numbers having rows with blank hetero/insertion_code.
            allowed = [
                df[(df.hetero == ' ') & (df.insertion_code == ' ')].residue.unique()
                for df in structs_df
            ]
            positives = keep_standard(positives, allowed)
            negatives = keep_standard(negatives, allowed)

            n_pos, n_neg = self._num_to_use(positives.shape[0], negatives.shape[0])
            pos_samples = draw(positives, n_pos)
            neg_samples = draw(negatives, n_neg)

            pos_pairs = self._get_res_pair_ca_coords(pos_samples, structs_df)
            neg_pairs = self._get_res_pair_ca_coords(neg_samples, structs_df)

            pos_examples = featurize(pos_pairs, 1, item, left, right)
            neg_examples = featurize(neg_pairs, 0, item, left, right)
            for example in intersperse(pos_examples, neg_examples):
                yield example
コード例 #5
0
ファイル: feature_ppi.py プロジェクト: sailfish009/atom3d
def dataset_generator(sharded,
                      grid_config,
                      shuffle=True,
                      repeat=None,
                      max_num_ensembles=None,
                      testing=False,
                      use_shard_nums=None,
                      random_seed=None):
    """Yield voxelized residue-pair features for PPI ensembles.

    Iterates over the shards of ``sharded`` (all of them, or only
    ``use_shard_nums``) for ``repeat`` passes. For every ensemble in a shard:

    - positives are the rows of the ``'neighbors'`` table whose ``ensemble0``
      matches, and negatives come from ``nb.get_negatives``;
    - both sets are restricted to residue numbers that appear on rows with
      blank ``hetero`` and ``insertion_code``;
    - both sets are resampled to the counts returned by ``__num_to_use``
      (used as-is when the count matches, otherwise drawn with replacement);
    - each pair is featurized with ``df_to_feature``.

    Positive and negative examples are combined with ``util.intersperse``.

    Args:
        sharded: sharded dataset exposing ``get_num_shards()`` and
            ``read_shard(shard_num, key=...)``.
        grid_config: passed through to ``__num_to_use`` and ``df_to_feature``.
        shuffle: shuffle shard order, ensemble order, and the order of
            positive/negative pairs.
        repeat: number of passes over the shards; ``None`` means 1.
        max_num_ensembles: stop once this many ensembles have been processed
            (counted across all passes); ``None`` means no limit.
        testing: passed through to ``__num_to_use``.
        use_shard_nums: optional sequence of shard numbers to use instead of
            all shards (a list, tuple or ndarray).
        random_seed: passed through to ``df_to_feature``.

    Yields:
        ``(id, fgrid, label)`` tuples. ``id`` is
        ``'<ensemble>/<res0>/<res1>'``, and ``label`` is ``np.array([1])``
        for positives or ``np.array([0])`` for negatives.
    """
    if use_shard_nums is None:
        num_shards = sharded.get_num_shards()
        all_shard_nums = np.arange(num_shards)
    else:
        # Coerce to an ndarray: the permutation indexing below fails on a
        # plain list/tuple.
        all_shard_nums = np.asarray(use_shard_nums)

    seen = col.defaultdict(set)
    ensemble_count = 0

    if repeat is None:
        repeat = 1
    for _ in range(repeat):
        if shuffle:
            p = np.random.permutation(len(all_shard_nums))
            all_shard_nums = all_shard_nums[p]

        for shard_num in all_shard_nums:
            shard_neighbors_df = sharded.read_shard(shard_num, key='neighbors')
            shard_structs_df = sharded.read_shard(shard_num, key='structures')

            ensemble_names = shard_structs_df.ensemble.unique()
            # Once every ensemble of this shard has been visited, start over.
            if len(seen[shard_num]) == len(ensemble_names):
                seen[shard_num] = set()

            if shuffle:
                p = np.random.permutation(len(ensemble_names))
                ensemble_names = ensemble_names[p]

            for ensemble_name in ensemble_names:
                if (max_num_ensembles is not None
                        and ensemble_count >= max_num_ensembles):
                    return
                if ensemble_name in seen[shard_num]:
                    continue
                seen[shard_num].add(ensemble_name)
                ensemble_count += 1

                ensemble_df = shard_structs_df[shard_structs_df.ensemble ==
                                               ensemble_name]
                # Subunits: prefer the unbound structures when present.
                _, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble_df)
                structs_df = [udf0, udf1] if udf0 is not None else [bdf0, bdf1]
                # Get positives
                pos_neighbors_df = shard_neighbors_df[
                    shard_neighbors_df.ensemble0 == ensemble_name]
                # Get negatives
                neg_neighbors_df = nb.get_negatives(pos_neighbors_df,
                                                    structs_df[0],
                                                    structs_df[1])

                # Keep only residues with a blank hetero flag and insertion code.
                non_heteros = [
                    df[(df.hetero == ' ')
                       & (df.insertion_code == ' ')].residue.unique()
                    for df in structs_df
                ]
                pos_neighbors_df = pos_neighbors_df[
                    pos_neighbors_df.residue0.isin(non_heteros[0])
                    & pos_neighbors_df.residue1.isin(non_heteros[1])]
                neg_neighbors_df = neg_neighbors_df[
                    neg_neighbors_df.residue0.isin(non_heteros[0])
                    & neg_neighbors_df.residue1.isin(non_heteros[1])]

                # Sample pos and neg samples
                num_pos = pos_neighbors_df.shape[0]
                num_neg = neg_neighbors_df.shape[0]
                num_pos_to_use, num_neg_to_use = __num_to_use(
                    num_pos, num_neg, testing, grid_config)

                if shuffle:
                    pos_neighbors_df = pos_neighbors_df.sample(
                        frac=1).reset_index(drop=True)
                    neg_neighbors_df = neg_neighbors_df.sample(
                        frac=1).reset_index(drop=True)
                if pos_neighbors_df.shape[0] == num_pos_to_use:
                    pos_samples_df = pos_neighbors_df.reset_index(drop=True)
                else:
                    pos_samples_df = pos_neighbors_df.sample(
                        num_pos_to_use, replace=True).reset_index(drop=True)
                if neg_neighbors_df.shape[0] == num_neg_to_use:
                    neg_samples_df = neg_neighbors_df.reset_index(drop=True)
                else:
                    neg_samples_df = neg_neighbors_df.sample(
                        num_neg_to_use, replace=True).reset_index(drop=True)

                pos_pairs_cas = __get_res_pair_ca_coords(
                    pos_samples_df, structs_df)
                neg_pairs_cas = __get_res_pair_ca_coords(
                    neg_samples_df, structs_df)

                pos_features = []
                for (res0, res1, center0, center1) in pos_pairs_cas:
                    fgrid = df_to_feature(structs_df[0], structs_df[1],
                                          center0, center1, grid_config,
                                          random_seed)
                    pos_features.append(
                        ('{:}/{:}/{:}'.format(ensemble_name, res0,
                                              res1), fgrid, np.array([1])))

                neg_features = []
                for (res0, res1, center0, center1) in neg_pairs_cas:
                    fgrid = df_to_feature(structs_df[0], structs_df[1],
                                          center0, center1, grid_config,
                                          random_seed)
                    neg_features.append(
                        ('{:}/{:}/{:}'.format(ensemble_name, res0,
                                              res1), fgrid, np.array([0])))

                for f in util.intersperse(pos_features, neg_features):
                    yield f