def test(model, data, test_train_links, saver, fold_num):
    """Run final-test evaluation of `model` on `data` and report results.

    Evaluates the trained model, runs the project `Eval` harness, and (when a
    comet experiment is active) sends a completion notification carrying the
    result dict.

    Args:
        model: trained model to evaluate.
        data: test dataset/wrapper expected by `evaluate` / `Eval`.
        test_train_links: link structure passed through to `evaluate`.
        saver: project Saver used for logging and naming the notification.
        fold_num: cross-validation fold number, or None outside CV.
    """
    print("testing...")
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)
    # Fix: the returned loss was bound but never used; the original also bound
    # the Eval instance to the name `eval`, shadowing the builtin.
    pairs, _loss = evaluate(model, data, test_train_links, saver, test=True)
    evaluator = Eval(model, data, pairs, set_name="test", saver=saver)
    res = evaluator.eval(fold_str=fold_str)
    if COMET_EXPERIMENT:
        with COMET_EXPERIMENT.test():
            COMET_EXPERIMENT.send_notification(saver.get_f_name(),
                                               status="finished",
                                               additional_data=res)
def train(train_data, val_data, saver):
    """Full-batch training loop with per-epoch validation and early stopping.

    Trains a model created by `create_model` on `train_data`, validating on
    `val_data` every epoch.  The best model (per FLAGS.validation_metric,
    tracked by MovingAverage) is checkpointed via `saver`; training stops
    early when the moving average says so.

    Args:
        train_data: training dataset; must support init_node_feats and
            get_pyg_graph.
        val_data: validation dataset.
        saver: project Saver used to persist and reload the best checkpoint.

    Returns:
        (best_model, model): the best checkpointed model reloaded from disk,
        and the model in its final (last-epoch) state.
    """
    train_data.init_node_feats(FLAGS.init_type, FLAGS.device)
    val_data.init_node_feats(FLAGS.init_type, FLAGS.device)
    model = create_model(train_data)
    model = model.to(FLAGS.device)
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print("Number params: ", pytorch_total_params)
    # Second arg presumably tells MovingAverage whether higher-is-better
    # (any metric other than 'loss') — TODO confirm against MovingAverage.
    moving_avg = MovingAverage(FLAGS.validation_window_size,
                               FLAGS.validation_metric != 'loss')
    pyg_graph = train_data.get_pyg_graph(FLAGS.device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=FLAGS.lr,
    )
    for epoch in range(FLAGS.num_epochs):
        t = time.time()
        model.train()
        model.zero_grad()
        # The model's forward returns (loss, predictions) given the graph
        # plus the dataset wrapper.
        loss, preds_train = model(pyg_graph, train_data)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        if COMET_EXPERIMENT:
            COMET_EXPERIMENT.log_metric("loss", loss, epoch + 1)
        # Validation forward pass without gradient tracking.
        with torch.no_grad():
            val_loss, preds_val = model(pyg_graph, val_data)
            val_loss = val_loss.item()
        # NOTE(review): `eval` here is a project-level metrics function that
        # shadows the builtin — presumably defined elsewhere in this file.
        eval_res_val = eval(preds_val, val_data)
        print("Epoch: {:04d}, Train Loss: {:.5f}, Time: {:.5f}".format(
            epoch, loss, time.time() - t))
        print("Val Loss: {:.5f}".format(val_loss))
        print("Val Results: ...")
        pprint(eval_res_val)
        eval_res_val["loss"] = val_loss
        if COMET_EXPERIMENT:
            COMET_EXPERIMENT.log_metrics(eval_res_val, prefix="validation",
                                         step=epoch + 1)
        # Checkpoint on the first epoch, or whenever the chosen validation
        # metric is a new best according to the moving-average tracker.
        if len(moving_avg.results) == 0 or moving_avg.best_result(
                eval_res_val[FLAGS.validation_metric]):
            saver.save_trained_model(model, epoch + 1)
        moving_avg.add_to_moving_avg(eval_res_val[FLAGS.validation_metric])
        if moving_avg.stop():
            break
    # Reload the best checkpoint; also return the last-epoch model.
    best_model = saver.load_trained_model(train_data)
    return best_model, model
