def load_from_ckpt(ckpt_file, epoch, model):
    ckpt_file = ckpt_file + '_{}'.format(epoch)
    if os.path.isfile(ckpt_file):
        print(get_time(), 'load from ckpt {}'.format(ckpt_file))
        ckpt_state_dict = torch.load(ckpt_file)
        model.load_state_dict(ckpt_state_dict['model_state_dict'])
        print(get_time(), 'finish load from ckpt {}'.format(ckpt_file))
    else:
        print('ckpt file does not exist {}'.format(ckpt_file))

def save_to_ckpt(ckpt_file, epoch, model, optimizer, lr_scheduler):
    ckpt_file = ckpt_file + '_{}'.format(epoch)
    print(get_time(), 'save to ckpt {}'.format(ckpt_file))
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
        }, ckpt_file)
    print(get_time(), 'finish save to ckpt {}'.format(ckpt_file))

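# A minimal round-trip sketch of the two checkpoint helpers above, assuming the
# module-level `os`/`torch` imports and `get_time()` used throughout this file.
# The tiny Linear model, optimizer, and scheduler here are illustrative
# stand-ins only, not the ranking nets defined elsewhere in the repo.
def _example_ckpt_roundtrip(ckpt_file='/tmp/example_ckpt', epoch=0):
    model = torch.nn.Linear(136, 1)  # hypothetical stand-in scorer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.75)
    # writes '<ckpt_file>_<epoch>' holding model/optimizer/scheduler state
    save_to_ckpt(ckpt_file, epoch, model, optimizer, scheduler)
    # restores only the model weights from the same epoch-suffixed file
    load_from_ckpt(ckpt_file, epoch, model)
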
def eval_cross_entropy_loss(model, device, loader, phase="Eval", sigma=1.0):
    """
    formula in https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/MSR-TR-2010-82.pdf

    C = 0.5 * (1 - S_ij) * sigma * (si - sj) + log(1 + exp(-sigma * (si - sj)))
    when S_ij = 1:  C = log(1 + exp(-sigma * (si - sj)))
    when S_ij = -1: C = log(1 + exp(-sigma * (sj - si)))
    sigma can change the shape of the curve
    """
    print(get_time(), "{} Phase evaluate pairwise cross entropy loss".format(phase))
    model.eval()
    with torch.set_grad_enabled(False):
        total_cost = 0
        total_pairs = loader.get_num_pairs()
        pairs_in_compute = 0
        for X, Y in loader.generate_batch_per_query(loader.df):
            Y = Y.reshape(-1, 1)
            rel_diff = Y - Y.T
            pos_pairs = (rel_diff > 0).astype(np.float32)
            num_pos_pairs = np.sum(pos_pairs, (0, 1))
            # skip negative sessions, no relevant info:
            if num_pos_pairs == 0:
                continue
            neg_pairs = (rel_diff < 0).astype(np.float32)
            num_pairs = 2 * num_pos_pairs  # num pos pairs and neg pairs are always the same

            pos_pairs = torch.tensor(pos_pairs, device=device)
            neg_pairs = torch.tensor(neg_pairs, device=device)
            Sij = pos_pairs - neg_pairs
            # only calculate the different pairs
            diff_pairs = pos_pairs + neg_pairs
            pairs_in_compute += num_pairs

            X_tensor = torch.Tensor(X).to(device)
            y_pred = model(X_tensor)
            y_pred_diff = y_pred - y_pred.t()

            # -logsigmoid(sigma * x) = log(1 + exp(-sigma * x)), the second term of C above
            C = 0.5 * (1 - Sij) * sigma * y_pred_diff - F.logsigmoid(
                sigma * y_pred_diff)
            C = C * diff_pairs
            cost = torch.sum(C, (0, 1))
            if cost.item() == float('inf') or np.isnan(cost.item()):
                import ipdb
                ipdb.set_trace()
            total_cost += cost

        assert total_pairs == pairs_in_compute
        avg_cost = total_cost / total_pairs
    print(
        get_time(),
        "{} Phase pairwise cross entropy loss {:.6f}, total_pairs {}".format(
            phase, avg_cost.item(), total_pairs))

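# A small scalar reference for the per-pair cost from the docstring above,
# C = 0.5 * (1 - S_ij) * sigma * (si - sj) + log(1 + exp(-sigma * (si - sj))).
# This sketch is only for sanity-checking the vectorized computation and is not
# part of the evaluation path; it assumes the module-level `numpy as np` import.
def _pairwise_cost_reference(si, sj, Sij, sigma=1.0):
    diff = si - sj
    return 0.5 * (1 - Sij) * sigma * diff + np.log1p(np.exp(-sigma * diff))

# e.g. with S_ij = 1 the cost shrinks as si moves above sj:
#   _pairwise_cost_reference(2.0, 0.0, 1)  -> ~0.127
#   _pairwise_cost_reference(0.0, 2.0, 1)  -> ~2.127
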
def eval_ndcg_at_k(inference_model, device, df_valid, valid_loader, batch_size,
                   k_list, epoch, writer=None, phase="Eval"):
    # print("Eval Phase evaluate NDCG @ {}".format(k_list))
    ndcg_metrics = {k: NDCG(k) for k in k_list}
    qids, rels, scores = [], [], []
    inference_model.eval()
    with torch.no_grad():
        for qid, rel, x in valid_loader.generate_query_batch(df_valid, batch_size):
            if x is None or x.shape[0] == 0:
                continue
            y_tensor = inference_model.forward(torch.Tensor(x).to(device))
            scores.append(y_tensor.cpu().numpy().squeeze())
            qids.append(qid)
            rels.append(rel)

    qids = np.hstack(qids)
    rels = np.hstack(rels)
    scores = np.hstack(scores)
    result_df = pd.DataFrame({'qid': qids, 'rel': rels, 'score': scores})
    session_ndcgs = defaultdict(list)
    for qid in result_df.qid.unique():
        result_qid = result_df[result_df.qid == qid].sort_values('score', ascending=False)
        rel_rank = result_qid.rel.values
        for k, ndcg in ndcg_metrics.items():
            if ndcg.maxDCG(rel_rank) == 0:
                continue
            ndcg_k = ndcg.evaluate(rel_rank)
            if not np.isnan(ndcg_k):
                session_ndcgs[k].append(ndcg_k)

    ndcg_result = {k: np.mean(session_ndcgs[k]) for k in k_list}
    ndcg_result_print = ", ".join(
        ["NDCG@{}: {:.5f}".format(k, ndcg_result[k]) for k in k_list])
    print(get_time(), "{} Phase evaluate {}".format(phase, ndcg_result_print))
    if writer:
        for k in k_list:
            writer.add_scalars("metrics/NDCG@{}".format(k), {phase: ndcg_result[k]}, epoch)
    return ndcg_result

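# For reference, a standalone sketch of what each per-query NDCG@k above boils
# down to, assuming the common exp2 gain (2^rel - 1) and log2 rank discount; the
# NDCG class used in this repo may differ in details, so treat this as an
# illustrative approximation only.
def _ndcg_at_k_reference(rels_sorted_by_score, k):
    def dcg(rels):
        return sum((2.0 ** r - 1.0) / np.log2(i + 2.0) for i, r in enumerate(rels[:k]))
    ideal = dcg(sorted(rels_sorted_by_score, reverse=True))
    return dcg(rels_sorted_by_score) / ideal if ideal > 0 else float('nan')

# e.g. _ndcg_at_k_reference([3, 2, 0, 1], k=3) compares the model's ordering of
# the relevance labels against the ideal ordering [3, 2, 1, 0], truncated at rank 3.
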
def baseline_pairwise_training_loop(epoch, net, loss_func, optimizer,
                                    train_loader, batch_size=100000,
                                    precision=torch.float32, device="cpu",
                                    debug=False):
    minibatch_loss = []
    minibatch = 0
    count = 0
    for x_i, y_i, x_j, y_j in train_loader.generate_query_pair_batch(batch_size):
        if x_i is None or x_i.shape[0] == 0:
            continue
        x_i, x_j = (torch.tensor(x_i, dtype=precision, device=device),
                    torch.tensor(x_j, dtype=precision, device=device))
        # binary label
        y = torch.tensor((y_i > y_j).astype(np.float32), dtype=precision, device=device)

        net.zero_grad()
        y_pred = net(x_i, x_j)
        loss = loss_func(y_pred, y)
        loss.backward()
        count += 1
        if count % 25 == 0 and debug:
            net.dump_param()
        optimizer.step()

        minibatch_loss.append(loss.item())
        minibatch += 1
        if minibatch % 100 == 0:
            print(
                get_time(),
                'Epoch {}, Minibatch: {}, loss : {}'.format(
                    epoch, minibatch, loss.item()))

    return np.mean(minibatch_loss)

def train(
    start_epoch=0,
    additional_epoch=100,
    lr=0.0001,
    optim="adam",
    leaky_relu=False,
    ndcg_gain_in_train="exp2",
    sigma=1.0,
    double_precision=False,
    standardize=False,
    small_dataset=False,
    debug=False,
    output_dir="/tmp/ranking_output/",
):
    print("start_epoch:{}, additional_epoch:{}, lr:{}".format(
        start_epoch, additional_epoch, lr))
    writer = SummaryWriter(output_dir)

    precision = torch.float64 if double_precision else torch.float32

    # get training and validation data:
    data_fold = 'Fold1'
    train_loader, df_train, valid_loader, df_valid = load_train_vali_data(
        data_fold, small_dataset)
    if standardize:
        df_train, scaler = train_loader.train_scaler_and_transform()
        df_valid = valid_loader.apply_scaler(scaler)

    lambdarank_structure = [136, 64, 16]
    net = LambdaRank(lambdarank_structure, leaky_relu=leaky_relu,
                     double_precision=double_precision, sigma=sigma)
    device = get_device('LambdaRank')
    net.to(device)
    net.apply(init_weights)
    print(net)

    ckptfile = get_ckptdir('lambdarank', lambdarank_structure, sigma)

    if optim == "adam":
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    elif optim == "sgd":
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    else:
        raise ValueError(
            "Optimization method {} not implemented".format(optim))
    print(optimizer)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.75)

    ideal_dcg = NDCG(2**9, ndcg_gain_in_train)

    for i in range(start_epoch, start_epoch + additional_epoch):
        net.train()
        net.zero_grad()

        count = 0
        batch_size = 200
        grad_batch, y_pred_batch = [], []

        for X, Y in train_loader.generate_batch_per_query(shuffle=True):
            if np.sum(Y) == 0:
                # negative session, cannot learn useful signal
                continue
            N = 1.0 / ideal_dcg.maxDCG(Y)

            X_tensor = torch.tensor(X, dtype=precision, device=device)
            y_pred = net(X_tensor)
            y_pred_batch.append(y_pred)

            # compute the rank order of each document
            rank_df = pd.DataFrame({"Y": Y, "doc": np.arange(Y.shape[0])})
            rank_df = rank_df.sort_values("Y").reset_index(drop=True)
            rank_order = rank_df.sort_values("doc").index.values + 1

            with torch.no_grad():
                # 1 + exp(sigma * (si - sj)), the denominator of the RankNet gradient
                pos_pairs_score_diff = 1.0 + torch.exp(sigma * (y_pred - y_pred.t()))

                Y_tensor = torch.tensor(Y, dtype=precision, device=device).view(-1, 1)
                rel_diff = Y_tensor - Y_tensor.t()
                pos_pairs = (rel_diff > 0).type(precision)
                neg_pairs = (rel_diff < 0).type(precision)
                Sij = pos_pairs - neg_pairs

                if ndcg_gain_in_train == "exp2":
                    gain_diff = torch.pow(2.0, Y_tensor) - torch.pow(2.0, Y_tensor.t())
                elif ndcg_gain_in_train == "identity":
                    gain_diff = Y_tensor - Y_tensor.t()
                else:
                    raise ValueError(
                        "ndcg_gain method not supported yet {}".format(ndcg_gain_in_train))

                rank_order_tensor = torch.tensor(rank_order, dtype=precision,
                                                 device=device).view(-1, 1)
                decay_diff = (1.0 / torch.log2(rank_order_tensor + 1.0) -
                              1.0 / torch.log2(rank_order_tensor.t() + 1.0))

                # |delta NDCG| of swapping documents i and j, normalized by the ideal DCG
                delta_ndcg = torch.abs(N * gain_diff * decay_diff)
                lambda_update = sigma * (0.5 * (1 - Sij) - 1 / pos_pairs_score_diff) * delta_ndcg
                lambda_update = torch.sum(lambda_update, 1, keepdim=True)

                assert lambda_update.shape == y_pred.shape
                check_grad = torch.sum(lambda_update, (0, 1)).item()
                if check_grad == float('inf') or np.isnan(check_grad):
                    import ipdb
                    ipdb.set_trace()
                grad_batch.append(lambda_update)

            # optimization is similar to RankNetListWise, but aims to maximize NDCG:
            # lambda_update scales with gain and decay
            count += 1
            if count % batch_size == 0:
                for grad, y_pred in zip(grad_batch, y_pred_batch):
                    y_pred.backward(grad / batch_size)

                if count % (4 * batch_size) == 0 and debug:
                    net.dump_param()
                optimizer.step()
                net.zero_grad()
                grad_batch, y_pred_batch = [], []  # grad_batch, y_pred_batch used for gradient_acc

        # optimizer.step()
        print(get_time(),
              "training dataset at epoch {}, total queries: {}".format(i, count))
        if debug:
            eval_cross_entropy_loss(net, device, train_loader, phase="Train")
            # eval_ndcg_at_k(net, device, df_train, train_loader, 100000, [10, 30, 50])

        if i % 5 == 0 and i != start_epoch:
            print(get_time(), "eval for epoch: {}".format(i))
            eval_cross_entropy_loss(net, device, valid_loader)
            eval_ndcg_at_k(net, device, df_valid, valid_loader, 100000, [10, 30], i, writer)
        if i % 10 == 0 and i != start_epoch:
            save_to_ckpt(ckptfile, i, net, optimizer, scheduler)
        scheduler.step()

    # save the last ckpt
    save_to_ckpt(ckptfile, start_epoch + additional_epoch, net, optimizer, scheduler)

    # save the final model
    torch.save(net.state_dict(), ckptfile)
    ndcg_result = eval_ndcg_at_k(net, device, df_valid, valid_loader, 100000,
                                 [10, 30], start_epoch + additional_epoch, writer)
    print(
        get_time(),
        "finish training " + ", ".join(
            ["NDCG@{}: {:.5f}".format(k, ndcg_result[k]) for k in ndcg_result]),
        '\n\n')

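# The training loop above never forms a scalar loss; it feeds the precomputed
# lambdas directly into autograd via `y_pred.backward(grad)`, which treats them
# as dC/dy_pred. A minimal sketch of that mechanism on a toy linear scorer
# (all names below are illustrative only, not part of the training code):
def _example_lambda_backward():
    torch.manual_seed(0)
    scorer = torch.nn.Linear(4, 1)                     # hypothetical tiny scorer
    X = torch.randn(3, 4)                              # 3 documents, 4 features
    y_pred = scorer(X)                                 # shape (3, 1)
    lambdas = torch.tensor([[0.5], [-0.3], [-0.2]])    # pretend per-document lambdas
    y_pred.backward(lambdas)                           # accumulates lambdas as dC/dy_pred
    # scorer.weight.grad now equals lambdas.T @ X, the LambdaRank-style update
    return scorer.weight.grad
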
def train_rank_net(
    start_epoch=0,
    additional_epoch=100,
    lr=0.0001,
    optim="adam",
    train_algo=SUM_SESSION,
    double_precision=False,
    standardize=False,
    small_dataset=False,
    debug=False,
    output_dir="/tmp/ranking_output/",
):
    """
    :param start_epoch: int
    :param additional_epoch: int
    :param lr: float
    :param optim: str
    :param train_algo: str
    :param double_precision: boolean
    :param standardize: boolean
    :param small_dataset: boolean
    :param debug: boolean
    :return:
    """
    print("start_epoch:{}, additional_epoch:{}, lr:{}".format(
        start_epoch, additional_epoch, lr))
    writer = SummaryWriter(output_dir)

    precision = torch.float64 if double_precision else torch.float32

    # get training and validation data:
    data_fold = 'Fold1'
    train_loader, df_train, valid_loader, df_valid = load_train_vali_data(
        data_fold, small_dataset)
    if standardize:
        df_train, scaler = train_loader.train_scaler_and_transform()
        df_valid = valid_loader.apply_scaler(scaler)

    net, net_inference, ckptfile = get_train_inference_net(
        train_algo, train_loader.num_features, start_epoch, double_precision)
    device = get_device()
    net.to(device)
    net_inference.to(device)

    # initialize to make training faster
    net.apply(init_weights)

    if optim == "adam":
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    elif optim == "sgd":
        optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    else:
        raise ValueError(
            "Optimization method {} not implemented".format(optim))
    print(optimizer)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.75)

    loss_func = None
    if train_algo == BASELINE:
        loss_func = torch.nn.BCELoss()
        loss_func.to(device)

    losses = []

    for i in range(start_epoch, start_epoch + additional_epoch):
        net.zero_grad()
        net.train()

        if train_algo == BASELINE:
            epoch_loss = baseline_pairwise_training_loop(
                i, net, loss_func, optimizer, train_loader,
                precision=precision, device=device, debug=debug)
        elif train_algo in [SUM_SESSION, ACC_GRADIENT]:
            epoch_loss = factorized_training_loop(
                i, net, None, optimizer, train_loader,
                training_algo=train_algo, precision=precision, device=device,
                debug=debug)

        losses.append(epoch_loss)
        # step the LR scheduler once per epoch, after the optimizer updates
        scheduler.step()
        print('=' * 20 + '\n', get_time(),
              'Epoch {}, loss : {}'.format(i, losses[-1]), '\n' + '=' * 20)

        # save to checkpoint every 5 epochs, and run eval
        if i % 5 == 0 and i != start_epoch:
            save_to_ckpt(ckptfile, i, net, optimizer, scheduler)
            net_inference.load_state_dict(net.state_dict())
            eval_model(net_inference, device, df_valid, valid_loader, i, writer)

    # save the last ckpt
    save_to_ckpt(ckptfile, start_epoch + additional_epoch, net, optimizer, scheduler)

    # final evaluation
    net_inference.load_state_dict(net.state_dict())
    ndcg_result = eval_model(net_inference, device, df_valid, valid_loader,
                             start_epoch + additional_epoch, writer)

    # save the final model
    torch.save(net.state_dict(), ckptfile)
    print(
        get_time(),
        "finish training " + ", ".join(
            ["NDCG@{}: {:.5f}".format(k, ndcg_result[k]) for k in ndcg_result]),
        '\n\n')

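# A minimal way to kick off training from this module. The repo's real entry
# point may parse these options from the command line instead, so treat the
# values below as illustrative defaults only.
if __name__ == "__main__":
    train_rank_net(
        start_epoch=0,
        additional_epoch=100,
        lr=0.0001,
        optim="adam",
        train_algo=SUM_SESSION,
        small_dataset=True,   # assumption: quicker smoke run on the small split
        debug=False,
        output_dir="/tmp/ranking_output/",
    )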