def test(model, data, test_train_links, saver, fold_num):
    """Evaluate `model` on the test split and report results.

    Runs `evaluate` to produce test pairs, scores them with `Eval`, and —
    when a Comet experiment is active — sends a completion notification
    carrying the result dict.

    Args:
        model: trained model to evaluate.
        data: test dataset.
        test_train_links: links between test and train entities passed
            through to `evaluate` (semantics defined by that helper).
        saver: project Saver; provides the run/file name for the notification.
        fold_num: cross-validation fold number, or None outside CV; only
            used to build the metric-name prefix.
    """
    print("testing...")
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)

    # The loss returned here is not used for the final report; only the
    # scored pairs feed the evaluator.
    pairs, _unused_loss = evaluate(model, data, test_train_links, saver,
                                   test=True)
    # Renamed from `eval` — the original name shadowed the builtin.
    evaluator = Eval(model, data, pairs, set_name="test", saver=saver)
    res = evaluator.eval(fold_str=fold_str)
    if COMET_EXPERIMENT:
        with COMET_EXPERIMENT.test():
            COMET_EXPERIMENT.send_notification(saver.get_f_name(),
                                               status="finished",
                                               additional_data=res)
# Esempio n. 2
# 0
def train(train_data, val_data, saver):
    """Train a model with full-batch epochs, checkpointing on validation.

    Builds the model via `create_model`, then for each epoch performs one
    full-graph forward/backward pass, validates, and saves the model
    whenever the validation metric beats the moving-average window.
    Early-stops when `moving_avg.stop()` fires.

    Args:
        train_data: training dataset; supplies node features and pyg graph.
        val_data: validation dataset.
        saver: project Saver used to persist/reload the best checkpoint.

    Returns:
        (best_model, model): the best checkpoint reloaded from disk, and
        the model in its final (possibly non-best) state.
    """
    train_data.init_node_feats(FLAGS.init_type, FLAGS.device)
    val_data.init_node_feats(FLAGS.init_type, FLAGS.device)
    model = create_model(train_data)
    model = model.to(FLAGS.device)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Number params: ", pytorch_total_params)
    # Second arg tells MovingAverage whether higher is better (any metric
    # except 'loss').
    moving_avg = MovingAverage(FLAGS.validation_window_size,
                               FLAGS.validation_metric != 'loss')
    pyg_graph = train_data.get_pyg_graph(FLAGS.device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=FLAGS.lr,
    )

    for epoch in range(FLAGS.num_epochs):
        t = time.time()
        model.train()
        model.zero_grad()
        loss, _preds_train = model(pyg_graph, train_data)  # preds unused here
        loss.backward()
        optimizer.step()
        loss = loss.item()
        if COMET_EXPERIMENT:
            COMET_EXPERIMENT.log_metric("loss", loss, epoch + 1)
        # Bug fix: put the model in inference mode for validation so
        # dropout/batch-norm behave deterministically; train mode is
        # re-enabled at the top of the next epoch.
        model.eval()
        with torch.no_grad():
            val_loss, preds_val = model(pyg_graph, val_data)
            val_loss = val_loss.item()
            # `eval` here is a project-level metric function, not the builtin.
            eval_res_val = eval(preds_val, val_data)
            print("Epoch: {:04d}, Train Loss: {:.5f}, Time: {:.5f}".format(
                epoch, loss,
                time.time() - t))
            print("Val Loss: {:.5f}".format(val_loss))
            print("Val Results: ...")
            pprint(eval_res_val)
            eval_res_val["loss"] = val_loss
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metrics(eval_res_val,
                                             prefix="validation",
                                             step=epoch + 1)

            # Checkpoint on the first epoch, or whenever the chosen metric
            # beats the moving-average window.
            if len(moving_avg.results) == 0 or moving_avg.best_result(
                    eval_res_val[FLAGS.validation_metric]):
                saver.save_trained_model(model, epoch + 1)
            moving_avg.add_to_moving_avg(eval_res_val[FLAGS.validation_metric])
            if moving_avg.stop():
                break
    best_model = saver.load_trained_model(train_data)
    return best_model, model
