Example #1
def _load_lba_data(datafiles, dist, maxnum):
    """
    Load LBA datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param dist: Radius of the selected region around the ligand.
    :type dist: float
    :param maxnum: Maximum total number of atoms of the ligand and the region around it.
    :type maxnum: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict

    """
    datasets = {}
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=TransformLBA(dist, maxnum, move_lig=False))
        # Load original atoms
        dsdict = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_pocket','atoms_ligand'])
        # Add the label data
        dsdict['neglog_aff'] = np.array([item['scores']['neglog_aff'] for item in dataset])
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
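A minimal usage sketch for the loader above; the split paths and the `dist` and `maxnum` values are illustrative placeholders, not benchmark defaults:

datafiles = {split: f'/path/to/lba/lmdb/{split}' for split in ['train', 'val', 'test']}  # hypothetical paths
datasets = _load_lba_data(datafiles, dist=6.0, maxnum=500)  # illustrative cutoff values
print(datasets['train']['neglog_aff'].shape)  # one -log(affinity) label per complex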
Example #2
def _load_lba_data_siamese(datafiles, dist, maxnum):
    """
    Load LBA datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param dist: Radius of the selected region around the ligand.
    :type dist: float
    :param maxnum: Maximum total number of atoms of the ligand and the region around it.
    :type maxnum: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict

    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=TransformLBA(dist, maxnum, move_lig=True))
        # Load atoms of the bound complex
        bound = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_pocket','atoms_ligand'])
        for k in key_names: bound['bound_'+k] = bound.pop(k)
        # Load atoms with the ligand moved apart
        apart = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_pocket','atoms_ligand_moved'])
        for k in key_names: apart['apart_'+k] = apart.pop(k)
        # Merge datasets with atoms
        dsdict = {**bound, **apart}
        # Add the label data
        dsdict['neglog_aff'] = np.array([item['scores']['neglog_aff'] for item in dataset])
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
Example #3
def _load_msp_data(datafiles, radius, droph=False):
    """
    Load MSP datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param radius: Radius of the selected region around the mutated residue.
    :type radius: float
    :param droph: Drop hydrogen atoms.
    :type droph: bool

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict

    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=EnvironmentSelection(radius, droph))
        # Load original atoms
        ori = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['original_atoms'])
        for k in key_names: ori['original_'+k] = ori.pop(k)
        # Load mutated atoms
        mut = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['mutated_atoms'])
        for k in key_names: mut['mutated_'+k] = mut.pop(k)
        # Merge datasets with atoms
        dsdict = {**ori, **mut}
        # Add labels
        labels = [dataset[i]['label'] for i in range(len(dataset))]
        dsdict['label'] = np.array(labels, dtype=int)
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
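The prefix-renaming loops above exist so the shared keys returned by `extract_coordinates_as_numpy_arrays` can be merged into one dictionary without collisions. A self-contained toy sketch of the same pattern:

key_names = ['index', 'num_atoms', 'charges', 'positions']
ori = {k: f'orig_{k}_array' for k in key_names}   # stand-in for the extracted numpy arrays
for k in key_names:
    ori['original_' + k] = ori.pop(k)
mut = {'mutated_' + k: f'mut_{k}_array' for k in key_names}
dsdict = {**ori, **mut}   # eight distinct keys, no collisions
print(sorted(dsdict))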
Example #4
    def __init__(self, lmdb_path, testing, random_seed=None, **kwargs):
        self._lmdb_dataset = LMDBDataset(lmdb_path)
        self.testing = testing
        self.random_seed = random_seed
        self.grid_config = dotdict({
            # Mapping from elements to position in channel dimension.
            'element_mapping': {
                'C': 0,
                'O': 1,
                'N': 2,
                'S': 3
            },
            # Radius of the grids to generate, in angstroms.
            'radius': 17.0,
            # Resolution of each voxel, in angstroms.
            'resolution': 1.0,
            # Number of directions to apply for data augmentation.
            'num_directions': 20,
            # Number of rolls to apply for data augmentation.
            'num_rolls': 20,

            ### PPI specific
            # Number of negatives to sample per positive example. -1 means all.
            'neg_to_pos_ratio': 1,
            'neg_to_pos_ratio_testing': 1,
            # Max number of positive regions to take from a structure. -1 means all.
            'max_pos_regions_per_ensemble': 5,
            'max_pos_regions_per_ensemble_testing': 5,
            # Whether to use all negatives at test time.
            'full_test': False,
        })
        # Update grid configs as necessary
        self.grid_config.update(kwargs)
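Because the constructor ends with `self.grid_config.update(kwargs)`, any extra keyword argument overrides the matching default. A hedged sketch, assuming the enclosing class is instantiated as `PPIDataset` (a placeholder name) and that `dotdict` supports attribute access:

ds = PPIDataset('/path/to/ppi/lmdb', testing=False, radius=12.0, num_rolls=5)  # hypothetical path
assert ds.grid_config.radius == 12.0   # default of 17.0 overridden via **kwargs
assert ds.grid_config.num_rolls == 5   # default of 20 overridden via **kwargs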
Example #5
def _load_smp_data(datafiles):
    """
    Load SMP datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict

    """
    datasets = {}
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile)
        # Load original atoms
        dsdict = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms'])
        # Add the label data (label_names is a module-level list of SMP target names)
        labels = np.zeros([len(label_names), len(dataset)])
        for i, item in enumerate(dataset):
            labels[:,i] = item['labels']
        for j, label in enumerate(label_names):
            dsdict[label] = labels[j]
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
Example #6
def _load_res_data(datafiles,
                   maxnum,
                   samples=42,
                   keep=['C', 'N', 'O', 'P', 'S'],
                   seed=1):
    """
    Load RES datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param maxnum: Maximum number of atoms to consider.
    :type maxnum: int
    :param samples: Number of sample structures.
    :type samples: int
    :param keep: Element symbols of the atom types to keep.
    :type keep: list
    :param seed: Random seed for the sampling.
    :type seed: int
    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict

    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile)
        # Randomly pick the samples.
        np.random.seed(seed)
        indices = np.random.choice(np.arange(len(dataset)),
                                   size=int(samples),
                                   replace=True)
        # Get labels
        labels = np.concatenate(
            [dataset[i]['labels']['label'] for i in indices])
        print('Labels:', labels)
        # Load original atoms
        dsdict = _extract_coordinates_as_numpy_arrays(dataset,
                                                      indices=indices,
                                                      keep=keep)
        for k in key_names:
            dsdict[k] = dsdict.pop(k)  # re-insert so the standard keys sit at the end of the dict
        # Add labels
        dsdict['label'] = np.array(labels, dtype=int)
        print('Sizes:')
        for key, val in dsdict.items():
            print(key, len(val))
        # Convert everything to tensors
        datasets[split] = {
            key: torch.from_numpy(val)
            for key, val in dsdict.items()
        }
    return datasets
Example #7
def _load_lep_data(datafiles, radius, droph, maxnum):
    """
    Load LEP datasets from LMDB format.

    :param datafiles: Dictionary of LMDB dataset directories.
    :type datafiles: dict
    :param radius: Radius of the selected region around the ligand.
    :type radius: float
    :param droph: Drop hydrogen atoms.
    :type droph: bool
    :param maxnum: Maximum number of atoms to consider.
    :type maxnum: int

    :return datasets: Dictionary of processed dataset objects.
    :rtype datasets: dict

    """
    datasets = {}
    key_names = ['index', 'num_atoms', 'charges', 'positions']
    for split, datafile in datafiles.items():
        dataset = LMDBDataset(datafile, transform=EnvironmentSelection(radius, droph, maxnum))
        # Load atoms of the active conformation
        act = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_active'])
        for k in key_names: act[k+'_active'] = act.pop(k)
        # Load atoms of the inactive conformation
        ina = extract_coordinates_as_numpy_arrays(dataset, atom_frames=['atoms_inactive'])
        for k in key_names: ina[k+'_inactive'] = ina.pop(k)
        # Merge datasets with atoms
        dsdict = {**act, **ina}
        # Add labels (1 for active, 0 for inactive)
        ldict = {'A':1, 'I':0}
        labels = [ldict[dataset[i]['label']] for i in range(len(dataset))]
        dsdict['label'] = np.array(labels, dtype=int)
        # Convert everything to tensors
        datasets[split] = {key: torch.from_numpy(val) for key, val in dsdict.items()}
    return datasets
Example #8
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)

    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                transform=GNNTransformSMP(args.target_name))
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                              transform=GNNTransformSMP(args.target_name))
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                               transform=GNNTransformSMP(args.target_name))

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4)
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_SMP(num_features, dim=args.hidden_dim).to(device)

    best_val_loss = 999

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.7,
                                                           patience=3,
                                                           min_lr=0.00001)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        print('validating...')
        val_loss, _, _ = test(model, val_loader, device)
        scheduler.step(val_loss)
        if val_loss < best_val_loss:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
            best_val_loss = val_loss
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print('\tTrain Loss: {:.7f}, Val MAE: {:.7f}'.format(
            train_loss, val_loss))

    if test_mode:
        train_file = os.path.join(log_dir, f'smp-rep{rep}.best.train.pt')
        val_file = os.path.join(log_dir, f'smp-rep{rep}.best.val.pt')
        test_file = os.path.join(log_dir, f'smp-rep{rep}.best.test.pt')
        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
        model.load_state_dict(cpt['model_state_dict'])
        _, y_true_train, y_pred_train = test(model, train_loader, device)
        torch.save({
            'targets': y_true_train,
            'predictions': y_pred_train
        }, train_file)
        _, y_true_val, y_pred_val = test(model, val_loader, device)
        torch.save({
            'targets': y_true_val,
            'predictions': y_pred_val
        }, val_file)
        mae, y_true_test, y_pred_test = test(model, test_loader, device)
        print(f'\tTest MAE {mae}')
        torch.save({
            'targets': y_true_test,
            'predictions': y_pred_test
        }, test_file)
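All of the `train` functions in these examples read their hyperparameters off an `args` object. A minimal stand-in built with `argparse.Namespace`; the values are illustrative, not the benchmark defaults:

from argparse import Namespace

args = Namespace(
    data_dir='/path/to/smp/lmdb',  # hypothetical path with train/val/test subdirectories
    target_name='mu',              # one of the SMP target labels
    batch_size=32,
    hidden_dim=64,
    learning_rate=1e-4,
    num_epochs=50,
)
# train(args, device=torch.device('cuda'), log_dir='./logs', rep=0, test_mode=False)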
Example #9
def train(args, device, test_mode=False):
    print("Training model with config:")
    print(str(json.dumps(args.__dict__, indent=4)) + "\n")

    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    train_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'train'),
        transform=CNN3D_TransformRSR(random_seed=args.random_seed))
    val_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'val'),
        transform=CNN3D_TransformRSR(random_seed=args.random_seed))
    test_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'test'),
        transform=CNN3D_TransformRSR(random_seed=args.random_seed))

    train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)

    for data in train_loader:
        in_channels, spatial_size = data['feature'].size()[1:3]
        print('num channels: {:}, spatial size: {:}'.format(
            in_channels, spatial_size))
        break

    model = conv_model(in_channels, spatial_size, args)
    print(model)
    model.to(device)

    best_val_loss = np.inf
    best_corrs = None

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, corrs, val_df = test(model, val_loader, device)
        if val_loss < best_val_loss:
            print(
                f"\nSave model at epoch {epoch:03d}, val_loss: {val_loss:.4f}")
            save_weights(model,
                         os.path.join(args.output_dir, f'best_weights.pt'))
            best_val_loss = val_loss
            best_corrs = corrs
        elapsed = (time.time() - start)
        print('Epoch {:03d} finished in : {:.3f} s'.format(epoch, elapsed))
        print(
            '\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
            .format(train_loss, val_loss, corrs['per_target_spearman'],
                    corrs['all_spearman']))

    if test_mode:
        model.load_state_dict(
            torch.load(os.path.join(args.output_dir, f'best_weights.pt')))
        rmse, corrs, test_df = test(model, test_loader, device)
        test_df.to_pickle(os.path.join(args.output_dir, 'test_results.pkl'))
        print(
            'Test RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
            .format(rmse, corrs['per_target_spearman'], corrs['all_spearman']))
        test_file = os.path.join(args.output_dir, f'test_results.txt')
        with open(test_file, 'a+') as out:
            out.write('{}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(
                args.random_seed, rmse, corrs['per_target_spearman'],
                corrs['all_spearman']))

    return best_val_loss, best_corrs['per_target_spearman'], best_corrs[
        'all_spearman']
Example #10
            graph.mut_idx[total_n_mut:total_n_mut +
                          graph.num_mut_atoms[i + 1]] += total_n
        return graph

    def __call__(self, batch):
        return self.collate(batch)


if __name__ == "__main__":
    from tqdm import tqdm
    # dataset = LMDBDataset(os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data', 'train'))
    # dataloader = DataLoader(dataset, batch_size=3, shuffle=False, num_workers=4)
    # for i, item in tqdm(enumerate(dataloader)):
    #     if i < 578:
    #         continue
    #     print(item)

    dataset = LMDBDataset(os.path.join(
        '/scratch/users/raphtown/atom3d_mirror/lmdb/MSP/splits/split-by-sequence-identity-30/data',
        'train'),
                          transform=GNNTransformMSP())
    dataloader = DataLoader(dataset,
                            batch_size=3,
                            shuffle=False,
                            collate_fn=CollaterMSP(batch_size=3),
                            num_workers=4)
    for original, mutated in tqdm(dataloader):
        if mutated.mut_idx.max() > mutated.batch.shape[0]:
            print(mutated.batch.shape)
            print(mutated.mut_idx)
            break
Example #11
    # Assume labels/classes are integers (0, 1, 2, ...).
    labels = [item['label'] for item in dataset]
    classes, class_sample_count = np.unique(labels, return_counts=True)
    # Weighted sampler for imbalanced classification (1:1 ratio for each class)
    weight = 1. / class_sample_count
    sample_weights = torch.tensor([weight[t] for t in labels])
    sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights,
                                                     num_samples=len(dataset),
                                                     replacement=True)
    return sampler


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['MSP_DATA'], 'val')
    dataset = LMDBDataset(dataset_path,
                          transform=CNN3D_TransformMSP(add_flag=True,
                                                       center_at_mut=True,
                                                       radius=10.0))
    dataloader = DataLoader(dataset,
                            batch_size=8,
                            sampler=create_balanced_sampler(dataset))

    non_zeros = []
    for item in dataloader:
        print('feature original shape:', item['feature_original'].shape)
        print('feature mutated shape:', item['feature_mutated'].shape)
        print('label:', item['label'])
        print('id:', item['id'])
        non_zeros.append(np.count_nonzero(item['label']))
        break
    for item in dataloader:
        non_zeros.append(np.count_nonzero(item['label']))
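The weighting in `create_balanced_sampler` gives every class the same expected share of draws: each sample's weight is the reciprocal of its class count. A self-contained numeric sketch of the same computation on toy labels:

import numpy as np
import torch

labels = [0, 0, 0, 1]                                          # imbalanced toy labels
_, class_sample_count = np.unique(labels, return_counts=True)  # counts: [3, 1]
weight = 1. / class_sample_count                               # [0.333..., 1.0]
sample_weights = torch.tensor([weight[t] for t in labels])
# Expected draws per class are equal: 3 * (1/3) == 1 * 1.0
sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights,
                                                 num_samples=len(labels),
                                                 replacement=True)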
Example #12
        item = prot_graph_transform(item, ['atoms'], 'scores')
        graph = item['atoms']
        graph.y = torch.FloatTensor([graph.y['rms']])
        split = item['id'].split("'")
        graph.target = split[1]
        graph.decoy = split[3]
        return graph


if __name__ == "__main__":
    save_dir = '/scratch/users/aderry/atom3d/rsr'
    data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/RSR/splits/candidates-split-by-time/data'
    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'),
                                transform=GNNTransformRSR())
    val_dataset = LMDBDataset(os.path.join(data_dir, 'val'),
                              transform=GNNTransformRSR())
    test_dataset = LMDBDataset(os.path.join(data_dir, 'test'),
                               transform=GNNTransformRSR())

    # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4)
    # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    print('processing train dataset...')
    for i, item in enumerate(tqdm(train_dataset)):
        torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))

    print('processing validation dataset...')
Example #13
import numpy as np
import os
from atom3d.util.transforms import prot_graph_transform, PairedGraphTransform
from atom3d.datasets import LMDBDataset
from torch_geometric.data import Data, Dataset, DataLoader
import atom3d.util.graph as gr

    
class GNNTransformLEP(object):
    def __init__(self, atom_keys, label_key):
        self.atom_keys = atom_keys
        self.label_key = label_key
    
    def __call__(self, item):
        # transform protein and/or pocket to PTG graphs
        item = prot_graph_transform(item, atom_keys=self.atom_keys, label_key=self.label_key)
        
        return item
    

        
if __name__=="__main__":
    dataset = LMDBDataset(os.path.join('/scratch/users/raphtown/atom3d_mirror/lmdb/LEP/splits/split-by-protein/data', 'train'), transform=PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label'))
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    for active, inactive in dataloader:
        print(active)
        print(inactive)
        break
    # for item in dataloader:
    #     print(item)
    #     break
Example #14
    # Assume labels/classes are integers (0, 1, 2, ...).
    labels = [item['label'] for item in dataset]
    classes, class_sample_count = np.unique(labels, return_counts=True)
    # Weighted sampler for imbalanced classification (1:1 ratio for each class)
    weight = 1. / class_sample_count
    sample_weights = torch.tensor([weight[t] for t in labels])
    sampler = torch.utils.data.WeightedRandomSampler(weights=sample_weights,
                                                     num_samples=len(dataset),
                                                     replacement=True)
    return sampler


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['LEP_DATA'], 'val')
    dataset = LMDBDataset(dataset_path,
                          transform=CNN3D_TransformLEP(add_flag=True,
                                                       radius=10.0))
    dataloader = DataLoader(dataset,
                            batch_size=8,
                            sampler=create_balanced_sampler(dataset))

    non_zeros = []
    for item in dataloader:
        print('feature inactive shape:', item['feature_inactive'].shape)
        print('feature active shape:', item['feature_active'].shape)
        print('label:', item['label'])
        print('id:', item['id'])
        non_zeros.append(np.count_nonzero(item['label']))
        break
    for item in dataloader:
        non_zeros.append(np.count_nonzero(item['label']))
Example #15
import numpy as np
import os
import torch
from atom3d.util.transforms import prot_graph_transform
from atom3d.datasets import LMDBDataset
from torch_geometric.data import Data, Dataset, DataLoader

    
class GNNTransformPSR(object):
    def __init__(self):
        pass
    def __call__(self, item):
        item = prot_graph_transform(item, ['atoms'], 'scores')
        graph = item['atoms']
        graph.y = torch.FloatTensor([graph.y['gdt_ts']])
        graph.target = item['id'][0]
        graph.decoy = item['id'][1]
        return graph
    

        
if __name__=="__main__":
    dataset = LMDBDataset('/scratch/users/raphtown/atom3d_mirror/lmdb/PSR/splits/split-by-year/data/train', transform=GNNTransformPSR())
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for item in dataloader:
        print(item)
        break
Example #16
import os
import torch
from atom3d.util.transforms import prot_graph_transform
from atom3d.datasets import LMDBDataset
from torch_geometric.data import Data, Dataset, DataLoader


class GNNTransformRSR(object):
    def __init__(self):
        pass

    def __call__(self, item):
        item = prot_graph_transform(item, ['atoms'], 'scores')
        graph = item['atoms']
        graph.y = torch.FloatTensor([graph.y['rms']])
        graph.target = item['id'][0]
        graph.decoy = item['id'][1]
        return graph


if __name__ == "__main__":
    dataset = LMDBDataset(
        '/scratch/users/raphtown/atom3d_mirror/lmdb/RSR/splits/candidates-split-by-time/data/train',
        transform=GNNTransformRSR())
    dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for item in dataloader:
        print(item)
        break
Example #17
def train(args, device, test_mode=False):
    print("Training model with config:")
    print(str(json.dumps(args.__dict__, indent=4)) + "\n")

    # Save config
    with open(os.path.join(args.output_dir, 'config.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    train_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'train'),
        transform=CNN3D_TransformMSP(args.add_flag, args.center_at_mut, random_seed=args.random_seed))
    val_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'val'),
        transform=CNN3D_TransformMSP(args.add_flag, args.center_at_mut, random_seed=args.random_seed))
    test_dataset = LMDBDataset(
        os.path.join(args.data_dir, 'test'),
        transform=CNN3D_TransformMSP(args.add_flag, args.center_at_mut, random_seed=args.random_seed))

    train_loader = DataLoader(train_dataset, args.batch_size,
                              sampler=create_balanced_sampler(train_dataset))
    val_loader = DataLoader(val_dataset, args.batch_size,
                            sampler=create_balanced_sampler(val_dataset))
    test_loader = DataLoader(test_dataset, args.batch_size, shuffle=False)

    for data in train_loader:
        in_channels, spatial_size = data['feature_original'].size()[1:3]
        print('num channels: {:}, spatial size: {:}'.format(in_channels, spatial_size))
        break

    model = conv_model(in_channels, spatial_size, args)
    print(model)
    model.to(device)

    prev_val_loss = np.inf
    best_val_loss = np.inf
    best_val_auroc = 0
    best_stats = None

    criterion = nn.BCELoss()
    criterion.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs+1):
        start = time.time()
        train_loss = train_loop(model, train_loader, args.repeat_gen, criterion,
                                optimizer, device)
        val_loss, stats, val_df = test(model, val_loader, args.repeat_gen,
                                       criterion, device)
        elapsed = (time.time() - start)
        print(f'Epoch {epoch:03d} finished in : {elapsed:.3f} s')
        print(f"\tTrain loss {train_loss:.4f}, Val loss: {val_loss:.4f}, "
              f"Val AUROC: {stats['auroc']:.4f}, Val AUPRC: {stats['auroc']:.4f}")
        #if stats['auroc'] > best_val_auroc:
        if val_loss < best_val_loss:
            print(f"\nSave model at epoch {epoch:03d}, val_loss: {val_loss:.4f}, "
                  f"auroc: {stats['auroc']:.4f}, auprc: {stats['auroc']:.4f}")
            save_weights(model, os.path.join(args.output_dir, f'best_weights.pt'))
            best_val_loss = val_loss
            best_val_auroc = stats['auroc']
            best_stats = stats
        if args.early_stopping and val_loss >= prev_val_loss:
            print(f"Validation loss stopped decreasing, stopping at epoch {epoch:03d}...")
            break
        prev_val_loss = val_loss

    if test_mode:
        model.load_state_dict(torch.load(os.path.join(args.output_dir, f'best_weights.pt')))
        test_loss, stats, test_df = test(model, test_loader, args.repeat_gen,
                                         criterion, device)
        test_df.to_pickle(os.path.join(args.output_dir, 'test_results.pkl'))
        print(f"Test loss: {test_loss:.4f}, Test AUROC: {stats['auroc']:.4f}, "
              f"Test AUPRC: {stats['auprc']:.4f}")
        test_file = os.path.join(args.output_dir, f'test_results.txt')
        with open(test_file, 'w') as f:
            f.write(f"test_loss\tAUROC\tAUPRC\n")
            f.write(f"{test_loss:}\t{stats['auroc']:}\t{stats['auprc']:}\n")
Example #18
        pass
    def __call__(self, data_list):
        batch_1 = Batch.from_data_list([d[0] for d in data_list])
        batch_2 = Batch.from_data_list([d[1] for d in data_list])
        return batch_1, batch_2
    

        
if __name__=="__main__":
    save_dir = '/scratch/users/aderry/atom3d/lep'
    data_dir = '/scratch/users/raphtown/atom3d_mirror/lmdb/LEP/splits/split-by-protein/data'
    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
    transform = PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label')
    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'), transform=transform)
    val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=transform)
    test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=transform)
    
    # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4)
    # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for i, item in enumerate(tqdm(train_dataset)):
        torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))
    
    for i, item in enumerate(tqdm(val_dataset)):
        torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt'))
    
    for i, item in enumerate(tqdm(test_dataset)):
Example #19
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)

    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                transform=GNNTransformRSR())
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                              transform=GNNTransformRSR())
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                               transform=GNNTransformRSR())

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4)
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_RSR(num_features, hidden_dim=args.hidden_dim).to(device)

    best_val_loss = 999
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, corrs, results_df = test(model, val_loader, device)
        if corrs['all_spearman'] > best_rs:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights.pt'))
            best_rs = corrs['all_spearman']
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print(
            '\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
            .format(train_loss, val_loss, corrs['per_target_spearman'],
                    corrs['all_spearman']))

    if test_mode:
        test_file = os.path.join(log_dir, f'rsr_rep{rep}.csv')
        model.load_state_dict(
            torch.load(os.path.join(log_dir, f'best_weights.pt')))
        val_loss, corrs, results_df = test(model, test_loader, device)
        # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
        print(
            '\tTest RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
            .format(val_loss, corrs['per_target_spearman'],
                    corrs['all_spearman']))
        results_df.to_csv(test_file, index=False)
Example #20
    def __init__(self, lmdb_path):
        self.dataset = LMDBDataset(lmdb_path)
Example #21
def train(args, device, log_dir, seed=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)

    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                transform=GNNTransformPSR())
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                              transform=GNNTransformPSR())
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                               transform=GNNTransformPSR())

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4)
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_PSR(num_features, hidden_dim=args.hidden_dim).to(device)

    best_val_loss = 999
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.7,
                                                           patience=3,
                                                           min_lr=0.00001)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        print('validating...')
        val_loss, corrs, test_df = test(model, val_loader, device)
        scheduler.step(val_loss)
        if corrs['all_spearman'] > best_rs:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights.pt'))
            best_rs = corrs['all_spearman']
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print(
            '\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
            .format(train_loss, val_loss, corrs['per_target_spearman'],
                    corrs['all_spearman']))

    if test_mode:
        test_file = os.path.join(log_dir, f'test_results.txt')
        model.load_state_dict(
            torch.load(os.path.join(log_dir, f'best_weights.pt')))
        test_loss, corrs, test_df = test(model, test_loader, device)
        print(
            'Test RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'
            .format(test_loss, corrs['per_target_spearman'],
                    corrs['all_spearman']))
        test_df.to_csv(test_file)
Example #22
        # Last dimension is atom channel, so we need to move it to the front
        # per PyTorch convention (channels-first)
        grid = np.moveaxis(grid, -1, 0)
        return grid

    def __call__(self, item):
        # Transform protein into voxel grids.
        # Apply random rotation matrix.
        id = eval(item['id'])  # item['id'] is a stringified (target, decoy) tuple
        transformed = {
            'feature': self._voxelize(item['atoms']),
            'label': item['scores']['gdt_ts'],
            'target': id[0],
            'decoy': id[1],
        }
        return transformed


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['PSR_DATA'], 'val')
    dataset = LMDBDataset(dataset_path,
                          transform=CNN3D_TransformPSR(radius=10.0))
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

    for item in dataloader:
        print('feature shape:', item['feature'].shape)
        print('label:', item['label'])
        print('target:', item['target'])
        print('decoy:', item['decoy'])
        break
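The `np.moveaxis(grid, -1, 0)` call in `_voxelize` converts the channel-last grid produced by the voxelizer into the channels-first layout that `torch.nn.Conv3d` expects. A standalone sketch of that shape change (the 35³ spatial size is illustrative):

import numpy as np

grid = np.zeros((35, 35, 35, 4))   # (D, H, W, channels), e.g. one channel per element type
grid = np.moveaxis(grid, -1, 0)    # -> (4, 35, 35, 35), channels-first for PyTorch
print(grid.shape)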
Example #23
def train(args, device, log_dir, seed=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)

    train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                transform=GNNTransformLBA())
    val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                              transform=GNNTransformLBA())
    test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                               transform=GNNTransformLBA())

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4)
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_LBA(num_features, hidden_dim=args.hidden_dim).to(device)

    best_val_loss = 999
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, r_p, r_s, y_true, y_pred = test(model, val_loader, device)
        if val_loss < best_val_loss:
            save_weights(model, os.path.join(log_dir, f'best_weights.pt'))
            # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}.png'))
            best_val_loss = val_loss
            best_rp = r_p
            best_rs = r_s
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print(
            '\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'
            .format(train_loss, val_loss, r_p, r_s))
        # logger.info('{:03d}\t{:.7f}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(epoch, train_loss, val_loss, r_p, r_s))

    if test_mode:
        test_file = os.path.join(log_dir, f'test_results.txt')
        model.load_state_dict(
            torch.load(os.path.join(log_dir, f'best_weights.pt')))
        rmse, pearson, spearman, y_true, y_pred = test(model, test_loader,
                                                       device)
        # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
        print(
            'Test RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'.format(
                rmse, pearson, spearman))
        with open(test_file, 'a+') as out:
            out.write('{}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(
                seed, rmse, pearson, spearman))

    return best_val_loss, best_rp, best_rs
Example #24
                'h298_atom',
                'g298_atom',
                'cv_atom',
            ]
            self.label_mapping = {k: v for v, k in enumerate(label_mapping)}
        return item['labels'][self.label_mapping[name]]

    def __call__(self, item):
        # Transform molecule into voxel grids.
        # Apply random rotation matrix.
        transformed = {
            'feature': self._voxelize(item['atoms']),
            'label': self._lookup_label(item, self.label_name),
            'id': item['id'],
        }
        return transformed


if __name__ == "__main__":
    dataset_path = os.path.join(os.environ['SMP_DATA'], 'val')
    dataset = LMDBDataset(dataset_path,
                          transform=CNN3D_TransformSMP(label_name='alpha',
                                                       radius=10.0))
    dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

    for item in dataloader:
        print('feature shape:', item['feature'].shape)
        print('label:', item['label'])
        print('id:', item['id'])
        break
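The `_lookup_label` logic above inverts the label-name list into a name-to-column mapping and indexes the per-item label vector with it. A toy sketch (the name list is abbreviated and the values are made up):

label_mapping = ['mu', 'alpha', 'homo']            # abbreviated stand-in for the full SMP list
name_to_idx = {k: v for v, k in enumerate(label_mapping)}
item_labels = [2.3, 75.1, -0.24]                   # hypothetical per-item label vector
print(item_labels[name_to_idx['alpha']])           # 75.1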
Example #25
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)
    transform = GNNTransformMSP()
    if args.precomputed:
        train_dataset = PTGDataset(os.path.join(args.data_dir, 'train'))
        val_dataset = PTGDataset(os.path.join(args.data_dir, 'val'))
        test_dataset = PTGDataset(os.path.join(args.data_dir, 'test'))
    else:
        train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                    transform=transform)
        val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                                  transform=transform)
        test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                                   transform=transform)

    train_loader = DataLoader(
        train_dataset,
        args.batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=CollaterMSP(batch_size=args.batch_size))
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4,
                            collate_fn=CollaterMSP(batch_size=args.batch_size))
    test_loader = DataLoader(
        test_dataset,
        args.batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=CollaterMSP(batch_size=args.batch_size))

    for original, mutated in train_loader:
        num_features = original.num_features
        break

    gcn_model = GNN_MSP(num_features, hidden_dim=args.hidden_dim).to(device)
    ff_model = MLP_MSP(args.hidden_dim).to(device)

    best_val_loss = 999
    best_val_auroc = 0

    params = list(gcn_model.parameters()) + list(ff_model.parameters())

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
    criterion.to(device)
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(epoch, gcn_model, ff_model, train_loader,
                                criterion, optimizer, device)
        print('validating...')
        val_loss, auroc, auprc, _, _ = test(gcn_model, ff_model, val_loader,
                                            criterion, device)
        if auroc > best_val_auroc:
            torch.save(
                {
                    'epoch': epoch,
                    'gcn_state_dict': gcn_model.state_dict(),
                    'ff_state_dict': ff_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
            best_val_auroc = auroc
            best_val_loss = val_loss
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print(
            f'\tTrain loss {train_loss}, Val loss {val_loss}, Val AUROC {auroc}, Val auprc {auprc}'
        )

    if test_mode:
        train_file = os.path.join(log_dir, f'msp-rep{rep}.best.train.pt')
        val_file = os.path.join(log_dir, f'msp-rep{rep}.best.val.pt')
        test_file = os.path.join(log_dir, f'msp-rep{rep}.best.test.pt')
        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
        gcn_model.load_state_dict(cpt['gcn_state_dict'])
        ff_model.load_state_dict(cpt['ff_state_dict'])
        _, _, _, y_true_train, y_pred_train = test(gcn_model, ff_model,
                                                   train_loader, criterion,
                                                   device)
        torch.save({
            'targets': y_true_train,
            'predictions': y_pred_train
        }, train_file)
        _, _, _, y_true_val, y_pred_val = test(gcn_model, ff_model, val_loader,
                                               criterion, device)
        torch.save({
            'targets': y_true_val,
            'predictions': y_pred_val
        }, val_file)
        test_loss, auroc, auprc, y_true_test, y_pred_test = test(
            gcn_model, ff_model, test_loader, criterion, device)
        print(
            f'\tTest loss {test_loss}, Test AUROC {auroc}, Test auprc {auprc}')
        torch.save({
            'targets': y_true_test,
            'predictions': y_pred_test
        }, test_file)
        return test_loss, auroc, auprc

    return best_val_loss
Example #26
        graph = item['atoms']
        x2 = torch.tensor(item['atom_feats'],
                          dtype=torch.float).t().contiguous()
        graph.x = torch.cat([graph.x.to(torch.float), x2], dim=-1)
        graph.y = self._lookup_label(item, self.label_name)
        graph.id = item['id']
        return graph


if __name__ == "__main__":
    save_dir = '/scratch/users/aderry/atom3d/smp'
    data_dir = '/scratch/users/aderry/lmdb/atom3d/small_molecule_properties/splits/split-randomly/data'
    os.makedirs(os.path.join(save_dir, 'train'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'val'), exist_ok=True)
    os.makedirs(os.path.join(save_dir, 'test'), exist_ok=True)
    train_dataset = LMDBDataset(os.path.join(data_dir, 'train'),
                                transform=GNNTransformSMP(label_name='mu'))
    # val_dataset = LMDBDataset(os.path.join(data_dir, 'val'), transform=GNNTransformSMP())
    # test_dataset = LMDBDataset(os.path.join(data_dir, 'test'), transform=GNNTransformSMP())

    # train_loader = DataLoader(train_dataset, 1, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, 1, shuffle=False, num_workers=4)
    # test_loader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for i, item in enumerate(tqdm(train_dataset)):
        print(item.y)
        # torch.save(item, os.path.join(save_dir, 'train', f'data_{i}.pt'))

    # for i, item in enumerate(tqdm(val_dataset)):
    #     torch.save(item, os.path.join(save_dir, 'val', f'data_{i}.pt'))
Example #27
                                        label_key='scores')
        else:
            item = prot_graph_transform(
                item,
                atom_keys=['atoms_protein', 'atoms_pocket'],
                label_key='scores')
        # transform ligand into PTG graph
        item = mol_graph_transform(item,
                                   'atoms_ligand',
                                   'scores',
                                   use_bonds=True)
        node_feats, edges, edge_feats, node_pos = gr.combine_graphs(
            item['atoms_pocket'], item['atoms_ligand'])
        combined_graph = Data(node_feats,
                              edges,
                              edge_feats,
                              y=item['scores']['neglog_aff'],
                              pos=node_pos)
        return combined_graph


if __name__ == "__main__":
    dataset = LMDBDataset(
        '/scratch/users/aderry/lmdb/atom3d/lba_lmdb/splits/split-by-sequence-identity-30/data/train',
        transform=GNNTransformLBA())
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    # for item in dataset[0]:
    #     print(item, type(dataset[0][item]))
    for item in dataloader:
        print(item)
        break
Example #28
def train(args, device, log_dir, rep=None, test_mode=False):
    # logger = logging.getLogger('lba')
    # logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)

    if args.precomputed:
        train_dataset = PTGDataset(os.path.join(args.data_dir, 'train'))
        val_dataset = PTGDataset(os.path.join(args.data_dir, 'val'))
        test_dataset = PTGDataset(os.path.join(args.data_dir, 'test'))
    else:
        transform = GNNTransformLBA()
        train_dataset = LMDBDataset(os.path.join(args.data_dir, 'train'),
                                    transform=transform)
        val_dataset = LMDBDataset(os.path.join(args.data_dir, 'val'),
                                  transform=transform)
        test_dataset = LMDBDataset(os.path.join(args.data_dir, 'test'),
                                   transform=transform)

    train_loader = DataLoader(train_dataset,
                              args.batch_size,
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_dataset,
                            args.batch_size,
                            shuffle=False,
                            num_workers=4)
    test_loader = DataLoader(test_dataset,
                             args.batch_size,
                             shuffle=False,
                             num_workers=4)

    for data in train_loader:
        num_features = data.num_features
        break

    model = GNN_LBA(num_features, hidden_dim=args.hidden_dim).to(device)

    best_val_loss = 999
    best_rp = 0
    best_rs = 0

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(1, args.num_epochs + 1):
        start = time.time()
        train_loss = train_loop(model, train_loader, optimizer, device)
        val_loss, r_p, r_s, y_true, y_pred = test(model, val_loader, device)
        if val_loss < best_val_loss:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': train_loss,
                }, os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
            # plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}.png'))
            best_val_loss = val_loss
            best_rp = r_p
            best_rs = r_s
        elapsed = (time.time() - start)
        print('Epoch: {:03d}, Time: {:.3f} s'.format(epoch, elapsed))
        print(
            '\tTrain RMSE: {:.7f}, Val RMSE: {:.7f}, Pearson R: {:.7f}, Spearman R: {:.7f}'
            .format(train_loss, val_loss, r_p, r_s))
        # logger.info('{:03d}\t{:.7f}\t{:.7f}\t{:.7f}\t{:.7f}\n'.format(epoch, train_loss, val_loss, r_p, r_s))

    if test_mode:
        train_file = os.path.join(log_dir, f'lba-rep{rep}.best.train.pt')
        val_file = os.path.join(log_dir, f'lba-rep{rep}.best.val.pt')
        test_file = os.path.join(log_dir, f'lba-rep{rep}.best.test.pt')
        cpt = torch.load(os.path.join(log_dir, f'best_weights_rep{rep}.pt'))
        model.load_state_dict(cpt['model_state_dict'])
        _, _, _, y_true_train, y_pred_train = test(model, train_loader, device)
        torch.save({
            'targets': y_true_train,
            'predictions': y_pred_train
        }, train_file)
        _, _, _, y_true_val, y_pred_val = test(model, val_loader, device)
        torch.save({
            'targets': y_true_val,
            'predictions': y_pred_val
        }, val_file)
        rmse, pearson, spearman, y_true_test, y_pred_test = test(
            model, test_loader, device)
        print(
            f'\tTest RMSE {rmse}, Test Pearson {pearson}, Test Spearman {spearman}'
        )
        torch.save({
            'targets': y_true_test,
            'predictions': y_pred_test
        }, test_file)

    return best_val_loss, best_rp, best_rs