Example #1
    def step(self, batch, stage, mode):
        assert stage in ('rec', 'conv')
        assert mode in ('train', 'valid', 'test')

        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device)

        if stage == 'rec':
            rec_loss, rec_scores = self.model.forward(batch, mode, stage)
            rec_loss = rec_loss.sum()
            if mode == 'train':
                self.backward(rec_loss)
            else:
                self.rec_evaluate(rec_scores, batch['item'])
            rec_loss = rec_loss.item()
            self.evaluator.optim_metrics.add("rec_loss",
                                             AverageMetric(rec_loss))
        else:
            if mode != 'test':
                gen_loss, preds = self.model.forward(batch, mode, stage)
                if mode == 'train':
                    self.backward(gen_loss)
                else:
                    self.conv_evaluate(preds, batch['response'])
                gen_loss = gen_loss.item()
                self.evaluator.optim_metrics.add('gen_loss',
                                                 AverageMetric(gen_loss))
                self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
            else:
                preds = self.model.forward(batch, mode, stage)
                self.conv_evaluate(preds, batch['response'])
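This step() moves every tensor in the batch dict to the system device, backpropagates the recommendation or generation loss in train mode, scores predictions against batch['item'] or batch['response'] otherwise, and logs running losses through the evaluator. Below is a minimal sketch of how such a step could be driven from an epoch loop; the run_rec_epoch name and the two dataloader arguments are assumptions for illustration, not part of the snippet above.

    import torch

    def run_rec_epoch(system, train_dataloader, valid_dataloader):
        # One training pass over the recommendation task ...
        for batch in train_dataloader:          # each batch is a dict of tensors
            system.step(batch, stage='rec', mode='train')
        # ... followed by one validation pass with gradients disabled.
        with torch.no_grad():
            for batch in valid_dataloader:
                system.step(batch, stage='rec', mode='valid')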
Example #2
    def step(self, batch, stage, mode):
        '''
        converse:
        context_tokens, context_entities, context_words, response, all_movies = batch

        recommend:
        context_entities, context_words, entities, movie = batch
        '''
        batch = [ele.to(self.device) for ele in batch]
        if stage == 'pretrain':
            info_loss = self.model.forward(batch, stage, mode)
            if info_loss is not None:
                self.backward(info_loss.sum())
                info_loss = info_loss.sum().item()
                self.evaluator.optim_metrics.add("info_loss",
                                                 AverageMetric(info_loss))
        elif stage == 'rec':
            rec_loss, info_loss, rec_predict = self.model.forward(
                batch, stage, mode)
            if info_loss:
                loss = rec_loss + 0.025 * info_loss
            else:
                loss = rec_loss
            if mode == "train":
                self.backward(loss.sum())
            else:
                self.rec_evaluate(rec_predict, batch[-1])
            rec_loss = rec_loss.sum().item()
            self.evaluator.optim_metrics.add("rec_loss",
                                             AverageMetric(rec_loss))
            if info_loss:
                info_loss = info_loss.sum().item()
                self.evaluator.optim_metrics.add("info_loss",
                                                 AverageMetric(info_loss))
        elif stage == "conv":
            if mode != "test":
                gen_loss, selection_loss, pred = self.model.forward(
                    batch, stage, mode)
                if mode == 'train':
                    loss = self.gen_loss_weight * gen_loss + selection_loss
                    self.backward(loss.sum())
                    loss = loss.sum().item()
                    self.evaluator.optim_metrics.add("gen_total_loss",
                                                     AverageMetric(loss))
                gen_loss = gen_loss.sum().item()

                self.evaluator.optim_metrics.add("gen_loss",
                                                 AverageMetric(gen_loss))
                self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
                selection_loss = selection_loss.sum().item()
                self.evaluator.optim_metrics.add('sel_loss',
                                                 AverageMetric(selection_loss))

            else:
                pred, matching_pred, matching_logist = self.model.forward(
                    batch, stage, mode)
                self.conv_evaluate(pred, matching_pred, batch[-2], batch[-1])
        else:
            raise ValueError(f'Unexpected stage: {stage}')
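This variant consumes list batches and adds a 'pretrain' stage; in the 'rec' stage an auxiliary info_loss is mixed into the recommendation loss with a fixed 0.025 weight before backpropagation. A standalone sketch of that mixing rule follows; the function name and the example tensors are invented for illustration, and the original's truthiness test on info_loss is written here as an explicit None check.

    import torch

    def mix_rec_loss(rec_loss, info_loss, info_weight=0.025):
        # Mirrors the 'rec' branch above: add the weighted auxiliary loss only
        # when the model actually returned one.
        if info_loss is not None:
            return rec_loss + info_weight * info_loss
        return rec_loss

    rec_loss = torch.tensor([2.0, 1.0])   # per-sample recommendation losses
    info_loss = torch.tensor([0.4, 0.8])  # per-sample auxiliary losses
    print(mix_rec_loss(rec_loss, info_loss).sum())  # tensor(3.0300)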
Example #3
    def step(self, batch, stage, mode):
        """
        stage: ['policy', 'rec', 'conv']
        mode: ['train', 'val', 'test']
        """
        batch = [ele.to(self.device) for ele in batch]
        if stage == 'policy':
            if mode == 'train':
                self.policy_model.train()
            else:
                self.policy_model.eval()

            policy_loss, policy_predict = self.policy_model.forward(
                batch, mode)
            if mode == "train" and policy_loss is not None:
                policy_loss = policy_loss.sum()
                self.backward(policy_loss)
            else:
                self.policy_evaluate(policy_predict, batch[-1])
            if isinstance(policy_loss, torch.Tensor):
                policy_loss = policy_loss.item()
                self.evaluator.optim_metrics.add("policy_loss",
                                                 AverageMetric(policy_loss))
        elif stage == 'rec':
            if mode == 'train':
                self.rec_model.train()
            else:
                self.rec_model.eval()
            rec_loss, rec_predict = self.rec_model.forward(batch, mode)
            rec_loss = rec_loss.sum()
            if mode == "train":
                self.backward(rec_loss)
            else:
                self.rec_evaluate(rec_predict, batch[-1])
            rec_loss = rec_loss.item()
            self.evaluator.optim_metrics.add("rec_loss",
                                             AverageMetric(rec_loss))
        elif stage == "conv":
            if mode != "test":
                # train + valid: need to compute ppl
                gen_loss, pred = self.conv_model.forward(batch, mode)
                gen_loss = gen_loss.sum()
                if mode == 'train':
                    self.backward(gen_loss)
                else:
                    self.conv_evaluate(pred, batch[-1])
                gen_loss = gen_loss.item()
                self.evaluator.optim_metrics.add("gen_loss",
                                                 AverageMetric(gen_loss))
                self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
            else:
                # generate response in conv_model.step
                pred = self.conv_model.forward(batch, mode)
                self.conv_evaluate(pred, batch[-1])
        else:
            raise ValueError(f'Unexpected stage: {stage}')
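Here the system keeps separate policy, recommendation, and conversation models and switches the relevant one between train() and eval() on every call, so the caller only has to pick stage and mode. A small sketch of a test-time pass over the conversation task; test_dataloader and the function name are assumed for illustration.

    import torch

    def test_conversation(system, test_dataloader):
        # In 'test' mode step() puts conv_model in eval mode, generates a
        # response, and scores it against the gold response in batch[-1].
        with torch.no_grad():
            for batch in test_dataloader:
                system.step(batch, stage='conv', mode='test')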
Example #4
    def step(self, batch, stage, mode):
        batch = [ele.to(self.device) for ele in batch]
        if stage == 'pretrain':
            info_loss = self.model.forward(batch, stage, mode)
            if info_loss is not None:
                self.backward(info_loss.sum())
                info_loss = info_loss.sum().item()
                self.evaluator.optim_metrics.add("info_loss",
                                                 AverageMetric(info_loss))
        elif stage == 'rec':
            rec_loss, info_loss, rec_predict = self.model.forward(
                batch, stage, mode)
            if info_loss:
                loss = rec_loss + 0.025 * info_loss
            else:
                loss = rec_loss
            if mode == "train":
                self.backward(loss.sum())
            else:
                self.rec_evaluate(rec_predict, batch[-1])
            rec_loss = rec_loss.sum().item()
            self.evaluator.optim_metrics.add("rec_loss",
                                             AverageMetric(rec_loss))
            if info_loss:
                info_loss = info_loss.sum().item()
                self.evaluator.optim_metrics.add("info_loss",
                                                 AverageMetric(info_loss))
        elif stage == "conv":
            if mode != "test":
                gen_loss, pred = self.model.forward(batch, stage, mode)
                if mode == 'train':
                    self.backward(gen_loss.sum())
                else:
                    self.conv_evaluate(pred, batch[-1])
                gen_loss = gen_loss.sum().item()
                self.evaluator.optim_metrics.add("gen_loss",
                                                 AverageMetric(gen_loss))
                self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
            else:
                pred = self.model.forward(batch, stage, mode)
                self.conv_evaluate(pred, batch[-1])
        else:
            raise ValueError(f'Unexpected stage: {stage}')
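In the conversation branch the same scalar generation loss feeds both AverageMetric and PPLMetric. Assuming that value behaves like a mean per-token cross-entropy (which is how perplexity metrics are usually fed, but an assumption here), perplexity is just its exponential:

    import math

    def perplexity(mean_token_nll):
        # ppl = exp(average negative log-likelihood per generated token)
        return math.exp(mean_token_nll)

    print(perplexity(2.3))  # ~9.97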