def train(train_data, val_data, val_pairs, saver, fold_num=None):
    """Set up a Model and training state, then delegate to `_train`.

    Initializes embeddings for both splits, shuffles and stacks the
    validation pairs, and runs `_train` inside the Comet train context
    when an experiment is active.

    Args:
        train_data: training dataset wrapper.
        val_data: validation dataset wrapper.
        val_pairs: iterable of validation pair tensors; stacked after a
            shuffle.
        saver: project Saver for logging and checkpoints.
        fold_num: cross-validation fold number, or None.

    Returns:
        (model, val_results): the trained model and the per-iteration
        validation results dict from `_train`.
    """
    print('creating models...')
    model = Model(train_data)
    model = model.to(FLAGS.device)
    print(model)

    if "model_init" in FLAGS.init_embds:
        print('initial embedding models:')
        print(model.init_layers)
        _get_initial_embd(train_data, model)

    # Interaction-graph embeddings are initialized on both splits.
    for split in (train_data, val_data):
        split.dataset.init_interaction_graph_embds(device=FLAGS.device)

    shuffled_pairs = list(val_pairs)
    random.shuffle(shuffled_pairs)
    val_pairs = torch.stack(shuffled_pairs)
    saver.log_model_architecture(model)
    model.train_data = train_data
    optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.lr)
    num_iters_total = 0
    if COMET_EXPERIMENT:
        with COMET_EXPERIMENT.train():
            val_results = _train(num_iters_total, train_data, val_data,
                                 val_pairs, model, optimizer, saver, fold_num)
    else:
        val_results = _train(num_iters_total, train_data, val_data, val_pairs,
                             model, optimizer, saver, fold_num)

    return model, val_results
    def eval(self, round=None, fold_str=''):
        """Run the final pairwise evaluation, persist and log the results.

        Args:
            round: must be None — only the 'final_test' pass is
                implemented; anything else raises. (Name shadows the
                builtin `round`; kept for interface compatibility.)
            fold_str: metric-name prefix (e.g. 'Fold_1_') used when
                cross-validating; empty outside CV.

        Returns:
            OrderedDict holding the 'pairwise' results; the same dict is
            also stored under self.global_result_dict['final_test'].

        Raises:
            NotImplementedError: if `round` is not None.
        """
        if round is None:
            info = 'final_test'
            d = OrderedDict()
            self.global_result_dict[info] = d
        else:
            raise NotImplementedError()

        # `supplement` carries extras used only for logging below
        # (per-degree breakdowns, confusion matrix).
        d['pairwise'], supplement = self._eval_pairs(info)

        if self.saver:
            self.saver.save_global_eval_result_dict(self.global_result_dict)
        if COMET_EXPERIMENT:
            with COMET_EXPERIMENT.test():
                COMET_EXPERIMENT.log_metrics(
                    self.global_result_dict['final_test']['pairwise'],
                    prefix=fold_str)
                confusion_matrix = self.global_result_dict['final_test'][
                    'pairwise'].get('confusion_matrix')
                if 'fine_grained_by_degree' in supplement:
                    for k, v in supplement['fine_grained_by_degree'].items():
                        COMET_EXPERIMENT.log_metrics(v,
                                                     prefix=fold_str +
                                                     'degree_bin_' + str(k))
                if confusion_matrix is not None:
                    # Label names ordered by their integer id so they line
                    # up with the confusion-matrix row/column order.
                    labels = [
                        k for k, v in sorted(self.eval_data.dataset.
                                             interaction_edge_labels.items(),
                                             key=lambda item: item[1])
                    ]
                    COMET_EXPERIMENT.log_confusion_matrix(
                        matrix=confusion_matrix, labels=labels)
        return d
# Esempio n. 5
# 0
def main():
    """Train on the loaded data, then evaluate on the test split.

    Trains via `train` (inside the Comet train context when active),
    runs a no-grad test forward pass, prints the test metrics, and —
    when a Comet experiment is active — logs metrics plus a confusion
    matrix whose cells link back to the raw test documents.
    """
    saver = Saver()
    train_data, val_data, test_data, raw_doc_list = load_data()

    print(train_data.graph.shape)
    if COMET_EXPERIMENT:
        with COMET_EXPERIMENT.train():
            _best_model, model = train(train_data, val_data, saver)
    else:
        _best_model, model = train(train_data, val_data, saver)
    with torch.no_grad():
        _test_loss, preds_model = model(
            train_data.get_pyg_graph(device=FLAGS.device), test_data)
    # `eval` is the project-level metric function, not the builtin.
    eval_res = eval(preds_model, test_data, True)
    y_true = eval_res.pop('y_true')
    y_pred = eval_res.pop('y_pred')
    print("Test...")
    pprint(eval_res)
    if COMET_EXPERIMENT:
        from comet_ml.utils import ConfusionMatrix

        def index_to_example(index):
            # Map a test-set row index back to its raw document text.
            return raw_doc_list[test_data.node_ids[index]]

        label_names = list(test_data.label_dict.keys())
        confusion_matrix = ConfusionMatrix(
            index_to_example_function=index_to_example,
            labels=label_names)
        confusion_matrix.compute_matrix(y_true, y_pred)

        with COMET_EXPERIMENT.test():
            COMET_EXPERIMENT.log_metrics(eval_res)
            COMET_EXPERIMENT.log_confusion_matrix(
                matrix=confusion_matrix,
                labels=label_names)