def train(train_data, val_data, val_pairs, saver, fold_num=None):
    """Build a Model for `train_data` and run `_train` on it.

    Initializes (optionally) model-based node embeddings and interaction-graph
    embeddings, shuffles and stacks the validation pairs, then delegates the
    actual optimization loop to `_train`, wrapped in the comet `train()`
    context when an experiment is active.

    Args:
        train_data: training dataset wrapper (provides `.dataset`).
        val_data: validation dataset wrapper.
        val_pairs: iterable of validation pair tensors; shuffled then stacked.
        saver: project Saver for logging and checkpointing.
        fold_num: cross-validation fold number, or None outside CV.

    Returns:
        (model, val_results): the model (final state) and the per-iteration
        validation results dict returned by `_train`.
    """
    # Local import: used only to merge the duplicated COMET/no-COMET branches.
    from contextlib import nullcontext

    print('creating models...')
    model = Model(train_data)
    model = model.to(FLAGS.device)
    print(model)
    if "model_init" in FLAGS.init_embds:
        print('initial embedding models:')
        print(model.init_layers)
        _get_initial_embd(train_data, model)
    # NOTE(review): reconstructed as unconditional (both train and val) —
    # confirm these were not meant to sit inside the "model_init" branch.
    train_data.dataset.init_interaction_graph_embds(device=FLAGS.device)
    val_data.dataset.init_interaction_graph_embds(device=FLAGS.device)
    # Shuffle pairs before stacking so validation order is randomized.
    val_pairs = list(val_pairs)
    random.shuffle(val_pairs)
    val_pairs = torch.stack(val_pairs)
    saver.log_model_architecture(model)
    model.train_data = train_data
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=FLAGS.lr,
    )
    num_iters_total = 0
    # Fix: the original duplicated the `_train(...)` call in both branches of
    # `if COMET_EXPERIMENT`; a null context keeps a single call site.
    ctx = COMET_EXPERIMENT.train() if COMET_EXPERIMENT else nullcontext()
    with ctx:
        val_results = _train(num_iters_total, train_data, val_data, val_pairs,
                             model, optimizer, saver, fold_num)
    return model, val_results
def eval(self, round=None, fold_str=''):
    """Compute and persist final-test pairwise results.

    Only `round is None` ("final_test") is supported; any other value raises
    NotImplementedError.  Results are stored under
    self.global_result_dict['final_test'], saved via self.saver when present,
    and mirrored to the active comet experiment (metrics, per-degree-bin
    metrics, and the confusion matrix when available).

    Returns:
        The OrderedDict of results for this round.
    """
    # Guard clause: anything other than the final-test round is unsupported.
    if round is not None:
        raise NotImplementedError()
    info = 'final_test'
    results = OrderedDict()
    self.global_result_dict[info] = results

    results['pairwise'], supplement = self._eval_pairs(info)

    if self.saver:
        self.saver.save_global_eval_result_dict(self.global_result_dict)

    if COMET_EXPERIMENT:
        with COMET_EXPERIMENT.test():
            pairwise = self.global_result_dict['final_test']['pairwise']
            COMET_EXPERIMENT.log_metrics(pairwise, prefix=fold_str)
            confusion_matrix = pairwise.get('confusion_matrix')
            # Per-degree-bin metrics, when the evaluation produced them.
            by_degree = supplement.get('fine_grained_by_degree')
            if by_degree is not None:
                for bin_key, bin_metrics in by_degree.items():
                    COMET_EXPERIMENT.log_metrics(
                        bin_metrics,
                        prefix=fold_str + 'degree_bin_' + str(bin_key))
            if confusion_matrix is not None:
                # Label names ordered by their integer label value.
                label_items = sorted(
                    self.eval_data.dataset.interaction_edge_labels.items(),
                    key=lambda item: item[1])
                labels = [name for name, _ in label_items]
                COMET_EXPERIMENT.log_confusion_matrix(
                    matrix=confusion_matrix, labels=labels)
    return results
