def _select_with_gids(self, want_gids):
    """Restrict this dataset to the graphs whose gids appear in `want_gids`.

    Args:
        want_gids: iterable of graph ids defining one split (train/test/...).

    Returns:
        (graphs, pairs): the subset of `self.gs` whose gid is wanted, and the
        subset of `self.pairs` whose two endpoint gids are BOTH wanted.
    """
    timer = Timer()
    wanted = set(want_gids)  # O(1) membership tests below
    graphs = []
    for g in self.gs:
        if g.gid() in wanted:
            graphs.append(g)
    print('Done graphs', timer.time_and_clear())
    # Keep a pair only when both of its graphs live in this one split.
    pairs = {
        (gid1, gid2): pair
        for (gid1, gid2), pair in self.pairs.items()
        if gid1 in wanted and gid2 in wanted
    }
    print('Done pairs', timer.time_and_clear())
    return graphs, pairs
def evaluate(model, data, eval_links, saver, max_num_examples=None, test=False):
    """Run `model` over `eval_links` without gradients and collect pair results.

    Args:
        model: the model; moved to FLAGS.device and put in eval mode here.
        data: wrapper exposing `.dataset`, handed to BatchData.
        eval_links: tensor of gid pairs to evaluate (wrapped in a TensorDataset).
        saver: logger used only when `test` is True.
        max_num_examples: if set, stop once this many pairs were collected.
        test: when True, log per-batch loss/timing via `saver.log_tvt_info`.

    Returns:
        (all_pair_list, average_loss_per_processed_batch)

    NOTE(review): if the data loader yields no batches, `iter` is never bound
    and the final division raises UnboundLocalError — confirm `eval_links`
    is always non-empty.
    """
    with torch.no_grad():
        model = model.to(FLAGS.device)
        model.eval()  # disable dropout / batch-norm updates during evaluation
        total_loss = 0
        all_pair_list = []
        iter_timer = Timer()
        eval_dataset = torch.utils.data.dataset.TensorDataset(eval_links)
        # NOTE(review): shuffle=True during evaluation only changes batch order,
        # but combined with max_num_examples it makes the evaluated subset
        # non-deterministic — confirm that is intended.
        data_loader = DataLoader(eval_dataset, batch_size=FLAGS.batch_size,
                                 shuffle=True)
        for iter, batch_gids in enumerate(data_loader):
            # Early stop once enough example pairs have been collected.
            if max_num_examples and len(all_pair_list) >= max_num_examples:
                break
            batch_gids = batch_gids[0]  # TensorDataset yields a 1-tuple per batch
            if len(batch_gids) == 0:
                continue
            batch_data = BatchData(batch_gids, data.dataset, is_train=False,
                                   unique_graphs=FLAGS.batch_unique_graphs)
            # When both layer groups are enabled, optionally run a two-stage
            # forward: a lower-layers pass first, then the higher-layers pass
            # done by the `loss = model(batch_data)` call below.
            if FLAGS.lower_level_layers and FLAGS.higher_level_layers:
                if FLAGS.pair_interaction:
                    model.use_layers = 'lower_layers'
                    model(batch_data)
                    model.use_layers = 'higher_layers'
                else:
                    model.use_layers = 'all'
            loss = model(batch_data)
            batch_data.restore_interaction_nxgraph()
            total_loss += loss.item()
            all_pair_list.extend(batch_data.pair_list)
            if test:
                saver.log_tvt_info(
                    '\tIter: {:03d}, Test Loss: {:.7f}\t\t{}'.format(
                        iter + 1, loss, iter_timer.time_and_clear()))
        # Mean loss over the number of batches actually processed.
        return all_pair_list, total_loss / (iter + 1)
