# Training loop, PyTorch backend.
def train(args):
    print(args)
    dataset = MovieLens(args.data_name, args.device,
                        use_one_hot_fea=args.use_one_hot_fea,
                        symm=args.gcn_agg_norm_symm,
                        test_ratio=args.data_test_ratio,
                        valid_ratio=args.data_valid_ratio)
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    ### build the net
    net = Net(args=args)
    net = net.to(args.device)
    nd_possible_rating_values = th.FloatTensor(dataset.possible_rating_values).to(args.device)
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(net.parameters(), lr=learning_rate)
    print("Loading network finished ...\n")

    ### prepare training data
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(['iter', 'loss', 'rmse'], ['%d', '%.4f', '%.4f'],
                                     os.path.join(args.save_dir, 'train_loss%d.csv' % args.save_id))
    valid_loss_logger = MetricLogger(['iter', 'rmse'], ['%d', '%.4f'],
                                     os.path.join(args.save_dir, 'valid_loss%d.csv' % args.save_id))
    test_loss_logger = MetricLogger(['iter', 'rmse'], ['%d', '%.4f'],
                                    os.path.join(args.save_dir, 'test_loss%d.csv' % args.save_id))

    ### declare the loss information
    best_valid_rmse = np.inf
    no_better_valid = 0
    best_iter = -1
    count_rmse = 0
    count_num = 0
    count_loss = 0

    dataset.train_enc_graph = dataset.train_enc_graph.to(args.device)
    dataset.train_dec_graph = dataset.train_dec_graph.to(args.device)
    dataset.valid_enc_graph = dataset.train_enc_graph
    dataset.valid_dec_graph = dataset.valid_dec_graph.to(args.device)
    dataset.test_enc_graph = dataset.test_enc_graph.to(args.device)
    dataset.test_dec_graph = dataset.test_dec_graph.to(args.device)

    print("Start training ...")
    dur = []
    for iter_idx in range(1, args.train_max_iter):
        if iter_idx > 3:
            t0 = time.time()
        net.train()
        pred_ratings = net(dataset.train_enc_graph, dataset.train_dec_graph,
                           dataset.user_feature, dataset.movie_feature)
        loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
        count_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
        optimizer.step()

        if iter_idx > 3:
            dur.append(time.time() - t0)

        if iter_idx == 1:
            print("Total #Param of net: %d" % (torch_total_param_num(net)))
            print(torch_net_info(net,
                                 save_path=os.path.join(args.save_dir, 'net%d.txt' % args.save_id)))

        # Expected rating under the predicted class distribution.
        real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
                             nd_possible_rating_values.view(1, -1)).sum(dim=1)
        rmse = ((real_pred_ratings - train_gt_ratings) ** 2).sum()
        count_rmse += rmse.item()
        count_num += pred_ratings.shape[0]

        if iter_idx % args.train_log_interval == 0:
            train_loss_logger.log(iter=iter_idx,
                                  loss=count_loss / (iter_idx + 1),
                                  rmse=count_rmse / count_num)
            logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                iter_idx, count_loss / iter_idx, count_rmse / count_num, np.average(dur))
            count_rmse = 0
            count_num = 0

        if iter_idx % args.train_valid_interval == 0:
            valid_rmse = evaluate(args=args, net=net, dataset=dataset, segment='valid')
            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
            logging_str += ',\tVal RMSE={:.4f}'.format(valid_rmse)

            if valid_rmse < best_valid_rmse:
                best_valid_rmse = valid_rmse
                no_better_valid = 0
                best_iter = iter_idx
                test_rmse = evaluate(args=args, net=net, dataset=dataset, segment='test')
                best_test_rmse = test_rmse
                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                logging_str += ', Test RMSE={:.4f}'.format(test_rmse)
            else:
                no_better_valid += 1
                if no_better_valid > args.train_early_stopping_patience \
                        and learning_rate <= args.train_min_lr:
                    logging.info("Early stopping threshold reached. Stop training.")
                    break
                if no_better_valid > args.train_decay_patience:
                    new_lr = max(learning_rate * args.train_lr_decay_factor, args.train_min_lr)
                    if new_lr < learning_rate:
                        learning_rate = new_lr
                        logging.info("\tChange the LR to %g" % new_lr)
                        for p in optimizer.param_groups:
                            p['lr'] = learning_rate
                        no_better_valid = 0

        if iter_idx % args.train_log_interval == 0:
            print(logging_str)

    print('Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format(
        best_iter, best_valid_rmse, best_test_rmse))
    train_loss_logger.close()
    valid_loss_logger.close()
    test_loss_logger.close()
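# The PyTorch training loop above calls an `evaluate` helper that is not shown
# in this section. A minimal sketch is given below; the dataset attribute names
# `valid_truths` and `test_truths` are assumptions mirroring the `train_truths`
# attribute used above, not confirmed by this file.
def evaluate(args, net, dataset, segment='valid'):
    nd_possible_rating_values = th.FloatTensor(dataset.possible_rating_values).to(args.device)
    if segment == 'valid':
        rating_values = dataset.valid_truths   # assumed attribute name
        enc_graph = dataset.valid_enc_graph
        dec_graph = dataset.valid_dec_graph
    elif segment == 'test':
        rating_values = dataset.test_truths    # assumed attribute name
        enc_graph = dataset.test_enc_graph
        dec_graph = dataset.test_dec_graph
    else:
        raise NotImplementedError(segment)

    net.eval()
    with th.no_grad():
        pred_ratings = net(enc_graph, dec_graph,
                           dataset.user_feature, dataset.movie_feature)
    # Expected rating under the predicted distribution, as in the training loop.
    real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
                         nd_possible_rating_values.view(1, -1)).sum(dim=1)
    rmse = np.sqrt(((real_pred_ratings - rating_values) ** 2).mean().item())
    return rmse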
# Training loop, MXNet/Gluon backend.
def train(args):
    print(args)
    dataset = MovieLens(args.data_name, args.ctx,
                        use_one_hot_fea=args.use_one_hot_fea,
                        symm=args.gcn_agg_norm_symm,
                        test_ratio=args.data_test_ratio,
                        valid_ratio=args.data_valid_ratio)
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    ### build the net
    net = Net(args=args)
    net.initialize(init=mx.init.Xavier(factor_type='in'), ctx=args.ctx)
    net.hybridize()
    nd_possible_rating_values = mx.nd.array(dataset.possible_rating_values,
                                            ctx=args.ctx, dtype=np.float32)
    rating_loss_net = gluon.loss.SoftmaxCELoss()
    rating_loss_net.hybridize()
    trainer = gluon.Trainer(net.collect_params(), args.train_optimizer,
                            {'learning_rate': args.train_lr})
    print("Loading network finished ...\n")

    ### prepare training data
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(['iter', 'loss', 'rmse'], ['%d', '%.4f', '%.4f'],
                                     os.path.join(args.save_dir, 'train_loss%d.csv' % args.save_id))
    valid_loss_logger = MetricLogger(['iter', 'rmse'], ['%d', '%.4f'],
                                     os.path.join(args.save_dir, 'valid_loss%d.csv' % args.save_id))
    test_loss_logger = MetricLogger(['iter', 'rmse'], ['%d', '%.4f'],
                                    os.path.join(args.save_dir, 'test_loss%d.csv' % args.save_id))

    ### declare the loss information
    best_valid_rmse = np.inf
    no_better_valid = 0
    best_iter = -1
    avg_gnorm = 0
    count_rmse = 0
    count_num = 0
    count_loss = 0

    dataset.train_enc_graph = dataset.train_enc_graph.to(args.ctx)
    dataset.train_dec_graph = dataset.train_dec_graph.to(args.ctx)
    dataset.valid_enc_graph = dataset.train_enc_graph
    dataset.valid_dec_graph = dataset.valid_dec_graph.to(args.ctx)
    dataset.test_enc_graph = dataset.test_enc_graph.to(args.ctx)
    dataset.test_dec_graph = dataset.test_dec_graph.to(args.ctx)

    print("Start training ...")
    dur = []
    for iter_idx in range(1, args.train_max_iter):
        if iter_idx > 3:
            t0 = time.time()
        with mx.autograd.record():
            pred_ratings = net(dataset.train_enc_graph, dataset.train_dec_graph,
                               dataset.user_feature, dataset.movie_feature)
            loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
            loss.backward()

        count_loss += loss.asscalar()
        gnorm = params_clip_global_norm(net.collect_params(), args.train_grad_clip, args.ctx)
        avg_gnorm += gnorm
        trainer.step(1.0)

        if iter_idx > 3:
            dur.append(time.time() - t0)

        if iter_idx == 1:
            print("Total #Param of net: %d" % (gluon_total_param_num(net)))
            print(gluon_net_info(net,
                                 save_path=os.path.join(args.save_dir, 'net%d.txt' % args.save_id)))

        # Expected rating under the predicted class distribution.
        real_pred_ratings = (mx.nd.softmax(pred_ratings, axis=1) *
                             nd_possible_rating_values.reshape((1, -1))).sum(axis=1)
        rmse = mx.nd.square(real_pred_ratings - train_gt_ratings).sum()
        count_rmse += rmse.asscalar()
        count_num += pred_ratings.shape[0]

        if iter_idx % args.train_log_interval == 0:
            train_loss_logger.log(iter=iter_idx,
                                  loss=count_loss / (iter_idx + 1),
                                  rmse=count_rmse / count_num)
            logging_str = "Iter={}, gnorm={:.3f}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                iter_idx, avg_gnorm / args.train_log_interval,
                count_loss / iter_idx, count_rmse / count_num, np.average(dur))
            avg_gnorm = 0
            count_rmse = 0
            count_num = 0

        if iter_idx % args.train_valid_interval == 0:
            valid_rmse = evaluate(args=args, net=net, dataset=dataset, segment='valid')
            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
            logging_str += ',\tVal RMSE={:.4f}'.format(valid_rmse)

            if valid_rmse < best_valid_rmse:
                best_valid_rmse = valid_rmse
                no_better_valid = 0
                best_iter = iter_idx
                #net.save_parameters(filename=os.path.join(args.save_dir, 'best_valid_net{}.params'.format(args.save_id)))
                test_rmse = evaluate(args=args, net=net, dataset=dataset, segment='test')
                best_test_rmse = test_rmse
                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                logging_str += ', Test RMSE={:.4f}'.format(test_rmse)
            else:
                no_better_valid += 1
                if no_better_valid > args.train_early_stopping_patience \
                        and trainer.learning_rate <= args.train_min_lr:
                    logging.info("Early stopping threshold reached. Stop training.")
                    break
                if no_better_valid > args.train_decay_patience:
                    new_lr = max(trainer.learning_rate * args.train_lr_decay_factor,
                                 args.train_min_lr)
                    if new_lr < trainer.learning_rate:
                        logging.info("\tChange the LR to %g" % new_lr)
                        trainer.set_learning_rate(new_lr)
                        no_better_valid = 0

        if iter_idx % args.train_log_interval == 0:
            print(logging_str)

    print('Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format(
        best_iter, best_valid_rmse, best_test_rmse))
    train_loss_logger.close()
    valid_loss_logger.close()
    test_loss_logger.close()
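# The MXNet/Gluon loop relies on a `params_clip_global_norm` helper that is not
# defined in this section. A minimal sketch, assuming it simply applies
# `gluon.utils.clip_global_norm` to all parameter gradients on the given context
# and returns the pre-clipping global norm:
def params_clip_global_norm(param_dict, clip, ctx):
    # Collect gradient arrays for every parameter that actually has a gradient.
    grads = [p.grad(ctx) for p in param_dict.values() if p.grad_req != 'null']
    # Rescale gradients in place so their global L2 norm does not exceed `clip`.
    gnorm = gluon.utils.clip_global_norm(grads, clip)
    return gnorm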