def main():
    """Train on the loaded corpus, then evaluate on the test split.

    Loads data, trains via `train`, runs a no-grad test forward pass,
    prints test metrics, and — when a comet experiment is active — logs the
    metrics and a confusion matrix whose cells link back to raw documents.
    """
    # Local import: used only to merge the duplicated COMET/no-COMET branches.
    from contextlib import nullcontext

    saver = Saver()
    train_data, val_data, test_data, raw_doc_list = load_data()
    print(train_data.graph.shape)
    # Fix: the original duplicated the `train(...)` call in both branches of
    # `if COMET_EXPERIMENT`; a null context keeps a single call site.
    # NOTE(review): `saved_model` (best checkpoint) is intentionally unused
    # here — the final-state `model` is what gets tested. Confirm intended.
    ctx = COMET_EXPERIMENT.train() if COMET_EXPERIMENT else nullcontext()
    with ctx:
        saved_model, model = train(train_data, val_data, saver)
    with torch.no_grad():
        test_loss_model, preds_model = model(
            train_data.get_pyg_graph(device=FLAGS.device), test_data)
        # `eval` is the project-level metrics function (shadows the builtin).
        eval_res = eval(preds_model, test_data, True)
        y_true = eval_res.pop('y_true')
        y_pred = eval_res.pop('y_pred')
    print("Test...")
    pprint(eval_res)
    if COMET_EXPERIMENT:
        from comet_ml.utils import ConfusionMatrix

        def index_to_example(index):
            # Map a confusion-matrix cell index back to its raw document so
            # comet can display examples per cell.
            test_docs_ids = test_data.node_ids
            return raw_doc_list[test_docs_ids[index]]

        confusion_matrix = ConfusionMatrix(
            index_to_example_function=index_to_example,
            labels=list(test_data.label_dict.keys()))
        confusion_matrix.compute_matrix(y_true, y_pred)
        with COMET_EXPERIMENT.test():
            COMET_EXPERIMENT.log_metrics(eval_res)
            COMET_EXPERIMENT.log_confusion_matrix(
                matrix=confusion_matrix,
                labels=list(test_data.label_dict.keys()))