def main():
    """Run the cross-validation pipeline: load splits, train/test per fold.

    For each fold of the pair TVT splits: deep-copies the encoded dataset,
    loads or trains a model, optionally saves final node embeddings, runs
    the test pass, and records wall-clock time. After all folds, aggregates
    Comet results when cross-validating.
    """
    # Bug fix: `t` was used below (`time() - t`) but never assigned,
    # which would raise NameError at runtime. Start the clock here.
    # NOTE(review): this function calls `time()` directly (i.e. assumes
    # `from time import time`), unlike `time.time()` elsewhere in this
    # file — confirm the module's actual import style.
    t = time()
    tvt_pairs_dict = load_pair_tvt_splits()

    orig_dataset = load_dataset(FLAGS.dataset, 'all', FLAGS.node_feats,
                                FLAGS.edge_feats)

    orig_dataset, num_node_feat = encode_node_features(dataset=orig_dataset)
    num_interaction_edge_feat = encode_edge_features(
        orig_dataset.interaction_combo_nxgraph, FLAGS.hyper_eatts)

    for i, (train_pairs, val_pairs, test_pairs) in \
            enumerate(zip(tvt_pairs_dict['train'],
                          tvt_pairs_dict['val'],
                          tvt_pairs_dict['test'])):
        fold_num = i + 1
        # When restricted to a single fold, skip all others.
        if FLAGS.cross_val and FLAGS.run_only_on_fold != -1 and FLAGS.run_only_on_fold != fold_num:
            continue

        set_seed(FLAGS.random_seed + 5)
        print(f'======== FOLD {fold_num} ========')
        saver = Saver(fold=fold_num)
        # Deep copy so per-fold mutations never leak into other folds.
        dataset = deepcopy(orig_dataset)
        train_data, val_data, test_data, val_pairs, test_pairs, _ = \
            load_pairs_to_dataset(num_node_feat, num_interaction_edge_feat,
                                  train_pairs, val_pairs, test_pairs,
                                  dataset)
        print('========= Training... ========')
        if FLAGS.load_model is not None:
            print('loading models: {}'.format(FLAGS.load_model))
            trained_model = Model(train_data).to(FLAGS.device)
            # strict=False tolerates checkpoints with extra/missing keys.
            trained_model.load_state_dict(torch.load(
                FLAGS.load_model, map_location=FLAGS.device),
                                          strict=False)
            print('models loaded')
            print(trained_model)
        else:
            train(train_data, val_data, val_pairs, saver, fold_num=fold_num)
            trained_model = saver.load_trained_model(train_data)
            if FLAGS.save_model:
                saver.save_trained_model(trained_model)

        print('======== Testing... ========')

        # Initialize test-time embeddings depending on which layer levels
        # the configuration enables.
        if FLAGS.lower_level_layers and FLAGS.higher_level_layers:
            _get_initial_embd(test_data, trained_model)
            test_data.dataset.init_interaction_graph_embds(device=FLAGS.device)
        elif FLAGS.higher_level_layers and not FLAGS.lower_level_layers:
            test_data.dataset.init_interaction_graph_embds(device=FLAGS.device)

        if FLAGS.save_final_node_embeddings and 'gmn' not in FLAGS.model:
            with torch.no_grad():
                trained_model = trained_model.to(FLAGS.device)
                trained_model.eval()
                if FLAGS.higher_level_layers:
                    batch_data = model_forward(trained_model,
                                               test_data,
                                               is_train=False)
                    trained_model.use_layers = "higher_no_eval_layers"
                    outs = trained_model(batch_data)
                else:
                    outs = _get_initial_embd(test_data, trained_model)
                    trained_model.use_layers = 'all'

            saver.save_graph_embeddings_mat(outs.cpu().detach().numpy(),
                                            test_data.dataset.id_map,
                                            test_data.dataset.gs_map)
            if FLAGS.higher_level_layers:
                batch_data.restore_interaction_nxgraph()

        test(trained_model, test_data, test_pairs, saver, fold_num)
        overall_time = convert_long_time_to_str(time() - t)
        print(overall_time)
        print(saver.get_log_dir())
        print(basename(saver.get_log_dir()))
        saver.save_overall_time(overall_time)
        saver.close()
    if FLAGS.cross_val and COMET_EXPERIMENT:
        results = aggregate_comet_results_from_folds(
            COMET_EXPERIMENT, FLAGS.num_folds, FLAGS.dataset,
            FLAGS.eval_performance_by_degree)
        COMET_EXPERIMENT.log_metrics(results, prefix='aggr')
