import torch
from ogb.graphproppred import Evaluator


def evaluate_network_sparse(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_ROC = 0
    with torch.no_grad():
        list_scores = []
        list_labels = []
        for iter, (batch_graphs, batch_labels, batch_snorm_n,
                   batch_snorm_e) in enumerate(data_loader):
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_e = batch_graphs.edata['feat'].to(device)
            batch_snorm_e = batch_snorm_e.to(device)
            batch_snorm_n = batch_snorm_n.to(device)
            batch_labels = batch_labels.to(device)
            batch_graphs = batch_graphs.to(device)
            batch_scores = model.forward(batch_graphs, batch_x, batch_e,
                                         batch_snorm_n, batch_snorm_e)
            loss = model.loss(batch_scores, batch_labels)
            epoch_test_loss += loss.detach().item()
            list_scores.append(batch_scores.detach())
            list_labels.append(batch_labels.detach().unsqueeze(-1))

        epoch_test_loss /= (iter + 1)
        evaluator = Evaluator(name='ogbg-molhiv')
        epoch_test_ROC = evaluator.eval({
            'y_pred': torch.cat(list_scores),
            'y_true': torch.cat(list_labels)
        })['rocauc']

    return epoch_test_loss, epoch_test_ROC
def train_epoch_sparse(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0
    epoch_train_ROC = 0
    list_scores = []
    list_labels = []
    for iter, (batch_graphs, batch_labels, batch_snorm_n,
               batch_snorm_e) in enumerate(data_loader):
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_snorm_e = batch_snorm_e.to(device)
        batch_snorm_n = batch_snorm_n.to(device)
        batch_labels = batch_labels.to(device)
        batch_graphs = batch_graphs.to(device)
        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs, batch_x, batch_e,
                                     batch_snorm_n, batch_snorm_e)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        list_scores.append(batch_scores.detach())
        list_labels.append(batch_labels.detach().unsqueeze(-1))

    epoch_loss /= (iter + 1)
    evaluator = Evaluator(name='ogbg-molhiv')
    epoch_train_ROC = evaluator.eval({
        'y_pred': torch.cat(list_scores),
        'y_true': torch.cat(list_labels)
    })['rocauc']

    return epoch_loss, epoch_train_ROC, optimizer
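# Usage sketch (not from the original source): a typical epoch loop driving the two
# functions above; `model`, `optimizer`, `device`, and the DGL data loaders are
# assumed to be constructed elsewhere.
def run_training(model, optimizer, device, train_loader, val_loader, n_epochs=100):
    for epoch in range(n_epochs):
        train_loss, train_roc, optimizer = train_epoch_sparse(
            model, optimizer, device, train_loader, epoch)
        val_loss, val_roc = evaluate_network_sparse(model, device, val_loader, epoch)
        print(f"epoch {epoch:03d} | train ROC-AUC {train_roc:.4f} | "
              f"val ROC-AUC {val_roc:.4f}")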
Example #3
@torch.no_grad()  # inference only: skip building autograd graphs
def test(loader):
    model.eval()
    evaluator = Evaluator(name='ogbg-molhiv')
    list_pred = []
    list_labels = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, None, data.batch)
        list_pred.append(out)
        list_labels.append(data.y)
    epoch_test_ROC = evaluator.eval({
        'y_pred': torch.cat(list_pred),
        'y_true': torch.cat(list_labels)
    })['rocauc']
    return epoch_test_ROC
def train_epoch_sparse(model, optimizer, device, data_loader, epoch,
                       distortion):
    model.train()
    epoch_loss = 0
    epoch_train_ROC = 0
    list_scores = []
    list_labels = []
    for iter, (batch_graphs, batch_labels, batch_snorm_n,
               batch_snorm_e) in enumerate(data_loader):
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_snorm_e = batch_snorm_e.to(device)
        batch_snorm_n = batch_snorm_n.to(device)
        batch_labels = batch_labels.to(device)
        if distortion > 1e-7:
            # jitter the eigenvector positional features: add uniform noise in
            # [-distortion, +distortion], scaled by each column's mean absolute value
            batch_graphs_eig = batch_graphs.ndata['eig'].clone()
            dist = (torch.rand(batch_x[:, 0].shape) - 0.5) * 2 * distortion
            batch_graphs.ndata['eig'][:, 1] = torch.mul(
                dist,
                torch.mean(torch.abs(batch_graphs_eig[:, 1]),
                           dim=-1,
                           keepdim=True)) + batch_graphs_eig[:, 1]
            batch_graphs.ndata['eig'][:, 2] = torch.mul(
                dist,
                torch.mean(torch.abs(batch_graphs_eig[:, 2]),
                           dim=-1,
                           keepdim=True)) + batch_graphs_eig[:, 2]

        optimizer.zero_grad()
        batch_scores = model.forward(batch_graphs, batch_x, batch_e,
                                     batch_snorm_n, batch_snorm_e)
        loss = model.loss(batch_scores, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        list_scores.append(batch_scores.detach())
        list_labels.append(batch_labels.detach().unsqueeze(-1))
        if distortion > 1e-7:
            # restore the undistorted eigenvectors so the cached graphs are unchanged
            batch_graphs.ndata['eig'] = batch_graphs_eig.detach()

    epoch_loss /= (iter + 1)
    evaluator = Evaluator(name='ogbg-molhiv')
    epoch_train_ROC = evaluator.eval({
        'y_pred': torch.cat(list_scores),
        'y_true': torch.cat(list_labels)
    })['rocauc']

    return epoch_loss, epoch_train_ROC, optimizer
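# Self-contained toy demo (not from the original source) of the eigenvector
# distortion above: each entry of an 'eig' column gets uniform noise in
# [-distortion, +distortion], scaled by that column's mean absolute value.
import torch

torch.manual_seed(0)
eig = torch.randn(5, 3)                        # 5 nodes, 3 positional columns
distortion = 0.1
dist = (torch.rand(5) - 0.5) * 2 * distortion  # uniform in [-0.1, 0.1]
scale = torch.mean(torch.abs(eig[:, 1]), dim=-1, keepdim=True)
eig_distorted = eig.clone()
eig_distorted[:, 1] = dist * scale + eig[:, 1]
print((eig_distorted[:, 1] - eig[:, 1]).abs().max())  # bounded by distortion * scale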
def train_epoch(model, optimizer, device, data_loader, epoch):
    model.train()
    epoch_loss = 0
    epoch_train_AP = 0
    list_scores = []
    list_labels = []
    for iter, (batch_graphs, batch_targets) in enumerate(data_loader):
        batch_graphs = batch_graphs.to(device)
        batch_x = batch_graphs.ndata['feat'].to(device)  # num x feat
        batch_e = batch_graphs.edata['feat'].to(device)
        batch_targets = batch_targets.to(device)
        optimizer.zero_grad()
        try:
            # random sign flip of the Laplacian positional encodings: eigenvectors
            # are only defined up to sign, so flipping signs is a valid augmentation
            batch_lap_pos_enc = batch_graphs.ndata['lap_pos_enc'].to(device)
            sign_flip = torch.rand(batch_lap_pos_enc.size(1)).to(device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            batch_lap_pos_enc = batch_lap_pos_enc * sign_flip.unsqueeze(0)
        except KeyError:
            batch_lap_pos_enc = None

        try:
            batch_wl_pos_enc = batch_graphs.ndata['wl_pos_enc'].to(device)
        except KeyError:
            batch_wl_pos_enc = None

        batch_scores = model.forward(batch_graphs, batch_x, batch_e,
                                     batch_lap_pos_enc, batch_wl_pos_enc)
        is_labeled = batch_targets == batch_targets  # NaN != NaN: mask out unlabeled targets
        loss = model.loss(batch_scores[is_labeled],
                          batch_targets.float()[is_labeled])
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        list_scores.append(batch_scores.detach().cpu())
        list_labels.append(batch_targets.detach().cpu())

    epoch_loss /= (iter + 1)
    evaluator = Evaluator(name='ogbg-molpcba')
    epoch_train_AP = evaluator.eval({
        'y_pred': torch.cat(list_scores),
        'y_true': torch.cat(list_labels)
    })['ap']

    return epoch_loss, epoch_train_AP, optimizer
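# Self-contained demo (not from the original source) of the `targets == targets`
# NaN mask used in train_epoch above: NaN is the only value not equal to itself,
# so the comparison keeps exactly the labeled entries of a multi-task target.
import torch

targets = torch.tensor([[1.0, float('nan')], [0.0, 1.0]])
scores = torch.zeros_like(targets)
is_labeled = targets == targets               # False exactly at the NaN position
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    scores[is_labeled], targets[is_labeled])  # computed on the 3 labeled entries
print(is_labeled.sum().item(), loss.item())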
Example #6
    def eval_on(self, loader, trainer):
        results_dict = super().eval_on(loader, trainer)
        evaluator = GraphPropEvaluator(name=self.task_name)
        y_trues = []
        y_preds = []
        for batch in loader:
            if trainer.on_gpu:
                batch = batch.to("cuda")
            y_preds.append(self.model(batch).cpu().detach().numpy())
            y_trues.append(batch.y.cpu().detach().numpy())
        y_trues = np.concatenate(y_trues, axis=0)
        y_preds = np.concatenate(y_preds, axis=0)
        results_dict.update(
            evaluator.eval({
                "y_true": y_trues,
                "y_pred": y_preds
            }))
        return results_dict
def evaluate_network(model, device, data_loader, epoch):
    model.eval()
    epoch_test_loss = 0
    epoch_test_AP = 0
    with torch.no_grad():
        list_scores = []
        list_labels = []
        for iter, (batch_graphs, batch_targets) in enumerate(data_loader):
            batch_graphs = batch_graphs.to(device)
            batch_x = batch_graphs.ndata['feat'].to(device)
            batch_e = batch_graphs.edata['feat'].to(device)
            batch_targets = batch_targets.to(device)
            try:
                batch_lap_pos_enc = batch_graphs.ndata['lap_pos_enc'].to(
                    device)
            except KeyError:
                batch_lap_pos_enc = None

            try:
                batch_wl_pos_enc = batch_graphs.ndata['wl_pos_enc'].to(device)
            except KeyError:
                batch_wl_pos_enc = None

            batch_scores = model.forward(batch_graphs, batch_x, batch_e,
                                         batch_lap_pos_enc, batch_wl_pos_enc)
            is_labeled = batch_targets == batch_targets  # NaN != NaN: mask out unlabeled targets
            loss = model.loss(batch_scores[is_labeled],
                              batch_targets.float()[is_labeled])
            epoch_test_loss += loss.detach().item()
            list_scores.append(batch_scores.detach().cpu())
            list_labels.append(batch_targets.detach().cpu())

        epoch_test_loss /= (iter + 1)
        evaluator = Evaluator(name='ogbg-molpcba')
        epoch_test_AP = evaluator.eval({
            'y_pred': torch.cat(list_scores),
            'y_true': torch.cat(list_labels)
        })['ap']

    return epoch_test_loss, epoch_test_AP
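# Minimal sketch (not from the original source) of the input contract the OGB
# Evaluator relies on above: y_true and y_pred share shape (num_graphs, num_tasks),
# y_true may contain NaN for unlabeled tasks, and ogbg-molpcba reports 'ap'.
import torch
from ogb.graphproppred import Evaluator

evaluator = Evaluator(name='ogbg-molpcba')
y_true = torch.randint(0, 2, (8, 128)).float()
y_true[0, 0] = float('nan')                   # unlabeled entries are allowed
y_pred = torch.randn(8, 128)                  # raw scores/logits are accepted
print(evaluator.eval({'y_true': y_true, 'y_pred': y_pred})['ap'])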
Example #8
def run(args):
    from ogb.graphproppred import DglGraphPropPredDataset, Evaluator, collate_dgl
    from torch.utils.data import DataLoader

    dataset = DglGraphPropPredDataset(name="ogbg-molhiv")

    import os
    if not os.path.exists("heterographs.bin"):
        dataset.graphs = [hpno.heterograph(graph) for graph in dataset.graphs]
        from dgl.data.utils import save_graphs
        save_graphs("heterographs.bin", dataset.graphs)
    else:
        from dgl.data.utils import load_graphs
        dataset.graphs = load_graphs("heterographs.bin")[0]

    evaluator = Evaluator(name="ogbg-molhiv")
    in_features = 9
    out_features = 1

    split_idx = dataset.get_idx_split()
    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=128, drop_last=True, shuffle=True, collate_fn=collate_dgl)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=len(split_idx["valid"]), shuffle=False, collate_fn=collate_dgl)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=len(split_idx["test"]), shuffle=False, collate_fn=collate_dgl)

    model = hpno.HierarchicalPathNetwork(
        in_features=in_features,
        out_features=args.hidden_features,
        hidden_features=args.hidden_features,
        depth=args.depth,
        readout=hpno.GraphReadout(
            in_features=args.hidden_features,
            out_features=out_features,
            hidden_features=args.hidden_features,
        )
    )


    if torch.cuda.is_available():
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.5, patience=20)

    for idx_epoch in range(args.n_epochs):
        print(idx_epoch, flush=True)
        model.train()
        for g, y in train_loader:
            y = y.float()
            if torch.cuda.is_available():
                g = g.to("cuda:0")
                y = y.cuda()
            optimizer.zero_grad()
            y_hat = model.forward(g, g.nodes['n1'].data["feat"].float())
            loss = torch.nn.BCELoss()(
                input=y_hat.sigmoid(),
                target=y,
            )
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            g, y = next(iter(valid_loader))
            y = y.float()
            if torch.cuda.is_available():
                g = g.to("cuda:0")
                y = y.cuda()
            y_hat = model.forward(g, g.nodes['n1'].data["feat"].float())
            loss = torch.nn.BCELoss()(
                input=y_hat.sigmoid(),
                target=y,
            )
            scheduler.step(loss)

        if optimizer.param_groups[0]["lr"] <= 0.01 * args.learning_rate: break

    model = model.cpu()
    g, y = next(iter(valid_loader))
    rocauc_vl = evaluator.eval(
        {
            "y_true": y.float(),
            "y_pred": model.forward(g, g.nodes['n1'].data["feat"].float()).sigmoid()
        }
    )["rocauc"]

    g, y = next(iter(test_loader))
    rocauc_te = evaluator.eval(
        {
            "y_true": y.float(),
            "y_pred": model.forward(g, g.nodes['n1'].data["feat"].float()).sigmoid()
        }
    )["rocauc"]

    import pandas as pd
    df = pd.DataFrame(
        {
            args.data: {
                "rocauc_te": rocauc_te,
                "rocauc_vl": rocauc_vl,
            }
        }
    )

    df.to_csv("%s.csv" % args.out)
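# Hypothetical CLI wrapper (not from the original source) showing the arguments
# run() reads above; each flag mirrors an `args.` attribute used in the function.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden_features", type=int, default=128)
    parser.add_argument("--depth", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--n_epochs", type=int, default=500)
    parser.add_argument("--data", type=str, default="ogbg-molhiv")
    parser.add_argument("--out", type=str, default="results")
    run(parser.parse_args())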
step = loss = 0
for batch in loader_tr:
    step += 1
    loss += train_step(*batch)
    if step == loader_tr.steps_per_epoch:
        step = 0
        print("Loss: {}".format(loss / loader_tr.steps_per_epoch))
        loss = 0
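# Sketch of the `train_step` the loop above assumes but does not show, in the
# usual Keras GradientTape form; `model`, `optimizer`, `loss_fn`, and `loader_tr`
# are the objects this snippet already uses.
import tensorflow as tf

@tf.function(input_signature=loader_tr.tf_signature(), experimental_relax_shapes=True)
def train_step(inputs, target):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(target, predictions) + sum(model.losses)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss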

################################################################################
# Evaluate model
################################################################################
print("Testing model")
evaluator = Evaluator(name=dataset_name)
y_true = []
y_pred = []
for batch in loader_te:
    inputs, target = batch
    p = model(inputs, training=False)
    y_true.append(target)
    y_pred.append(p.numpy())

y_true = np.vstack(y_true)
y_pred = np.vstack(y_pred)
model_loss = loss_fn(y_true, y_pred)
ogb_score = evaluator.eval({"y_true": y_true, "y_pred": y_pred})

print(
    "Done. Test loss: {:.4f}. ROC-AUC: {:.2f}".format(model_loss, ogb_score["rocauc"])
)
Example #10
def main(_):
    tf.keras.mixed_precision.set_global_policy("float16" if FLAGS.dtype == 'float16' else "float32")

    dset_name = 'ogbg-molhiv'
    dataset = GraphPropPredDataset(name=dset_name, )
    split_idx = dataset.get_idx_split()
    train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]

    ds = data.get_tf_dataset(FLAGS.batch_size, [dataset[idx] for idx in train_idx], shuffle=True)
    val_ds = data.get_tf_dataset(FLAGS.batch_size, [dataset[idx] for idx in valid_idx], shuffle=False)
    strategy = xpu.configure_and_get_strategy()

    if FLAGS.total_batch_size is not None:
        gradient_accumulation_factor = FLAGS.total_batch_size // FLAGS.batch_size
    else:
        gradient_accumulation_factor = 1

    # pre-calculated number of steps per epoch (note: will vary somewhat for training, due to packing,
    #  but is found to be fairly consistent)
    steps = {
        32: (1195, 162, 148),
        64: (585, 80, 73),
        128: (288, 40, 37),
        256: (143, 20, 18)
    }
    try:
        steps_per_epoch, val_steps_per_epoch, test_steps_per_epoch = steps[FLAGS.batch_size]
    except KeyError:
        print("Batch size should have the number of steps defined")
        raise

    # need the steps per epoch to be divisible by the gradient accumulation factor
    steps_per_epoch = gradient_accumulation_factor * (steps_per_epoch // gradient_accumulation_factor)

    # we apply a linear scaling rule for learning rate with batch size, which we benchmark against BS=128
    batch_size = FLAGS.total_batch_size or FLAGS.batch_size
    lr = FLAGS.lr * batch_size / 128

    with strategy.scope():
        model = create_model()
        utils.print_trainable_variables(model)

        losses = tf.keras.losses.BinaryCrossentropy()
        if FLAGS.opt.lower() == 'sgd':
            opt = tf.keras.optimizers.SGD(learning_rate=lr)
        elif FLAGS.opt.lower() == 'adam':
            opt = tf.keras.optimizers.Adam(learning_rate=lr)
        else:
            raise NotImplementedError()

        callbacks = []

        if not os.path.isdir(FLAGS.model_dir):
            os.makedirs(FLAGS.model_dir)
        # randomly named directory
        model_dir = os.path.join(FLAGS.model_dir, str(uuid.uuid4()))

        print(f"Saving weights to {model_dir}")
        model_path = os.path.join(model_dir, 'model')

        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            model_path, monitor="val_loss", verbose=1, save_best_only=True,
            save_weights_only=True, mode="min", save_freq="epoch")
        )

        callbacks.append(ThroughputCallback(
            samples_per_epoch=steps_per_epoch * FLAGS.batch_size * gradient_accumulation_factor))
        if FLAGS.reduce_lr_on_plateau_patience > 0:
            callbacks.append(tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss', mode='min', factor=FLAGS.reduce_lr_on_plateau_factor,
                patience=FLAGS.reduce_lr_on_plateau_patience, min_lr=1e-8, verbose=1)
            )

        if FLAGS.early_stopping_patience > 0:
            print(f"Training will stop early after {FLAGS.early_stopping_patience} epochs without improvement.")
            callbacks.append(
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_loss', min_delta=0, patience=FLAGS.early_stopping_patience,
                    verbose=1, mode='min', baseline=None, restore_best_weights=False)
            )

        # weighted metrics are used because of the batch packing
        model.compile(optimizer=opt, loss=losses,
                      weighted_metrics=[tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.AUC()],
                      steps_per_execution=steps_per_epoch)

        # accumulate gradients when the total batch size exceeds the compute batch size
        model.set_gradient_accumulation_options(gradient_accumulation_steps_per_replica=gradient_accumulation_factor)

        model.fit(ds,
                  steps_per_epoch=steps_per_epoch,
                  epochs=FLAGS.epochs,
                  validation_data=val_ds,
                  validation_steps=val_steps_per_epoch,
                  callbacks=callbacks
                  )

        # we will use the official AUC evaluator from the OGB repo, not the keras one
        model.load_weights(model_path)
        print("Loaded best validation weights for evaluation")

        evaluator = Evaluator(name='ogbg-molhiv')
        for test_or_val, idx, steps in zip(
                ('validation', 'test'),
                (valid_idx, test_idx),
                (val_steps_per_epoch, test_steps_per_epoch)):
            prediction, ground_truth = get_predictions(model, dataset, idx, steps)
            result = evaluator.eval({'y_true': ground_truth[:, None], 'y_pred': prediction[:, None]})

            print(f'Final {test_or_val} ROC-AUC {result["rocauc"]:.3f}')
Example #11
def WEGL(dataset,
         num_hidden_layers,
         node_embedding_sizes,
         final_node_embedding,
         num_pca_components=20,
         num_experiments=10,
         classifiers=['RF'],
         random_seed=0,
         device='cpu'):
    """
    # The WEGL pipeline
    
    Inputs:
        - dataset: dataset object
        - num_hidden_layers: number of diffusion layers
        - node_embedding_sizes: node embedding dimensionality created by the AtomEncoder module
        - final_node_embedding: final node embedding type $\in$ {'concat', 'avg', 'final'}
        - num_pca_components: number of PCA components applied on node embeddings. -1 means no PCA.
        - num_experiments: number of experiments with different random seeds
        - classifiers: list of downstream classifiers
          (currently random forest ('RF') only; other classifiers, e.g., SVM, can be added if desired)
        - random_seed: the random seed
        - device: the device to run the diffusion over ('cpu'/'cuda')
        
    Outputs:
        - A table containing the classification results
        
    """

    # Set the random seed
    random.seed(random_seed)
    np.random.seed(random_seed)

    # Create data loaders
    split_idx = dataset.get_idx_split()  # train/val/test split
    loader_dict = {}
    for phase in split_idx:
        batch_size = 32
        loader_dict[phase] = DataLoader(dataset[split_idx[phase]],
                                        batch_size=batch_size,
                                        shuffle=False)

    # prepare the output table
    results_table = PrettyTable()
    results_table.title = 'Final ROC-AUC(%) results for the {0} dataset with \'{1}\' node embedding and one-hot 13-dim edge embedding'.\
                               format(dataset.name, final_node_embedding)

    results_table.field_names = [
        'Classifier', '# Diffusion Layers', 'Node Embedding Size', 'Train.',
        'Val.', 'Test'
    ]

    n_jobs = 14
    verbose = 0

    for L, F in itertools.product(num_hidden_layers, node_embedding_sizes):
        print('*' * 100)
        print('# diffusion layers = {0}, node embedding size = {1}, node embedding mode: {2}\n'.\
              format(L, F, final_node_embedding))

        # create an instance of the diffusion object
        diffusion = Diffusion(
            num_hidden_layers=L,
            final_node_embedding=final_node_embedding).to(device)
        diffusion.eval()

        # create the node encoder
        node_feature_encoder = AtomEncoder(F).to(device)
        node_feature_encoder.eval()

        phases = list(
            loader_dict.keys()
        )  # determine different partitions of data ('train', 'valid' and 'test')

        # pass all the graphs in the dataset through the GNN
        X = defaultdict(list)
        Y = defaultdict(list)

        for phase in phases:
            print('Now diffusing the ' + phase + ' data ...')
            for i, batch in enumerate(tqdm(loader_dict[phase])):
                batch = batch.to(device)

                # encode node features
                batch.x = node_feature_encoder(batch.x)

                # encode edge features
                batch.edge_attr = BondEncoderOneHot(batch.edge_attr)

                # add virtual nodes
                batch_size = len(batch.y)
                num_original_nodes = batch.x.size(0)
                batch.batch = torch.cat(
                    (batch.batch, torch.Tensor(range(batch_size)).to(
                        batch.batch.dtype)),
                    dim=0)

                # make the initial features of all virtual nodes zero
                batch.x = torch.cat(
                    (batch.x, batch.x.new_zeros(batch_size, batch.x.size(1))),
                    dim=0)

                # add edges between all nodes in each graph and the virtual node for that graph
                for g in range(batch_size):
                    node_indices = np.where(
                        batch.batch ==
                        g)[0][:-1]  # last node is the virtual node
                    virtual_edges_one_way = np.array([
                        node_indices,
                        (num_original_nodes + g) * np.ones_like(node_indices)
                    ])
                    virtual_edges_two_ways = np.concatenate(
                        (virtual_edges_one_way,
                         np.take(virtual_edges_one_way, [1, 0], axis=0)),
                        axis=1)

                    batch.edge_index = torch.cat(
                        (batch.edge_index,
                         torch.Tensor(virtual_edges_two_ways).to(
                             batch.edge_index.dtype)),
                        dim=1)

                    # set the initial features of all edges to/from a virtual node to 1 / (number of graph nodes)
                    batch.edge_attr = torch.cat(
                        (batch.edge_attr,
                         batch.edge_attr.new_ones(2 * len(node_indices),
                                                  batch.edge_attr.size(1)) /
                         len(node_indices)),
                        dim=0)

                # pass the data through the diffusion process
                z = diffusion(batch)

                batch_indices = batch.batch.cpu()
                for b in range(batch_size):
                    node_indices = np.where(batch_indices == b)[0]
                    X[phase].append(z[node_indices].detach().cpu().numpy())

                Y[phase].extend(
                    batch.y.detach().cpu().numpy().flatten().tolist())

        # standardize the features based on mean and std of the training data
        ss = StandardScaler()
        ss.fit(np.concatenate(X['train'], 0))
        for phase in phases:
            for i in range(len(X[phase])):
                X[phase][i] = ss.transform(X[phase][i])

        # apply PCA if needed
        if num_pca_components > 0:
            print('Now running PCA ...')
            pca = PCA(n_components=num_pca_components,
                      random_state=random_seed)
            pca.fit(np.concatenate(X['train'], 0))
            for phase in phases:
                for i in range(len(X[phase])):
                    X[phase][i] = pca.transform(X[phase][i])

            # plot the variance % explained by PCA components
            plt.plot(np.arange(1, num_pca_components + 1),
                     pca.explained_variance_ratio_, 'o--')
            plt.grid(True)
            plt.xlabel('Principal component')
            plt.ylabel('Explained variance ratio')
            plt.xticks(np.arange(1, num_pca_components + 1, step=2))
            plt.show()

        # number of samples in the template distribution
        N = int(round(np.asarray([x.shape[0] for x in X['train']]).mean()))

        # derive the template distribution using K-means
        print('Now running k-means for deriving the template ...\n')
        kmeans = KMeans(n_clusters=N,
                        verbose=verbose,
                        random_state=random_seed)
        kmeans.fit(np.concatenate(X['train'], 0))
        template = kmeans.cluster_centers_

        # calculate the final graph embeddings based on LOT
        V = defaultdict(list)
        for phase in phases:
            print('Now deriving the final graph embeddings for the ' + phase +
                  ' data ...')
            for x in tqdm(X[phase]):
                M = x.shape[0]
                C = ot.dist(x, template)
                b = np.ones((N, )) / float(N)
                a = np.ones((M, )) / float(M)
                p = ot.emd(a, b, C)  # exact linear program
                V[phase].append(np.matmul((N * p).T, x) - template)
            V[phase] = np.stack(V[phase])

        # create the parameter grid for random forest
        param_grid_RF = {
            'max_depth': [None],
            'min_samples_leaf': [1, 2, 5],
            'min_samples_split': [2, 5, 10],
            'n_estimators': [25, 50, 100, 150, 200]
        }

        param_grid_all = {'RF': param_grid_RF}

        # load the ROC-AUC evaluator
        evaluator = Evaluator(name=dataset.name)

        # run the classifier
        print('Now running the classifiers ...')
        for classifier in classifiers:
            if classifier not in param_grid_all:
                print('Classifier {} not supported! Skipping ...'.format(
                    classifier))
                continue

            param_grid = param_grid_all[classifier]

            # determine train and validation index split for grid search
            test_fold = [-1] * len(V['train']) + [0] * len(V['valid'])
            ps = PredefinedSplit(test_fold)

            # concatenate train and validation datasets
            X_grid_search = np.concatenate((V['train'], V['valid']), axis=0)
            X_grid_search = X_grid_search.reshape(X_grid_search.shape[0], -1)
            Y_grid_search = np.concatenate((Y['train'], Y['valid']), axis=0)

            results = defaultdict(list)
            for experiment in range(num_experiments):

                # Create a base model
                if classifier == 'RF':
                    model = RandomForestClassifier(n_jobs=n_jobs,
                                                   class_weight='balanced',
                                                   random_state=random_seed +
                                                   experiment)

                # Instantiate the grid search model
                grid_search = GridSearchCV(estimator=model,
                                           param_grid=param_grid,
                                           cv=ps,
                                           n_jobs=n_jobs,
                                           verbose=verbose,
                                           refit=False)

                # Fit the grid search to the data
                grid_search.fit(X_grid_search, Y_grid_search)

                # Refit the model with the best parameters on the training data
                model.set_params(**grid_search.best_params_)
                model.fit(V['train'].reshape(V['train'].shape[0], -1),
                          Y['train'])

                # Evaluate the performance
                for phase in phases:
                    pred_probs = model.predict_proba(V[phase].reshape(
                        V[phase].shape[0], -1))
                    input_dict = {
                        'y_true': np.array(Y[phase]).reshape(-1, 1),
                        'y_pred': pred_probs[:, 1].reshape(-1, 1)
                    }
                    result_dict = evaluator.eval(input_dict)
                    results[phase].append(result_dict['rocauc'])

                print('experiment {0}/{1} for {2} completed ...'.format\
                      (experiment+1, num_experiments, classifier))

            results_table.add_row([classifier, str(L), str(F)] + [r'{0:.2f} $\pm$ {1:.2f}'.
                                      format(100 * np.mean(results[phase]),
                                             100 * np.std(results[phase])) for phase in phases])

    print('\n\n' + results_table.title)
    print(results_table)
    return results_table
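# Hypothetical usage sketch (not from the original source): run the WEGL pipeline
# on ogbg-molhiv with a single diffusion-depth/embedding-size configuration.
from ogb.graphproppred import PygGraphPropPredDataset

dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
results_table = WEGL(dataset,
                     num_hidden_layers=[4],
                     node_embedding_sizes=[100],
                     final_node_embedding='avg',
                     num_pca_components=20,
                     num_experiments=10,
                     classifiers=['RF'],
                     device='cpu')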
Example #12
model_loss = current_batch = 0
for batch in loader_tr:
    outs = train_step(*batch)
    model_loss += outs
    current_batch += 1
    if current_batch == loader_tr.steps_per_epoch:
        print('Loss: {}'.format(model_loss / loader_tr.steps_per_epoch))
        model_loss = 0
        current_batch = 0

################################################################################
# EVALUATE MODEL
################################################################################
print('Testing model')
evaluator = Evaluator(name=dataset_name)
y_true = []
y_pred = []
for batch in loader_te:
    inputs, target = batch
    p = model(inputs, training=False)
    y_true.append(target)
    y_pred.append(p.numpy())

y_true = np.vstack(y_true)
y_pred = np.vstack(y_pred)
model_loss = loss_fn(y_true, y_pred)
ogb_score = evaluator.eval({'y_true': y_true, 'y_pred': y_pred})

print('Done. Test loss: {:.4f}. ROC-AUC: {:.2f}'
      .format(model_loss, ogb_score['rocauc']))
def run(rank, world_size: int, dataset_name: str, root: str):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

    dataset = Dataset(dataset_name,
                      root,
                      pre_transform=T.ToSparseTensor(attr='edge_attr'))
    split_idx = dataset.get_idx_split()
    evaluator = Evaluator(dataset_name)

    train_dataset = dataset[split_idx['train']]
    train_sampler = DistributedSampler(train_dataset,
                                       num_replicas=world_size,
                                       rank=rank)
    train_loader = DataLoader(train_dataset,
                              batch_size=128,
                              sampler=train_sampler)

    torch.manual_seed(12345)
    model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank)
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCEWithLogitsLoss()

    if rank == 0:
        val_loader = DataLoader(dataset[split_idx['valid']], batch_size=256)
        test_loader = DataLoader(dataset[split_idx['test']], batch_size=256)

    for epoch in range(1, 51):
        model.train()

        total_loss = 0
        for data in train_loader:
            data = data.to(rank)
            optimizer.zero_grad()
            logits = model(data.x, data.adj_t, data.batch)
            loss = criterion(logits, data.y.to(torch.float))
            loss.backward()
            optimizer.step()
            total_loss += float(loss) * logits.size(0)
        loss = total_loss / len(train_loader.dataset)

        dist.barrier()

        if rank == 0:  # We evaluate on a single GPU for now.
            model.eval()

            y_pred, y_true = [], []
            for data in val_loader:
                data = data.to(rank)
                with torch.no_grad():
                    y_pred.append(model.module(data.x, data.adj_t, data.batch))
                    y_true.append(data.y)
            val_rocauc = evaluator.eval({
                'y_pred': torch.cat(y_pred, dim=0),
                'y_true': torch.cat(y_true, dim=0),
            })['rocauc']

            y_pred, y_true = [], []
            for data in test_loader:
                data = data.to(rank)
                with torch.no_grad():
                    y_pred.append(model.module(data.x, data.adj_t, data.batch))
                    y_true.append(data.y)
            test_rocauc = evaluator.eval({
                'y_pred': torch.cat(y_pred, dim=0),
                'y_true': torch.cat(y_true, dim=0),
            })['rocauc']

            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
                  f'Val: {val_rocauc:.4f}, Test: {test_rocauc:.4f}')

        dist.barrier()

    dist.destroy_process_group()
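# Launcher sketch for the DDP run() above (standard torch.multiprocessing usage);
# the dataset name and root path are placeholders. mp.spawn passes the rank as the
# first argument to run().
if __name__ == '__main__':
    import torch.multiprocessing as mp
    world_size = torch.cuda.device_count()
    mp.spawn(run, args=(world_size, 'ogbg-molhiv', 'data/'), nprocs=world_size,
             join=True)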
class OGBPCBADataset(WILDSDataset):
    """
    The OGB-molpcba dataset.
    This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet.

    Supported `split_scheme`:
        'official' or 'scaffold', which are equivalent

    Input (x):
        Molecular graphs represented as Pytorch Geometric data objects

    Label (y):
        y represents 128-class binary labels.

    Metadata:
        - scaffold
            Each molecule is annotated with the scaffold ID that the molecule is assigned to.

    Website:
        https://ogb.stanford.edu/docs/graphprop/#ogbg-mol

    Original publication:
        @article{hu2020ogb,
            title={Open Graph Benchmark: Datasets for Machine Learning on Graphs},
            author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}},
            journal={arXiv preprint arXiv:2005.00687},
            year={2020}
        }

        @article{wu2018moleculenet,
            title={MoleculeNet: a benchmark for molecular machine learning},
            author={Z. {Wu}, B. {Ramsundar}, E. V {Feinberg}, J. {Gomes}, C. {Geniesse}, A. S {Pappu}, K. {Leswing}, V. {Pande}},
            journal={Chemical science},
            volume={9},
            number={2},
            pages={513--530},
            year={2018},
            publisher={Royal Society of Chemistry}
        }

    License:
        This dataset is distributed under the MIT license.
        https://github.com/snap-stanford/ogb/blob/master/LICENSE
    """

    _dataset_name = 'ogbg-molpcba'
    _versions_dict = {'1.0': {'download_url': None, 'compressed_size': None}}

    def __init__(self,
                 version=None,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        self._version = version
        if version is not None:
            raise ValueError(
                'Versioning for OGB-MolPCBA is handled through the OGB package. Please set version=None.'
            )
        # internally call ogb package
        self.ogb_dataset = PygGraphPropPredDataset(name='ogbg-molpcba',
                                                   root=root_dir)

        # set variables
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'scaffold'
        self._split_scheme = split_scheme
        self._y_type = 'float'  # although the task is binary classification, the prediction targets contain NaN values, so we need float
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()
        split_idx = self.ogb_dataset.get_idx_split()
        self._split_array[split_idx['train']] = 0
        self._split_array[split_idx['valid']] = 1
        self._split_array[split_idx['test']] = 2

        self._y_array = self.ogb_dataset.data.y

        self._metadata_fields = ['scaffold']

        metadata_file_path = os.path.join(self.ogb_dataset.root, 'raw',
                                          'scaffold_group.npy')
        if not os.path.exists(metadata_file_path):
            download_url(
                'https://snap.stanford.edu/ogb/data/misc/ogbg_molpcba/scaffold_group.npy',
                os.path.join(self.ogb_dataset.root, 'raw'))
        self._metadata_array = torch.from_numpy(
            np.load(metadata_file_path)).reshape(-1, 1).long()

        if torch_geometric.__version__ >= '1.7.0':
            self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
        else:
            self._collate = PyGCollater(follow_batch=[])

        self._metric = Evaluator('ogbg-molpcba')

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (FloatTensor): Binary logits from a model
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels. 
                                        Only None is supported because OGB Evaluators accept binary logits
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        assert prediction_fn is None, "OGBPCBADataset.eval() does not support prediction_fn. Only binary logits accepted"
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"Average precision: {results['ap']:.3f}\n"
Example #15
class OGBGMolpcbaModel(LightningModule):
    def __init__(self,
                 architecture: str = "GCN",
                 num_node_features: int = 300,
                 activation: str = "prelu",
                 num_conv_layers: int = 3,
                 conv_size: int = 256,
                 pool_method: str = "add",
                 lin1_size: int = 128,
                 lin2_size: int = 64,
                 output_size: int = 128,
                 lr: float = 0.001,
                 weight_decay: float = 0,
                 **kwargs):
        super().__init__()

        # this line ensures params passed to LightningModule will be saved to ckpt
        # it also allows accessing params via the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # init node embedding layer
        self.atom_encoder = AtomEncoder(emb_dim=self.hparams.num_node_features)
        # self.bond_encoder = BondEncoder(emb_dim=self.hparams.edge_emb_size)

        # init network architecture
        if self.hparams.architecture == "GCN":
            self.model = gcn.GCN(hparams=self.hparams)
        elif self.hparams.architecture == "GAT":
            self.model = gat.GAT(hparams=self.hparams)
        elif self.hparams.architecture == "GraphSAGE":
            self.model = graph_sage.GraphSAGE(hparams=self.hparams)
        elif self.hparams.architecture == "GIN":
            self.model = gin.GIN(hparams=self.hparams)
        else:
            raise ValueError("Incorrect architecture name!")

        # loss function
        self.criterion = torch.nn.BCEWithLogitsLoss()

        # metric
        self.evaluator = Evaluator(name="ogbg-molpcba")

        self.metric_hist = {
            "train/ap": [],
            "val/ap": [],
            "train/loss": [],
            "val/loss": [],
        }

    def forward(self, batch: Any):
        batch.x = self.atom_encoder(batch.x)
        # batch.edge_attr = self.bond_encoder(batch.edge_attr)
        return self.model(batch)

    def step(self, batch: Any):
        y_pred = self.forward(batch)
        is_labeled_idx = batch.y == batch.y
        loss = self.criterion(y_pred[is_labeled_idx],
                              batch.y.to(torch.float32)[is_labeled_idx])
        y_true = batch.y.view(y_pred.shape)
        return loss, y_pred, y_true

    def training_step(self, batch: Any, batch_idx: int):
        loss, y_pred, y_true = self.step(batch)
        self.log("train/loss",
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=False)

        # log number of NaNs
        y_true_nans = torch.sum(batch.y != batch.y)
        self.log(
            "y_true_NaNs",
            y_true_nans,
            reduce_fx=torch.sum,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
        )

        return {"loss": loss, "y_pred": y_pred, "y_true": y_true}

    def training_epoch_end(self, outputs: List[Any]):
        ap = self.calculate_metric(outputs)
        self.metric_hist["train/ap"].append(ap)
        self.metric_hist["train/loss"].append(
            self.trainer.callback_metrics["train/loss"])
        self.log("train/ap", ap, prog_bar=True)
        self.log("train/ap_best",
                 max(self.metric_hist["train/ap"]),
                 prog_bar=True)
        self.log("train/loss_best",
                 min(self.metric_hist["train/loss"]),
                 prog_bar=False)

    def validation_step(self, batch: Any, batch_idx: int):
        loss, y_pred, y_true = self.step(batch)
        self.log("val/loss",
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=False)
        return {"loss": loss, "y_pred": y_pred, "y_true": y_true}

    def validation_epoch_end(self, outputs: List[Any]):
        ap = self.calculate_metric(outputs)
        self.metric_hist["val/ap"].append(ap)
        self.metric_hist["val/loss"].append(
            self.trainer.callback_metrics["val/loss"])
        self.log("val/ap", ap, prog_bar=True)
        self.log("val/ap_best", max(self.metric_hist["val/ap"]), prog_bar=True)
        self.log("val/loss_best",
                 min(self.metric_hist["val/loss"]),
                 prog_bar=False)

    def test_step(self, batch: Any, batch_idx: int):
        loss, y_pred, y_true = self.step(batch)
        self.log("test/loss",
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=False)
        return {"loss": loss, "y_pred": y_pred, "y_true": y_true}

    def test_epoch_end(self, outputs: List[Any]):
        self.log("test/ap", self.calculate_metric(outputs), prog_bar=False)

    def configure_optimizers(self):
        return torch.optim.Adam(params=self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)

    def calculate_metric(self, outputs: List[Any]):
        y_true = torch.cat([x["y_true"] for x in outputs], dim=0)
        y_pred = torch.cat([x["y_pred"] for x in outputs], dim=0)
        result_dict = self.evaluator.eval({"y_true": y_true, "y_pred": y_pred})
        return result_dict["ap"]
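# Hypothetical training sketch for the LightningModule above; loader construction
# and trainer flags are illustrative, not from the source.
from ogb.graphproppred import PygGraphPropPredDataset
from pytorch_lightning import Trainer
from torch_geometric.loader import DataLoader

dataset = PygGraphPropPredDataset(name="ogbg-molpcba")
split_idx = dataset.get_idx_split()
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True)
val_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32)

model = OGBGMolpcbaModel(architecture="GIN")
trainer = Trainer(max_epochs=50)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)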
Example #16
class Trainer(object):

    def __init__(self, args):

        super(Trainer, self).__init__()

        # Random Seed
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.args = args
        self.exp_name = self.set_experiment_name()

        self.use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(args.gpu)
            self.args.device = 'cuda:{}'.format(args.gpu)
        else:
            self.args.device = 'cpu'

        self.dataset = self.load_data()

        self.evaluator = Evaluator(args.data)

    def load_data(self):

        dataset = PygGraphPropPredDataset(name = self.args.data)
        self.args.task_type, self.args.num_features, self.args.num_classes, self.args.avg_num_nodes \
            = dataset.task_type, dataset.num_features, dataset.num_tasks, np.ceil(np.mean([data.num_nodes for data in dataset]))
        print('# %s: [Task]-%s [FEATURES]-%d [NUM_CLASSES]-%d [AVG_NODES]-%d' % (dataset, self.args.task_type, self.args.num_features, self.args.num_classes, self.args.avg_num_nodes))

        return dataset

    def load_dataloader(self):

        split_idx = self.dataset.get_idx_split()

        train_loader = DataLoader(self.dataset[split_idx["train"]], batch_size=self.args.batch_size, shuffle=True)
        val_loader = DataLoader(self.dataset[split_idx["valid"]], batch_size=self.args.batch_size, shuffle=False)
        test_loader = DataLoader(self.dataset[split_idx["test"]], batch_size=self.args.batch_size, shuffle=False)

        return train_loader, val_loader, test_loader

    def load_model(self):

        if self.args.model == 'GMT':

            model = GraphMultisetTransformer_for_OGB(self.args)

        else:

            raise ValueError("Model Name <{}> is Unknown".format(self.args.model))

        if self.use_cuda:

            model.to(self.args.device)

        return model

    def set_log(self):

        self.train_curve = []
        self.valid_curve = []
        self.test_curve = []

        logger = Logger(str(os.path.join('./logs/{}/'.format(self.log_folder_name), 'experiment-{}_seed-{}.log'.format(self.exp_name, self.args.seed))), mode='a')

        t_start = time.perf_counter()

        return logger, t_start

    def organize_log(self, logger, train_perf, valid_perf, test_perf, train_loss, epoch):

        self.train_curve.append(train_perf[self.dataset.eval_metric])
        self.valid_curve.append(valid_perf[self.dataset.eval_metric])
        self.test_curve.append(test_perf[self.dataset.eval_metric])

        logger.log("[Val: Epoch %d] (Loss) Loss: %.4f Train: %.4f%% Valid: %.4f%% Test: %.4f%% " % (
            epoch, train_loss, self.train_curve[-1], self.valid_curve[-1], self.test_curve[-1]))

    def train(self):

        wandb.init(entity='samjkwong', project='gmt')

        train_loader, val_loader, test_loader = self.load_dataloader()

        # Load Model & Optimizer
        self.model = self.load_model()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.args.lr, weight_decay = self.args.weight_decay)

        self.cls_criterion = torch.nn.BCEWithLogitsLoss()
        self.reg_criterion = torch.nn.MSELoss()

        if self.args.lr_schedule:
            self.scheduler = get_cosine_schedule_with_warmup(self.optimizer, self.args.patience * len(train_loader), self.args.num_epochs * len(train_loader))

        logger, t_start = self.set_log()

        for epoch in trange(0, (self.args.num_epochs), desc = '[Epoch]', position = 1):

            self.model.train()
            total_loss = 0

            for _, data in enumerate(tqdm(train_loader, desc="[Iteration]")):

                if data.x.shape[0] == 1 or data.batch[-1] == 0: continue

                self.optimizer.zero_grad()
                data = data.to(self.args.device)
                out = self.model(data)

                is_labeled = data.y == data.y

                if "classification" in self.args.task_type: 
                    loss = self.cls_criterion(out.to(torch.float32)[is_labeled], data.y.to(torch.float32)[is_labeled])
                else:
                    loss = self.reg_criterion(out.to(torch.float32)[is_labeled], data.y.to(torch.float32)[is_labeled])

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm)
                total_loss += loss.item() * num_graphs(data)
                self.optimizer.step()

                if self.args.lr_schedule:
                    self.scheduler.step()

            total_loss = total_loss / len(train_loader.dataset)

            train_perf, valid_perf, test_perf = self.eval(train_loader), self.eval(val_loader), self.eval(test_loader)
            self.organize_log(logger, train_perf, valid_perf, test_perf, total_loss, epoch)

            # WANDB logging
            wandb.log({
                'Epoch': epoch,
                'Train Loss': total_loss,
                'Train ROC-AUC': train_perf[self.dataset.eval_metric],
                'Val ROC-AUC': valid_perf[self.dataset.eval_metric],
                'Test ROC-AUC': test_perf[self.dataset.eval_metric]
            })

        t_end = time.perf_counter()

        if 'classification' in self.dataset.task_type:
            best_val_epoch = np.argmax(np.array(self.valid_curve))
            best_train = max(self.train_curve)
        else:
            best_val_epoch = np.argmin(np.array(self.valid_curve))
            best_train = min(self.train_curve)

        best_val = self.valid_curve[best_val_epoch]
        test_score = self.test_curve[best_val_epoch]

        logger.log("Train: {} Valid: {} Test: {} with Time: {}".format(best_train, best_val, test_score, (t_end - t_start)))

        result_file = "./results/{}/{}-results.txt".format(self.log_folder_name, self.exp_name)
        with open(result_file, 'a+') as f:
            f.write("{}: {} {} {} {}\n".format(self.args.seed, best_train, self.train_curve[best_val_epoch], best_val, test_score))

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'Val': best_val,
            'Train': self.train_curve[best_val_epoch],
            'Test': test_score,
            'BestTrain': best_train
            }, './checkpoints/{}/best-model_{}.pth'.format(self.log_folder_name, self.args.seed))

    def eval(self, loader):

        self.model.eval()

        y_true = []
        y_pred = []

        for _, batch in enumerate(tqdm(loader, desc="[Iteration]")):
            batch = batch.to(self.args.device)

            if batch.x.shape[0] == 1: continue

            with torch.no_grad():
                pred = self.model(batch)

            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

        y_true = torch.cat(y_true, dim = 0).numpy()
        y_pred = torch.cat(y_pred, dim = 0).numpy()

        input_dict = {"y_true": y_true, "y_pred": y_pred}

        return self.evaluator.eval(input_dict)

    def set_experiment_name(self):

        ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

        self.log_folder_name = os.path.join(*[self.args.data, self.args.model, self.args.experiment_number])

        if not(os.path.isdir('./checkpoints/{}'.format(self.log_folder_name))):
            os.makedirs(os.path.join('./checkpoints/{}'.format(self.log_folder_name)))

        if not(os.path.isdir('./results/{}'.format(self.log_folder_name))):
            os.makedirs(os.path.join('./results/{}'.format(self.log_folder_name)))

        if not(os.path.isdir('./logs/{}'.format(self.log_folder_name))):
            os.makedirs(os.path.join('./logs/{}'.format(self.log_folder_name)))

        print("Make Directory {} in Logs, Checkpoints and Results Folders".format(self.log_folder_name))

        exp_name = str()
        exp_name += "CV={}_".format(self.args.conv)
        exp_name += "NC={}_".format(self.args.num_convs)
        exp_name += "MC={}_".format(self.args.mab_conv)
        exp_name += "MS={}_".format(self.args.model_string)
        exp_name += "BS={}_".format(self.args.batch_size)
        exp_name += "LR={}_".format(self.args.lr)
        exp_name += "WD={}_".format(self.args.weight_decay)
        exp_name += "GN={}_".format(self.args.grad_norm)
        exp_name += "DO={}_".format(self.args.dropout)
        exp_name += "HD={}_".format(self.args.num_hidden)
        exp_name += "NH={}_".format(self.args.num_heads)
        exp_name += "PL={}_".format(self.args.pooling_ratio)
        exp_name += "LN={}_".format(self.args.ln)
        exp_name += "LS={}_".format(self.args.lr_schedule)
        exp_name += "CS={}_".format(self.args.cluster)
        exp_name += "TS={}".format(ts)

        # Save training arguments for reproduction
        torch.save(self.args, os.path.join('./checkpoints/{}'.format(self.log_folder_name), 'training_args.bin'))

        return exp_name
class OGBPCBADataset(WILDSDataset):
    """
    The OGB-molpcba dataset.
    This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet.

    Supported `split_scheme`:
        'official' or 'scaffold', which are equivalent

    Input (x):
        Molecular graphs represented as Pytorch Geometric data objects

    Label (y):
        y represents 128-class binary labels.

    Metadata:
        - scaffold
            Each molecule is annotated with the scaffold ID that the molecule is assigned to.

    Website:
        https://ogb.stanford.edu/docs/graphprop/#ogbg-mol

    Original publication:
        @article{hu2020ogb,
            title={Open Graph Benchmark: Datasets for Machine Learning on Graphs},
            author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}},
            journal={arXiv preprint arXiv:2005.00687},
            year={2020}
        }

        @article{wu2018moleculenet,
            title={MoleculeNet: a benchmark for molecular machine learning},
            author={Z. {Wu}, B. {Ramsundar}, E. V {Feinberg}, J. {Gomes}, C. {Geniesse}, A. S {Pappu}, K. {Leswing}, V. {Pande}},
            journal={Chemical science},
            volume={9},
            number={2},
            pages={513--530},
            year={2018},
            publisher={Royal Society of Chemistry}
        }

    License:
        This dataset is distributed under the MIT license.
        https://github.com/snap-stanford/ogb/blob/master/LICENSE
    """
    def __init__(self,
                 root_dir='data',
                 download=False,
                 split_scheme='official'):
        # internally call ogb package
        self.ogb_dataset = PygGraphPropPredDataset(name='ogbg-molpcba',
                                                   root=root_dir)

        # set variables
        self._dataset_name = 'ogbg-molpcba'
        self._data_dir = self.ogb_dataset.root
        if split_scheme == 'official':
            split_scheme = 'scaffold'
        self._split_scheme = split_scheme
        self._y_type = 'float'  # although the task is binary classification, the prediction targets contain NaN values, so we need float
        self._y_size = self.ogb_dataset.num_tasks
        self._n_classes = self.ogb_dataset.__num_classes__

        self._split_array = torch.zeros(len(self.ogb_dataset)).long()
        split_idx = self.ogb_dataset.get_idx_split()
        self._split_array[split_idx['train']] = 0
        self._split_array[split_idx['valid']] = 1
        self._split_array[split_idx['test']] = 2

        self._y_array = self.ogb_dataset.data.y

        self._metadata_fields = ['scaffold']

        metadata_file_path = os.path.join(self.ogb_dataset.root, 'raw',
                                          'scaffold_group.npy')
        if not os.path.exists(metadata_file_path):
            download_url(
                'https://snap.stanford.edu/ogb/data/misc/ogbg_molpcba/scaffold_group.npy',
                os.path.join(self.ogb_dataset.root, 'raw'))
        self._metadata_array = torch.from_numpy(
            np.load(metadata_file_path)).reshape(-1, 1).long()
        self._collate = PyGCollater(follow_batch=[])

        self._metric = Evaluator('ogbg-molpcba')

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        return self.ogb_dataset[int(idx)]

    def eval(self, y_pred, y_true, metadata):
        input_dict = {"y_true": y_true, "y_pred": y_pred}
        results = self._metric.eval(input_dict)

        return results, f"Average precision: {results['ap']:.3f}\n"