def _train(num_iters_total, train_data, val_data, train_val_links, model,
           optimizer, saver, fold_num, retry_num=0):
    """Run the training loop with periodic validation and early stopping.

    Fixes vs. previous version: removed a leftover debug statement
    (`if fold_str == '': print("here")`), renamed the loop variable `iter`
    (shadowed the builtin) to `iteration`, and hoisted the loop-invariant
    read of FLAGS.num_iters out of the loop. All logged strings unchanged.

    Args:
        num_iters_total: global iteration count compared against
            FLAGS.num_iters as a stopping condition.
            NOTE(review): never incremented in this function — confirm the
            caller updates it, otherwise the break can only trigger on entry.
        train_data, val_data: dataset wrappers exposing `.dataset`.
        train_val_links: link pairs passed through to `validation`.
        model, optimizer, saver: training triple; saver also persists models.
        fold_num: cross-validation fold index, or None outside CV.
        retry_num: retry counter; > 0 adds a 'retry_k_' tag to log prefixes.

    Returns:
        dict mapping 1-based iteration number -> validation result dict.
    """
    # Logging prefix, e.g. 'Fold_2_retry_1_' (empty outside CV).
    fold_str = '' if fold_num is None else 'Fold_{}_'.format(fold_num)
    fold_str = fold_str + 'retry_{}_'.format(
        retry_num) if retry_num > 0 else fold_str
    epoch_timer = Timer()
    total_loss = 0
    curr_num_iters = 0
    val_results = {}

    # Pick a sampler and estimate how many iterations make up one epoch.
    if FLAGS.sampler == "neighbor_sampler":
        sampler = NeighborSampler(train_data, FLAGS.num_neighbors_sample,
                                  FLAGS.batch_size)
        estimated_iters_per_epoch = ceil(
            len(train_data.dataset.gs_map) / FLAGS.batch_size)
    elif FLAGS.sampler == "random_sampler":
        sampler = RandomSampler(train_data, FLAGS.batch_size,
                                FLAGS.sample_induced)
        estimated_iters_per_epoch = ceil(
            len(train_data.dataset.train_pairs) / FLAGS.batch_size)
    else:
        sampler = EverythingSampler(train_data)
        estimated_iters_per_epoch = 1

    moving_avg = MovingAverage(FLAGS.validation_window_size)
    iters_per_validation = FLAGS.iters_per_validation \
        if FLAGS.iters_per_validation != -1 else estimated_iters_per_epoch
    # Loop-invariant: read the global iteration cap once.
    num_iters_total_limit = FLAGS.num_iters

    for iteration in range(FLAGS.num_iters):
        model.train()
        model.zero_grad()
        batch_data = model_forward(model, train_data, sampler=sampler)
        loss = _train_iter(batch_data, model, optimizer)
        batch_data.restore_interaction_nxgraph()
        total_loss += loss
        curr_num_iters += 1
        if num_iters_total_limit is not None and \
                num_iters_total == num_iters_total_limit:
            break
        if iteration % FLAGS.print_every_iters == 0:
            saver.log_tvt_info("{}Iter {:04d}, Loss: {:.7f}".format(
                fold_str, iteration + 1, loss))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metric("{}loss".format(fold_str), loss,
                                            iteration + 1)
        if (iteration + 1) % iters_per_validation == 0:
            eval_res, supplement = validation(
                model, val_data, train_val_links, saver,
                max_num_examples=FLAGS.max_eval_pairs)
            epoch = iteration / estimated_iters_per_epoch
            saver.log_tvt_info('{}Estimated Epoch: {:05f}, Loss: {:.7f} '
                               '({} iters)\t\t{}\n Val Result: {}'.format(
                                   fold_str, epoch, eval_res["Loss"],
                                   curr_num_iters,
                                   epoch_timer.time_and_clear(), eval_res))
            if COMET_EXPERIMENT:
                COMET_EXPERIMENT.log_metrics(
                    eval_res, prefix="{}validation".format(fold_str),
                    step=iteration + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_pred'],
                    name="{}y_pred".format(fold_str), step=iteration + 1)
                COMET_EXPERIMENT.log_histogram_3d(
                    supplement['y_true'],
                    name='{}y_true'.format(fold_str), step=iteration + 1)
                confusion_matrix = supplement.get('confusion_matrix')
                if confusion_matrix is not None:
                    # Label names ordered by their integer id so the matrix
                    # rows/columns line up with the label list.
                    labels = [
                        k for k, v in sorted(
                            batch_data.dataset.interaction_edge_labels.items(),
                            key=lambda item: item[1])
                    ]
                    COMET_EXPERIMENT.log_confusion_matrix(
                        matrix=confusion_matrix, labels=labels,
                        step=iteration + 1)
            curr_num_iters = 0
            val_results[iteration + 1] = eval_res
            # Save only on a strict improvement (epsilon 1e-7) over the best
            # value in the moving-average window.
            if len(moving_avg.results) == 0 or (
                    eval_res[FLAGS.validation_metric] - 1e-7) > max(
                        moving_avg.results):
                saver.save_trained_model(model, iteration + 1)
            moving_avg.add_to_moving_avg(eval_res[FLAGS.validation_metric])
            if moving_avg.stop():  # early stopping on plateaued metric
                break
    return val_results