def train(self, sess, config):
    """Trains node classification model or joint node edge model.

    Runs `config.epochs` optimization steps in the given TensorFlow session,
    logging train/val metrics and TF summaries every epoch. The best model is
    selected on validation metrics (via `check_improve`) and its test metrics
    are recorded; training early-stops after `config.patience` epochs without
    any improvement.

    Args:
        sess: live TensorFlow session holding the model's variables.
        config: training configuration exposing `epochs`, `lr`, `patience`.
    """
    self._log('Training {} model...'.format(self.model_name))
    self._log('Training parameters : \n ' + format_params(config))
    epochs = config.epochs
    # NOTE(review): `lr` is only interpolated into the log line below; the
    # actual learning rate used by `self.train_ops` lives in the graph.
    lr = config.lr
    patience = config.patience
    # best_step = self.global_step
    # step for patience
    curr_step = 0
    # best metrics to select model
    best_val_metrics = self._init_best_metrics()
    best_test_metrics = self._init_best_metrics()
    # train the model
    for epoch in range(epochs):
        self.global_step += 1
        sess.run(self.train_ops, feed_dict=self._make_feed_dict('train'))
        train_metrics = self._eval_model(sess, 'train')
        val_metrics = self._eval_model(sess, 'val')
        self._log('Epoch {} : lr = {:.4f} | '.format(epoch, lr)
                  + format_metrics(train_metrics, 'train')
                  + format_metrics(val_metrics, 'val'))
        # write summaries
        train_summary = sess.run(self.train_summary, self._make_feed_dict('train'))
        val_summary = sess.run(self.val_summary, self._make_feed_dict('val'))
        self.summary_writer.add_summary(train_summary, global_step=self.global_step)
        self.summary_writer.add_summary(val_summary, global_step=self.global_step)
        # save model checkpoint if val acc increased and val loss decreased
        comp = check_improve(best_val_metrics, val_metrics, self.target_metrics)
        if np.any(comp):
            # only treat the model as a new best when ALL target metrics improved
            if np.all(comp):
                # best_step = self.global_step
                # save_path = os.path.join(save_dir, 'model')
                # self.saver.save(sess, save_path, global_step=self.global_step)
                best_test_metrics = self._eval_model(sess, 'test')
                best_val_metrics = val_metrics
            # any improvement resets the patience counter
            curr_step = 0
        else:
            curr_step += 1
            if curr_step == patience:
                self._log('Early stopping')
                break
    self._log('\n' + '*' * 40 + ' Best model metrics ' + '*' * 40)
    # load best model to evaluate on test set
    # save_path = os.path.join(save_dir, 'model-{}'.format(best_step))
    # self.restore_checkpoint(sess, save_path)
    self._log(format_metrics(best_val_metrics, 'val'))
    self._log(format_metrics(best_test_metrics, 'test'))
    self._log('\n' + '*' * 40 + ' Training done ' + '*' * 40)
def restore_checkpoint(self, sess, model_checkpoint=None):
    """Loads model checkpoint if found and computes evaluation metrics.

    Falls back to fresh weight initialization when no usable checkpoint is
    given; otherwise restores the saved variables, recovers the global step
    from the checkpoint name, and logs val/test metrics.
    """
    usable = (model_checkpoint is not None
              and tf.train.checkpoint_exists(model_checkpoint))
    if not usable:
        self.init_model_weights(sess)
        return
    self._log('Loading existing model saved at {}'.format(model_checkpoint))
    self.saver.restore(sess, model_checkpoint)
    # checkpoint paths end in '-<global_step>'
    self.global_step = int(model_checkpoint.split('-')[-1])
    for split in ('val', 'test'):
        self._log(format_metrics(self._eval_model(sess, split), split))
