def _save_graphs(sharded, shard_num, out_dir):
    print(f'Processing shard {shard_num:}')
    shard = sharded.read_shard(shard_num)
    neighbors = sharded.read_shard(shard_num, 'neighbors')

    curr_idx = 0
    for i, (ensemble_name, target_df) in enumerate(shard.groupby(['ensemble'])):
        sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
        positives = neighbors[neighbors.ensemble0 == ensemble_name]
        negatives = nb.get_negatives(positives, bound1, bound2)
        negatives['label'] = 0
        labels = create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)

        for index, row in labels.iterrows():
            label = float(row['label'])
            chain_res1 = row[['chain0', 'residue0']].values
            chain_res2 = row[['chain1', 'residue1']].values
            graph1 = df_to_graph(bound1, chain_res1, label)
            graph2 = df_to_graph(bound2, chain_res2, label)
            if (graph1 is None) or (graph2 is None):
                continue
            pair = Batch.from_data_list([graph1, graph2])
            torch.save(pair, os.path.join(out_dir, f'data_{shard_num}_{curr_idx}.pt'))
            curr_idx += 1

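# Hedged usage sketch (not part of the original module): driving _save_graphs over
# every shard of a dataset. `sh.Sharded.load` and `get_num_shards()` follow the API
# used elsewhere in this file; the function name and paths are illustrative only.
def _save_all_graphs_example(sharded_path, out_dir):
    sharded = sh.Sharded.load(sharded_path)
    os.makedirs(out_dir, exist_ok=True)
    for shard_num in range(sharded.get_num_shards()):
        _save_graphs(sharded, shard_num, out_dir)
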
def dataset_generator(dataset, indices, shuffle=True):
    """
    Generator that converts an LMDB-style dataset of sharded pairs to graphs.
    """
    for idx in indices:
        data = dataset[idx]
        neighbors = data['atoms_neighbors']
        pairs = data['atoms_pairs']

        if shuffle:
            # Reassemble the shuffled ensembles so the loop below sees them in
            # random order (the original assigned to an unused variable).
            groups = [df for _, df in pairs.groupby('ensemble')]
            random.shuffle(groups)
            pairs = pd.concat(groups).reset_index(drop=True)

        for i, (ensemble_name, target_df) in enumerate(pairs.groupby(['ensemble'])):
            sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
            positives = neighbors[neighbors.ensemble0 == ensemble_name]
            negatives = nb.get_negatives(positives, bound1, bound2)
            negatives['label'] = 0
            labels = create_labels(positives, negatives, num_pos=10, neg_pos_ratio=1)

            for index, row in labels.iterrows():
                label = float(row['label'])
                chain_res1 = row[['chain0', 'residue0']].values
                chain_res2 = row[['chain1', 'residue1']].values
                graph1 = df_to_graph(bound1, chain_res1, label)
                graph2 = df_to_graph(bound2, chain_res2, label)
                if (graph1 is None) or (graph2 is None):
                    continue
                yield graph1, graph2

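# Hedged usage sketch: collecting the (graph1, graph2) pairs yielded above into
# torch_geometric Batch objects, mirroring how _save_graphs packs a pair. The
# function name and batch size are illustrative, not from the original code.
def example_batched_pairs(dataset, indices, batch_size=32):
    graphs1, graphs2 = [], []
    for graph1, graph2 in dataset_generator(dataset, indices, shuffle=True):
        graphs1.append(graph1)
        graphs2.append(graph2)
        if len(graphs1) == batch_size:
            yield Batch.from_data_list(graphs1), Batch.from_data_list(graphs2)
            graphs1, graphs2 = [], []
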
def get_data_stats(sharded_list):
    data = []
    for i, sharded in enumerate(sharded_list):
        num_shards = sharded.get_num_shards()
        for shard_num in range(num_shards):
            shard_neighbors_df = sharded.read_shard(shard_num, key='neighbors')
            shard_structs_df = sharded.read_shard(shard_num, key='structures')
            ensemble_names = shard_structs_df.ensemble.unique()
            for ensemble_name in ensemble_names:
                ensemble_df = shard_structs_df[shard_structs_df.ensemble == ensemble_name]
                # Subunits
                names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble_df)
                structs_df = [udf0, udf1] if udf0 is not None else [bdf0, bdf1]
                # Get positives
                pos_neighbors_df = shard_neighbors_df[
                    shard_neighbors_df.ensemble0 == ensemble_name]
                num_pos = pos_neighbors_df.shape[0]
                data.append((i, shard_num, ensemble_name, num_pos))

    df = pd.DataFrame(data, columns=['sharded', 'shard_num', 'ensemble', 'num_pos'])
    df = df.sort_values(['sharded', 'num_pos'],
                        ascending=[True, False]).reset_index(drop=True)
    print(df.describe())
    return df

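# Hedged usage sketch: summarizing positive-pair counts for one or more sharded
# datasets. The dataset paths below are placeholders, not from the original code.
# train_sharded = sh.Sharded.load('pairs_train@40')
# stats_df = get_data_stats([train_sharded])
# print(stats_df.head())
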
def dataset_generator(sharded, shard_indices, shuffle=True):
    """
    Generator that converts a sharded HDF dataset to graphs.
    """
    for shard_idx in shard_indices:
        shard = sharded.read_shard(shard_idx)
        neighbors = sharded.read_shard(shard_idx, 'neighbors')

        if shuffle:
            groups = [df for _, df in shard.groupby('ensemble')]
            random.shuffle(groups)
            shard = pd.concat(groups).reset_index(drop=True)

        for i, (ensemble_name, target_df) in enumerate(shard.groupby(['ensemble'])):
            sub_names, (bound1, bound2, _, _) = nb.get_subunits(target_df)
            # bound1 = remove_waters(bound1)
            # bound2 = remove_waters(bound2)
            positives = neighbors[neighbors.ensemble0 == ensemble_name]
            negatives = nb.get_negatives(positives, bound1, bound2)
            negatives['label'] = 0
            labels = create_labels(positives, negatives, neg_pos_ratio=1)
            labels = labels.sample(frac=1)

            for index, row in labels.iterrows():
                label = float(row['label'])
                chain_res1 = row[['chain0', 'residue0']].values
                chain_res2 = row[['chain1', 'residue1']].values
                graph1 = df_to_graph(bound1, chain_res1, label)
                if graph1 is None:
                    continue
                graph2 = df_to_graph(bound2, chain_res2, label)
                if graph2 is None:
                    continue
                yield graph1, graph2

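# Hedged sketch: wrapping the sharded generator in a torch IterableDataset so it can
# be consumed by a DataLoader (with a graph-aware collate function). The class name
# is illustrative and not part of the original module.
from torch.utils.data import IterableDataset

class GraphPairIterableDataset(IterableDataset):
    def __init__(self, sharded, shard_indices, shuffle=True):
        self._sharded = sharded
        self._shard_indices = shard_indices
        self._shuffle = shuffle

    def __iter__(self):
        # Delegate to the generator above; each item is a (graph1, graph2) pair.
        return dataset_generator(self._sharded, self._shard_indices,
                                 shuffle=self._shuffle)
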
def __iter__(self):
    for index in range(len(self._lmdb_dataset)):
        item = self._lmdb_dataset[index]
        # Subunits
        names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(item['atoms_pairs'])
        structs_df = [udf0, udf1] if udf0 is not None else [bdf0, bdf1]
        # Get positives
        pos_neighbors_df = item['atoms_neighbors']
        # Get negatives
        neg_neighbors_df = nb.get_negatives(pos_neighbors_df, structs_df[0], structs_df[1])

        # Throw away residues with a non-empty hetero/insertion_code
        non_heteros = []
        for df in structs_df:
            non_heteros.append(
                df[(df.hetero == ' ') & (df.insertion_code == ' ')].residue.unique())
        pos_neighbors_df = pos_neighbors_df[
            pos_neighbors_df.residue0.isin(non_heteros[0]) &
            pos_neighbors_df.residue1.isin(non_heteros[1])]
        neg_neighbors_df = neg_neighbors_df[
            neg_neighbors_df.residue0.isin(non_heteros[0]) &
            neg_neighbors_df.residue1.isin(non_heteros[1])]

        # Sample pos and neg samples
        num_pos = pos_neighbors_df.shape[0]
        num_neg = neg_neighbors_df.shape[0]
        num_pos_to_use, num_neg_to_use = self._num_to_use(num_pos, num_neg)

        if pos_neighbors_df.shape[0] == num_pos_to_use:
            pos_samples_df = pos_neighbors_df.reset_index(drop=True)
        else:
            pos_samples_df = pos_neighbors_df.sample(
                num_pos_to_use, replace=True).reset_index(drop=True)
        if neg_neighbors_df.shape[0] == num_neg_to_use:
            neg_samples_df = neg_neighbors_df.reset_index(drop=True)
        else:
            neg_samples_df = neg_neighbors_df.sample(
                num_neg_to_use, replace=True).reset_index(drop=True)

        pos_pairs_cas = self._get_res_pair_ca_coords(pos_samples_df, structs_df)
        neg_pairs_cas = self._get_res_pair_ca_coords(neg_samples_df, structs_df)

        pos_features = []
        for (res0, res1, center0, center1) in pos_pairs_cas:
            grid0, grid1 = self._voxelize(structs_df[0], structs_df[1], center0, center1)
            pos_features.append({
                'feature_left': grid0,
                'feature_right': grid1,
                'label': 1,
                'id': '{:}/{:}/{:}'.format(item['id'], res0, res1),
            })
        neg_features = []
        for (res0, res1, center0, center1) in neg_pairs_cas:
            grid0, grid1 = self._voxelize(structs_df[0], structs_df[1], center0, center1)
            neg_features.append({
                'feature_left': grid0,
                'feature_right': grid1,
                'label': 0,
                'id': '{:}/{:}/{:}'.format(item['id'], res0, res1),
            })
        for f in intersperse(pos_features, neg_features):
            yield f

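# Hedged sketch of the `intersperse` helper assumed above (also used as
# `util.intersperse` below): alternate elements of two iterables so positive and
# negative examples interleave. This is a minimal stand-in; the project's actual
# implementation may differ.
import itertools

def intersperse_sketch(pos_items, neg_items):
    for a, b in itertools.zip_longest(pos_items, neg_items):
        if a is not None:
            yield a
        if b is not None:
            yield b
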
def dataset_generator(sharded, grid_config, shuffle=True, repeat=None,
                      max_num_ensembles=None, testing=False,
                      use_shard_nums=None, random_seed=None):
    """Generator that yields (id, feature grid, label) examples from a sharded dataset."""
    if use_shard_nums is None:
        num_shards = sharded.get_num_shards()
        all_shard_nums = np.arange(num_shards)
    else:
        all_shard_nums = use_shard_nums

    seen = col.defaultdict(set)
    ensemble_count = 0

    if repeat is None:
        repeat = 1
    for epoch in range(repeat):
        if shuffle:
            p = np.random.permutation(len(all_shard_nums))
            all_shard_nums = all_shard_nums[p]
        shard_nums = all_shard_nums

        for i, shard_num in enumerate(shard_nums):
            shard_neighbors_df = sharded.read_shard(shard_num, key='neighbors')
            shard_structs_df = sharded.read_shard(shard_num, key='structures')

            ensemble_names = shard_structs_df.ensemble.unique()
            if len(seen[shard_num]) == len(ensemble_names):
                seen[shard_num] = set()

            if shuffle:
                p = np.random.permutation(len(ensemble_names))
                ensemble_names = ensemble_names[p]

            for j, ensemble_name in enumerate(ensemble_names):
                if (max_num_ensembles is not None) and \
                        (ensemble_count >= max_num_ensembles):
                    return
                if ensemble_name in seen[shard_num]:
                    continue
                seen[shard_num].add(ensemble_name)
                ensemble_count += 1

                ensemble_df = shard_structs_df[shard_structs_df.ensemble == ensemble_name]
                # Subunits
                names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble_df)
                structs_df = [udf0, udf1] if udf0 is not None else [bdf0, bdf1]
                # Get positives
                pos_neighbors_df = shard_neighbors_df[
                    shard_neighbors_df.ensemble0 == ensemble_name]
                # Get negatives
                neg_neighbors_df = nb.get_negatives(pos_neighbors_df,
                                                    structs_df[0], structs_df[1])

                # Throw away residues with a non-empty hetero/insertion_code
                non_heteros = []
                for df in structs_df:
                    non_heteros.append(
                        df[(df.hetero == ' ') & (df.insertion_code == ' ')].residue.unique())
                pos_neighbors_df = pos_neighbors_df[
                    pos_neighbors_df.residue0.isin(non_heteros[0]) &
                    pos_neighbors_df.residue1.isin(non_heteros[1])]
                neg_neighbors_df = neg_neighbors_df[
                    neg_neighbors_df.residue0.isin(non_heteros[0]) &
                    neg_neighbors_df.residue1.isin(non_heteros[1])]

                # Sample pos and neg samples
                num_pos = pos_neighbors_df.shape[0]
                num_neg = neg_neighbors_df.shape[0]
                num_pos_to_use, num_neg_to_use = __num_to_use(
                    num_pos, num_neg, testing, grid_config)

                if shuffle:
                    pos_neighbors_df = pos_neighbors_df.sample(frac=1).reset_index(drop=True)
                    neg_neighbors_df = neg_neighbors_df.sample(frac=1).reset_index(drop=True)

                if pos_neighbors_df.shape[0] == num_pos_to_use:
                    pos_samples_df = pos_neighbors_df.reset_index(drop=True)
                else:
                    pos_samples_df = pos_neighbors_df.sample(
                        num_pos_to_use, replace=True).reset_index(drop=True)
                if neg_neighbors_df.shape[0] == num_neg_to_use:
                    neg_samples_df = neg_neighbors_df.reset_index(drop=True)
                else:
                    neg_samples_df = neg_neighbors_df.sample(
                        num_neg_to_use, replace=True).reset_index(drop=True)

                pos_pairs_cas = __get_res_pair_ca_coords(pos_samples_df, structs_df)
                neg_pairs_cas = __get_res_pair_ca_coords(neg_samples_df, structs_df)

                pos_features = []
                for (res0, res1, center0, center1) in pos_pairs_cas:
                    fgrid = df_to_feature(structs_df[0], structs_df[1], center0, center1,
                                          grid_config, random_seed)
                    pos_features.append(
                        ('{:}/{:}/{:}'.format(ensemble_name, res0, res1),
                         fgrid, np.array([1])))
                neg_features = []
                for (res0, res1, center0, center1) in neg_pairs_cas:
                    fgrid = df_to_feature(structs_df[0], structs_df[1], center0, center1,
                                          grid_config, random_seed)
                    neg_features.append(
                        ('{:}/{:}/{:}'.format(ensemble_name, res0, res1),
                         fgrid, np.array([0])))

                for f in util.intersperse(pos_features, neg_features):
                    yield f

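# Hedged usage sketch: pulling a few (id, grid, label) examples from the grid
# generator. `grid_config` is whatever config object df_to_feature / __num_to_use
# expect; it is not defined in this file, so this is shown as comments only.
# gen = dataset_generator(sharded, grid_config, shuffle=True, max_num_ensembles=2)
# for name, grid, label in itertools.islice(gen, 4):
#     print(name, grid.shape, label)
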
def form_scop_pair_filter_against(sharded, level):
    """Remove pairs that have matching scop classes in both subunits."""
    scop_index = scop.get_scop_index()[level]

    scop_pairs = []
    for _, shard in sh.iter_shards(sharded):
        for e, ensemble in shard.groupby(['ensemble']):
            names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble)
            chains0 = bdf0[['structure', 'chain']].drop_duplicates()
            chains1 = bdf1[['structure', 'chain']].drop_duplicates()
            chains0['pdb_code'] = chains0['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())
            chains1['pdb_code'] = chains1['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())

            scop0, scop1 = [], []
            for (pc, c) in chains0[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop0.append(scop_index.loc[(pc, c)].values)
            for (pc, c) in chains1[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop1.append(scop_index.loc[(pc, c)].values)
            scop0 = list(np.unique(np.concatenate(scop0))) \
                if len(scop0) > 0 else []
            scop1 = list(np.unique(np.concatenate(scop1))) \
                if len(scop1) > 0 else []

            pairs = [tuple(sorted((a, b))) for a in scop0 for b in scop1]
            scop_pairs.extend(pairs)
    scop_pairs = set(scop_pairs)

    def filter_fn(df):
        to_keep = {}
        for e, ensemble in df.groupby(['ensemble']):
            names, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble)
            chains0 = bdf0[['structure', 'chain']].drop_duplicates()
            chains1 = bdf1[['structure', 'chain']].drop_duplicates()
            chains0['pdb_code'] = chains0['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())
            chains1['pdb_code'] = chains1['structure'].apply(
                lambda x: fi.get_pdb_code(x).lower())

            scop0, scop1 = [], []
            for (pc, c) in chains0[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop0.append(scop_index.loc[(pc, c)].values)
            for (pc, c) in chains1[['pdb_code', 'chain']].to_numpy():
                if (pc, c) in scop_index:
                    scop1.append(scop_index.loc[(pc, c)].values)
            scop0 = list(np.unique(np.concatenate(scop0))) \
                if len(scop0) > 0 else []
            scop1 = list(np.unique(np.concatenate(scop1))) \
                if len(scop1) > 0 else []

            pairs = [tuple(sorted((a, b))) for a in scop0 for b in scop1]
            to_keep[e] = True
            for p in pairs:
                if p in scop_pairs:
                    to_keep[e] = False
        to_keep = pd.Series(to_keep)[df['ensemble']]
        return df[to_keep.values]

    return filter_fn

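# Hedged usage sketch: build the SCOP-pair filter from a reference dataset, then
# apply it shard by shard to another dataset. The function name, 'superfamily'
# level, and variable names are illustrative assumptions, not from the original code.
def example_apply_scop_filter(reference_sharded, target_sharded, level='superfamily'):
    filter_fn = form_scop_pair_filter_against(reference_sharded, level)
    filtered = []
    for _, shard in sh.iter_shards(target_sharded):
        # Keep only ensembles sharing no SCOP class pair with the reference set.
        filtered.append(filter_fn(shard))
    return pd.concat(filtered).reset_index(drop=True)
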
def compute_all_bsa_main(sharded_path, ensemble):
    sharded = sh.Sharded.load(sharded_path)
    ensemble = sharded.read_ensemble(ensemble)
    _, (bdf0, bdf1, udf0, udf1) = nb.get_subunits(ensemble)
    print(compute_bsa(bdf0, bdf1))

def _bsa_db(sharded, shard_num, output_bsa):
    logger.info(f'Processing shard {shard_num:}')
    start_time = timeit.default_timer()
    start_time_reading = timeit.default_timer()
    shard = sharded.read_shard(shard_num)
    elapsed_reading = timeit.default_timer() - start_time_reading

    start_time_waiting = timeit.default_timer()
    with db_sem:
        start_time_reading = timeit.default_timer()
        if os.path.exists(output_bsa):
            curr_bsa_db = pd.read_csv(output_bsa).set_index(['ensemble'])
        else:
            curr_bsa_db = None
        tmp_elapsed_reading = timeit.default_timer() - start_time_reading
    elapsed_waiting = timeit.default_timer() - start_time_waiting - \
        tmp_elapsed_reading
    elapsed_reading += tmp_elapsed_reading

    start_time_processing = timeit.default_timer()
    all_results = []
    cache = {}
    for e, ensemble in shard.groupby('ensemble'):
        if (curr_bsa_db is not None) and (e in curr_bsa_db.index):
            continue
        (name0, name1, _, _), (bdf0, bdf1, _, _) = nb.get_subunits(ensemble)
        try:
            # We use bound for individual subunits in bsa computation, as
            # sometimes the actual structure between bound and unbound differ.
            if name0 not in cache:
                cache[name0] = bsa._compute_asa(bdf0)
            if name1 not in cache:
                cache[name1] = bsa._compute_asa(bdf1)
            result = bsa.compute_bsa(bdf0, bdf1, cache[name0], cache[name1])
            result['ensemble'] = e
            all_results.append(result)
        except AssertionError as err:
            # Use a distinct name so the ensemble id `e` is not shadowed in the log.
            logger.warning(err)
            logger.warning(f'Failed BSA on {e:}')

    if len(all_results) > 0:
        to_add = pd.concat(all_results, axis=1).T
    elapsed_processing = timeit.default_timer() - start_time_processing

    if len(all_results) > 0:
        start_time_waiting = timeit.default_timer()
        with db_sem:
            start_time_writing = timeit.default_timer()
            # Update db in case it has updated since last run.
            if os.path.exists(output_bsa):
                curr_bsa_db = pd.read_csv(output_bsa)
                new_bsa_db = pd.concat([curr_bsa_db, to_add])
            else:
                new_bsa_db = to_add
            new_bsa_db.to_csv(output_bsa + f'.tmp{shard_num:}', index=False)
            os.rename(output_bsa + f'.tmp{shard_num:}', output_bsa)
            elapsed_writing = timeit.default_timer() - start_time_writing
        elapsed_waiting += timeit.default_timer() - start_time_waiting - \
            elapsed_writing
    else:
        elapsed_writing = 0
    elapsed = timeit.default_timer() - start_time

    logger.info(
        f'For {len(all_results):03d} pairs buried in shard {shard_num:} spent '
        f'{elapsed_reading:05.2f} reading, '
        f'{elapsed_processing:05.2f} processing, '
        f'{elapsed_writing:05.2f} writing, '
        f'{elapsed_waiting:05.2f} waiting, and '
        f'{elapsed:05.2f} overall.')

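# Hedged sketch: fanning _bsa_db out over all shards with a thread pool, which is
# consistent with the module-level `db_sem` semaphore guarding the shared CSV.
# The function name, worker count, and paths are placeholders.
import concurrent.futures as cf

def example_compute_all_bsa(sharded_path, output_bsa, num_threads=8):
    sharded = sh.Sharded.load(sharded_path)
    with cf.ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [pool.submit(_bsa_db, sharded, shard_num, output_bsa)
                   for shard_num in range(sharded.get_num_shards())]
        for f in cf.as_completed(futures):
            f.result()  # Surface any exception raised inside a worker.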