def _select_with_gids(self, want_gids):
    """Restrict self.gs and self.pairs to the requested graph ids.

    Args:
        want_gids: iterable of graph ids to keep.

    Returns:
        (graphs, pairs): the graphs whose gid is in want_gids, and the
        subset of self.pairs where BOTH endpoint gids are in want_gids
        (a pair must live entirely inside one train/test/... split).
    """
    timer = Timer()
    wanted = set(want_gids)
    graphs = [graph for graph in self.gs if graph.gid() in wanted]
    print('Done graphs', timer.time_and_clear())
    pairs = {
        (gid1, gid2): pair
        for (gid1, gid2), pair in self.pairs.items()
        if gid1 in wanted and gid2 in wanted
    }
    print('Done pairs', timer.time_and_clear())
    return graphs, pairs
def evaluate(model,
             data,
             eval_links,
             saver,
             max_num_examples=None,
             test=False):
    """Run `model` over `eval_links` with gradients disabled.

    Args:
        model: model to evaluate; moved to FLAGS.device and put in eval mode.
        data: dataset wrapper; `data.dataset` is handed to BatchData.
        eval_links: tensor of gid pairs wrapped into a TensorDataset.
        saver: logging helper; used per-batch only when `test` is True.
        max_num_examples: if truthy, stop once this many pairs are collected.
        test: when True, log per-batch loss and timing via `saver`.

    Returns:
        (all_pair_list, avg_loss): the evaluated pairs, and the mean loss over
        the batches that actually contributed a loss (0.0 if none did).
    """
    with torch.no_grad():
        model = model.to(FLAGS.device)
        model.eval()
        total_loss = 0
        all_pair_list = []
        iter_timer = Timer()
        eval_dataset = torch.utils.data.dataset.TensorDataset(eval_links)
        # NOTE(review): shuffling during evaluation only changes WHICH pairs
        # survive a max_num_examples truncation; kept as in the original.
        data_loader = DataLoader(eval_dataset,
                                 batch_size=FLAGS.batch_size,
                                 shuffle=True)

        num_batches = 0  # batches that actually contributed to total_loss
        for iter_num, batch_gids in enumerate(data_loader):
            if max_num_examples and len(all_pair_list) >= max_num_examples:
                break

            batch_gids = batch_gids[0]
            if len(batch_gids) == 0:
                continue
            batch_data = BatchData(batch_gids,
                                   data.dataset,
                                   is_train=False,
                                   unique_graphs=FLAGS.batch_unique_graphs)
            # Two-stage forward: lower layers compute pair interaction first,
            # then the higher layers produce the loss; otherwise run everything.
            if FLAGS.lower_level_layers and FLAGS.higher_level_layers:
                if FLAGS.pair_interaction:
                    model.use_layers = 'lower_layers'
                    model(batch_data)
                model.use_layers = 'higher_layers'
            else:
                model.use_layers = 'all'
            loss = model(batch_data)

            batch_data.restore_interaction_nxgraph()
            total_loss += loss.item()
            num_batches += 1
            all_pair_list.extend(batch_data.pair_list)
            if test:
                saver.log_tvt_info(
                    '\tIter: {:03d}, Test Loss: {:.7f}\t\t{}'.format(
                        iter_num + 1, loss, iter_timer.time_and_clear()))
    # Fix: the original divided by the raw loop index, which raised
    # UnboundLocalError on an empty loader and over-counted iterations that
    # were skipped (empty batch) or cut short (max_num_examples break).
    avg_loss = total_loss / num_batches if num_batches else 0.0
    return all_pair_list, avg_loss