def eval(self):
    """Evaluates the trained model and runs clustering evaluation.

    Recomputes test metrics if they were not produced during training, adds
    clustering metrics over the full node set, logs val/test results, and
    optionally persists embeddings / attention / config / weights.

    Returns:
        Tuple of (best embeddings, predicted cluster labels).
    """
    model = self.model
    data = self.data
    args = self.args
    save_dir = self.args.save_dir
    # NOTE(review): presumably clears a seed consumed by downstream helpers
    # (e.g. k-means in eval_cluster) — confirm against load/eval code.
    args.np_seed = None
    # (Re)compute test metrics only if training never recorded them and a
    # test split actually exists.
    if not self.best_test_metrics and args.test_prop > 0:
        model.eval()
        self.best_emb = model.encode(data['features'], self.adj_train_enc)
        self.best_test_metrics = model.compute_metrics(self.best_emb, data, 'test')
    ## CLUSTERING EVAL START
    if args.node_cluster == 1:
        # pure clustering run: val metrics mirror the test/cluster metrics
        metrics_clustering, pred_label = model.eval_cluster(self.best_emb, data, 'all')
        self.best_test_metrics.update(metrics_clustering)
        self.best_val_metrics = self.best_test_metrics
    else:
        metrics_clustering, pred_label = model.eval_cluster(self.best_emb, data, 'all')
        self.best_test_metrics.update(metrics_clustering)
    ## CLUSTERING EVAL END
    logging.info(" ".join(["Val set results:",
                           format_metrics(self.best_val_metrics, 'val')]))
    logging.info(" ".join(["Test set results:",
                           format_metrics(self.best_test_metrics, 'test')]))
    if args.save:
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                self.best_emb.cpu().detach().numpy())
        # dump attention adjacency when the encoder exposes one
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(), open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)
        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
    return self.best_emb, pred_label