def main():
    """Cross-validated train/test driver over precomputed pair splits.

    For each fold: deep-copies the encoded dataset, loads the fold's
    train/val/test pairs into it, trains (or loads) a Model, optionally saves
    final node embeddings, and runs `test`.  After the loop, logs overall
    timing and (under cross-validation with comet) aggregated fold results.
    """
    tvt_pairs_dict = load_pair_tvt_splits()
    orig_dataset = load_dataset(FLAGS.dataset, 'all', FLAGS.node_feats,
                                FLAGS.edge_feats)
    orig_dataset, num_node_feat = encode_node_features(dataset=orig_dataset)
    num_interaction_edge_feat = encode_edge_features(
        orig_dataset.interaction_combo_nxgraph, FLAGS.hyper_eatts)
    for i, (train_pairs, val_pairs, test_pairs) in \
            enumerate(zip(tvt_pairs_dict['train'], tvt_pairs_dict['val'],
                          tvt_pairs_dict['test'])):
        fold_num = i + 1
        # Under cross-validation, optionally restrict to a single fold.
        if FLAGS.cross_val and FLAGS.run_only_on_fold != -1 \
                and FLAGS.run_only_on_fold != fold_num:
            continue
        set_seed(FLAGS.random_seed + 5)
        print(f'======== FOLD {fold_num} ========')
        saver = Saver(fold=fold_num)
        # Work on a copy so each fold starts from the pristine encoded dataset.
        dataset = deepcopy(orig_dataset)
        train_data, val_data, test_data, val_pairs, test_pairs, _ = \
            load_pairs_to_dataset(num_node_feat, num_interaction_edge_feat,
                                  train_pairs, val_pairs, test_pairs, dataset)
        print('========= Training... ========')
        if FLAGS.load_model is not None:
            print('loading models: {}'.format(FLAGS.load_model))
            trained_model = Model(train_data).to(FLAGS.device)
            trained_model.load_state_dict(
                torch.load(FLAGS.load_model, map_location=FLAGS.device),
                strict=False)
            print('models loaded')
            print(trained_model)
        else:
            train(train_data, val_data, val_pairs, saver, fold_num=fold_num)
            trained_model = saver.load_trained_model(train_data)
        # NOTE(review): reconstructed as applying to both branches — confirm
        # it was not meant to live inside the `else` (train) branch only.
        if FLAGS.save_model:
            saver.save_trained_model(trained_model)
        # Fix: this banner string was garbled (split across lines) in the
        # original; reconstructed to match the sibling banner prints.
        print('======== Testing... ========')
        if FLAGS.lower_level_layers and FLAGS.higher_level_layers:
            _get_initial_embd(test_data, trained_model)
            test_data.dataset.init_interaction_graph_embds(device=FLAGS.device)
        elif FLAGS.higher_level_layers and not FLAGS.lower_level_layers:
            test_data.dataset.init_interaction_graph_embds(device=FLAGS.device)
        if FLAGS.save_final_node_embeddings and 'gmn' not in FLAGS.model:
            with torch.no_grad():
                trained_model = trained_model.to(FLAGS.device)
                trained_model.eval()
                if FLAGS.higher_level_layers:
                    batch_data = model_forward(trained_model, test_data,
                                               is_train=False)
                    # Restrict the forward pass to the higher-level stack
                    # when extracting embeddings.
                    trained_model.use_layers = "higher_no_eval_layers"
                    outs = trained_model(batch_data)
                else:
                    outs = _get_initial_embd(test_data, trained_model)
                trained_model.use_layers = 'all'
                saver.save_graph_embeddings_mat(outs.cpu().detach().numpy(),
                                                test_data.dataset.id_map,
                                                test_data.dataset.gs_map)
                if FLAGS.higher_level_layers:
                    batch_data.restore_interaction_nxgraph()
        test(trained_model, test_data, test_pairs, saver, fold_num)
    # NOTE(review): `t` is not defined in this function — presumably a
    # module-level start time; verify. Likewise `saver` here is the last
    # fold's saver and is undefined if every fold was skipped.
    overall_time = convert_long_time_to_str(time() - t)
    print(overall_time)
    print(saver.get_log_dir())
    print(basename(saver.get_log_dir()))
    saver.save_overall_time(overall_time)
    saver.close()
    if FLAGS.cross_val and COMET_EXPERIMENT:
        results = aggregate_comet_results_from_folds(
            COMET_EXPERIMENT, FLAGS.num_folds, FLAGS.dataset,
            FLAGS.eval_performance_by_degree)
        COMET_EXPERIMENT.log_metrics(results, prefix='aggr')
