Example #1
def mol_to_graph(mols: list, canonical: bool = False) -> list:
    if canonical:
        graph = [
            mol_to_bigraph(
                mol,
                node_featurizer=CanonicalAtomFeaturizer(),
                edge_featurizer=CanonicalBondFeaturizer(),
            )
            for mol in mols
        ]
    else:
        graph = [mol_to_bigraph(m, node_featurizer=MyNodeFeaturizer()) for m in mols]

    return graph
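
A quick hedged usage sketch of the helper above (assuming RDKit and dgllife are installed; field names follow the dgllife defaults):

from rdkit import Chem

smiles = ["CCO", "c1ccccc1"]
mols = [Chem.MolFromSmiles(s) for s in smiles]
graphs = mol_to_graph(mols, canonical=True)
# CanonicalAtomFeaturizer stores atom features under 'h' by default,
# CanonicalBondFeaturizer stores bond features under 'e'.
print(graphs[0].ndata["h"].shape)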
Example #2
    def __FeaturizerSimple(self, mols) -> list:
        atom_featurizer = BaseAtomFeaturizer({
            "n_feat": ConcatFeaturizer([
                # partial(atom_type_one_hot,
                #         allowable_set=['C', 'N', 'O', 'F', 'Si', 'S'],
                #         encode_unknown=True),
                # partial(atom_degree_one_hot, allowable_set=list(range(6))),
                atom_is_aromatic,
                atom_formal_charge,
                atom_num_radical_electrons,
                partial(atom_hybridization_one_hot, encode_unknown=True),
                lambda atom: [0],  # placeholder for aromatic information
                atom_total_num_H_one_hot,
            ])
        })
        bond_featurizer = BaseBondFeaturizer(
            {"e_feat": ConcatFeaturizer([bond_type_one_hot, bond_is_in_ring])})

        train_graph = [
            mol_to_bigraph(mol,
                           node_featurizer=atom_featurizer,
                           edge_featurizer=bond_featurizer) for mol in mols
        ]
        return train_graph
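
For intuition: ConcatFeaturizer applies each listed featurizer to an atom and concatenates the resulting lists into a single feature vector. A hedged sketch on one atom, using the same dgllife.utils functions as above:

from rdkit import Chem

mol = Chem.MolFromSmiles("c1ccccc1O")
atom = mol.GetAtomWithIdx(0)
print(atom_is_aromatic(atom))    # e.g. [True]
print(atom_formal_charge(atom))  # e.g. [0]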
Example #3
    def __Featurizer(self, train_mols) -> list:
        atom_featurizer = BaseAtomFeaturizer({
            "n_feat": ConcatFeaturizer([
                partial(
                    atom_type_one_hot,
                    allowable_set=["C", "N", "O", "F", "Si", "P", "S"],
                    encode_unknown=True,
                ),
                partial(atom_degree_one_hot, allowable_set=list(range(6))),
                atom_is_aromatic,
                atom_formal_charge,
                atom_num_radical_electrons,
                partial(atom_hybridization_one_hot, encode_unknown=True),
                atom_implicit_valence,
                lambda atom: [0],  # placeholder for aromatic information
                atom_total_num_H_one_hot,
            ])
        })
        bond_featurizer = BaseBondFeaturizer(
            {"e_feat": ConcatFeaturizer([bond_type_one_hot, bond_is_in_ring])})
        afp_train_graph = [
            mol_to_bigraph(mol,
                           node_featurizer=atom_featurizer,
                           edge_featurizer=bond_featurizer)
            for mol in tqdm(train_mols)
        ]
        return afp_train_graph
Example #4
def graph_construction_and_featurization(smiles):
    """Construct graphs from SMILES and featurize them

    Parameters
    ----------
    smiles : list of str
        SMILES of molecules for embedding computation

    Returns
    -------
    list of DGLGraph
        List of graphs constructed and featurized
    list of bool
        Indicators for whether the SMILES string can be
        parsed by RDKit
    """
    graphs = []
    success = []
    for smi in smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                success.append(False)
                continue
            g = mol_to_bigraph(mol,
                               add_self_loop=True,
                               node_featurizer=PretrainAtomFeaturizer(),
                               edge_featurizer=PretrainBondFeaturizer(),
                               canonical_atom_order=False)
            graphs.append(g)
            success.append(True)
        except Exception:  # RDKit occasionally raises during sanitisation
            success.append(False)

    return graphs, success
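
A hedged usage sketch of the function above; the second SMILES string is deliberately malformed to exercise the success flags:

graphs, success = graph_construction_and_featurization(
    ["CCO", "not_a_smiles", "c1ccccc1"])
print(success)      # e.g. [True, False, True]
print(len(graphs))  # graphs are only built for parseable SMILES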
Example #5
    def __CanonicalFeatureize(self, train_mols) -> list:
        atom_featurizer = CanonicalAtomFeaturizer("n_feat")  # first positional arg is atom_data_field
        bond_featurizer = CanonicalBondFeaturizer("e_feat")  # first positional arg is bond_data_field

        train_graph = [
            mol_to_bigraph(mol,
                           node_featurizer=atom_featurizer,
                           edge_featurizer=bond_featurizer)
            for mol in train_mols
        ]
        return train_graph
Example #6
def moonshot():

    from dgllife.utils import mol_to_bigraph, CanonicalAtomFeaturizer
    import pandas as pd
    import os
    df = pd.read_csv(
        os.path.dirname(graca.data.collections.__file__) +
        "/covid_submissions_all_info.csv")
    # Reset the index so the iloc[idx0 + 1:] slice below lines up with iterrows indices
    df = df.dropna(subset=["f_avg_pIC50"]).reset_index(drop=True)

    from rdkit import Chem
    from rdkit.Chem import MCS

    ds = []
    for idx0, row0 in df.iterrows():
        smiles0 = row0["SMILES"]
        mol0 = Chem.MolFromSmiles(smiles0)
        for idx1, row1 in df.iloc[idx0 + 1:].iterrows():
            smiles1 = row1["SMILES"]
            mol1 = Chem.MolFromSmiles(smiles1)
            res = MCS.FindMCS([mol0, mol1])
            if res.numAtoms > 15:
                ds.append((
                    mol_to_bigraph(mol1,
                                   node_featurizer=CanonicalAtomFeaturizer(
                                       atom_data_field='feat')),
                    mol_to_bigraph(mol0,
                                   node_featurizer=CanonicalAtomFeaturizer(
                                       atom_data_field='feat')),
                    row1["f_avg_pIC50"],
                    row0["f_avg_pIC50"],
                ))

    ds_tr = ds[:500]
    ds_te = ds[500:]

    return ds_tr, ds_te
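
For reference, a hedged sketch of the MCS criterion used above on a single pair of molecules (SMILES chosen arbitrarily):

from rdkit import Chem
from rdkit.Chem import MCS

mol_a = Chem.MolFromSmiles("c1ccccc1CCO")
mol_b = Chem.MolFromSmiles("c1ccccc1CCN")
res = MCS.FindMCS([mol_a, mol_b])
print(res.numAtoms)  # a pair enters the dataset above only if this exceeds 15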
Example #7
    def _make_atom_graph(self,
                         pdb_code=None,
                         pdb_path=None,
                         node_featurizer=None,
                         edge_featurizer=None,
                         graph_type='bigraph'):
        """
        Create atom-level graph from PDB structure

        :param graph_type:
        :param pdb_code:
        :param pdb_path:
        :param node_featurizer:
        :param edge_featurizer:
        :return:
        """

        if node_featurizer is None:
            node_featurizer = CanonicalAtomFeaturizer()
        if edge_featurizer is None:
            edge_featurizer = CanonicalBondFeaturizer()

        # Read in protein as mol
        # if pdb_path:
        if pdb_code:
            pdb_path = self.pdb_dir + pdb_code + '.pdb'
            if not os.path.isfile(pdb_path):
                self._download_pdb(pdb_code)

        assert os.path.isfile(pdb_path)
        mol = MolFromPDBFile(pdb_path)

        # DGL mol to graph
        if graph_type == 'bigraph':
            g = mol_to_bigraph(mol,
                               node_featurizer=node_featurizer,
                               edge_featurizer=edge_featurizer)
        elif graph_type == 'complete':
            g = mol_to_complete_graph(
                mol,
                node_featurizer=node_featurizer,
            )
        elif graph_type == 'k_nn':
            raise NotImplementedError
        print(g)
        return g
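
A hedged sketch of the underlying PDB-to-graph conversion, independent of the class (the file path is a placeholder):

from rdkit.Chem import MolFromPDBFile
from dgllife.utils import (mol_to_bigraph, CanonicalAtomFeaturizer,
                           CanonicalBondFeaturizer)

mol = MolFromPDBFile("1crn.pdb")  # placeholder path
g = mol_to_bigraph(mol,
                   node_featurizer=CanonicalAtomFeaturizer(),
                   edge_featurizer=CanonicalBondFeaturizer())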
Example #8
def main():
    """
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set
    """
    mpi_comm = MPI.COMM_WORLD
    mpi_rank = mpi_comm.Get_rank()
    mpi_size = mpi_comm.Get_size()

    df = pd.read_csv('data/sars_lip.csv')
    smiles_list = df['smiles']

    my_border_low, my_border_high = return_borders(mpi_rank, len(smiles_list), mpi_size)

    my_smiles = smiles_list[my_border_low:my_border_high]
    my_mols = np.array([Chem.MolFromSmiles(m) for m in my_smiles])

    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()

    e_feats = bond_featurizer.feat_size('e')
    n_feats = atom_featurizer.feat_size('h')

    my_graphs = np.array([mol_to_bigraph(m, node_featurizer=atom_featurizer,
                                         edge_featurizer=bond_featurizer) for m in my_mols])

    sendcounts = np.array(mpi_comm.gather(len(my_graphs), root=0))

    # my_descs = np.array([generate_descriptors(m) for m in my_smiles])
    # if mpi_rank == 0:
        # descs = np.empty((len(smiles_list), 114), dtype=np.float64)
    # else:
        # descs = None

    # mpi_comm.Gatherv(sendbuf=my_descs, recvbuf=(descs, sendcounts), root=0)
    graphs = mpi_comm.gather(my_graphs, root=0)
    if mpi_rank == 0:
        # gather returns a list of per-rank arrays on the root (None elsewhere)
        X = np.concatenate(graphs)
        # np.save('/rds-d2/user/wjm41/hpc-work/sars_descs.npy', descs)
        np.save('/rds-d2/user/wjm41/hpc-work/sars_graphs.npy', X)
        print('SAVED!')
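
return_borders is not defined in this snippet; a minimal sketch of a compatible implementation, assuming it splits n_items indices as evenly as possible across size ranks:

def return_borders(rank, n_items, size):
    # The first (n_items % size) ranks each take one extra item.
    base, extra = divmod(n_items, size)
    low = rank * base + min(rank, extra)
    high = low + base + (1 if rank < extra else 0)
    return low, high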
Example #9
def main(path, task, graph_type):
    """
    :param path: str specifying path to dataset.
    :param task: str specifying the task. One of ['e_iso_pi', 'z_iso_pi', 'e_iso_n', 'z_iso_n']
    :param graph_type: str. either 'bigraph' or 'complete'
    """

    data_loader = TaskDataLoader(task, path)
    X, y = data_loader.load_property_data()
    X = [Chem.MolFromSmiles(m) for m in X]

    # Collate Function for Dataloader
    def collate(sample):
        graphs, labels = map(list, zip(*sample))
        batched_graph = dgl.batch(graphs)
        batched_graph.set_n_initializer(dgl.init.zero_initializer)
        batched_graph.set_e_initializer(dgl.init.zero_initializer)
        return batched_graph, torch.tensor(labels)

    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    n_feats = atom_featurizer.feat_size('h')
    print('Number of features: ', n_feats)

    X_full, _, y_full, _ = train_test_split(X,
                                            y,
                                            test_size=0.2,
                                            random_state=30)
    y_full = y_full.reshape(-1, 1)

    #  We standardise the outputs but leave the inputs unchanged

    y_scaler = StandardScaler()
    y_full = torch.Tensor(y_scaler.fit_transform(y_full))

    # Set up cross-validation splits

    n_splits = 5
    kf = KFold(n_splits=n_splits)

    X_train_splits = []
    y_train_splits = []
    X_val_splits = []
    y_val_splits = []

    for train_index, test_index in kf.split(X_full):
        X_train, X_val = np.array(X_full)[train_index], np.array(
            X_full)[test_index]
        y_train, y_val = y_full[train_index], y_full[test_index]
        # Create graphs and labels
        if graph_type == 'complete':
            X_train = [
                mol_to_complete_graph(m, node_featurizer=atom_featurizer)
                for m in X_train
            ]
            X_val = [
                mol_to_complete_graph(m, node_featurizer=atom_featurizer)
                for m in X_val
            ]
        elif graph_type == 'bigraph':
            X_train = [
                mol_to_bigraph(m, node_featurizer=atom_featurizer) for m in X_train
            ]
            X_val = [
                mol_to_bigraph(m, node_featurizer=atom_featurizer)
                for m in X_val
            ]
        X_train_splits.append(X_train)
        X_val_splits.append(X_val)
        y_train_splits.append(y_train)
        y_val_splits.append(y_val)

    def lognuniform(low=1, high=5, size=None, base=10):
        # Sample log-uniformly over [base**-high, base**-low]; used for the learning rate.
        return np.power(base, -np.random.uniform(low, high, size))

    best_rmse = np.inf

    for i in range(1000):

        num_layers = np.random.randint(1, 4)
        classifier_hidden_feats = np.random.randint(1, 128)
        hidden_feats = [np.random.choice([16, 32, 64])] * num_layers
        dropout = [np.random.uniform(0, 0.5)] * num_layers
        batchnorm = [np.random.choice([True, False])] * num_layers
        learning_rate = lognuniform()

        param_set = {
            'num_layers': num_layers,
            'classifier_hidden_feats': classifier_hidden_feats,
            'hidden_feats': hidden_feats,
            'dropout': dropout,
            'batchnorm': batchnorm,
            'lr': learning_rate
        }

        print(f'\nParameter set in trial {i} is \n')
        print(param_set)
        print('\n')

        cv_rmse_list = []

        for j in range(n_splits):

            X_train = X_train_splits[j]
            y_train = y_train_splits[j]
            X_val = X_val_splits[j]
            y_val = y_val_splits[j]

            train_data = list(zip(X_train, y_train))
            test_data = list(zip(X_val, y_val))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)
            test_loader = DataLoader(test_data,
                                     batch_size=32,
                                     shuffle=False,
                                     collate_fn=collate,
                                     drop_last=False)

            gcn_net = GCNPredictor(
                in_feats=n_feats,
                hidden_feats=hidden_feats,
                batchnorm=batchnorm,
                dropout=dropout,
                classifier_hidden_feats=classifier_hidden_feats,
            )
            gcn_net.to(device)

            loss_fn = MSELoss()
            optimizer = torch.optim.Adam(gcn_net.parameters(),
                                         lr=learning_rate)

            gcn_net.train()

            epoch_losses = []
            epoch_rmses = []
            for epoch in range(1, 501):
                epoch_loss = 0
                preds = []
                labs = []
                for i, (bg, labels) in enumerate(train_loader):
                    labels = labels.to(device)
                    atom_feats = bg.ndata.pop('h').to(device)
                    atom_feats, labels = atom_feats.to(device), labels.to(
                        device)
                    y_pred = gcn_net(bg, atom_feats)
                    labels = labels.unsqueeze(dim=1)
                    loss = loss_fn(y_pred, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.detach().item()

                    # Inverse transform to get RMSE (move tensors to CPU first)
                    labels = y_scaler.inverse_transform(
                        labels.detach().cpu().reshape(-1, 1))
                    y_pred = y_scaler.inverse_transform(
                        y_pred.detach().cpu().numpy().reshape(-1, 1))
                    # store labels and preds
                    preds.append(y_pred)
                    labs.append(labels)

                labs = np.concatenate(labs, axis=None)
                preds = np.concatenate(preds, axis=None)
                pearson, p = pearsonr(preds, labs)
                mae = mean_absolute_error(preds, labs)
                rmse = np.sqrt(mean_squared_error(preds, labs))
                r2 = r2_score(labs, preds)  # r2_score expects (y_true, y_pred)

                epoch_loss /= (i + 1)
                if epoch % 20 == 0:
                    print(f"epoch: {epoch}, "
                          f"LOSS: {epoch_loss:.3f}, "
                          f"RMSE: {rmse:.3f}, "
                          f"MAE: {mae:.3f}, "
                          f"R: {pearson:.3f}, "
                          f"R2: {r2:.3f}")
                epoch_losses.append(epoch_loss)
                epoch_rmses.append(rmse)

            # Evaluate
            gcn_net.eval()
            preds = []
            labs = []
            for i, (bg, labels) in enumerate(test_loader):
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                atom_feats, labels = atom_feats.to(device), labels.to(device)
                y_pred = gcn_net(bg, atom_feats)
                labels = labels.unsqueeze(dim=1)

                # Inverse transform to get RMSE (move tensors to CPU first)
                labels = y_scaler.inverse_transform(
                    labels.detach().cpu().reshape(-1, 1))
                y_pred = y_scaler.inverse_transform(
                    y_pred.detach().cpu().numpy().reshape(-1, 1))

                preds.append(y_pred)
                labs.append(labels)

            preds = np.concatenate(preds, axis=None)
            labs = np.concatenate(labs, axis=None)

            pearson, p = pearsonr(preds, labs)
            mae = mean_absolute_error(preds, labs)
            rmse = np.sqrt(mean_squared_error(preds, labs))
            cv_rmse_list.append(rmse)
            r2 = r2_score(labs, preds)  # (y_true, y_pred)

            print(
                f'Test RMSE: {rmse:.3f}, MAE: {mae:.3f}, R: {pearson:.3f}, R2: {r2:.3f}'
            )

        param_rmse = np.mean(cv_rmse_list)
        if param_rmse < best_rmse:
            best_rmse = param_rmse
            best_params = param_set

    print('Best RMSE and best params \n')
    print(best_rmse)
    print(best_params)
    np.save('saved_hypers/GCN.npy', best_params)  # np.savetxt cannot serialise a dict; np.save pickles it
Example #10
def main(path, task, n_trials, test_set_size):
    """
    :param path: str specifying path to dataset.
    :param task: str specifying the task. One of ['e_iso_pi', 'z_iso_pi', 'e_iso_n', 'z_iso_n']
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set.
    """

    data_loader = TaskDataLoader(task, path)
    smiles_list, y = data_loader.load_property_data()
    X = [Chem.MolFromSmiles(m) for m in smiles_list]

    # Collate Function for Dataloader
    def collate(sample):
        graphs, labels = map(list, zip(*sample))
        batched_graph = dgl.batch(graphs)
        batched_graph.set_n_initializer(dgl.init.zero_initializer)
        batched_graph.set_e_initializer(dgl.init.zero_initializer)
        return batched_graph, torch.tensor(labels)

    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()

    e_feats = bond_featurizer.feat_size('e')
    n_feats = atom_featurizer.feat_size('h')
    print('Number of features: ', n_feats)

    X = [
        mol_to_bigraph(m,
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer) for m in X
    ]

    r2_list = []
    rmse_list = []
    mae_list = []
    skipped_trials = 0

    for i in range(0, n_trials):

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_set_size, random_state=i + 5)

        y_train = y_train.reshape(-1, 1)
        y_test = y_test.reshape(-1, 1)

        #  We standardise the outputs but leave the inputs unchanged

        y_scaler = StandardScaler()
        y_train_scaled = torch.Tensor(y_scaler.fit_transform(y_train))
        y_test_scaled = torch.Tensor(y_scaler.transform(y_test))

        train_data = list(zip(X_train, y_train_scaled))
        test_data = list(zip(X_test, y_test_scaled))

        train_loader = DataLoader(train_data,
                                  batch_size=32,
                                  shuffle=True,
                                  collate_fn=collate,
                                  drop_last=False)
        test_loader = DataLoader(test_data,
                                 batch_size=32,
                                 shuffle=False,
                                 collate_fn=collate,
                                 drop_last=False)

        gat_net = GATPredictor(in_feats=n_feats)

        gat_net.to(device)

        loss_fn = MSELoss()
        optimizer = torch.optim.Adam(gat_net.parameters(), lr=0.001)

        gat_net.train()

        epoch_losses = []
        epoch_rmses = []
        for epoch in range(1, 201):
            epoch_loss = 0
            preds = []
            labs = []
            for i, (bg, labels) in enumerate(train_loader):
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                bond_feats = bg.edata.pop('e').to(device)
                atom_feats, bond_feats, labels = atom_feats.to(
                    device), bond_feats.to(device), labels.to(device)
                y_pred = gat_net(bg, atom_feats)
                labels = labels.unsqueeze(dim=1)
                loss = loss_fn(y_pred, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.detach().item()

                # Inverse transform to get RMSE (move tensors to CPU first)
                labels = y_scaler.inverse_transform(
                    labels.detach().cpu().reshape(-1, 1))
                y_pred = y_scaler.inverse_transform(
                    y_pred.detach().cpu().numpy().reshape(-1, 1))

                # store labels and preds
                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=None)
            preds = np.concatenate(preds, axis=None)
            pearson, p = pearsonr(preds, labs)
            mae = mean_absolute_error(preds, labs)
            rmse = np.sqrt(mean_squared_error(preds, labs))
            r2 = r2_score(labs, preds)  # r2_score expects (y_true, y_pred)

            epoch_loss /= (i + 1)
            if epoch % 20 == 0:
                print(f"epoch: {epoch}, "
                      f"LOSS: {epoch_loss:.3f}, "
                      f"RMSE: {rmse:.3f}, "
                      f"MAE: {mae:.3f}, "
                      f"R: {pearson:.3f}, "
                      f"R2: {r2:.3f}")
            epoch_losses.append(epoch_loss)
            epoch_rmses.append(rmse)

        # Discount trial if train R2 finishes negative (optimiser failure).

        if r2 < 0:
            skipped_trials += 1
            print('Skipped trials is {}'.format(skipped_trials))
            continue

        # Evaluate
        gat_net.eval()
        preds = []
        labs = []
        for i, (bg, labels) in enumerate(test_loader):
            labels = labels.to(device)
            atom_feats = bg.ndata.pop('h').to(device)
            bond_feats = bg.edata.pop('e').to(device)
            atom_feats, labels = atom_feats.to(device), labels.to(device)
            y_pred = gat_net(bg, atom_feats)
            labels = labels.unsqueeze(dim=1)

            # Inverse transform to get RMSE (move tensors to CPU first)
            labels = y_scaler.inverse_transform(
                labels.detach().cpu().reshape(-1, 1))
            y_pred = y_scaler.inverse_transform(
                y_pred.detach().cpu().numpy().reshape(-1, 1))

            preds.append(y_pred)
            labs.append(labels)

        labs = np.concatenate(labs, axis=None)
        preds = np.concatenate(preds, axis=None)

        pearson, p = pearsonr(preds, labs)
        mae = mean_absolute_error(preds, labs)
        rmse = np.sqrt(mean_squared_error(preds, labs))
        r2 = r2_score(labs, preds)  # (y_true, y_pred)

        r2_list.append(r2)
        rmse_list.append(rmse)
        mae_list.append(mae)

        print(
            f'Test RMSE: {rmse:.3f}, MAE: {mae:.3f}, R: {pearson:.3f}, R2: {r2:.3f}'
        )

    r2_list = np.array(r2_list)
    rmse_list = np.array(rmse_list)
    mae_list = np.array(mae_list)

    print("\nmean R^2: {:.4f} +- {:.4f}".format(
        np.mean(r2_list),
        np.std(r2_list) / np.sqrt(len(r2_list))))
    print("mean RMSE: {:.4f} +- {:.4f}".format(
        np.mean(rmse_list),
        np.std(rmse_list) / np.sqrt(len(rmse_list))))
    print("mean MAE: {:.4f} +- {:.4f}\n".format(
        np.mean(mae_list),
        np.std(mae_list) / np.sqrt(len(mae_list))))
    print("\nSkipped trials is {}".format(skipped_trials))
Example #11
def main(args):
    """
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set
    """

    df = pd.read_csv('data/covid_multitask_pIC50.smi')
    smiles_list = df['SMILES']
    y = df[['acry_class', 'chloro_class', 'rest_class', 'acry_reg', 'chloro_reg', 'rest_reg']].to_numpy()
    n_tasks = y.shape[1]
    class_inds = [0, 1, 2]
    reg_inds = [3, 4, 5]
    X = [Chem.MolFromSmiles(m) for m in smiles_list]

    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()

    e_feats = bond_featurizer.feat_size('e')
    n_feats = atom_featurizer.feat_size('h')
    print('Number of features: ', n_feats)

    X = np.array([mol_to_bigraph(m, node_featurizer=atom_featurizer, edge_featurizer=bond_featurizer) for m in X])

    r2_list = []
    rmse_list = []
    roc_list = []
    prc_list = []

    for i in range(args.n_trials):
        #kf = StratifiedKFold(n_splits=3, random_state=i, shuffle=True)
        #split_list = kf.split(X, y)

        #X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_set_size, random_state=i+5)
        X_train_acry, X_test_acry, \
        y_train_acry, y_test_acry = train_test_split(X[~np.isnan(y[:,0])], y[~np.isnan(y[:,0])],
                                                     stratify=y[:,0][~np.isnan(y[:,0])],
                                                     test_size=args.test_set_size, shuffle=True, random_state=i+5)
        X_train_chloro, X_test_chloro, \
        y_train_chloro, y_test_chloro = train_test_split(X[~np.isnan(y[:,1])], y[~np.isnan(y[:,1])],
                                                         stratify=y[:,1][~np.isnan(y[:,1])],
                                                         test_size=args.test_set_size, shuffle=True, random_state=i+5)
        X_train_rest, X_test_rest, \
        y_train_rest, y_test_rest = train_test_split(X[~np.isnan(y[:,2])], y[~np.isnan(y[:,2])],
                                                     stratify=y[:,2][~np.isnan(y[:,2])],
                                                     test_size=args.test_set_size, shuffle=True, random_state=i+5)

        X_train = np.concatenate([X_train_acry, X_train_chloro, X_train_rest])
        X_test = np.concatenate([X_test_acry, X_test_chloro, X_test_rest])

        y_train = np.concatenate([y_train_acry, y_train_chloro, y_train_rest])
        y_test = np.concatenate([y_test_acry, y_test_chloro, y_test_rest])

        writer = SummaryWriter('runs/multitask_pIC50/run_' + str(i))

        # writer = SummaryWriter('runs/multitask/run_' + str(i) + '_fold_' + str(j))


        y_train = torch.Tensor(y_train)
        y_test = torch.Tensor(y_test)

        train_data = list(zip(X_train, y_train))
        test_data = list(zip(X_test, y_test))

        train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate, drop_last=False)
        test_loader = DataLoader(test_data, batch_size=32, shuffle=True, collate_fn=collate, drop_last=False)

        process = Net(class_inds, reg_inds)
        process.to(device)

        mpnn_net = MPNNPredictor(node_in_feats=n_feats,
                                 edge_in_feats=e_feats,
                                 node_out_feats=128,
                                 n_tasks=n_tasks)
        mpnn_net.to(device)

        reg_loss_fn = MSELoss()
        class_loss_fn = BCELoss()

        optimizer = torch.optim.Adam(mpnn_net.parameters(), lr=0.001)

        for epoch in range(1, args.n_epochs+1):
            epoch_loss = 0
            preds = []
            labs = []
            mpnn_net.train()
            for i, (bg, labels) in enumerate(train_loader):
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                bond_feats = bg.edata.pop('e').to(device)
                #atom_feats, bond_feats, dcs, labels = atom_feats.to(device), bond_feats.to(device), dcs.to(device), labels.to(device)
                y_pred = mpnn_net(bg, atom_feats, bond_feats)
                y_pred = process(y_pred)
                loss = torch.tensor(0.0)  # float accumulator for the summed losses
                loss = loss.to(device)
                for ind in reg_inds:
                    if len(labels[:, ind][~torch.isnan(labels[:, ind])]) == 0:
                        continue
                    loss = loss + reg_loss_fn(
                        y_pred[:, ind][~torch.isnan(labels[:, ind])],
                        labels[:, ind][~torch.isnan(labels[:, ind])])
                for ind in class_inds:
                    if len(labels[:, ind][~torch.isnan(labels[:, ind])]) == 0:
                        continue
                    loss = loss + class_loss_fn(
                        y_pred[:, ind][~torch.isnan(labels[:, ind])],
                        labels[:, ind][~torch.isnan(labels[:, ind])])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.detach().item()

                labels = labels.cpu().numpy()
                y_pred = y_pred.detach().cpu().numpy()


                # store labels and preds
                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=0)
            preds = np.concatenate(preds, axis=0)
            rmses = []
            r2s = []
            rocs = []
            prcs = []
            for ind in reg_inds:
                mask = ~np.isnan(labs[:, ind])
                rmse = np.sqrt(mean_squared_error(labs[:, ind][mask],
                                                  preds[:, ind][mask]))
                r2 = r2_score(labs[:, ind][mask], preds[:, ind][mask])
                rmses.append(rmse)
                r2s.append(r2)

            for ind in class_inds:
                mask = ~np.isnan(labs[:, ind])
                roc = roc_auc_score(labs[:, ind][mask], preds[:, ind][mask])
                precision, recall, thresholds = precision_recall_curve(
                    labs[:, ind][mask], preds[:, ind][mask])
                prc = auc(recall, precision)
                rocs.append(roc)
                prcs.append(prc)

            writer.add_scalar('LOSS/train', epoch_loss, epoch)
            writer.add_scalar('train/acry_rocauc', rocs[0], epoch)
            writer.add_scalar('train/acry_prcauc', prcs[0], epoch)
            writer.add_scalar('train/chloro_rocauc', rocs[1], epoch)
            writer.add_scalar('train/chloro_prcauc', prcs[1], epoch)
            writer.add_scalar('train/rest_rocauc', rocs[2], epoch)
            writer.add_scalar('train/rest_prcauc', prcs[2], epoch)

            writer.add_scalar('train/acry_rmse', rmses[0], epoch)
            writer.add_scalar('train/acry_r2', r2s[0], epoch)
            writer.add_scalar('train/chloro_rmse', rmses[1], epoch)
            writer.add_scalar('train/chloro_r2', r2s[1], epoch)
            writer.add_scalar('train/rest_rmse', rmses[2], epoch)
            writer.add_scalar('train/rest_r2', r2s[2], epoch)

            if epoch % 20 == 0:
                print(f"\nepoch: {epoch}, "
                      f"LOSS: {epoch_loss:.3f}"
                      f"\n acry ROC-AUC: {rocs[0]:.3f}, "
                      f"acry PRC-AUC: {prcs[0]:.3f}"
                      f"\n chloro ROC-AUC: {rocs[1]:.3f}, "
                      f"chloro PRC-AUC: {prcs[1]:.3f}"
                      f"\n rest ROC-AUC: {rocs[2]:.3f}, "
                      f"rest PRC-AUC: {prcs[2]:.3f}"
                      f"\n acry R2: {r2s[0]:.3f}, "
                      f"acry RMSE: {rmses[0]:.3f}"
                      f"\n chloro R2: {r2s[1]:.3f}, "
                      f"chloro RMSE: {rmses[1]:.3f}"
                      f"\n rest R2: {r2s[2]:.3f}, "
                      f"rest RMSE: {rmses[2]:.3f}")

            # Evaluate
            mpnn_net.eval()
            preds = []
            labs = []
            for i, (bg, labels) in enumerate(test_loader):
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                bond_feats = bg.edata.pop('e').to(device)
                #atom_feats, bond_feats, labels = atom_feats.to(device), bond_feats.to(device), labels.to(device)
                y_pred = mpnn_net(bg, atom_feats, bond_feats)
                y_pred = process(y_pred)

                labels = labels.cpu().numpy()
                y_pred = y_pred.detach().cpu().numpy()

                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=0)
            preds = np.concatenate(preds, axis=0)
            rmses = []
            r2s = []
            rocs = []
            prcs = []
            for ind in reg_inds:
                mask = ~np.isnan(labs[:, ind])
                rmse = np.sqrt(mean_squared_error(labs[:, ind][mask],
                                                  preds[:, ind][mask]))
                r2 = r2_score(labs[:, ind][mask], preds[:, ind][mask])
                rmses.append(rmse)
                r2s.append(r2)
            for ind in class_inds:
                mask = ~np.isnan(labs[:, ind])
                roc = roc_auc_score(labs[:, ind][mask], preds[:, ind][mask])
                precision, recall, thresholds = precision_recall_curve(
                    labs[:, ind][mask], preds[:, ind][mask])
                prc = auc(recall, precision)
                rocs.append(roc)
                prcs.append(prc)
            writer.add_scalar('test/acry_rocauc', rocs[0], epoch)
            writer.add_scalar('test/acry_prcauc', prcs[0], epoch)
            writer.add_scalar('test/chloro_rocauc', rocs[1], epoch)
            writer.add_scalar('test/chloro_prcauc', prcs[1], epoch)
            writer.add_scalar('test/rest_rocauc', rocs[2], epoch)
            writer.add_scalar('test/rest_prcauc', prcs[2], epoch)

            writer.add_scalar('test/acry_rmse', rmses[0], epoch)
            writer.add_scalar('test/acry_r2', r2s[0], epoch)
            writer.add_scalar('test/chloro_rmse', rmses[1], epoch)
            writer.add_scalar('test/chloro_r2', r2s[1], epoch)
            writer.add_scalar('test/rest_rmse', rmses[2], epoch)
            writer.add_scalar('test/rest_r2', r2s[2], epoch)
            if epoch == args.n_epochs:
                print(f"\n======================== TEST ========================"
                      f"\n acry ROC-AUC: {rocs[0]:.3f}, "
                      f"acry PRC-AUC: {prcs[0]:.3f}"
                      f"\n chloro ROC-AUC: {rocs[1]:.3f}, "
                      f"chloro PRC-AUC: {prcs[1]:.3f}"
                      f"\n rest ROC-AUC: {rocs[2]:.3f}, "
                      f"rest PRC-AUC: {prcs[2]:.3f}"
                      f"\n acry R2: {r2s[0]:.3f}, "
                      f"acry RMSE: {rmses[0]:.3f}"
                      f"\n chloro R2: {r2s[1]:.3f}, "
                      f"chloro RMSE: {rmses[1]:.3f}"
                      f"\n rest R2: {r2s[2]:.3f}, "
                      f"rest RMSE: {rmses[2]:.3f}")
                roc_list.append(rocs)
                prc_list.append(prcs)
                r2_list.append(r2s)
                rmse_list.append(rmses)
    roc_list = np.array(roc_list).T
    prc_list = np.array(prc_list).T
    r2_list = np.array(r2_list).T
    rmse_list = np.array(rmse_list).T
    print("\n ACRY")
    print("R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list[0]), np.std(r2_list[0])/np.sqrt(len(r2_list[0]))))
    print("RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list[0]), np.std(rmse_list[0])/np.sqrt(len(rmse_list[0]))))
    print("ROC-AUC: {:.3f} +- {:.3f}".format(np.mean(roc_list[0]), np.std(roc_list[0]) / np.sqrt(len(roc_list[0]))))
    print("PRC-AUC: {:.3f} +- {:.3f}".format(np.mean(prc_list[0]), np.std(prc_list[0]) / np.sqrt(len(prc_list[0]))))
    print("\n CHLORO")
    print("R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list[1]), np.std(r2_list[1])/np.sqrt(len(r2_list[1]))))
    print("RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list[1]), np.std(rmse_list[1])/np.sqrt(len(rmse_list[1]))))
    print("ROC-AUC: {:.3f} +- {:.3f}".format(np.mean(roc_list[1]), np.std(roc_list[1]) / np.sqrt(len(roc_list[1]))))
    print("PRC-AUC: {:.3f} +- {:.3f}".format(np.mean(prc_list[1]), np.std(prc_list[1]) / np.sqrt(len(prc_list[1]))))
    print("\n REST")
    print("R^2: {:.4f} +- {:.4f}".format(np.mean(r2_list[2]), np.std(r2_list[2])/np.sqrt(len(r2_list[2]))))
    print("RMSE: {:.4f} +- {:.4f}".format(np.mean(rmse_list[2]), np.std(rmse_list[2])/np.sqrt(len(rmse_list[2]))))
    print("ROC-AUC: {:.3f} +- {:.3f}".format(np.mean(roc_list[2]), np.std(roc_list[2]) / np.sqrt(len(roc_list[2]))))
    print("PRC-AUC: {:.3f} +- {:.3f}".format(np.mean(prc_list[2]), np.std(prc_list[2]) / np.sqrt(len(prc_list[2]))))
Example #12
def main(args):
    """
    :param path: str specifying path to dataset.
    :param task: str specifying the task. One of ['e_iso_pi', 'z_iso_pi', 'e_iso_n', 'z_iso_n']
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set
    """

    # data_loader = TaskDataLoader(args.task, args.path)
    # smiles_list, y = data_loader.load_property_data()

    smiles_list, y = parse_dataset(args.task, PATHS[args.task], args.reg)
    X = [Chem.MolFromSmiles(m) for m in smiles_list]

    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()

    e_feats = bond_featurizer.feat_size('e')
    n_feats = atom_featurizer.feat_size('h')
    print('Number of features: ', n_feats)

    X = [
        mol_to_bigraph(m,
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer) for m in X
    ]

    r2_list = []
    rmse_list = []
    mae_list = []
    skipped_trials = 0

    for i in range(args.n_trials):

        # X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=args.test_set_size, random_state=i + 5)

        # Note: StratifiedKFold requires discrete labels, so this split only
        # works when y holds class labels (use KFold for regression targets).
        kf = StratifiedKFold(n_splits=args.n_folds,
                             random_state=i,
                             shuffle=True)
        split_list = kf.split(X, y)
        j = 0
        for train_ind, test_ind in split_list:
            if args.reg:
                writer = SummaryWriter('runs/' + args.task + '/mpnn/reg/run_' +
                                       str(i) + '_fold_' + str(j))
            else:
                writer = SummaryWriter('runs/' + args.task +
                                       '/mpnn/class/run_' + str(i) + '_fold_' +
                                       str(j))
            X_train, X_test = np.array(X)[train_ind], np.array(X)[test_ind]
            y_train, y_test = np.array(y)[train_ind], np.array(y)[test_ind]

            y_train = y_train.reshape(-1, 1)
            y_test = y_test.reshape(-1, 1)

            #  We standardise the outputs but leave the inputs unchanged
            if args.reg:
                y_scaler = StandardScaler()
                y_train_scaled = torch.Tensor(y_scaler.fit_transform(y_train))
                y_test_scaled = torch.Tensor(y_scaler.transform(y_test))
            else:
                y_train_scaled = torch.Tensor(y_train)
                y_test_scaled = torch.Tensor(y_test)

            train_data = list(zip(X_train, y_train_scaled))
            test_data = list(zip(X_test, y_test_scaled))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)
            test_loader = DataLoader(test_data,
                                     batch_size=32,
                                     shuffle=False,
                                     collate_fn=collate,
                                     drop_last=False)

            mpnn_net = MPNNPredictor(node_in_feats=n_feats,
                                     edge_in_feats=e_feats)
            mpnn_net.to(device)

            if args.reg:
                loss_fn = MSELoss()
            else:
                loss_fn = BCELoss()
            optimizer = torch.optim.Adam(mpnn_net.parameters(), lr=1e-4)

            mpnn_net.train()

            epoch_losses = []
            epoch_rmses = []
            for epoch in tqdm(range(1, args.n_epochs + 1)):
                epoch_loss = 0
                preds = []
                labs = []
                for i, (bg, labels) in tqdm(enumerate(train_loader)):
                    labels = labels.to(device)
                    atom_feats = bg.ndata.pop('h').to(device)
                    bond_feats = bg.edata.pop('e').to(device)
                    atom_feats, bond_feats, labels = atom_feats.to(
                        device), bond_feats.to(device), labels.to(device)
                    y_pred = mpnn_net(bg, atom_feats, bond_feats)
                    labels = labels.unsqueeze(dim=1)
                    loss = loss_fn(y_pred, labels)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    epoch_loss += loss.detach().item()

                    if args.reg:
                        # Inverse transform to get RMSE
                        labels = y_scaler.inverse_transform(
                            labels.cpu().reshape(-1, 1))
                        y_pred = y_scaler.inverse_transform(
                            y_pred.detach().cpu().numpy().reshape(-1, 1))
                    else:
                        labels = labels.cpu().numpy()
                        y_pred = y_pred.detach().cpu().numpy()

                    # store labels and preds
                    preds.append(y_pred)
                    labs.append(labels)

                labs = np.concatenate(labs, axis=None)
                preds = np.concatenate(preds, axis=None)
                pearson, p = pearsonr(preds, labs)
                if args.reg:
                    mae = mean_absolute_error(preds, labs)
                    rmse = np.sqrt(mean_squared_error(preds, labs))
                    r2 = r2_score(labs, preds)  # (y_true, y_pred)
                else:
                    r2 = roc_auc_score(labs, preds)
                    precision, recall, thresholds = precision_recall_curve(
                        labs, preds)
                    rmse = auc(recall, precision)
                    mae = 0

                if args.reg:
                    writer.add_scalar('Loss/train', epoch_loss, epoch)
                    writer.add_scalar('RMSE/train', rmse, epoch)
                    writer.add_scalar('R2/train', r2, epoch)
                else:
                    writer.add_scalar('Loss/train', epoch_loss, epoch)
                    writer.add_scalar('ROC-AUC/train', r2, epoch)
                    writer.add_scalar('PRC-AUC/train', rmse, epoch)

                if epoch % 20 == 0:
                    if args.reg:
                        print(f"epoch: {epoch}, "
                              f"LOSS: {epoch_loss:.3f}, "
                              f"RMSE: {rmse:.3f}, "
                              f"MAE: {mae:.3f}, "
                              f"rho: {pearson:.3f}, "
                              f"R2: {r2:.3f}")

                    else:
                        print(f"epoch: {epoch}, "
                              f"LOSS: {epoch_loss:.3f}, "
                              f"ROC-AUC: {r2:.3f}, "
                              f"PRC-AUC: {rmse:.3f}, "
                              f"rho: {pearson:.3f}")
                epoch_losses.append(epoch_loss)
                epoch_rmses.append(rmse)

            # Discount trial if train R2 finishes below -1 (optimiser failure).

            if r2 < -1:
                skipped_trials += 1
                print('Skipped trials is {}'.format(skipped_trials))
                continue

            # Evaluate
            mpnn_net.eval()
            preds = []
            labs = []
            for i, (bg, labels) in enumerate(test_loader):
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                bond_feats = bg.edata.pop('e').to(device)
                atom_feats, bond_feats, labels = atom_feats.to(
                    device), bond_feats.to(device), labels.to(device)
                y_pred = mpnn_net(bg, atom_feats, bond_feats)
                labels = labels.unsqueeze(dim=1)

                if args.reg:
                    # Inverse transform to get RMSE
                    labels = y_scaler.inverse_transform(labels.cpu().reshape(
                        -1, 1))
                    y_pred = y_scaler.inverse_transform(
                        y_pred.detach().cpu().numpy().reshape(-1, 1))
                else:
                    labels = labels.cpu().numpy()
                    y_pred = y_pred.detach().cpu().numpy()
                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=None)
            preds = np.concatenate(preds, axis=None)
            pearson, p = pearsonr(preds, labs)
            if args.reg:
                mae = mean_absolute_error(preds, labs)
                rmse = np.sqrt(mean_squared_error(preds, labs))
                r2 = r2_score(labs, preds)  # (y_true, y_pred)
                writer.add_scalar('RMSE/test', rmse)
                writer.add_scalar('R2/test', r2)
                print(
                    f'Test RMSE: {rmse:.3f}, MAE: {mae:.3f}, R: {pearson:.3f}, R2: {r2:.3f}'
                )
            else:
                r2 = roc_auc_score(labs, preds)
                precision, recall, thresholds = precision_recall_curve(
                    labs, preds)
                rmse = auc(recall, precision)
                mae = 0
                writer.add_scalar('ROC-AUC/test', r2)
                writer.add_scalar('PRC-AUC/test', rmse)
                print(
                    f'Test ROC-AUC: {r2:.3f}, PRC-AUC: {rmse:.3f}, rho: {pearson:.3f}'
                )

            r2_list.append(r2)
            rmse_list.append(rmse)
            mae_list.append(mae)
            j += 1

    r2_list = np.array(r2_list)
    rmse_list = np.array(rmse_list)
    mae_list = np.array(mae_list)
    if args.reg:
        print("\nmean R^2: {:.4f} +- {:.4f}".format(
            np.mean(r2_list),
            np.std(r2_list) / np.sqrt(len(r2_list))))
        print("mean RMSE: {:.4f} +- {:.4f}".format(
            np.mean(rmse_list),
            np.std(rmse_list) / np.sqrt(len(rmse_list))))
        print("mean MAE: {:.4f} +- {:.4f}\n".format(
            np.mean(mae_list),
            np.std(mae_list) / np.sqrt(len(mae_list))))
    else:
        print("mean ROC-AUC^2: {:.3f} +- {:.3f}".format(
            np.mean(r2_list),
            np.std(r2_list) / np.sqrt(len(r2_list))))
        print("mean PRC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(rmse_list),
            np.std(rmse_list) / np.sqrt(len(rmse_list))))
    print("\nSkipped trials is {}".format(skipped_trials))
Example #13
if torch.cuda.is_available():
    print('use GPU')
    device = 'cuda'
else:
    print('use CPU')
    device = 'cpu'

train_mols = [Chem.MolFromSmiles(s) for s in train_smiles]

atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
n_feats = atom_featurizer.feat_size('h')
bond_featurizer = CanonicalBondFeaturizer(bond_data_field='h')
b_feat = bond_featurizer.feat_size('h')

train_graph = [mol_to_bigraph(mol, node_featurizer=atom_featurizer,
                              edge_featurizer=bond_featurizer) for mol in train_mols]


ncls = 1
model = GCNPredictor(in_feats=n_feats,
                     hidden_feats=[60,20],
                     n_tasks=ncls,
                     predictor_dropout=0.2)
#model = AttentiveFPGNN(n_feats,b_feat,2,200)
model = model.to(device)
print(model)


def collate_molgraphs(data):
    """Batching a list of datapoints for dataloader.

    Batches the graphs with dgl.batch and stacks the labels,
    as in the `collate` helpers of the earlier examples.
    """
    graphs, labels = map(list, zip(*data))
    batched_graph = dgl.batch(graphs)
    batched_graph.set_n_initializer(dgl.init.zero_initializer)
    batched_graph.set_e_initializer(dgl.init.zero_initializer)
    return batched_graph, torch.tensor(labels)
Example #14
def name2g(data_path, name):
    path = os.path.join(data_path, name + '.sdf')
    # Note: if the SDF contains several molecules, only the last graph is returned.
    for mol in Chem.SDMolSupplier(path):
        g = mol_to_bigraph(mol, node_featurizer=CanonicalAtomFeaturizer())
    return g
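
A hedged usage sketch (directory and file name are placeholders; note the function keeps only the last molecule if the SDF holds several):

g = name2g("data/sdf", "mol_0001")  # reads data/sdf/mol_0001.sdf
print(g.num_nodes(), g.num_edges())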
Example #15
def main(args):
    """
    :param n_trials: int specifying number of random train/test splits to use
    :param test_set_size: float in range [0, 1] specifying fraction of dataset to use as test set
    """

    df = pd.read_csv('data/covid_multitask_HTS.smi')
    if args.dry:
        df = df[:2000]
    smiles_list = df['SMILES'].values
    y = df[[
        'acry_reg', 'chloro_reg', 'rest_reg', 'acry_class', 'chloro_class',
        'rest_class', 'activity'
    ]].to_numpy()
    n_tasks = y.shape[1]
    reg_inds = [0, 1, 2]
    class_inds = [3, 4, 5, 6]
    # print(smiles_list)
    X = [Chem.MolFromSmiles(m) for m in smiles_list]

    # Initialise featurisers
    atom_featurizer = CanonicalAtomFeaturizer()
    bond_featurizer = CanonicalBondFeaturizer()

    e_feats = bond_featurizer.feat_size('e')
    n_feats = atom_featurizer.feat_size('h')
    print('Number of features: ', n_feats)

    X = np.array([
        mol_to_bigraph(m,
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer) for m in X
    ])

    r2_list = []
    rmse_list = []
    roc_list = []
    prc_list = []

    for i in range(args.n_trials):
        writer = SummaryWriter('runs/' + args.savename)

        if args.test:
            X_train_acry, X_test_acry, \
            y_train_acry, y_test_acry = train_test_split(X[~np.isnan(y[:,3])],
                                                         y[~np.isnan(y[:,3])], stratify=y[:,3][~np.isnan(y[:,3])],
                                                         test_size=args.test_set_size, shuffle=True, random_state=i+5)
            X_train_chloro, X_test_chloro, \
            y_train_chloro, y_test_chloro = train_test_split(X[~np.isnan(y[:,4])],
                                                             y[~np.isnan(y[:,4])], stratify=y[:,4][~np.isnan(y[:,4])],
                                                              test_size=args.test_set_size, shuffle=True, random_state=i+5)
            X_train_rest, X_test_rest, \
            y_train_rest, y_test_rest = train_test_split(X[~np.isnan(y[:,5])],
                                                         y[~np.isnan(y[:,5])], stratify=y[:,5][~np.isnan(y[:,5])],
                                                          test_size=args.test_set_size, shuffle=True, random_state=i+5)

            X_train = np.concatenate([
                X_train_acry, X_train_chloro, X_train_rest,
                X[~np.isnan(y[:, 6])]
            ])
            X_test = np.concatenate([X_test_acry, X_test_chloro, X_test_rest])
            y_train = np.concatenate([
                y_train_acry, y_train_chloro, y_train_rest,
                y[~np.isnan(y[:, 6])]
            ])
            y_test = np.concatenate([y_test_acry, y_test_chloro, y_test_rest])

            y_train = torch.Tensor(y_train)
            y_test = torch.Tensor(y_test)

            train_data = list(zip(X_train, y_train))
            test_data = list(zip(X_test, y_test))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)
            test_loader = DataLoader(test_data,
                                     batch_size=32,
                                     shuffle=True,
                                     collate_fn=collate,
                                     drop_last=False)
        else:
            y = torch.Tensor(y)
            train_data = list(zip(X, y))

            train_loader = DataLoader(train_data,
                                      batch_size=32,
                                      shuffle=True,
                                      collate_fn=collate,
                                      drop_last=False)

        process = Net(class_inds, reg_inds)
        process = process.to(device)

        mpnn_net = MPNNPredictor(node_in_feats=n_feats,
                                 edge_in_feats=e_feats,
                                 node_out_feats=128,
                                 n_tasks=n_tasks)
        mpnn_net = mpnn_net.to(device)
        # try:
        #     mpnn_net.load_state_dict(torch.load('/rds-d2/user/wjm41/hpc-work/models/' + args.savename +
        #                        '/model_epoch_20.pt'))
        # except: pass
        reg_loss_fn = MSELoss()
        class_loss_fn = BCELoss()

        optimizer = torch.optim.Adam(mpnn_net.parameters(), lr=1e-4)

        for epoch in range(1, args.n_epochs + 1):
            epoch_loss = 0
            preds = []
            labs = []
            mpnn_net.train()
            n = 0
            for i, (bg, labels) in enumerate(train_loader):
                labels = labels.to(device)
                atom_feats = bg.ndata.pop('h').to(device)
                bond_feats = bg.edata.pop('e').to(device)
                y_pred = mpnn_net(bg, atom_feats, bond_feats)
                y_pred = process(y_pred)
                loss = torch.tensor(0.0)  # float accumulator for the summed losses
                loss = loss.to(device)

                if args.debug:
                    print('label: {}'.format(labels))
                    print('y_pred: {}'.format(y_pred))
                for ind in reg_inds:
                    if len(labels[:, ind][~torch.isnan(labels[:, ind])]) == 0:
                        continue
                    loss = loss + reg_loss_fn(
                        y_pred[:, ind][~torch.isnan(labels[:, ind])],
                        labels[:, ind][~torch.isnan(labels[:, ind])])
                if args.debug:
                    print('reg loss: {}'.format(loss))
                for ind in class_inds:
                    if len(labels[:, ind][~torch.isnan(labels[:, ind])]) == 0:
                        continue
                    loss = loss + class_loss_fn(
                        y_pred[:, ind][~torch.isnan(labels[:, ind])],
                        labels[:, ind][~torch.isnan(labels[:, ind])])
                if args.debug:
                    print('class + reg loss: {}'.format(loss))
                optimizer.zero_grad()
                loss.backward()

                optimizer.step()
                epoch_loss += loss.detach().item()

                labels = labels.cpu().numpy()
                y_pred = y_pred.detach().cpu().numpy()

                # store labels and preds
                preds.append(y_pred)
                labs.append(labels)

            labs = np.concatenate(labs, axis=0)
            preds = np.concatenate(preds, axis=0)
            rmses = []
            r2s = []
            rocs = []
            prcs = []
            for ind in reg_inds:
                rmse = np.sqrt(
                    mean_squared_error(labs[:, ind][~np.isnan(labs[:, ind])],
                                       preds[:, ind][~np.isnan(labs[:, ind])]))
                r2 = r2_score(labs[:, ind][~np.isnan(labs[:, ind])],
                              preds[:, ind][~np.isnan(labs[:, ind])])
                rmses.append(rmse)
                r2s.append(r2)

            for ind in class_inds:
                roc = roc_auc_score(labs[:, ind][~np.isnan(labs[:, ind])],
                                    preds[:, ind][~np.isnan(labs[:, ind])])
                precision, recall, thresholds = precision_recall_curve(
                    labs[:, ind][~np.isnan(labs[:, ind])],
                    preds[:, ind][~np.isnan(labs[:, ind])])
                prc = auc(recall, precision)
                rocs.append(roc)
                prcs.append(prc)

            writer.add_scalar('LOSS/train', epoch_loss, epoch)
            writer.add_scalar('train/acry_rocauc', rocs[0], epoch)
            writer.add_scalar('train/acry_prcauc', prcs[0], epoch)
            writer.add_scalar('train/chloro_rocauc', rocs[1], epoch)
            writer.add_scalar('train/chloro_prcauc', prcs[1], epoch)
            writer.add_scalar('train/rest_rocauc', rocs[2], epoch)
            writer.add_scalar('train/rest_prcauc', prcs[2], epoch)
            writer.add_scalar('train/HTS_rocauc', rocs[3], epoch)
            writer.add_scalar('train/HTS_prcauc', prcs[3], epoch)

            writer.add_scalar('train/acry_rmse', rmses[0], epoch)
            writer.add_scalar('train/acry_r2', r2s[0], epoch)
            writer.add_scalar('train/chloro_rmse', rmses[1], epoch)
            writer.add_scalar('train/chloro_r2', r2s[1], epoch)
            writer.add_scalar('train/rest_rmse', rmses[2], epoch)
            writer.add_scalar('train/rest_r2', r2s[2], epoch)

            if epoch % 20 == 0:
                print(f"\nepoch: {epoch}, "
                      f"LOSS: {epoch_loss:.3f}"
                      f"\n acry ROC-AUC: {rocs[0]:.3f}, "
                      f"acry PRC-AUC: {prcs[0]:.3f}"
                      f"\n chloro ROC-AUC: {rocs[1]:.3f}, "
                      f"chloro PRC-AUC: {prcs[1]:.3f}"
                      f"\n rest ROC-AUC: {rocs[2]:.3f}, "
                      f"rest PRC-AUC: {prcs[2]:.3f}"
                      f"\n HTS ROC-AUC: {rocs[3]:.3f}, "
                      f"HTS PRC-AUC: {prcs[3]:.3f}"
                      f"\n acry R2: {r2s[0]:.3f}, "
                      f"acry RMSE: {rmses[0]:.3f}"
                      f"\n chloro R2: {r2s[1]:.3f}, "
                      f"chloro RMSE: {rmses[1]:.3f}"
                      f"\n rest R2: {r2s[2]:.3f}, "
                      f"rest RMSE: {rmses[2]:.3f}")
                save_dir = '/rds-d2/user/wjm41/hpc-work/models/' + args.savename
                os.makedirs(save_dir, exist_ok=True)
                torch.save(mpnn_net.state_dict(),
                           save_dir + '/model_epoch_' + str(epoch) + '.pt')
            if args.test:
                # Evaluate
                mpnn_net.eval()
                preds = []
                labs = []
                for i, (bg, labels) in enumerate(test_loader):
                    labels = labels.to(device)
                    atom_feats = bg.ndata.pop('h').to(device)
                    bond_feats = bg.edata.pop('e').to(device)
                    y_pred = mpnn_net(bg, atom_feats, bond_feats)
                    y_pred = process(y_pred)

                    labels = labels.cpu().numpy()
                    y_pred = y_pred.detach().cpu().numpy()

                    preds.append(y_pred)
                    labs.append(labels)

                labs = np.concatenate(labs, axis=0)
                preds = np.concatenate(preds, axis=0)
                rmses = []
                r2s = []
                rocs = []
                prcs = []
                for ind in reg_inds:

                    rmse = np.sqrt(
                        mean_squared_error(
                            labs[:, ind][~np.isnan(labs[:, ind])],
                            preds[:, ind][~np.isnan(labs[:, ind])]))

                    r2 = r2_score(labs[:, ind][~np.isnan(labs[:, ind])],
                                  preds[:, ind][~np.isnan(labs[:, ind])])
                    rmses.append(rmse)
                    r2s.append(r2)
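                # Only the first three classification tasks (acry, chloro, rest)
                # are scored on the test set; HTS is logged for training only.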
                for ind in class_inds[:3]:
                    roc = roc_auc_score(labs[:, ind][~np.isnan(labs[:, ind])],
                                        preds[:, ind][~np.isnan(labs[:, ind])])
                    precision, recall, thresholds = precision_recall_curve(
                        labs[:, ind][~np.isnan(labs[:, ind])],
                        preds[:, ind][~np.isnan(labs[:, ind])])
                    prc = auc(recall, precision)
                    rocs.append(roc)
                    prcs.append(prc)
                writer.add_scalar('test/acry_rocauc', rocs[0], epoch)
                writer.add_scalar('test/acry_prcauc', prcs[0], epoch)
                writer.add_scalar('test/chloro_rocauc', rocs[1], epoch)
                writer.add_scalar('test/chloro_prcauc', prcs[1], epoch)
                writer.add_scalar('test/rest_rocauc', rocs[2], epoch)
                writer.add_scalar('test/rest_prcauc', prcs[2], epoch)

                writer.add_scalar('test/acry_rmse', rmses[0], epoch)
                writer.add_scalar('test/acry_r2', r2s[0], epoch)
                writer.add_scalar('test/chloro_rmse', rmses[1], epoch)
                writer.add_scalar('test/chloro_r2', r2s[1], epoch)
                writer.add_scalar('test/rest_rmse', rmses[2], epoch)
                writer.add_scalar('test/rest_r2', r2s[2], epoch)
                if epoch == args.n_epochs:
                    print(
                        f"\n======================== TEST ========================"
                        f"\n acry ROC-AUC: {rocs[0]:.3f}, "
                        f"acry PRC-AUC: {prcs[0]:.3f}"
                        f"\n chloro ROC-AUC: {rocs[1]:.3f}, "
                        f"chloro PRC-AUC: {prcs[1]:.3f}"
                        f"\n rest ROC-AUC: {rocs[2]:.3f}, "
                        f"rest PRC-AUC: {prcs[2]:.3f}"
                        f"\n acry R2: {r2s[0]:.3f}, "
                        f"acry RMSE: {rmses[0]:.3f}"
                        f"\n chloro R2: {r2s[1]:.3f}, "
                        f"chloro RMSE: {rmses[1]:.3f}"
                        f"\n rest R2: {r2s[2]:.3f}, "
                        f"rest RMSE: {rmses[2]:.3f}")
                    roc_list.append(rocs)
                    prc_list.append(prcs)
                    r2_list.append(r2s)
                    rmse_list.append(rmses)
        torch.save(
            mpnn_net.state_dict(), '/rds-d2/user/wjm41/hpc-work/models/' +
            args.savename + '/model_epoch_final.pt')
    if args.test:
        roc_list = np.array(roc_list).T
        prc_list = np.array(prc_list).T
        r2_list = np.array(r2_list).T
        rmse_list = np.array(rmse_list).T
        print("\n ACRY")
        print("R^2: {:.4f} +- {:.4f}".format(
            np.mean(r2_list[0]),
            np.std(r2_list[0]) / np.sqrt(len(r2_list[0]))))
        print("RMSE: {:.4f} +- {:.4f}".format(
            np.mean(rmse_list[0]),
            np.std(rmse_list[0]) / np.sqrt(len(rmse_list[0]))))
        print("ROC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(roc_list[0]),
            np.std(roc_list[0]) / np.sqrt(len(roc_list[0]))))
        print("PRC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(prc_list[0]),
            np.std(prc_list[0]) / np.sqrt(len(prc_list[0]))))
        print("\n CHLORO")
        print("R^2: {:.4f} +- {:.4f}".format(
            np.mean(r2_list[1]),
            np.std(r2_list[1]) / np.sqrt(len(r2_list[1]))))
        print("RMSE: {:.4f} +- {:.4f}".format(
            np.mean(rmse_list[1]),
            np.std(rmse_list[1]) / np.sqrt(len(rmse_list[1]))))
        print("ROC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(roc_list[1]),
            np.std(roc_list[1]) / np.sqrt(len(roc_list[1]))))
        print("PRC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(prc_list[1]),
            np.std(prc_list[1]) / np.sqrt(len(prc_list[1]))))
        print("\n REST")
        print("R^2: {:.4f} +- {:.4f}".format(
            np.mean(r2_list[2]),
            np.std(r2_list[2]) / np.sqrt(len(r2_list[2]))))
        print("RMSE: {:.4f} +- {:.4f}".format(
            np.mean(rmse_list[2]),
            np.std(rmse_list[2]) / np.sqrt(len(rmse_list[2]))))
        print("ROC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(roc_list[2]),
            np.std(roc_list[2]) / np.sqrt(len(roc_list[2]))))
        print("PRC-AUC: {:.3f} +- {:.3f}".format(
            np.mean(prc_list[2]),
            np.std(prc_list[2]) / np.sqrt(len(prc_list[2]))))
if torch.cuda.is_available():
    print('use GPU')
    device = 'cuda'
else:
    print('use CPU')
    device = 'cpu'

mols = [Chem.MolFromSmiles(s) for s in train_smiles]

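# Canonical atom/bond featurizers from DGL-LifeSci; both store their features
# under the data field 'h', whose sizes set the model's input dimensions.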
atom_featurizer = CanonicalAtomFeaturizer(atom_data_field='h')
n_feats = atom_featurizer.feat_size('h')
bond_featurizer = CanonicalBondFeaturizer(bond_data_field='h')
b_feat = bond_featurizer.feat_size('h')

train_graph = [
    mol_to_bigraph(mol,
                   node_featurizer=atom_featurizer,
                   edge_featurizer=bond_featurizer) for mol in mols
]

model = AttentiveFPPredictor(node_feat_size=n_feats,
                             edge_feat_size=b_feat,
                             num_layers=2,
                             num_timesteps=2,
                             graph_feat_size=200,
                             n_tasks=1,
                             dropout=0.2)
#model = AttentiveFPGNN(n_feats,b_feat,2,200)
model = model.to(device)
print(model)
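A minimal usage sketch, not part of the original example: batch a few of the featurized graphs above and run a forward pass through the predictor. `train_graph`, `model`, and `device` come from the example; the 'h' field names match the featurizers used there, and the batch size of 32 is an arbitrary illustration.

import dgl
import torch

# Merge individual molecule graphs into one batched graph
bg = dgl.batch(train_graph[:32])
node_feats = bg.ndata.pop('h')
edge_feats = bg.edata.pop('h')

# Forward pass; AttentiveFPPredictor takes (graph, node_feats, edge_feats)
with torch.no_grad():
    preds = model(bg.to(device),
                  node_feats.to(device),
                  edge_feats.to(device))
print(preds.shape)  # (batch_size, n_tasks) == (32, 1)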