def train(args):
    """Trains an NC or LP graph model end-to-end from parsed CLI args.

    Seeds RNGs, sets up logging/saving, loads the dataset, builds the model
    and optimizer, runs the epoch loop with periodic evaluation and early
    stopping, then logs/saves the best results.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    # no explicit patience means "never early-stop before the last epoch"
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            # auto-create a dated run directory under $LOG_DIR/<task>/
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])
    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))
    # Load data
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
    if not args.lr_reduce_freq:
        # default: never decay the lr within the run
        args.lr_reduce_freq = args.epochs
    # Model and optimizer
    model = Model(args)
    logging.info(str(model))
    optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(),
                                                    lr=args.lr,
                                                    weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
    # Train model
    t_total = time.time()
    counter = 0  # epochs since last validation improvement
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train', args)
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            # clip each parameter tensor's gradient norm individually
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            val_metrics = model.compute_metrics(embeddings, data, 'val', args)
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if model.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = model.compute_metrics(
                    embeddings, data, 'test', args)
                # tuple embeddings mix hyperbolic + Euclidean parts; map the
                # hyperbolic component to the tangent space before saving
                if isinstance(embeddings, tuple):
                    best_emb = torch.cat(
                        (pmath.logmap0(embeddings[0], c=1.0), embeddings[1]),
                        dim=1).cpu()
                else:
                    best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break
    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    # evaluation never triggered an improvement: score the final model
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test', args)
    logging.info(" ".join(
        ["Val set results:", format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:", format_metrics(best_test_metrics, 'test')]))
    if args.save:
        if isinstance(best_emb, tuple):
            best_emb = torch.cat(
                (pmath.logmap0(best_emb[0], c=1.0), best_emb[1]), dim=1).cpu()
        else:
            best_emb = best_emb.cpu()
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)
        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
# NOTE(review): this chunk is a fragment of a larger training routine — the
# enclosing `for epoch in range(...)` loop and the definitions of `model`,
# `data`, `args`, `counter`, `t_total` and the best_* trackers live outside
# this chunk.
if (epoch + 1) % args.eval_freq == 0:
    model.eval()
    embeddings = model.encode(data['features'], data['adj_train_norm'])
    val_metrics = model.compute_metrics(embeddings, data, 'val')
    if model.has_improved(best_val_metrics, val_metrics):
        # track test + 'pred' split metrics for the best validation model
        best_test_metrics = model.compute_metrics(embeddings, data, 'test')
        best_pred_metrics = model.compute_metrics(embeddings, data, 'pred')
        best_emb = embeddings.cpu()
        best_val_metrics = val_metrics
        counter = 0
    else:
        counter += 1
        if counter == args.patience and epoch > args.min_epochs:
            print("Early stopping")
            break
# --- after the epoch loop in the full file ---
print("Optimization Finished!")
print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
# no improvement was ever recorded: score the final model instead
if not best_test_metrics:
    model.eval()
    best_emb = model.encode(data['features'], data['adj_train_norm'])
    best_test_metrics = model.compute_metrics(best_emb, data, 'test')
    best_pred_metrics = model.compute_metrics(best_emb, data, 'pred')
print(" ".join(["Val set results:",
                format_metrics(best_val_metrics, 'val')]))
print(" ".join(
    ["Test set results:", format_metrics(best_test_metrics, 'test')]))
print(" ".join(
    ["Pred set results:", format_metrics(best_pred_metrics, 'pred')]))
def train(args):
    """Adversarial dual-task training: an LP objective (`data1`) is pushed
    down while an auxiliary objective on `data` is used adversarially.

    Loads the dataset twice (forcing 'nc' then 'lp' task settings via
    `reserve_mark` bookkeeping), builds one model with three optimizers
    (whole model / encoder-only / decoder-only), and alternates the three
    update steps every epoch.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])
    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))
    # temporarily force task='nc' for the first data load; reserve_mark
    # remembers that the caller actually asked for something else
    reserve_mark = 0
    if args.task == 'nc':
        reserve_mark = 0
    else:
        args.task = 'nc'
        reserve_mark = 1
    # Load data
    data = load_data(args, os.path.join('data/', args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = ADVNCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = ADVLPModel
        else:
            Model = RECModel
            # No validation for reconstruction task
            args.eval_freq = args.epochs + 1
    # transfer loading
    if reserve_mark == 1:
        args.task = 'lp'
    # reset reserve mark; now force task='lp' for the second data load
    reserve_mark = 0
    if args.task == 'lp':
        reserve_mark = 0
    else:
        args.task = 'lp'
        reserve_mark = 1
    data1 = load_data(args, os.path.join('data/', args.dataset))
    args.n_nodes, args.feat_dim = data1['features'].shape
    if args.task == 'nc':
        Model = ADVNCModel
        args.n_classes = int(data1['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        print('*****')
        args.nb_false_edges = len(data1['train_edges_false'])
        args.nb_edges = len(data1['train_edges'])
        if args.task == 'lp':
            Model = ADVLPModel
        else:
            Model = RECModel
            # No validation for reconstruction task
            args.eval_freq = args.epochs + 1
    if reserve_mark == 1:
        args.task = 'nc'
    # final model choice follows the (restored) requested task
    if args.task == 'nc':
        Model = ADVNCModel
    else:
        Model = ADVLPModel
    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs
    # Model and optimizer
    model = Model(args)
    logging.info(str(model))
    # three optimizers: full model, encoder-only, decoder-only
    optimizer = getattr(optimizers, args.optimizer)(params=model.parameters(),
                                                    lr=args.lr,
                                                    weight_decay=args.weight_decay)
    optimizer_en = getattr(optimizers, args.optimizer)(params=model.encoder.parameters(),
                                                       lr=args.lr,
                                                       weight_decay=args.weight_decay)
    optimizer_de = getattr(optimizers, args.optimizer)(params=model.decoder.parameters(),
                                                       lr=args.lr,
                                                       weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    lr_scheduler_en = torch.optim.lr_scheduler.StepLR(optimizer_en,
                                                      step_size=int(args.lr_reduce_freq),
                                                      gamma=float(args.gamma))
    lr_scheduler_de = torch.optim.lr_scheduler.StepLR(optimizer_de,
                                                      step_size=int(args.lr_reduce_freq),
                                                      gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
        for x, val in data1.items():
            if torch.is_tensor(data1[x]):
                data1[x] = data1[x].to(args.device)
    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        model.mark = 1
        # if epoch%3==0:
        # model.save_net()
        # model.save_emb()
        # lr_scheduler.step()
        # if epoch%3==1:
        # if epoch > 100:
        # step 1: minimize the (scaled) primary LP loss on data1
        optimizer.zero_grad()
        optimizer_en.zero_grad()
        # model.load_emb()
        embeddings1 = model.encode(data1['features'], data1['adj_train_norm'])
        train_metrics1 = model.compute_metrics1(embeddings1, data1, 'train')
        loss1 = 10 * train_metrics1['loss']  # fixed 10x weighting of the LP loss
        loss1.backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        # model.load_net()
        # # lr_scheduler.step()
        # if (args.model == 'GCN' or args.model == 'GAT' and epoch > 100) or (args.model == 'HGCN' and epoch > 1000):
        # step 2: adversarial encoder update — MAXIMIZES the auxiliary loss
        # on `data` (note the negation)
        optimizer_en.zero_grad()
        embeddings = model.encode(data['features'], data['adj_train_norm'])
        train_metrics = model.compute_metrics(embeddings, data, 'train')
        loss = -(train_metrics['loss'])  # - train_metrics['loss_shuffle'])
        loss.backward()
        optimizer_en.step()
        # #
        # # if epoch%3==2:
        # step 3: decoder-only update on detached embeddings (encoder frozen)
        optimizer_de.zero_grad()
        embeddings2 = model.encode(data['features'],
                                   data['adj_train_norm']).detach_()
        train_metrics2 = model.compute_metrics(embeddings2, data, 'train')
        loss2 = (train_metrics2['loss'])  # - train_metrics2['loss_shuffle'])
        loss2.backward()
        optimizer_de.step()
        lr_scheduler.step()
        lr_scheduler_en.step()
        lr_scheduler_de.step()
        # if epoch<100:
        # train_metrics2 = train_metrics
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics1, 'train'),
                format_metrics(train_metrics2, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
            # NOTE(review): best_val_metrics1 / best_test_metrics(1) are only
            # assigned after the first improvement below — this log can raise
            # NameError if logging happens before any eval improvement.
            if not best_val_metrics == None:
                logging.info(" ".join([
                    "Val set results:",
                    format_metrics(best_val_metrics, 'val')
                ]))
                logging.info(" ".join([
                    "Val set results:",
                    format_metrics(best_val_metrics1, 'val')
                ]))
                logging.info(" ".join([
                    "Test set results:",
                    format_metrics(best_test_metrics, 'test')
                ]))
                logging.info(" ".join([
                    "Test set results:",
                    format_metrics(best_test_metrics1, 'test')
                ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            val_metrics = model.compute_metrics(embeddings, data, 'val')
            embeddings1 = model.encode(data1['features'], data1['adj_train_norm'])
            val_metrics1 = model.compute_metrics1(embeddings1, data1, 'val')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val'),
                    format_metrics(val_metrics1, 'val')
                ]))
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            test_metrics = model.compute_metrics(embeddings, data, 'test')
            embeddings1 = model.encode(data1['features'], data1['adj_train_norm'])
            test_metrics1 = model.compute_metrics1(embeddings1, data1, 'test')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(test_metrics, 'test'),
                    format_metrics(test_metrics1, 'test')
                ]))
            # model selection keys off the LP validation metrics (val_metrics1)
            if model.has_improved(best_val_metrics, val_metrics1):
                best_test_metrics = model.compute_metrics1(
                    embeddings1, data1, 'test')
                best_test_metrics1 = model.compute_metrics(
                    embeddings, data, 'test')
                best_emb = embeddings1.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics1
                best_val_metrics1 = val_metrics1
                counter = 0
            # else:
            # counter += 1
            # if counter == args.patience and epoch > args.min_epochs:
            # logging.info("Early stopping")
            # break
    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        model.eval()
        best_emb = model.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = model.compute_metrics(best_emb, data, 'test')
    logging.info(" ".join(
        ["Val set results:", format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:", format_metrics(best_test_metrics, 'test')]))
    if args.save:
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.cpu().detach().numpy())
        if hasattr(model.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(model.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)
        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
        logging.info(f"Saved model in {save_dir}")
def train(args):
    """Trains a model using a paired Euclidean + Stiefel-manifold optimizer.

    Uses hypergraph inputs (`hgnn_adj`, `hgnn_weight`) and steps both
    optimizers/schedulers each epoch. Returns the best test ROC (lp) or
    F1 (nc).
    """
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))
    data = load_data(args, os.path.join('./data', args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
        else:
            Model = RECModel
            # no validation for the reconstruction task
            args.eval_freq = args.epochs + 1
    # Model and optimizer
    model = Model(args)
    optimizer, lr_scheduler, stiefel_optimizer, stiefel_lr_scheduler = \
        set_up_optimizer_scheduler(False, args, model, args.lr, args.lr_stie)
    tot_params = sum([np.prod(p.size()) for p in model.parameters()])
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = model.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        stiefel_optimizer.zero_grad()
        embeddings = model.encode(data['features'], data['hgnn_adj'],
                                  data['hgnn_weight'])
        train_metrics = model.compute_metrics(embeddings, data, 'train')
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        stiefel_optimizer.step()
        lr_scheduler.step()
        stiefel_lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {:04f}, stie_lr: {:04f}'.format(
                    lr_scheduler.get_lr()[0],
                    stiefel_lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            model.eval()
            embeddings = model.encode(data['features'], data['hgnn_adj'],
                                      data['hgnn_weight'])
            # NaN probe: x != x is only true for NaN entries
            for i in range(embeddings.size(0)):
                if (embeddings[i] != embeddings[i]).sum() > 1:
                    print('PART train i', i, 'embeddings[i]', embeddings[i])
            val_metrics = model.compute_metrics(embeddings, data, 'val')
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if model.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = model.compute_metrics(
                    embeddings, data, 'test')
                best_emb = embeddings.cpu()
                # NOTE(review): `save_dir` is never defined in this function —
                # this branch raises NameError when args.save is set; confirm
                # whether save_dir should come from args.save_dir.
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                if counter == args.patience and epoch > args.min_epochs:
                    logging.info("Early stopping")
                    break
    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if args.save:
        np.save(os.path.join(save_dir, str(args.dataset) + '_embeddings.npy'),
                best_emb.cpu().detach().numpy())
        # NOTE(review): `best_rocauc` is never assigned in this function —
        # saving here raises NameError; likely meant best_test_metrics['roc'].
        torch.save(
            model.state_dict(),
            os.path.join(
                save_dir,
                str(args.dataset) + 'model_auc' + str(best_rocauc) + '.pth'))
        logging.info(f"Saved model in {save_dir}")
    if args.task == 'lp':
        return best_test_metrics['roc']
    if args.task == 'nc':
        return best_test_metrics['f1']
def fit(self):
    """Runs the training loop for the trainer's model.

    Two evaluation regimes: link-prediction style runs (node_cluster != 1)
    select on validation metrics with patience-based early stopping, while
    node-clustering runs (node_cluster == 1) train on 100% of the data and
    track the best training loss instead.
    """
    args = self.args
    model = self.model
    optimizer = self.optimizer
    data = self.data
    lr_scheduler = self.lr_scheduler
    save_dir = self.args.save_dir
    for epoch in range(args.epochs):
        t = time.time()
        model.train()
        optimizer.zero_grad()
        embeddings = model.encode(data['features'], self.adj_train_enc)
        train_metrics = model.compute_metrics(embeddings, data, 'train', epoch)
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(model.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()
        # evaluation below needs no gradients
        with torch.no_grad():
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1),
                                       'lr: {}'.format(lr_scheduler.get_lr()[0]),
                                       format_metrics(train_metrics, 'train'),
                                       'time: {:.4f}s'.format(time.time() - t)
                                       ]))
            if (epoch + 1) % args.eval_freq == 0:
                model.eval()
                embeddings = model.encode(data['features'], self.adj_train_enc)
                if args.node_cluster != 1:
                    ## Link Prediction Task that use train/val/test
                    val_metrics = model.compute_metrics(embeddings, data, 'val')
                    if (epoch + 1) % args.log_freq == 0:
                        logging.info(" ".join(['Epoch: {:04d}'.format(epoch + 1),
                                               format_metrics(val_metrics, 'val')]))
                    if model.has_improved(self.best_val_metrics, val_metrics):
                        self.best_test_metrics = model.compute_metrics(
                            embeddings, data, 'test')
                        self.best_emb = embeddings
                        if args.save:
                            np.save(os.path.join(save_dir, 'embeddings.npy'),
                                    self.best_emb.cpu().detach().numpy())
                        self.best_val_metrics = val_metrics
                        self.best_val_metrics['epoch'] = epoch + 1
                        self.counter = 0
                        # logging.info("improved")
                    else:
                        # logging.info("not improved :"+str(self.counter))
                        self.counter += 1
                        if self.counter >= args.patience and epoch > args.min_epochs:
                            # NOTE : fixed when improve only epoch0
                            logging.info("Early stopping")
                            break
                else:
                    ## Node Clustering Task that use 100 % trainset.
                    # 999 acts as +inf sentinel for the first comparison
                    if self.best_test_metrics.get('loss', 999) > train_metrics['loss']:
                        '''
                        # NOTE : when kmeans calculated, affect np.state and takes time, and not fair to monitor it.
                        # metrics_clustering = model.eval_cluster(embeddings, data, 'all')
                        # logging.info(" ".join(["Cluster results:", format_metrics(metrics_clustering, 'all')]))
                        '''
                        self.best_emb = embeddings
                        logging.info("Best loss found")
                        self.best_test_metrics['loss'] = train_metrics['loss']
                        self.best_test_metrics['epoch'] = epoch + 1
                        self.counter = 0
                    else:
                        self.counter += 1
                        if self.counter >= args.patience and epoch > args.min_epochs:
                            # NOTE : fixed when improve only epoch0
                            logging.info("Early stopping")
                            break
    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - self.t_total))
def train(args):
    """Trains an HGNN with RL-controlled curvature.

    Two Q-learning agents (HGNN accept/reject and ACE joint actions) steer
    per-layer curvature through `Env`; layer RL is terminated once its
    curvature trace is stable over a 30-epoch window. Model selection and
    early stopping follow validation metrics, and curvature/metric records
    are saved alongside the model.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join(os.environ['LOG_DIR'], args.task, date)
            save_dir = get_dir_name(models_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(
                                    os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])
    logging.info(f'Using: {args.device}')
    logging.info("Using seed {}.".format(args.seed))
    warnings.filterwarnings(action='ignore')
    # Load data
    data = load_data(args, os.path.join(os.environ['DATAPATH'], args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
    if not args.lr_reduce_freq:
        args.lr_reduce_freq = args.epochs
    # Model and optimizer
    # Initialize RL environment
    lr_q = args.lr_q  # RL Learning Rate
    action_space1 = ['reject', 'accept']  # HGNN action space
    action_space2 = [  # ACE action space
        'r, r', 'r, a', 'a, r', 'a, a'
    ]
    joint_actions = []
    for i in range(len(action_space1)):  # Joint action space
        for j in range(len(action_space2)):
            joint_actions.append((i, j))
    env = Env(theta=args.theta, initial_c=args.c)
    Agent1 = QLearningTable(actions=list(range(len(action_space1))),
                            joint=joint_actions,
                            start=args.start_q,
                            learning_rate=lr_q)
    Agent2 = QLearningTable(actions=list(range(len(action_space2))),
                            joint=joint_actions,
                            start=args.start_q,
                            learning_rate=lr_q)
    hgnn = Model(args)  # Agent1 HGNN
    ace = Model(args)  # Agent2 ACE
    logging.info(str(hgnn))
    optimizer = getattr(optimizers, args.optimizer)(params=hgnn.parameters(),
                                                    lr=args.lr,
                                                    weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=int(args.lr_reduce_freq),
                                                   gamma=float(args.gamma))
    tot_params = sum([np.prod(p.size()) for p in hgnn.parameters()])
    logging.info(f"Total number of parameters: {tot_params}")
    if args.cuda is not None and int(args.cuda) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda)
        hgnn = hgnn.to(args.device)
        ace = ace.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
    # Train model
    t_total = time.time()
    counter = 0
    best_val_metrics = hgnn.init_metric_dict()
    best_test_metrics = None
    best_emb = None
    val_metric_record = []
    train_metric_record = []
    for epoch in range(args.epochs):
        t = time.time()
        hgnn.train()
        ace.train()
        optimizer.zero_grad()
        # train model with RL and return Agent1's train metrics
        # Terminate mechanism: stop a layer's RL once its curvature trace
        # over the last 30 epochs is flat (spread <= 0.03)
        if epoch > args.start_q + 30:
            r1 = np.array(env.c1_record)[-30:-1, 0]
            r2 = np.array(env.c1_record)[-30:-1, 1]
            if abs(max(r1) - min(r1)) <= 0.03 and not env.stop[0]:
                env.stop[0] = True
                print("Layer1 RL terminate at {:.3f}.".format(
                    env.c1_record[-1][0]))
                # give the model half the patience budget after RL stops
                counter = args.patience // 2
            if abs(max(r2) - min(r2)) <= 0.03 and not env.stop[1]:
                env.stop[1] = True
                print("Layer2 RL terminate at {:.3f}.".format(
                    env.c1_record[-1][1]))
                counter = args.patience // 2
        train_metrics = hgnn.train_with_RL(env, Agent1, Agent2, data, epoch, ace)
        train_metric_record.append(train_metrics[hgnn.key_param])
        train_metrics['loss'].backward()
        if args.grad_clip is not None:
            max_norm = float(args.grad_clip)
            all_params = list(hgnn.parameters())
            for param in all_params:
                torch.nn.utils.clip_grad_norm_(param, max_norm)
        optimizer.step()
        lr_scheduler.step()
        if (epoch + 1) % args.log_freq == 0:
            logging.info(" ".join([
                'Epoch: {:04d}'.format(epoch + 1),
                'lr: {}'.format(lr_scheduler.get_lr()[0]),
                format_metrics(train_metrics, 'train'),
                'time: {:.4f}s'.format(time.time() - t)
            ]))
        if (epoch + 1) % args.eval_freq == 0:
            hgnn.eval()
            embeddings = hgnn.encode(data['features'], data['adj_train_norm'])
            val_metrics = hgnn.compute_metrics(embeddings, data, 'val')
            val_metric_record.append(val_metrics[hgnn.key_param])
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join([
                    'Epoch: {:04d}'.format(epoch + 1),
                    format_metrics(val_metrics, 'val')
                ]))
            if hgnn.has_improved(best_val_metrics, val_metrics):
                best_test_metrics = hgnn.compute_metrics(
                    embeddings, data, 'test')
                best_emb = embeddings.cpu()
                if args.save:
                    np.save(os.path.join(save_dir, 'embeddings.npy'),
                            best_emb.detach().numpy())
                best_val_metrics = val_metrics
                counter = 0
            else:
                counter += 1
                # only stop once patience is exhausted AND both RL layers
                # have already terminated
                if counter >= args.patience and epoch > args.min_epochs and all(
                        env.stop):
                    logging.info("Early stopping")
                    break
    logging.info("Optimization Finished!")
    logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
    if not best_test_metrics:
        hgnn.eval()
        best_emb = hgnn.encode(data['features'], data['adj_train_norm'])
        best_test_metrics = hgnn.compute_metrics(best_emb, data, 'test')
    logging.info(" ".join(
        ["Val set results:", format_metrics(best_val_metrics, 'val')]))
    logging.info(" ".join(
        ["Test set results:", format_metrics(best_test_metrics, 'test')]))
    if args.save:
        # Save embeddings and attentions
        np.save(os.path.join(save_dir, 'embeddings.npy'),
                best_emb.cpu().detach().numpy())
        if hasattr(hgnn.encoder, 'att_adj'):
            filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
            pickle.dump(hgnn.encoder.att_adj.cpu().to_dense(),
                        open(filename, 'wb'))
            print('Dumped attention adj: ' + filename)
        # Save model
        json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
        torch.save(hgnn.state_dict(), os.path.join(save_dir, 'model.pth'))
        # Save curvature record and figures
        np.save(os.path.join(save_dir, 'curv1.npy'), np.array(env.c1_record))
        np.save(os.path.join(save_dir, 'curv2.npy'), np.array(env.c2_record))
        # Save acc record
        np.save(os.path.join(save_dir, 'metric_record.npy'),
                np.array([train_metric_record, val_metric_record]))
        logging.info("Agent1: {}, Agent2: {}".format(env.c1, env.c2))
        logging.info(f"Saved model in {save_dir}")
def train(args):
    """Trains a node-classification or link-prediction model from CLI args.

    Applies dataset/manifold-specific argument tweaks, seeds all RNGs, sets
    up optional logging-to-disk, loads the data, builds the model and a StepLR
    optimizer, then runs the epoch loop with periodic validation, early
    stopping, and optional artifact saving.

    Args:
        args: parsed argument namespace (mutated in place with derived
            fields such as device, n_nodes, n_classes, ...).

    Raises:
        NotImplementedError: for tasks other than 'nc' and 'lp'.
    """
    # disease datasets ship pre-normalized features; their nc variant uses
    # one extra layer
    if args.dataset in ['disease_lp', 'disease_nc']:
        args.normalize_feats = 0
        if args.task == 'nc':
            args.num_layers += 1
    # FIX: original condition was `args.manifold == 'Lorentzian' or
    # 'Euclidean'`, which is always truthy (a non-empty string literal), so
    # EVERY manifold got the extra coordinate. Only these two manifolds carry
    # the extra (origin/time-like) dimension.
    if args.manifold in ('Lorentzian', 'Euclidean'):
        args.dim = args.dim + 1
    args.c = float(args.c) if args.c is not None else None
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if int(args.double_precision):
        torch.set_default_dtype(torch.float64)
    if int(args.cuda) >= 0:
        torch.cuda.manual_seed(args.seed)
    args.device = 'cuda:' + str(args.cuda) if int(args.cuda) >= 0 else 'cpu'
    # no explicit patience means "never early-stop before the last epoch"
    args.patience = args.epochs if not args.patience else int(args.patience)
    logging.getLogger().setLevel(logging.INFO)
    if args.save:
        if not args.save_dir:
            # auto-create a dated run directory under save/<task>/
            dt = datetime.datetime.now()
            date = f"{dt.year}_{dt.month}_{dt.day}"
            models_dir = os.path.join('save', args.task, date)
            save_dir = get_dir_name(models_dir)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
        else:
            save_dir = args.save_dir
        logging.basicConfig(level=logging.INFO,
                            handlers=[
                                logging.FileHandler(os.path.join(save_dir, 'log.txt')),
                                logging.StreamHandler()
                            ])
    logging.info(f'Using cuda: {args.cuda}')
    logging.info("Using seed {}.".format(args.seed))
    logging.info("Using dataset {}.".format(args.dataset))
    # Load data
    data = load_data(args, os.path.join('data/', args.dataset))
    args.n_nodes, args.feat_dim = data['features'].shape
    if args.task == 'nc':
        Model = NCModel
        args.n_classes = int(data['labels'].max() + 1)
        logging.info(f'Num classes: {args.n_classes}')
    else:
        args.nb_false_edges = len(data['train_edges_false'])
        args.nb_edges = len(data['train_edges'])
        if args.task == 'lp':
            Model = LPModel
        else:
            # reconstruction is unsupported here; the RECModel setup that used
            # to follow this raise was unreachable dead code and was removed
            raise NotImplementedError
    if not args.lr_reduce_freq:
        # default: never decay the lr within the run
        args.lr_reduce_freq = args.epochs
    # Model and optimizer
    model = Model(args)
    if args.cuda is not None and int(args.cuda) >= 0:
        model = model.to(args.device)
        for x, val in data.items():
            if torch.is_tensor(data[x]):
                data[x] = data[x].to(args.device)
    # single-element loop: iter_i labels this run in the log lines
    for iter_i in [str(args.run_times)]:
        optimizer = getattr(optimizers, args.optimizer)(
            params=model.parameters(), lr=args.lr,
            weight_decay=args.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=int(args.lr_reduce_freq),
            gamma=float(args.gamma)
        )
        # Train model
        t_total = time.time()
        counter = 0  # epochs since last validation improvement
        best_val_metrics = model.init_metric_dict()
        best_test_metrics = None
        best_emb = None
        for epoch in range(args.epochs):
            t = time.time()
            model.train()
            optimizer.zero_grad()
            embeddings = model.encode(data['features'], data['adj_train_norm'])
            train_metrics = model.compute_metrics(embeddings, data, 'train')
            train_metrics['loss'].backward()
            if args.grad_clip is not None:
                # clip each parameter tensor's gradient norm individually
                max_norm = float(args.grad_clip)
                all_params = list(model.parameters())
                for param in all_params:
                    torch.nn.utils.clip_grad_norm_(param, max_norm)
            optimizer.step()
            lr_scheduler.step()
            if (epoch + 1) % args.log_freq == 0:
                logging.info(" ".join(['{} {} Epoch: {} {:04d}'.format(
                                           args.dataset, args.task, iter_i, epoch + 1),
                                       'lr: {}'.format(lr_scheduler.get_lr()[0]),
                                       format_metrics(train_metrics, 'train'),
                                       'time: {:.4f}s'.format(time.time() - t)
                                       ]))
            if (epoch + 1) % args.eval_freq == 0:
                model.eval()
                embeddings = model.encode(data['features'], data['adj_train_norm'])
                val_metrics = model.compute_metrics(embeddings, data, 'val')
                if (epoch + 1) % args.log_freq == 0:
                    logging.info(" ".join(['{} {} Epoch: {} {:04d}'.format(
                                               args.dataset, args.task, iter_i, epoch + 1),
                                           format_metrics(val_metrics, 'val')]))
                if model.has_improved(best_val_metrics, val_metrics):
                    best_test_metrics = model.compute_metrics(embeddings, data, 'test')
                    best_emb = embeddings.cpu()
                    if args.save:
                        np.save(os.path.join(save_dir, 'embeddings.npy'),
                                best_emb.detach().numpy())
                    best_val_metrics = val_metrics
                    counter = 0
                else:
                    counter += 1
                    if counter == args.patience and epoch > args.min_epochs:
                        logging.info("Early stopping")
                        break
        logging.info("Optimization Finished!")
        logging.info("Total time elapsed: {:.4f}s".format(time.time() - t_total))
        # evaluation never triggered an improvement: score the final model
        if not best_test_metrics:
            model.eval()
            best_emb = model.encode(data['features'], data['adj_train_norm'])
            best_test_metrics = model.compute_metrics(best_emb, data, 'test')
        logging.info(" ".join(["Val set results:",
                               format_metrics(best_val_metrics, 'val')]))
        logging.info(" ".join(["Test set results:",
                               format_metrics(best_test_metrics, 'test')]))
        if args.save:
            print('Saved path', os.path.join(save_dir, 'embeddings.npy'))
            np.save(os.path.join(save_dir, 'embeddings.npy'),
                    best_emb.cpu().detach().numpy())
            if hasattr(model.encoder, 'att_adj'):
                filename = os.path.join(save_dir, args.dataset + '_att_adj.p')
                pickle.dump(model.encoder.att_adj.cpu().to_dense(),
                            open(filename, 'wb'))
                print('Dumped attention adj: ' + filename)
            json.dump(vars(args), open(os.path.join(save_dir, 'config.json'), 'w'))
            torch.save(model.state_dict(), os.path.join(save_dir, 'model.pth'))
            logging.info(f"Saved model in {save_dir}")