def _train(num_iters_total, train_data, val_data, train_val_links, model,
           optimizer, saver, fold_num, retry_num=0):
    """Core iteration loop: sample batches, step the optimizer, validate.

    Runs up to FLAGS.num_iters iterations, validating every
    `iters_per_validation` iterations, checkpointing on new-best validation
    metric, and early-stopping via MovingAverage.

    Args:
        num_iters_total: running iteration count carried in by the caller.
        train_data / val_data: dataset wrappers.
        train_val_links: validation pair links passed to `validation`.
        model, optimizer: the model being trained and its optimizer.
        saver: project Saver for logging and checkpointing.
        fold_num: CV fold number (None outside CV), used for log prefixes.
        retry_num: retry index appended to the log prefix when > 0.

    Returns:
        dict mapping iteration number -> validation result dict.
    """
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)
    fold_str = fold_str + 'retry_{}_'.format(
        retry_num) if retry_num > 0 else fold_str
    # NOTE(review): leftover debug output — consider removing.
    if fold_str == '':
        print("here")
    epoch_timer = Timer()
    total_loss = 0
    curr_num_iters = 0
    val_results = {}
    # Choose the batch sampler; each choice implies a different estimate of
    # how many iterations make up one "epoch" (used for validation cadence).
    if FLAGS.sampler == "neighbor_sampler":
        sampler = NeighborSampler(train_data, FLAGS.num_neighbors_sample,
                                  FLAGS.batch_size)
        estimated_iters_per_epoch = ceil(
            (len(train_data.dataset.gs_map) / FLAGS.batch_size))
    elif FLAGS.sampler == "random_sampler":
        sampler = RandomSampler(train_data, FLAGS.batch_size,
                                FLAGS.sample_induced)
        estimated_iters_per_epoch = ceil(
            (len(train_data.dataset.train_pairs) / FLAGS.batch_size))
    else:
        sampler = EverythingSampler(train_data)
        estimated_iters_per_epoch = 1
    moving_avg = MovingAverage(FLAGS.validation_window_size)
    iters_per_validation = FLAGS.iters_per_validation \
        if FLAGS.iters_per_validation != -1 else estimated_iters_per_epoch
    # NOTE(review): `iter` shadows the builtin throughout this loop.
    for iter in range(FLAGS.num_iters):
        model.train()
        model.zero_grad()
        batch_data = model_forward(model, train_data, sampler=sampler)
        loss = _train_iter(batch_data, model, optimizer)
        batch_data.restore_interaction_nxgraph()
        total_loss += loss
        num_iters_total_limit = FLAGS.num_iters
        curr_num_iters += 1
        # NOTE(review): `num_iters_total` is never incremented in this
        # function, so this break only fires if the caller passed it in
        # already equal to FLAGS.num_iters — looks like a latent bug; verify.
        if num_iters_total_limit is not None and \
                num_iters_total == num_iters_total_limit:
            break
        if iter % FLAGS.print_every_iters == 0:
            saver.log_tvt_info("{}Iter {:04d}, Loss: {:.7f}".format(
                fold_str, iter + 1, loss))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metric("{}loss".format(fold_str), loss,
                                            iter + 1)
        # Periodic validation, logging, checkpointing, and early stopping.
        if (iter + 1) % iters_per_validation == 0:
            eval_res, supplement = validation(
                model, val_data, train_val_links, saver,
                max_num_examples=FLAGS.max_eval_pairs)
            epoch = iter / estimated_iters_per_epoch
            saver.log_tvt_info('{}Estimated Epoch: {:05f}, Loss: {:.7f} '
                               '({} iters)\t\t{}\n Val Result: {}'.format(
                                   fold_str, epoch, eval_res["Loss"],
                                   curr_num_iters,
                                   epoch_timer.time_and_clear(), eval_res))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metrics(
                    eval_res, prefix="{}validation".format(fold_str),
                    step=iter + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_pred'], name="{}y_pred".format(fold_str),
                    step=iter + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_true'], name='{}y_true'.format(fold_str),
                    step=iter + 1)
                confusion_matrix = supplement.get('confusion_matrix')
                if confusion_matrix is not None:
                    # Label names ordered by their integer label value.
                    labels = [
                        k for k, v in sorted(
                            batch_data.dataset.interaction_edge_labels.items(),
                            key=lambda item: item[1])
                    ]
                    COMET_EXPERIMENT.log_confusion_matrix(
                        matrix=confusion_matrix, labels=labels, step=iter + 1)
            curr_num_iters = 0
            val_results[iter + 1] = eval_res
            # Checkpoint on the first validation, or when the metric beats
            # the previous best by more than a small tolerance (1e-7).
            if len(moving_avg.results) == 0 or (
                    eval_res[FLAGS.validation_metric] - 1e-7) > max(
                    moving_avg.results):
                saver.save_trained_model(model, iter + 1)
            moving_avg.add_to_moving_avg(eval_res[FLAGS.validation_metric])
            if moving_avg.stop():
                break
    return val_results