def _load_lba_data(datafiles, dist, maxnum):
    """
    Load LBA datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param dist: Radius of the selected region around the ligand.
    :type dist: float
    :param maxnum: Maximum total number of atoms of the ligand and the region around it.
    :type maxnum: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict
    """
    datasets = {}
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=TransformLBA(dist, maxnum, move_lig=False))
        # Load original atoms
        dsdict = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_pocket', 'atoms_ligand'])
        # Add the label data
        dsdict['neglog_aff'] = np.array([item['scores']['neglog_aff'] for item in dataset])
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
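# Usage sketch (illustrative only; the LMDB root below is hypothetical, and the
# dist/maxnum values are just plausible pocket-selection settings):
#   datafiles = {s: os.path.join('/path/to/lba/data', s) for s in ['train', 'val', 'test']}
#   datasets = _load_lba_data(datafiles, dist=6.0, maxnum=500)
#   print(datasets['train']['neglog_aff'].shape)  # one affinity label per complex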
def _load_lba_data_siamese(datafiles, dist, maxnum):
    """
    Load LBA datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param dist: Radius of the selected region around the ligand.
    :type dist: float
    :param maxnum: Maximum total number of atoms of the ligand and the region around it.
    :type maxnum: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict
    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=TransformLBA(dist, maxnum, move_lig=True))
        # Load the bound pose (pocket + ligand)
        bound = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_pocket', 'atoms_ligand'])
        for k in key_names:
            bound['bound_' + k] = bound.pop(k)
        # Load the unbound pose (pocket + moved ligand)
        apart = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_pocket', 'atoms_ligand_moved'])
        for k in key_names:
            apart['apart_' + k] = apart.pop(k)
        # Merge datasets with atoms
        dsdict = {**bound, **apart}
        # Add the label data
        dsdict['neglog_aff'] = np.array([item['scores']['neglog_aff'] for item in dataset])
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
def _load_msp_data(datafiles, radius, droph=False):
    """
    Load MSP datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param radius: Radius of the selected region around the mutated residue.
    :type radius: float
    :param droph: Drop hydrogen atoms.
    :type droph: bool

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict
    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=EnvironmentSelection(radius, droph))
        # Load original atoms
        ori = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['original_atoms'])
        for k in key_names:
            ori['original_' + k] = ori.pop(k)
        # Load mutated atoms
        mut = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['mutated_atoms'])
        for k in key_names:
            mut['mutated_' + k] = mut.pop(k)
        # Merge datasets with atoms
        dsdict = {**ori, **mut}
        # Add labels
        labels = [dataset[i]['label'] for i in range(len(dataset))]
        dsdict['label'] = np.array(labels, dtype=int)
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
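# Usage sketch (illustrative; the data root below is hypothetical):
#   datafiles = {s: os.path.join('/path/to/msp/data', s) for s in ['train', 'val', 'test']}
#   datasets = _load_msp_data(datafiles, radius=6.0, droph=True)
#   # each split holds original_*/mutated_* atom arrays plus a 0/1 label tensor
#   print(datasets['train']['label'].float().mean())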
def __init__(self, lmdb_path, testing, random_seed=None, **kwargs):
    self._lmdb_dataset = LMDBDataset(lmdb_path)
    self.testing = testing
    self.random_seed = random_seed
    self.grid_config = dotdict({
        # Mapping from elements to position in channel dimension.
        'element_mapping': {
            'C': 0,
            'O': 1,
            'N': 2,
            'S': 3,
        },
        # Radius of the grids to generate, in angstroms.
        'radius': 17.0,
        # Resolution of each voxel, in angstroms.
        'resolution': 1.0,
        # Number of directions to apply for data augmentation.
        'num_directions': 20,
        # Number of rolls to apply for data augmentation.
        'num_rolls': 20,

        ### PPI specific
        # Number of negatives to sample per positive example. -1 means all.
        'neg_to_pos_ratio': 1,
        'neg_to_pos_ratio_testing': 1,
        # Max number of positive regions to take from a structure. -1 means all.
        'max_pos_regions_per_ensemble': 5,
        'max_pos_regions_per_ensemble_testing': 5,
        # Whether to use all negatives at test time.
        'full_test': False,
    })
    # Update grid configs as necessary
    self.grid_config.update(kwargs)
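# Configuration sketch (illustrative): any grid_config entry can be overridden
# per-instance through **kwargs. The class name below is an assumption, since
# the excerpt omits it, and the path is hypothetical:
#   dataset = PPIDataset('/path/to/ppi/lmdb', testing=False,
#                        radius=12.0, neg_to_pos_ratio=2)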
def _load_smp_data(datafiles):
    """
    Load SMP datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict
    """
    datasets = {}
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile)
        # Load original atoms
        dsdict = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms'])
        # Add the label data (label_names is expected to be in scope,
        # e.g. a module-level list of SMP target names)
        labels = np.zeros([len(label_names), len(dataset)])
        for i, item in enumerate(dataset):
            labels[:, i] = item['labels']
        for j, label in enumerate(label_names):
            dsdict[label] = labels[j]
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
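# Usage sketch (illustrative; the data root is hypothetical, and label_names is
# assumed to contain SMP target names such as 'mu' and 'alpha'):
#   datafiles = {s: os.path.join('/path/to/smp/data', s) for s in ['train', 'val', 'test']}
#   datasets = _load_smp_data(datafiles)
#   print(datasets['train']['mu'].shape)  # one scalar target per molecule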
def _load_res_data(datafiles, maxnum, samples=42, keep=['C', 'N', 'O', 'P', 'S'], seed=1):
    """
    Load RES datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param maxnum: Maximum number of atoms to consider.
    :type maxnum: int
    :param samples: Number of sample structures.
    :type samples: int
    :param keep: Element symbols to keep.
    :type keep: list
    :param seed: Random seed for sampling.
    :type seed: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict
    """
    datasets = {}
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile)
        # Randomly pick the samples.
        np.random.seed(seed)
        indices = np.random.choice(np.arange(len(dataset)), size=int(samples), replace=True)
        # Get labels
        labels = np.concatenate([dataset[i]['labels']['label'] for i in indices])
        print('Labels:', labels)
        # Load original atoms
        dsdict = _extract_coordinates_as_numpy_arrays(dataset, indices=indices, keep=keep)
        # Add labels
        dsdict['label'] = np.array(labels, dtype=int)
        print('Sizes:')
        for key, val in dsdict.items():
            print(key, len(val))
        # Convert everything to tensors
        datasets[split] = {
            key: torch.from_numpy(val) for key, val in dsdict.items()
        }
    return datasets
def _load_lep_data(datafiles, radius, droph, maxnum):
    """
    Load LEP datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param radius: Radius of the selected region around the ligand.
    :type radius: float
    :param droph: Drop hydrogen atoms.
    :type droph: bool
    :param maxnum: Maximum number of atoms to consider.
    :type maxnum: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict
    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=EnvironmentSelection(radius, droph, maxnum))
        # Load active atoms
        act = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_active'])
        for k in key_names:
            act[k + '_active'] = act.pop(k)
        # Load inactive atoms
        ina = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_inactive'])
        for k in key_names:
            ina[k + '_inactive'] = ina.pop(k)
        # Merge datasets with atoms
        dsdict = {**act, **ina}
        # Add labels (1 for active, 0 for inactive)
        ldict = {'A': 1, 'I': 0}
        labels = [ldict[dataset[i]['label']] for i in range(len(dataset))]
        dsdict['label'] = np.array(labels, dtype=int)
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
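# Usage sketch (illustrative; the data root is hypothetical):
#   datafiles = {s: os.path.join('/path/to/lep/data', s) for s in ['train', 'val', 'test']}
#   datasets = _load_lep_data(datafiles, radius=6.0, droph=True, maxnum=400)
#   print(datasets['train']['label'][:10])  # 0/1 tensor ('A' -> 1, 'I' -> 0)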
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'), level=logging.INFO)
    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformSMP(args.target_name))
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformSMP(args.target_name))
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformSMP(args.target_name))

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4)

    # Infer the input feature dimension from the first batch.
    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_SMP(num_features, dim=args.hidden_dim).to(device)
    best_val_loss = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.7, patience=3,
                                                           min_lr=0.00001)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        print('validating...')
        val_loss, _, _ = test(model, val_loader, device)
        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
            best_val_loss = val_loss
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print('\tTrain Loss: {:.7f}, Val MAE: {:.7f}'.format(train_loss, val_loss))

    if test_mode:
        train_file = os.path.join(log_dir, f'smp-rep{rep}.best.train.pt')
        val_file = os.path.join(log_dir, f'smp-rep{rep}.best.val.pt')
        test_file = os.path.join(log_dir, f'smp-rep{rep}.best.test.pt')
        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
        model.load_state_dict(cpt['model_state_dict'])
        _, y_true_train, y_pred_train = test(model, train_loader, device)
        torch.save({'targets': y_true_train, 'predictions': y_pred_train}, train_file)
        _, y_true_val, y_pred_val = test(model, val_loader, device)
        torch.save({'targets': y_true_val, 'predictions': y_pred_val}, val_file)
        mae, y_true_test, y_pred_test = test(model, test_loader, device)
        print(f'\tTest MAE {mae}')
        torch.save({'targets': y_true_test, 'predictions': y_pred_test}, test_file)
def train(args, device, test_mode=False):
    print("Training model with config:")
    print(str(json.dumps(args.__dict__, indent=4)) + "\n")

    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    train_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'train'),
        transform=CNN3D_TransformRSR(random_seed=args.random_seed))
    val_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'val'),
        transform=CNN3D_TransformRSR(random_seed=args.random_seed))
    test_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'test'),
        transform=CNN3D_TransformRSR(random_seed=args.random_seed))

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)

    # Infer input channels and spatial size from the first batch.
    for data in train_loader:
        in_channels, spatial_size = data['feature'].size()[1:3]
        print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))
        break

    model = conv_model(in_channels, spatial_size, args)
    print(model)
    model.to(device)

    best_val_loss = np.inf
    best_corrs = None

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, corrs, val_df = test(model, val_loader, device)
        if val_loss < best_val_loss:
            print(f"\nSave model at epoch {epoch:03d}, val_loss: {val_loss:.4f}")
            save_weights(model, os.path.join(args.output_dir, 'best_weights.pt'))
            best_val_loss = val_loss
            best_corrs = corrs
        elapsed = (time.time() - start)
        print('Epoch {:03d} finished in : {:.3f} s'.format(epoch, elapsed))
        print('\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
              .format(train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))

    if test_mode:
        model.load_state_dict(torch.load(os.path.join(args.output_dir, 'best_weights.pt')))
        rmse, corrs, test_df = test(model, test_loader, device)
        test_df.to_pickle(os.path.join(args.output_dir, 'test_results.pkl'))
        print('Test RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
              .format(rmse, corrs['per_target_spearman'], corrs['all_spearman']))
        test_file = os.path.join(args.output_dir, 'test_results.txt')
        with open(test_file, 'a+') as out:
            out.write('{}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(
                args.random_seed, rmse, corrs['per_target_spearman'], corrs['all_spearman']))

    return best_val_loss, best_corrs['per_target_spearman'], best_corrs['all_spearman']
            graph.mut_idx[total_n_mut:total_n_mut + graph.num_mut_atoms[i + 1]] += total_n
        return graph

    def __call__(self, batch):
        return self.collate(batch)


if __name__ == "__main__":
    from tqdm import tqdm
    # dataset = LMDBDataset(os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data', 'train'))
    # dataloader = DataLoader(dataset, batch_size=3, shuffle=False, num_workers=4)
    # for i, item in tqdm(enumerate(dataloader)):
    #     if i < 578:
    #         continue
    #     print(item)
    dataset = LMDBDataset(os.path.join(
        '/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data',
        'train'), transform=GNNTransformMSP())
    dataloader = DataLoader(dataset, batch_size=3, shuffle=False,
                            collate_fn=CollaterMSP(batch_size=3), num_workers=4)
    # Sanity check: mutation indices must stay within the batched node range.
    for original, mutated in tqdm(dataloader):
        if mutated.mut_idx.max() > mutated.batch.shape[0]:
            print(mutated.batch.shape)
            print(mutated.mut_idx)
            break
    # Assume labels/classes are integers (0, 1, 2, ...).
    labels = [item['label'] for item in dataset]
    classes, class_sample_count = np.unique(labels, return_counts=True)
    # Weighted sampler for imbalanced classification (1:1 ratio for each class)
    weight = 1. / class_sample_count
    sample_weights = torch.tensor([weight[t] for t in labels])
    sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights,
                                                     num_samples=len(dataset),
                                                     replacement=True)
    return sampler


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['MSP_DATA'], 'val')
    dataset = LMDBDataset(dataset_path,
                          transform=CNN3D_TransformMSP(add_flag=True, center_at_mut=True, radius=10.0))
    dataloader = DataLoader(dataset, batch_size=8,
                            sampler=create_balanced_sampler(dataset))

    non_zeros = []
    for item in dataloader:
        print('feature original shape:', item['feature_original'].shape)
        print('feature mutated shape:', item['feature_mutated'].shape)
        print('label:', item['label'])
        print('id:', item['id'])
        non_zeros.append(np.count_nonzero(item['label']))
        break

    for item in dataloader:
        non_zeros.append(np.count_nonzero(item['label']))
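    # Added sanity-check sketch (not in the original script): with weights
    # proportional to 1/class_count, every class is drawn with equal probability,
    # so on a binary task roughly half of all sampled labels should be positive.
    frac_pos = sum(non_zeros) / (len(non_zeros) * dataloader.batch_size)
    print('approx. fraction of positive samples:', frac_pos)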
        item = prot_graph_transform(item, ['atoms'], 'scores')
        graph = item['atoms']
        graph.y = torch.FloatTensor([graph.y['rms']])
        # The id appears to be a string like "('target', 'decoy')"; pull out both fields.
        split = item['id'].split("'")
        graph.target = split[1]
        graph.decoy = split[3]
        return graph


if __name__ == "__main__":
    save_dir = '/scratch/users/aderry/atom3d/rsr'
    data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/RSR/splits/candidates-split-by-time/data'
    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=GNNTransformRSR())
    val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformRSR())
    test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=GNNTransformRSR())
    # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4)
    # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    print('processing train dataset...')
    for i, item in enumerate(tqdm(train_dataset)):
        torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))
    print('processing validation dataset...')
import numpy as np
import os
from atom3d.util.transforms import prot_graph_transform, PairedGraphTransform
from atom3d.datasets import LMDBDataset
from torch_geometric.data import Data, Dataset, DataLoader
import atom3d.util.graph as gr


class GNNTransformLEP(object):
    def __init__(self, atom_keys, label_key):
        self.atom_keys = atom_keys
        self.label_key = label_key

    def __call__(self, item):
        # transform protein and/or pocket to PTG graphs
        item = prot_graph_transform(item, atom_keys=self.atom_keys, label_key=self.label_key)
        return item


if __name__ == "__main__":
    dataset = LMDBDataset(
        os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/LEP/splits/split-by-protein/data', 'train'),
        transform=PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label'))
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    for active, inactive in dataloader:
        print(active)
        print(inactive)
        break
    # for item in dataloader:
    #     print(item)
    #     break
    # Assume labels/classes are integers (0, 1, 2, ...).
    labels = [item['label'] for item in dataset]
    classes, class_sample_count = np.unique(labels, return_counts=True)
    # Weighted sampler for imbalanced classification (1:1 ratio for each class)
    weight = 1. / class_sample_count
    sample_weights = torch.tensor([weight[t] for t in labels])
    sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights,
                                                     num_samples=len(dataset),
                                                     replacement=True)
    return sampler


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['LEP_DATA'], 'val')
    dataset = LMDBDataset(dataset_path, transform=CNN3D_TransformLEP(add_flag=True, radius=10.0))
    dataloader = DataLoader(dataset, batch_size=8,
                            sampler=create_balanced_sampler(dataset))

    non_zeros = []
    for item in dataloader:
        print('feature inactive shape:', item['feature_inactive'].shape)
        print('feature active shape:', item['feature_active'].shape)
        print('label:', item['label'])
        print('id:', item['id'])
        non_zeros.append(np.count_nonzero(item['label']))
        break

    for item in dataloader:
        non_zeros.append(np.count_nonzero(item['label']))
import numpy as np
import os
import torch
from atom3d.util.transforms import prot_graph_transform
from atom3d.datasets import LMDBDataset
from torch_geometric.data import Data, Dataset, DataLoader


class GNNTransformPSR(object):
    def __init__(self):
        pass

    def __call__(self, item):
        item = prot_graph_transform(item, ['atoms'], 'scores')
        graph = item['atoms']
        graph.y = torch.FloatTensor([graph.y['gdt_ts']])
        graph.target = item['id'][0]
        graph.decoy = item['id'][1]
        return graph


if __name__ == "__main__":
    dataset = LMDBDataset(
        '/scratch/users/raphtown/atom3d_mirror/lmdb/PSR/splits/split-by-year/data/train',
        transform=GNNTransformPSR())
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for item in dataloader:
        print(item)
        break
import os
import torch
from atom3d.util.transforms import prot_graph_transform
from atom3d.datasets import LMDBDataset
from torch_geometric.data import Data, Dataset, DataLoader


class GNNTransformRSR(object):
    def __init__(self):
        pass

    def __call__(self, item):
        item = prot_graph_transform(item, ['atoms'], 'scores')
        graph = item['atoms']
        graph.y = torch.FloatTensor([graph.y['rms']])
        graph.target = item['id'][0]
        graph.decoy = item['id'][1]
        return graph


if __name__ == "__main__":
    dataset = LMDBDataset(
        '/scratch/users/raphtown/atom3d_mirror/lmdb/RSR/splits/candidates-split-by-time/data/train',
        transform=GNNTransformRSR())
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for item in dataloader:
        print(item)
        break
def train(args, device, test_mode=False):
    print("Training model with config:")
    print(str(json.dumps(args.__dict__, indent=4)) + "\n")

    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    train_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'train'),
        transform=CNN3D_TransformMSP(args.add_flag, args.center_at_mut,
                                     random_seed=args.random_seed))
    val_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'val'),
        transform=CNN3D_TransformMSP(args.add_flag, args.center_at_mut,
                                     random_seed=args.random_seed))
    test_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'test'),
        transform=CNN3D_TransformMSP(args.add_flag, args.center_at_mut,
                                     random_seed=args.random_seed))

    train_loader = DataLoader(train_dataset, args.batch_size,
                              sampler=create_balanced_sampler(train_dataset))
    val_loader = DataLoader(val_dataset, args.batch_size,
                            sampler=create_balanced_sampler(val_dataset))
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)

    for data in train_loader:
        in_channels, spatial_size = data['feature_original'].size()[1:3]
        print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))
        break

    model = conv_model(in_channels, spatial_size, args)
    print(model)
    model.to(device)

    prev_val_loss = np.inf
    best_val_loss = np.inf
    best_val_auroc = 0
    best_stats = None

    criterion = nn.BCELoss()
    criterion.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, args.repeat_gen, criterion, optimizer, device)
        val_loss, stats, val_df = test(model, val_loader, args.repeat_gen, criterion, device)
        elapsed = (time.time() - start)
        print(f'Epoch {epoch:03d} finished in : {elapsed:.3f} s')
        print(f"\tTrain loss {train_loss:.4f}, Val loss: {val_loss:.4f}, "
              f"Val AUROC: {stats['auroc']:.4f}, Val AUPRC: {stats['auprc']:.4f}")
        # if stats['auroc'] > best_val_auroc:
        if val_loss < best_val_loss:
            print(f"\nSave model at epoch {epoch:03d}, val_loss: {val_loss:.4f}, "
                  f"auroc: {stats['auroc']:.4f}, auprc: {stats['auprc']:.4f}")
            save_weights(model, os.path.join(args.output_dir, 'best_weights.pt'))
            best_val_loss = val_loss
            best_val_auroc = stats['auroc']
            best_stats = stats
        if args.early_stopping and val_loss >= prev_val_loss:
            print(f"Validation loss stopped decreasing, stopping at epoch {epoch:03d}...")
            break
        prev_val_loss = val_loss

    if test_mode:
        model.load_state_dict(torch.load(os.path.join(args.output_dir, 'best_weights.pt')))
        test_loss, stats, test_df = test(model, test_loader, args.repeat_gen, criterion, device)
        test_df.to_pickle(os.path.join(args.output_dir, 'test_results.pkl'))
        print(f"Test loss: {test_loss:.4f}, Test AUROC: {stats['auroc']:.4f}, "
              f"Test AUPRC: {stats['auprc']:.4f}")
        test_file = os.path.join(args.output_dir, 'test_results.txt')
        with open(test_file, 'w') as f:
            f.write("test_loss\tAUROC\tAUPRC\n")
            f.write(f"{test_loss}\t{stats['auroc']}\t{stats['auprc']}\n")
        pass

    def __call__(self, data_list):
        batch_1 = Batch.from_data_list([d[0] for d in data_list])
        batch_2 = Batch.from_data_list([d[1] for d in data_list])
        return batch_1, batch_2


if __name__ == "__main__":
    save_dir = '/scratch/users/aderry/atom3d/lep'
    data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/LEP/splits/split-by-protein/data'
    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
    transform = PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label')
    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=transform)
    val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=transform)
    test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=transform)
    # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4)
    # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for i, item in enumerate(tqdm(train_dataset)):
        torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))
    for i, item in enumerate(tqdm(val_dataset)):
        torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt'))
    for i, item in enumerate(tqdm(test_dataset)):
        torch.save(item, os.path.join(save_dir, 'test', f'data_{i}.pt'))
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'), level=logging.INFO)
    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformRSR())
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformRSR())
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformRSR())

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_RSR(num_features, hidden_dim=args.hidden_dim).to(device)
    best_val_loss = float('inf')
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, corrs, results_df = test(model, val_loader, device)
        if corrs['all_spearman'] > best_rs:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, 'best_weights.pt'))
            best_rs = corrs['all_spearman']
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print('\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
              .format(train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))

    if test_mode:
        test_file = os.path.join(log_dir, f'rsr_rep{rep}.csv')
        model.load_state_dict(torch.load(os.path.join(log_dir, 'best_weights.pt')))
        test_loss, corrs, results_df = test(model, test_loader, device)
        # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
        print('\tTest RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
              .format(test_loss, corrs['per_target_spearman'], corrs['all_spearman']))
        results_df.to_csv(test_file, index=False)
def __init__(self, lmdb_path):
    self.dataset = LMDBDataset(lmdb_path)
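# Typical accessors for a thin LMDB wrapper like this (added sketch; these
# method bodies are assumptions, not part of the original excerpt):
def __len__(self):
    return len(self.dataset)

def __getitem__(self, index):
    return self.dataset[index]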
def train(args, device, log_dir, seed=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'), level=logging.INFO)
    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformPSR())
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformPSR())
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformPSR())

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_PSR(num_features, hidden_dim=args.hidden_dim).to(device)
    best_val_loss = float('inf')
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                           factor=0.7, patience=3,
                                                           min_lr=0.00001)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        print('validating...')
        val_loss, corrs, test_df = test(model, val_loader, device)
        scheduler.step(val_loss)
        if corrs['all_spearman'] > best_rs:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, 'best_weights.pt'))
            best_rs = corrs['all_spearman']
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print('\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
              .format(train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))

    if test_mode:
        test_file = os.path.join(log_dir, 'test_results.txt')
        model.load_state_dict(torch.load(os.path.join(log_dir, 'best_weights.pt')))
        test_loss, corrs, test_df = test(model, test_loader, device)
        print('Test RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
              .format(test_loss, corrs['per_target_spearman'], corrs['all_spearman']))
        test_df.to_csv(test_file)
        # Last dimension is atom channel, so we need to move it to the front,
        # per PyTorch convention.
        grid = np.moveaxis(grid, -1, 0)
        return grid

    def __call__(self, item):
        # Transform protein into voxel grids.
        # Apply random rotation matrix.
        id = eval(item['id'])
        transformed = {
            'feature': self._voxelize(item['atoms']),
            'label': item['scores']['gdt_ts'],
            'target': id[0],
            'decoy': id[1],
        }
        return transformed


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['PSR_DATA'], 'val')
    dataset = LMDBDataset(dataset_path, transform=CNN3D_TransformPSR(radius=10.0))
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
    for item in dataloader:
        print('feature shape:', item['feature'].shape)
        print('label:', item['label'])
        print('target:', item['target'])
        print('decoy:', item['decoy'])
        break
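    # Added shape-convention sketch (not in the original script): np.moveaxis
    # turns a channels-last (D, H, W, C) grid into channels-first (C, D, H, W),
    # which is what PyTorch's Conv3d expects.
    demo = np.zeros((21, 21, 21, 4))
    print('channels-first shape:', np.moveaxis(demo, -1, 0).shape)  # (4, 21, 21, 21)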
def train(args, device, log_dir, seed=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'), level=logging.INFO)
    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=GNNTransformLBA())
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=GNNTransformLBA())
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=GNNTransformLBA())

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_LBA(num_features, hidden_dim=args.hidden_dim).to(device)
    best_val_loss = float('inf')
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, r_p, r_s, y_true, y_pred = test(model, val_loader, device)
        if val_loss < best_val_loss:
            save_weights(model, os.path.join(log_dir, 'best_weights.pt'))
            # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}.png'))
            best_val_loss = val_loss
            best_rp = r_p
            best_rs = r_s
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print('\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'
              .format(train_loss, val_loss, r_p, r_s))
        # logger.info('{:03d}\t{:.7f}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(epoch, train_loss, val_loss, r_p, r_s))

    if test_mode:
        test_file = os.path.join(log_dir, 'test_results.txt')
        model.load_state_dict(torch.load(os.path.join(log_dir, 'best_weights.pt')))
        rmse, pearson, spearman, y_true, y_pred = test(model, test_loader, device)
        # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
        print('Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(rmse, pearson, spearman))
        with open(test_file, 'a+') as out:
            out.write('{}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(seed, rmse, pearson, spearman))

    return best_val_loss, best_rp, best_rs
            'h298_atom',
            'g298_atom',
            'cv_atom',
        ]
        self.label_mapping = {k: v for v, k in enumerate(label_mapping)}
        return item['labels'][self.label_mapping[name]]

    def __call__(self, item):
        # Transform molecule into voxel grids.
        # Apply random rotation matrix.
        transformed = {
            'feature': self._voxelize(item['atoms']),
            'label': self._lookup_label(item, self.label_name),
            'id': item['id'],
        }
        return transformed


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['SMP_DATA'], 'val')
    dataset = LMDBDataset(dataset_path, transform=CNN3D_TransformSMP(label_name='alpha', radius=10.0))
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)
    for item in dataloader:
        print('feature shape:', item['feature'].shape)
        print('label:', item['label'])
        print('id:', item['id'])
        break
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'), level=logging.INFO)
    transform = GNNTransformMSP()
    if args.precomputed:
        train_dataset = PTGDataset(os.path.join(args.data_dir, 'train'))
        val_dataset = PTGDataset(os.path.join(args.data_dir, 'val'))
        test_dataset = PTGDataset(os.path.join(args.data_dir, 'test'))
    else:
        train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform)
        val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform)
        test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform)

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4,
                              collate_fn=CollaterMSP(batch_size=args.batch_size))
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4,
                            collate_fn=CollaterMSP(batch_size=args.batch_size))
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4,
                             collate_fn=CollaterMSP(batch_size=args.batch_size))

    for original, mutated in train_loader:
        num_features = original.num_features
        break

    gcn_model = GNN_MSP(num_features, hidden_dim=args.hidden_dim).to(device)
    ff_model = MLP_MSP(args.hidden_dim).to(device)

    best_val_loss = float('inf')
    best_val_auroc = 0

    params = list(gcn_model.parameters()) + list(ff_model.parameters())
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
    criterion.to(device)
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(epoch, gcn_model, ff_model, train_loader, criterion, optimizer, device)
        print('validating...')
        val_loss, auroc, auprc, _, _ = test(gcn_model, ff_model, val_loader, criterion, device)
        if auroc > best_val_auroc:
            torch.save(
                {
                    'epoch': epoch,
                    'gcn_state_dict': gcn_model.state_dict(),
                    'ff_state_dict': ff_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
            best_val_auroc = auroc
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print(f'\tTrain loss {train_loss}, Val loss {val_loss}, Val AUROC {auroc}, Val auprc {auprc}')

    if test_mode:
        train_file = os.path.join(log_dir, f'msp-rep{rep}.best.train.pt')
        val_file = os.path.join(log_dir, f'msp-rep{rep}.best.val.pt')
        test_file = os.path.join(log_dir, f'msp-rep{rep}.best.test.pt')
        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
        gcn_model.load_state_dict(cpt['gcn_state_dict'])
        ff_model.load_state_dict(cpt['ff_state_dict'])
        _, _, _, y_true_train, y_pred_train = test(gcn_model, ff_model, train_loader, criterion, device)
        torch.save({'targets': y_true_train, 'predictions': y_pred_train}, train_file)
        _, _, _, y_true_val, y_pred_val = test(gcn_model, ff_model, val_loader, criterion, device)
        torch.save({'targets': y_true_val, 'predictions': y_pred_val}, val_file)
        test_loss, auroc, auprc, y_true_test, y_pred_test = test(
            gcn_model, ff_model, test_loader, criterion, device)
        print(f'\tTest loss {test_loss}, Test AUROC {auroc}, Test auprc {auprc}')
        torch.save({'targets': y_true_test, 'predictions': y_pred_test}, test_file)
        return test_loss, auroc, auprc

    return best_val_loss
        graph = item['atoms']
        x2 = torch.tensor(item['atom_feats'], dtype=torch.float).t().contiguous()
        graph.x = torch.cat([graph.x.to(torch.float), x2], dim=-1)
        graph.y = self._lookup_label(item, self.label_name)
        graph.id = item['id']
        return graph


if __name__ == "__main__":
    save_dir = '/scratch/users/aderry/atom3d/smp'
    data_dir = '/scratch/users/aderry/lmdb/atom3d/small_molecule_properties/splits/split-randomly/data'
    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=GNNTransformSMP(label_name='mu'))
    # val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformSMP())
    # test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=GNNTransformSMP())
    # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4)
    # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for i, item in enumerate(tqdm(train_dataset)):
        print(item.y)
        # torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))
    # for i, item in enumerate(tqdm(val_dataset)):
    #     torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt'))
                label_key='scores')
        else:
            item = prot_graph_transform(
                item, atom_keys=['atoms_protein', 'atoms_pocket'], label_key='scores')
        # transform ligand into PTG graph
        item = mol_graph_transform(item, 'atoms_ligand', 'scores', use_bonds=True)
        node_feats, edges, edge_feats, node_pos = gr.combine_graphs(
            item['atoms_pocket'], item['atoms_ligand'])
        combined_graph = Data(node_feats, edges, edge_feats,
                              y=item['scores']['neglog_aff'], pos=node_pos)
        return combined_graph


if __name__ == "__main__":
    dataset = LMDBDataset(
        '/scratch/users/aderry/lmdb/atom3d/lba_lmdb/splits/split-by-sequence-identity-30/data/train',
        transform=GNNTransformLBA())
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for item in dataloader:
        print(item)
        break
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'), level=logging.INFO)
    if args.precomputed:
        train_dataset = PTGDataset(os.path.join(args.data_dir, 'train'))
        val_dataset = PTGDataset(os.path.join(args.data_dir, 'val'))
        test_dataset = PTGDataset(os.path.join(args.data_dir, 'test'))
    else:
        transform = GNNTransformLBA()
        train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'), transform=transform)
        val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'), transform=transform)
        test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'), transform=transform)

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_LBA(num_features, hidden_dim=args.hidden_dim).to(device)
    best_val_loss = float('inf')
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, r_p, r_s, y_true, y_pred = test(model, val_loader, device)
        if val_loss < best_val_loss:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
            # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}.png'))
            best_val_loss = val_loss
            best_rp = r_p
            best_rs = r_s
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print('\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'
              .format(train_loss, val_loss, r_p, r_s))
        # logger.info('{:03d}\t{:.7f}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(epoch, train_loss, val_loss, r_p, r_s))

    if test_mode:
        train_file = os.path.join(log_dir, f'lba-rep{rep}.best.train.pt')
        val_file = os.path.join(log_dir, f'lba-rep{rep}.best.val.pt')
        test_file = os.path.join(log_dir, f'lba-rep{rep}.best.test.pt')
        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
        model.load_state_dict(cpt['model_state_dict'])
        _, _, _, y_true_train, y_pred_train = test(model, train_loader, device)
        torch.save({'targets': y_true_train, 'predictions': y_pred_train}, train_file)
        _, _, _, y_true_val, y_pred_val = test(model, val_loader, device)
        torch.save({'targets': y_true_val, 'predictions': y_pred_val}, val_file)
        rmse, pearson, spearman, y_true_test, y_pred_test = test(model, test_loader, device)
        print(f'\tTest RMSE {rmse}, Test Pearson {pearson}, Test Spearman {spearman}')
        torch.save({'targets': y_true_test, 'predictions': y_pred_test}, test_file)

    return best_val_loss, best_rp, best_rs