def _train(num_iters_total,
           train_data,
           val_data,
           train_val_links,
           model,
           optimizer,
           saver,
           fold_num,
           retry_num=0):
    """Train `model` for up to FLAGS.num_iters iterations with periodic validation.

    Args:
        num_iters_total: iteration count carried in from the caller; training
            stops early when it equals FLAGS.num_iters (see note in the loop).
        train_data: training data wrapper fed to the sampler / model_forward.
        val_data: validation data wrapper passed to validation().
        train_val_links: link set passed to validation().
        model: the model being trained.
        optimizer: optimizer used by _train_iter.
        saver: logging and checkpointing helper.
        fold_num: cross-validation fold index, or None when not doing CV.
        retry_num: retry counter; folded into log prefixes when > 0.

    Returns:
        dict mapping (iteration + 1) -> validation result dict.
    """
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)
    if retry_num > 0:
        fold_str += 'retry_{}_'.format(retry_num)
    # (Removed a leftover debug print that fired whenever fold_str was empty.)
    epoch_timer = Timer()
    total_loss = 0
    curr_num_iters = 0
    val_results = {}
    # Pick the sampler and estimate how many iterations make up one epoch.
    if FLAGS.sampler == "neighbor_sampler":
        sampler = NeighborSampler(train_data, FLAGS.num_neighbors_sample,
                                  FLAGS.batch_size)
        estimated_iters_per_epoch = ceil(
            len(train_data.dataset.gs_map) / FLAGS.batch_size)
    elif FLAGS.sampler == "random_sampler":
        sampler = RandomSampler(train_data, FLAGS.batch_size,
                                FLAGS.sample_induced)
        estimated_iters_per_epoch = ceil(
            len(train_data.dataset.train_pairs) / FLAGS.batch_size)
    else:
        sampler = EverythingSampler(train_data)
        estimated_iters_per_epoch = 1

    moving_avg = MovingAverage(FLAGS.validation_window_size)
    # -1 means "validate once per (estimated) epoch".
    iters_per_validation = FLAGS.iters_per_validation \
        if FLAGS.iters_per_validation != -1 else estimated_iters_per_epoch

    for iteration in range(FLAGS.num_iters):
        model.train()
        model.zero_grad()
        batch_data = model_forward(model, train_data, sampler=sampler)
        loss = _train_iter(batch_data, model, optimizer)
        batch_data.restore_interaction_nxgraph()
        total_loss += loss
        curr_num_iters += 1
        # NOTE(review): num_iters_total is never incremented in this function,
        # so this break only fires if the caller passes it in already at the
        # limit — confirm against the caller's bookkeeping. (The original's
        # `is not None` guard was dead: FLAGS.num_iters feeds range() above.)
        if num_iters_total == FLAGS.num_iters:
            break
        if iteration % FLAGS.print_every_iters == 0:
            saver.log_tvt_info("{}Iter {:04d}, Loss: {:.7f}".format(
                fold_str, iteration + 1, loss))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metric("{}loss".format(fold_str), loss,
                                            iteration + 1)
        if (iteration + 1) % iters_per_validation == 0:
            eval_res, supplement = validation(
                model,
                val_data,
                train_val_links,
                saver,
                max_num_examples=FLAGS.max_eval_pairs)
            epoch = iteration / estimated_iters_per_epoch
            saver.log_tvt_info('{}Estimated Epoch: {:05f}, Loss: {:.7f} '
                               '({} iters)\t\t{}\n Val Result: {}'.format(
                                   fold_str, epoch,
                                   eval_res["Loss"], curr_num_iters,
                                   epoch_timer.time_and_clear(), eval_res))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metrics(
                    eval_res,
                    prefix="{}validation".format(fold_str),
                    step=iteration + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_pred'],
                    name="{}y_pred".format(fold_str),
                    step=iteration + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_true'],
                    name='{}y_true'.format(fold_str),
                    step=iteration + 1)
                confusion_matrix = supplement.get('confusion_matrix')
                if confusion_matrix is not None:
                    # Label names ordered by their integer edge-label ids.
                    labels = [
                        k for k, v in sorted(
                            batch_data.dataset.interaction_edge_labels.items(),
                            key=lambda item: item[1])
                    ]
                    COMET_EXPERIMENT.log_confusion_matrix(
                        matrix=confusion_matrix, labels=labels,
                        step=iteration + 1)
            curr_num_iters = 0
            val_results[iteration + 1] = eval_res
            # Checkpoint only on a clear improvement (beyond 1e-7 tolerance)
            # over the best value currently in the moving-average window.
            if len(moving_avg.results) == 0 or (
                    eval_res[FLAGS.validation_metric] - 1e-7) > max(
                        moving_avg.results):
                saver.save_trained_model(model, iteration + 1)
            moving_avg.add_to_moving_avg(eval_res[FLAGS.validation_metric])
            # Early stopping once the moving average says progress has stalled.
            if moving_avg.stop():
                break
    return val_results