def _train(num_iters_total,
           train_data,
           val_data,
           train_val_links,
           model,
           optimizer,
           saver,
           fold_num,
           retry_num=0):
    """Inner training loop: sampled iterations with periodic validation.

    Runs up to FLAGS.num_iters iterations, logging loss every
    FLAGS.print_every_iters and validating every `iters_per_validation`
    iterations; checkpoints when the validation metric improves on the
    moving-average window's best, and early-stops via `moving_avg.stop()`.

    Args:
        num_iters_total: running iteration count carried across retries;
            training halts once it reaches FLAGS.num_iters.
        train_data / val_data: dataset wrappers for the two splits.
        train_val_links: links passed through to `validation`.
        model: model to train (mutated in place).
        optimizer: its optimizer.
        saver: project Saver for logs and checkpoints.
        fold_num: CV fold number or None; used in metric-name prefixes.
        retry_num: retry counter, appended to the prefix when > 0.

    Returns:
        dict mapping iteration number (1-based) -> validation result dict.
    """
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)
    fold_str = fold_str + 'retry_{}_'.format(
        retry_num) if retry_num > 0 else fold_str
    # (Removed a leftover `print("here")` debug statement that fired
    # whenever fold_str was empty.)
    epoch_timer = Timer()
    total_loss = 0
    curr_num_iters = 0
    val_results = {}
    # Pick the sampler and estimate how many iterations make one "epoch".
    if FLAGS.sampler == "neighbor_sampler":
        sampler = NeighborSampler(train_data, FLAGS.num_neighbors_sample,
                                  FLAGS.batch_size)
        estimated_iters_per_epoch = ceil(
            (len(train_data.dataset.gs_map) / FLAGS.batch_size))
    elif FLAGS.sampler == "random_sampler":
        sampler = RandomSampler(train_data, FLAGS.batch_size,
                                FLAGS.sample_induced)
        estimated_iters_per_epoch = ceil(
            (len(train_data.dataset.train_pairs) / FLAGS.batch_size))
    else:
        sampler = EverythingSampler(train_data)
        estimated_iters_per_epoch = 1

    moving_avg = MovingAverage(FLAGS.validation_window_size)
    iters_per_validation = FLAGS.iters_per_validation \
        if FLAGS.iters_per_validation != -1 else estimated_iters_per_epoch

    # `iteration` renamed from `iter`, which shadowed the builtin.
    for iteration in range(FLAGS.num_iters):
        model.train()
        model.zero_grad()
        batch_data = model_forward(model, train_data, sampler=sampler)
        loss = _train_iter(batch_data, model, optimizer)
        batch_data.restore_interaction_nxgraph()
        total_loss += loss
        num_iters_total_limit = FLAGS.num_iters
        curr_num_iters += 1
        # Bug fix: `num_iters_total` was never incremented, so the limit
        # check below could only fire if the caller passed in a value
        # already at the limit.
        num_iters_total += 1
        if num_iters_total_limit is not None and \
                num_iters_total == num_iters_total_limit:
            break
        if iteration % FLAGS.print_every_iters == 0:
            saver.log_tvt_info("{}Iter {:04d}, Loss: {:.7f}".format(
                fold_str, iteration + 1, loss))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metric("{}loss".format(fold_str), loss,
                                            iteration + 1)
        if (iteration + 1) % iters_per_validation == 0:
            eval_res, supplement = validation(
                model,
                val_data,
                train_val_links,
                saver,
                max_num_examples=FLAGS.max_eval_pairs)
            epoch = iteration / estimated_iters_per_epoch
            saver.log_tvt_info('{}Estimated Epoch: {:05f}, Loss: {:.7f} '
                               '({} iters)\t\t{}\n Val Result: {}'.format(
                                   fold_str, epoch,
                                   eval_res["Loss"], curr_num_iters,
                                   epoch_timer.time_and_clear(), eval_res))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metrics(
                    eval_res,
                    prefix="{}validation".format(fold_str),
                    step=iteration + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_pred'],
                    name="{}y_pred".format(fold_str),
                    step=iteration + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_true'],
                    name='{}y_true'.format(fold_str),
                    step=iteration + 1)
                confusion_matrix = supplement.get('confusion_matrix')
                if confusion_matrix is not None:
                    # Label names ordered by integer id to match the matrix.
                    labels = [
                        k for k, v in sorted(
                            batch_data.dataset.interaction_edge_labels.items(),
                            key=lambda item: item[1])
                    ]
                    COMET_EXPERIMENT.log_confusion_matrix(
                        matrix=confusion_matrix, labels=labels,
                        step=iteration + 1)
            curr_num_iters = 0
            val_results[iteration + 1] = eval_res
            # Save when the metric beats the best seen so far (with a small
            # tolerance), or unconditionally on the very first validation.
            if len(moving_avg.results) == 0 or (
                    eval_res[FLAGS.validation_metric] - 1e-7) > max(
                        moving_avg.results):
                saver.save_trained_model(model, iteration + 1)
            moving_avg.add_to_moving_avg(eval_res[FLAGS.validation_metric])
            if moving_avg.stop():
                break
    return val_results