# Shared imports for the training loops below. torch and tqdm are real
# dependencies; the project-local helpers referenced throughout (to_gpu,
# loss, bprLoss, orthogonalLoss, normLoss, getTrainTripleBatch,
# getNegRatings, getMappedEntities, getMappedItems, evaluate, evaluateRec,
# evaluateKG) come from the surrounding project and are assumed importable;
# their module paths are not shown in these snippets.
import torch
import torch.nn as nn
from torch.autograd import Variable as V
from tqdm import tqdm


def train_loop(FLAGS,
               model,
               trainer,
               train_dataset,
               eval_datasets,
               entity_total,
               relation_total,
               logger,
               vis=None,
               is_report=False):
    train_iter, train_total, train_list, train_head_dict, train_tail_dict = train_dataset

    all_head_dicts = None
    all_tail_dicts = None
    if FLAGS.filter_wrong_corrupted:
        all_head_dicts = [train_head_dict] + [
            tmp_data[4] for tmp_data in eval_datasets
        ]
        all_tail_dicts = [train_tail_dict] + [
            tmp_data[5] for tmp_data in eval_datasets
        ]

    # Train.
    logger.info("Training.")

    # New Training Loop
    pbar = None
    total_loss = 0.0
    model.train()
    model.enable_grad()
    for _ in range(trainer.step, FLAGS.training_steps):

        if FLAGS.early_stopping_steps_to_wait > 0 and (
                trainer.step -
                trainer.best_step) > FLAGS.early_stopping_steps_to_wait:
            logger.info('No improvement after ' +
                        str(FLAGS.early_stopping_steps_to_wait) +
                        ' steps. Stopping training.')
            if pbar is not None: pbar.close()
            break
        if trainer.step % FLAGS.eval_interval_steps == 0:
            if pbar is not None:
                pbar.close()
            total_loss /= FLAGS.eval_interval_steps
            logger.info("train loss:{:.4f}!".format(total_loss))

            performances = []
            for i, eval_data in enumerate(eval_datasets):
                eval_head_dicts = None
                eval_tail_dicts = None
                if FLAGS.filter_wrong_corrupted:
                    eval_head_dicts = [train_head_dict] + [
                        tmp_data[4]
                        for j, tmp_data in enumerate(eval_datasets) if j != i
                    ]
                    eval_tail_dicts = [train_tail_dict] + [
                        tmp_data[5]
                        for j, tmp_data in enumerate(eval_datasets) if j != i
                    ]

                performances.append(
                    evaluate(FLAGS,
                             model,
                             entity_total,
                             relation_total,
                             eval_data[0],
                             eval_data[1],
                             eval_data[4],
                             eval_data[5],
                             eval_head_dicts,
                             eval_tail_dicts,
                             logger,
                             eval_descending=False,
                             is_report=is_report))

            if trainer.step > 0:
                is_best = trainer.new_performance(performances[0],
                                                  performances)
                # visualization
                if vis is not None:
                    vis.plot_many_stack({'KG Train Loss': total_loss},
                                        win_name="Loss Curve")
                    hit_vis_dict = {}
                    meanrank_vis_dict = {}
                    for i, performance in enumerate(performances):
                        hit_vis_dict['KG Eval {} Hit'.format(
                            i)] = performance[0]
                        meanrank_vis_dict['KG Eval {} MeanRank'.format(
                            i)] = performance[1]

                    if is_best:
                        log_str = [
                            "Best performances in {} step!".format(
                                trainer.best_step)
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % hit_vis_dict[s])
                            for s in hit_vis_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % meanrank_vis_dict[s])
                            for s in meanrank_vis_dict
                        ]
                        vis.log("\n".join(log_str),
                                win_name="Best Performances")

                    vis.plot_many_stack(hit_vis_dict,
                                        win_name="KG Hit Ratio@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(meanrank_vis_dict,
                                        win_name="KG MeanRank")
            # Reset the progress bar and loss, and set the model back to training mode.
            pbar = tqdm(total=FLAGS.eval_interval_steps)
            pbar.set_description("Training")
            total_loss = 0.0
            model.train()
            model.enable_grad()

        triple_batch = next(train_iter)
        ph, pt, pr, nh, nt, nr = getTrainTripleBatch(
            triple_batch,
            entity_total,
            all_head_dicts=all_head_dicts,
            all_tail_dicts=all_tail_dicts)
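        # getTrainTripleBatch corrupts each positive (h, t, r) into a negative
        # by replacing the head or tail with a random entity; when
        # FLAGS.filter_wrong_corrupted is set, the head/tail dicts above let
        # the sampler skip corruptions that are actually true triples.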

        ph_var = to_gpu(V(torch.LongTensor(ph)))
        pt_var = to_gpu(V(torch.LongTensor(pt)))
        pr_var = to_gpu(V(torch.LongTensor(pr)))
        nh_var = to_gpu(V(torch.LongTensor(nh)))
        nt_var = to_gpu(V(torch.LongTensor(nt)))
        nr_var = to_gpu(V(torch.LongTensor(nr)))

        trainer.optimizer_zero_grad()

        # Run model. output: batch_size * 1
        pos_score = model(ph_var, pt_var, pr_var)
        neg_score = model(nh_var, nt_var, nr_var)

        # Calculate loss.
        # losses = nn.MarginRankingLoss(margin=FLAGS.margin).forward(pos_score, neg_score, to_gpu(torch.autograd.Variable(torch.FloatTensor([trainer.model_target]*len(ph)))))

        losses = loss.marginLoss()(pos_score, neg_score, FLAGS.margin)

        ent_embeddings = model.ent_embeddings(
            torch.cat([ph_var, pt_var, nh_var, nt_var]))
        rel_embeddings = model.rel_embeddings(torch.cat([pr_var, nr_var]))

        if FLAGS.model_type == "transh":
            norm_embeddings = model.norm_embeddings(torch.cat([pr_var,
                                                               nr_var]))
            losses += loss.orthogonalLoss(rel_embeddings, norm_embeddings)

        losses = losses + loss.normLoss(ent_embeddings) + loss.normLoss(
            rel_embeddings)

        # Backward pass.
        losses.backward()

        # for param in model.parameters():
        #     print(param.grad.data.sum())

        # Hard Gradient Clipping
        nn.utils.clip_grad_norm_(model.parameters(),
                                 FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()
        total_loss += losses.item()
        pbar.update(1)
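

# A minimal, hedged sketch of the loss helpers these loops assume
# (loss.marginLoss, loss.normLoss, loss.orthogonalLoss). The signatures
# follow the call sites above; the project's actual `loss` module may
# differ in reduction or scaling. The "Sketch" suffix marks these as
# illustrative stand-ins, not the project's implementations.
class marginLossSketch(nn.Module):
    def forward(self, pos_score, neg_score, margin):
        # Distance-based scores: push positives below negatives by `margin`.
        return torch.sum(torch.clamp(pos_score - neg_score + margin, min=0.0))


def normLossSketch(embeddings, dim=1):
    # Penalize only the part of each row's squared L2 norm that exceeds 1.
    norms = torch.sum(embeddings ** 2, dim=dim)
    return torch.sum(torch.clamp(norms - 1.0, min=0.0))


def orthogonalLossSketch(rel_embeddings, norm_embeddings, eps=1e-10):
    # TransH-style: keep each relation vector orthogonal to its hyperplane
    # normal, measured relative to the relation's own squared norm.
    dot = torch.sum(norm_embeddings * rel_embeddings, dim=1) ** 2
    return torch.sum(dot / (torch.sum(rel_embeddings ** 2, dim=1) + eps))

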
def train_loop(FLAGS,
               model,
               trainer,
               rating_train_dataset,
               triple_train_dataset,
               rating_eval_datasets,
               triple_eval_datasets,
               e_map,
               i_map,
               ikg_map,
               logger,
               vis=None,
               is_report=False):
    rating_train_iter, rating_train_total, rating_train_list, rating_train_dict = rating_train_dataset

    triple_train_iter, triple_train_total, triple_train_list, head_train_dict, tail_train_dict = triple_train_dataset

    all_rating_dicts = None
    if FLAGS.filter_wrong_corrupted:
        all_rating_dicts = [rating_train_dict] + [
            tmp_data[3] for tmp_data in rating_eval_datasets
        ]

    all_head_dicts = None
    all_tail_dicts = None
    if FLAGS.filter_wrong_corrupted:
        all_head_dicts = [head_train_dict] + [
            tmp_data[4] for tmp_data in triple_eval_datasets
        ]
        all_tail_dicts = [tail_train_dict] + [
            tmp_data[5] for tmp_data in triple_eval_datasets
        ]

    item_total = len(i_map)
    entity_total = len(e_map)
    step_to_switch = 10 * FLAGS.joint_ratio
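    # The loop alternates on a 10-step cycle: e.g. FLAGS.joint_ratio = 0.7
    # gives step_to_switch = 7, so steps 0-6 of each cycle draw
    # recommendation batches and steps 7-9 draw KG triple batches.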

    # Train.
    logger.info("Training.")

    # New Training Loop
    pbar = None
    rec_total_loss = 0.0
    kg_total_loss = 0.0
    model.train()
    model.enable_grad()
    for _ in range(trainer.step, FLAGS.training_steps):

        if FLAGS.early_stopping_steps_to_wait > 0 and (
                trainer.step -
                trainer.best_step) > FLAGS.early_stopping_steps_to_wait:
            logger.info('No improvement after ' +
                        str(FLAGS.early_stopping_steps_to_wait) +
                        ' steps. Stopping training.')
            if pbar is not None: pbar.close()
            break
        if trainer.step % FLAGS.eval_interval_steps == 0:
            if pbar is not None:
                pbar.close()
            rec_total_loss /= (FLAGS.eval_interval_steps * FLAGS.joint_ratio)
            kg_total_loss /= (FLAGS.eval_interval_steps *
                              (1 - FLAGS.joint_ratio))
            logger.info("rec train loss:{:.4f}, kg train loss:{:.4f}!".format(
                rec_total_loss, kg_total_loss))

            rec_performances = []
            for i, eval_data in enumerate(rating_eval_datasets):
                all_eval_dicts = None
                if FLAGS.filter_wrong_corrupted:
                    all_eval_dicts = [rating_train_dict] + [
                        tmp_data[3]
                        for j, tmp_data in enumerate(rating_eval_datasets)
                        if j != i
                    ]

                rec_performances.append(
                    evaluateRec(FLAGS,
                                model,
                                eval_data[0],
                                eval_data[3],
                                all_eval_dicts,
                                i_map,
                                logger,
                                eval_descending=(trainer.model_target == 1),
                                is_report=is_report))

            kg_performances = []
            for i, eval_data in enumerate(triple_eval_datasets):
                eval_head_dicts = None
                eval_tail_dicts = None
                if FLAGS.filter_wrong_corrupted:
                    eval_head_dicts = [head_train_dict] + [
                        tmp_data[4]
                        for j, tmp_data in enumerate(triple_eval_datasets)
                        if j != i
                    ]
                    eval_tail_dicts = [tail_train_dict] + [
                        tmp_data[5]
                        for j, tmp_data in enumerate(triple_eval_datasets)
                        if j != i
                    ]

                kg_performances.append(
                    evaluateKG(FLAGS,
                               model,
                               eval_data[0],
                               eval_data[1],
                               eval_data[4],
                               eval_data[5],
                               eval_head_dicts,
                               eval_tail_dicts,
                               e_map,
                               logger,
                               eval_descending=False,
                               is_report=is_report))

            if trainer.step > 0:
                is_best = trainer.new_performance(kg_performances[0],
                                                  kg_performances)
                # visualization
                if vis is not None:
                    vis.plot_many_stack(
                        {
                            'Rec Train Loss': rec_total_loss,
                            'KG Train Loss': kg_total_loss
                        },
                        win_name="Loss Curve")

                    f1_dict = {}
                    p_dict = {}
                    r_dict = {}
                    rec_hit_dict = {}
                    ndcg_dict = {}
                    for i, performance in enumerate(rec_performances):
                        f1_dict['Rec Eval {} F1'.format(i)] = performance[0]
                        p_dict['Rec Eval {} Precision'.format(
                            i)] = performance[1]
                        r_dict['Rec Eval {} Recall'.format(i)] = performance[2]
                        rec_hit_dict['Rec Eval {} Hit'.format(
                            i)] = performance[3]
                        ndcg_dict['Rec Eval {} NDCG'.format(
                            i)] = performance[4]

                    kg_hit_dict = {}
                    meanrank_dict = {}
                    for i, performance in enumerate(kg_performances):
                        kg_hit_dict['KG Eval {} Hit'.format(
                            i)] = performance[0]
                        meanrank_dict['KG Eval {} MeanRank'.format(
                            i)] = performance[1]

                    if is_best:
                        log_str = [
                            "Best performances in {} step!".format(
                                trainer.best_step)
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % f1_dict[s])
                            for s in f1_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % p_dict[s])
                            for s in p_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % r_dict[s])
                            for s in r_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % rec_hit_dict[s])
                            for s in rec_hit_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % ndcg_dict[s])
                            for s in ndcg_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % kg_hit_dict[s])
                            for s in kg_hit_dict
                        ]
                        log_str += [
                            "{} : {}.".format(s, "%.5f" % meanrank_dict[s])
                            for s in meanrank_dict
                        ]

                        vis.log("\n".join(log_str),
                                win_name="Best Performances")

                    vis.plot_many_stack(f1_dict,
                                        win_name="Rec F1 Score@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(p_dict,
                                        win_name="Rec Precision@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(r_dict,
                                        win_name="Rec Recall@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(rec_hit_dict,
                                        win_name="Rec Hit Ratio@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(ndcg_dict,
                                        win_name="Rec NDCG@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(kg_hit_dict,
                                        win_name="KG Hit Ratio@{}".format(
                                            FLAGS.topn))

                    vis.plot_many_stack(meanrank_dict, win_name="KG MeanRank")

            # Reset the progress bar and losses, and set the model back to training mode.
            pbar = tqdm(total=FLAGS.eval_interval_steps)
            pbar.set_description("Training")
            rec_total_loss = 0.0
            kg_total_loss = 0.0

            model.train()
            model.enable_grad()

        # recommendation train
        if trainer.step % 10 < step_to_switch:
            rating_batch = next(rating_train_iter)
            u, pi, ni = getNegRatings(rating_batch,
                                      item_total,
                                      all_dicts=all_rating_dicts)
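            # u: user ids; pi / ni: positive and sampled-negative item ids.
            # With filtering on, all_rating_dicts keeps the sampler from
            # drawing "negatives" the user actually interacted with.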

            e_ids, i_ids = getMappedEntities(pi + ni, i_map, ikg_map)

            if FLAGS.share_embeddings:
                ni = [i_map[i] for i in ni]
                pi = [i_map[i] for i in pi]

            u_var = to_gpu(V(torch.LongTensor(u)))
            pi_var = to_gpu(V(torch.LongTensor(pi)))
            ni_var = to_gpu(V(torch.LongTensor(ni)))

            trainer.optimizer_zero_grad()

            # Run model. output: batch_size * cand_num, input: ratings, triples, is_rec=True
            pos_score = model((u_var, pi_var), None, is_rec=True)
            neg_score = model((u_var, ni_var), None, is_rec=True)

            # Calculate loss.
            losses = bprLoss(pos_score, neg_score, target=trainer.model_target)

            if FLAGS.model_type in ["transup", "jtransup"]:
                losses += orthogonalLoss(model.pref_embeddings.weight,
                                         model.pref_norm_embeddings.weight)
        # kg train
        else:
            triple_batch = next(triple_train_iter)
            ph, pt, pr, nh, nt, nr = getTrainTripleBatch(
                triple_batch,
                entity_total,
                all_head_dicts=all_head_dicts,
                all_tail_dicts=all_tail_dicts)

            e_ids, i_ids = getMappedItems(ph + pt + nh + nt, e_map, ikg_map)

            if FLAGS.share_embeddings:
                ph = [e_map[e] for e in ph]
                pt = [e_map[e] for e in pt]
                nh = [e_map[e] for e in nh]
                nt = [e_map[e] for e in nt]

            ph_var = to_gpu(V(torch.LongTensor(ph)))
            pt_var = to_gpu(V(torch.LongTensor(pt)))
            pr_var = to_gpu(V(torch.LongTensor(pr)))
            nh_var = to_gpu(V(torch.LongTensor(nh)))
            nt_var = to_gpu(V(torch.LongTensor(nt)))
            nr_var = to_gpu(V(torch.LongTensor(nr)))

            trainer.optimizer_zero_grad()

            # Run model. output: batch_size * cand_num, input: ratings, triples, is_rec=False
            pos_score = model(None, (ph_var, pt_var, pr_var), is_rec=False)
            neg_score = model(None, (nh_var, nt_var, nr_var), is_rec=False)

            # Calculate loss.
            # losses = nn.MarginRankingLoss(margin=FLAGS.margin).forward(pos_score, neg_score, to_gpu(torch.autograd.Variable(torch.FloatTensor([trainer.model_target]*len(ph)))))

            losses = loss.marginLoss()(pos_score, neg_score, FLAGS.margin)

            ent_embeddings = model.ent_embeddings(
                torch.cat([ph_var, pt_var, nh_var, nt_var]))
            rel_embeddings = model.rel_embeddings(torch.cat([pr_var, nr_var]))
            if FLAGS.model_type in ["jtransup"]:
                norm_embeddings = model.norm_embeddings(
                    torch.cat([pr_var, nr_var]))
                losses += loss.orthogonalLoss(rel_embeddings, norm_embeddings)

            losses = losses + loss.normLoss(ent_embeddings) + loss.normLoss(
                rel_embeddings)
            losses = FLAGS.kg_lambda * losses
        # Alignment loss: when embeddings are not shared, pull each entity
        # embedding toward the embedding of the item it is mapped to.
        if not FLAGS.share_embeddings and FLAGS.model_type not in [
                'cke', 'jtransup'
        ]:
            e_var = to_gpu(V(torch.LongTensor(e_ids)))
            i_var = to_gpu(V(torch.LongTensor(i_ids)))
            ent_embeddings = model.ent_embeddings(e_var)
            item_embeddings = model.item_embeddings(i_var)
            losses += FLAGS.norm_lambda * loss.pNormLoss(
                ent_embeddings, item_embeddings, L1_flag=FLAGS.L1_flag)

        # Backward pass.
        losses.backward()

        # for param in model.parameters():
        #     print(param.grad.data.sum())

        # Hard Gradient Clipping
        nn.utils.clip_grad_norm_(model.parameters(),
                                 FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()
        if trainer.step % 10 < step_to_switch:
            rec_total_loss += losses.item()
        else:
            kg_total_loss += losses.item()
        pbar.update(1)
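

# A hedged sketch of the BPR loss that both the joint loop above and the
# rec-only loop below call. The signature mirrors the call sites
# (bprLoss(pos, neg, target=...)); `target` flips the ranking direction
# (-1 when lower scores are better, as with distance-based models). The
# real implementation may reduce by sum instead of mean.
import torch.nn.functional as F


def bprLossSketch(pos_score, neg_score, target=1.0):
    # Bayesian Personalized Ranking: maximize the log-likelihood that the
    # positive item outranks the sampled negative.
    return -torch.mean(F.logsigmoid(target * (pos_score - neg_score)))

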
def train_loop(FLAGS,
               model,
               trainer,
               train_dataset,
               eval_datasets,
               user_total,
               item_total,
               logger,
               vis=None,
               is_report=False):
    train_iter, train_total, train_list, train_dict = train_dataset

    all_dicts = None
    if FLAGS.filter_wrong_corrupted:
        all_dicts = [train_dict] + [tmp_data[3] for tmp_data in eval_datasets]

    # Train.
    logger.info("Training.")

    # New Training Loop
    pbar = None
    total_loss = 0.0
    model.train()
    model.enable_grad()
    for _ in range(trainer.step, FLAGS.training_steps):

        # if FLAGS.early_stopping_steps_to_wait > 0 and (trainer.step - trainer.best_step) > FLAGS.early_stopping_steps_to_wait:
        #     logger.info('No improvement after ' +
        #                str(FLAGS.early_stopping_steps_to_wait) +
        #                ' steps. Stopping training.')
        #     if pbar is not None: pbar.close()
        #     break
        if trainer.step % FLAGS.eval_interval_steps == 0:
            if pbar is not None:
                pbar.close()
            total_loss /= FLAGS.eval_interval_steps
            logger.info("train loss:{:.4f}!".format(total_loss))

            # performances = []
            # for i, eval_data in enumerate(eval_datasets):
            #     all_eval_dicts = None
            #     if FLAGS.filter_wrong_corrupted:
            #         all_eval_dicts = [train_dict] + [tmp_data[3] for j, tmp_data in enumerate(eval_datasets) if j!=i]

            #     performances.append( evaluate(FLAGS, model, eval_data[0], eval_data[3], all_eval_dicts, logger, eval_descending=True if trainer.model_target == 1 else False, is_report=is_report))

            # if trainer.step > 0 and len(performances) > 0:
            #     is_best = trainer.new_performance(performances[0], performances)

            # # visuliazation
            # if vis is not None:
            #     vis.plot_many_stack({'Rec Train Loss': total_loss},
            #     win_name="Loss Curve")
            #     f1_vis_dict = {}
            #     p_vis_dict = {}
            #     r_vis_dict = {}
            #     hit_vis_dict = {}
            #     ndcg_vis_dict = {}
            #     for i, performance in enumerate(performances):
            #         f1_vis_dict['Rec Eval {} F1'.format(i)] = performance[0]
            #         p_vis_dict['Rec Eval {} Precision'.format(i)] = performance[1]
            #         r_vis_dict['Rec Eval {} Recall'.format(i)] = performance[2]
            #         hit_vis_dict['Rec Eval {} Hit'.format(i)] = performance[3]
            #         ndcg_vis_dict['Rec Eval {} NDCG'.format(i)] = performance[4]

            #     if is_best:
            #         log_str = ["Best performances in {} step!".format(trainer.best_step)]
            #         log_str += ["{} : {}.".format(s, "%.5f" % f1_vis_dict[s]) for s in f1_vis_dict]
            #         log_str += ["{} : {}.".format(s, "%.5f" % p_vis_dict[s]) for s in p_vis_dict]
            #         log_str += ["{} : {}.".format(s, "%.5f" % r_vis_dict[s]) for s in r_vis_dict]
            #         log_str += ["{} : {}.".format(s, "%.5f" % hit_vis_dict[s]) for s in hit_vis_dict]
            #         log_str += ["{} : {}.".format(s, "%.5f" % ndcg_vis_dict[s]) for s in ndcg_vis_dict]

            #         vis.log("\n".join(log_str), win_name="Best Performances")

            #     vis.plot_many_stack(f1_vis_dict, win_name="Rec F1 Score@{}".format(FLAGS.topn))

            #     vis.plot_many_stack(p_vis_dict, win_name="Rec Precision@{}".format(FLAGS.topn))

            #     vis.plot_many_stack(r_vis_dict, win_name="Rec Recall@{}".format(FLAGS.topn))

            #     vis.plot_many_stack(hit_vis_dict, win_name="Rec Hit Ratio@{}".format(FLAGS.topn))

            #     vis.plot_many_stack(ndcg_vis_dict, win_name="Rec NDCG@{}".format(FLAGS.topn))

            # Reset the progress bar and loss, and set the model back to training mode.
            pbar = tqdm(total=FLAGS.eval_interval_steps)
            pbar.set_description("Training")
            total_loss = 0.0
            model.train()
            model.enable_grad()

        rating_batch = next(train_iter)
        u, pi, ni = getNegRatings(rating_batch,
                                  item_total,
                                  all_dicts=all_dicts)

        u_var = to_gpu(V(torch.LongTensor(u)))
        pi_var = to_gpu(V(torch.LongTensor(pi)))
        ni_var = to_gpu(V(torch.LongTensor(ni)))

        trainer.optimizer_zero_grad()

        # Run model. output: batch_size * cand_num
        pos_score = model(u_var, pi_var)
        neg_score = model(u_var, ni_var)

        # Calculate loss.
        losses = bprLoss(pos_score, neg_score, target=trainer.model_target)

        if FLAGS.model_type in ["transup", "transupb"]:
            user_embeddings = model.user_embeddings(u_var)
            item_embeddings = model.item_embeddings(torch.cat([pi_var,
                                                               ni_var]))
            losses += orthogonalLoss(
                model.pref_embeddings.weight,
                model.pref_norm_embeddings.weight) + normLoss(
                    user_embeddings) + normLoss(item_embeddings) + normLoss(
                        model.pref_embeddings.weight)

        # Backward pass.
        losses.backward()

        # for param in model.parameters():
        #     print(param.grad.data.sum())

        # Hard Gradient Clipping
        nn.utils.clip_grad_norm_(model.parameters(),
                                 FLAGS.clipping_max_value)

        # Gradient descent step.
        trainer.optimizer_step()
        total_loss += losses.item()
        pbar.update(1)
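

# Hedged sketch of the device helper applied to every batch above; the
# project's actual to_gpu may also honor a FLAGS-controlled device choice.
def to_gpu_sketch(var):
    return var.cuda() if torch.cuda.is_available() else var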