def step(self, batch, stage, mode):
    """Run one training or evaluation step on a single batch.

    Args:
        batch: dict of batch fields. Every ``torch.Tensor`` value is moved to
            ``self.device`` in place, so the caller's dict is mutated.
            ``batch['item']`` is the label for recommendation evaluation and
            ``batch['response']`` is the reference for conversation evaluation.
        stage: ``'rec'`` or ``'conv'``.
        mode: ``'train'``, ``'valid'`` or ``'test'``. ``'train'`` backpropagates;
            the other modes evaluate. In the conv stage, ``'test'`` evaluates the
            model's predictions without computing a loss.

    Raises:
        ValueError: if ``stage`` or ``mode`` is not one of the values above.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if stage not in ('rec', 'conv'):
        raise ValueError(f"unsupported stage: {stage!r}")
    if mode not in ('train', 'valid', 'test'):
        raise ValueError(f"unsupported mode: {mode!r}")

    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            # Reassigning an existing key does not resize the dict, so this is
            # safe while iterating.
            batch[k] = v.to(self.device)

    if stage == 'rec':
        rec_loss, rec_scores = self.model.forward(batch, mode, stage)
        rec_loss = rec_loss.sum()
        if mode == 'train':
            self.backward(rec_loss)
        else:
            self.rec_evaluate(rec_scores, batch['item'])
        self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss.item()))
    elif mode != 'test':
        # train + valid: the model returns a loss, from which ppl is also derived.
        gen_loss, preds = self.model.forward(batch, mode, stage)
        # Reduce to a scalar, mirroring the rec branch. Without this, backward()
        # and .item() fail when the loss comes back with more than one element.
        gen_loss = gen_loss.sum()
        if mode == 'train':
            self.backward(gen_loss)
        else:
            self.conv_evaluate(preds, batch['response'])
        gen_loss = gen_loss.item()
        self.evaluator.optim_metrics.add('gen_loss', AverageMetric(gen_loss))
        self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
    else:
        # test: the model returns predictions only.
        preds = self.model.forward(batch, mode, stage)
        self.conv_evaluate(preds, batch['response'])
def step(self, batch, stage, mode):
    """Run one step of the pretrain, recommendation or conversation stage.

    Expected ``batch`` layouts (carried over from the original notes):
        converse: context_tokens, context_entities, context_words, response, all_movies = batch
        recommend: context_entities, context_words, entities, movie = batch

    Args:
        batch: sequence of tensors. Each element is moved to ``self.device``.
        stage: ``'pretrain'``, ``'rec'`` or ``'conv'``.
        mode: in the rec and conv stages, ``'train'`` backpropagates. In the
            rec stage, any other mode evaluates. In the conv stage, ``'test'``
            evaluates predictions instead of computing losses. The pretrain
            stage backpropagates whenever the model returns a loss.

    Raises:
        ValueError: if ``stage`` is not one of the values above.
    """
    batch = [ele.to(self.device) for ele in batch]
    if stage == 'pretrain':
        info_loss = self.model.forward(batch, stage, mode)
        if info_loss is not None:
            # NOTE(review): backward runs regardless of ``mode`` here; presumably
            # this stage is only driven in training — confirm with the caller.
            info_loss = info_loss.sum()
            self.backward(info_loss)
            self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss.item()))
    elif stage == 'rec':
        rec_loss, info_loss, rec_predict = self.model.forward(batch, stage, mode)
        # Compare with None rather than relying on truthiness: bool() of a tensor
        # with more than one element raises, and the .sum() calls below allow
        # for such losses.
        has_info_loss = info_loss is not None
        # 0.025 is the fixed weight applied to info_loss when it is added to rec_loss.
        loss = rec_loss + 0.025 * info_loss if has_info_loss else rec_loss
        if mode == "train":
            self.backward(loss.sum())
        else:
            self.rec_evaluate(rec_predict, batch[-1])
        self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss.sum().item()))
        if has_info_loss:
            self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss.sum().item()))
    elif stage == "conv":
        if mode != "test":
            gen_loss, selection_loss, pred = self.model.forward(batch, stage, mode)
            if mode == 'train':
                total_loss = (self.gen_loss_weight * gen_loss + selection_loss).sum()
                self.backward(total_loss)
                self.evaluator.optim_metrics.add("gen_total_loss", AverageMetric(total_loss.item()))
            gen_loss = gen_loss.sum().item()
            self.evaluator.optim_metrics.add("gen_loss", AverageMetric(gen_loss))
            self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
            selection_loss = selection_loss.sum().item()
            self.evaluator.optim_metrics.add('sel_loss', AverageMetric(selection_loss))
        else:
            # The third output is not used at test time.
            pred, matching_pred, _ = self.model.forward(batch, stage, mode)
            self.conv_evaluate(pred, matching_pred, batch[-2], batch[-1])
    else:
        # A bare ``raise`` here only produced "No active exception to reraise".
        raise ValueError(f"unsupported stage: {stage!r}")
def step(self, batch, stage, mode):
    """Run one step of the policy, recommendation or conversation stage.

    Args:
        batch: sequence of tensors. Each element is moved to ``self.device``,
            and the last element is passed to the evaluators as the reference.
        stage: ``'policy'``, ``'rec'`` or ``'conv'``.
        mode: ``'train'`` backpropagates (and puts the policy/rec model in
            train mode); any other value evaluates. In the conv stage,
            ``'test'`` evaluates generated predictions without computing a loss.

    Raises:
        ValueError: if ``stage`` is not one of the values above.
    """
    batch = [ele.to(self.device) for ele in batch]
    if stage == 'policy':
        if mode == 'train':
            self.policy_model.train()
        else:
            self.policy_model.eval()
        policy_loss, policy_predict = self.policy_model.forward(batch, mode)
        if mode == "train" and policy_loss is not None:
            policy_loss = policy_loss.sum()
            self.backward(policy_loss)
        else:
            # NOTE(review): a train batch whose loss is None also lands here and
            # is evaluated — confirm that is intended.
            self.policy_evaluate(policy_predict, batch[-1])
        if isinstance(policy_loss, torch.Tensor):
            # Outside training the loss has not been reduced yet. sum() (a no-op
            # for a scalar) lets .item() work on multi-element losses too.
            policy_loss = policy_loss.sum().item()
            self.evaluator.optim_metrics.add("policy_loss", AverageMetric(policy_loss))
    elif stage == 'rec':
        if mode == 'train':
            self.rec_model.train()
        else:
            self.rec_model.eval()
        rec_loss, rec_predict = self.rec_model.forward(batch, mode)
        rec_loss = rec_loss.sum()
        if mode == "train":
            self.backward(rec_loss)
        else:
            self.rec_evaluate(rec_predict, batch[-1])
        self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss.item()))
    elif stage == "conv":
        # NOTE(review): unlike the policy and rec stages, conv_model's
        # train()/eval() mode is not switched here — confirm it is set elsewhere.
        if mode != "test":
            # train + valid: compute the loss, from which ppl is also derived.
            gen_loss, pred = self.conv_model.forward(batch, mode)
            gen_loss = gen_loss.sum()
            if mode == 'train':
                self.backward(gen_loss)
            else:
                self.conv_evaluate(pred, batch[-1])
            gen_loss = gen_loss.item()
            self.evaluator.optim_metrics.add("gen_loss", AverageMetric(gen_loss))
            self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
        else:
            # test: the conv model returns generated predictions only.
            pred = self.conv_model.forward(batch, mode)
            self.conv_evaluate(pred, batch[-1])
    else:
        # A bare ``raise`` here only produced "No active exception to reraise".
        raise ValueError(f"unsupported stage: {stage!r}")
def step(self, batch, stage, mode):
    """Run one step of the pretrain, recommendation or conversation stage.

    Args:
        batch: sequence of tensors. Each element is moved to ``self.device``,
            and the last element is passed to the evaluators as the reference.
        stage: ``'pretrain'``, ``'rec'`` or ``'conv'``.
        mode: in the rec and conv stages, ``'train'`` backpropagates and any
            other value evaluates. In the conv stage, ``'test'`` evaluates
            predictions without computing a loss. The pretrain stage
            backpropagates whenever the model returns a loss.

    Raises:
        ValueError: if ``stage`` is not one of the values above.
    """
    batch = [ele.to(self.device) for ele in batch]
    if stage == 'pretrain':
        info_loss = self.model.forward(batch, stage, mode)
        if info_loss is not None:
            # NOTE(review): backward runs regardless of ``mode`` here; presumably
            # this stage is only driven in training — confirm with the caller.
            info_loss = info_loss.sum()
            self.backward(info_loss)
            self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss.item()))
    elif stage == 'rec':
        rec_loss, info_loss, rec_predict = self.model.forward(batch, stage, mode)
        # Compare with None rather than relying on truthiness: bool() of a tensor
        # with more than one element raises, and the .sum() calls below allow
        # for such losses.
        has_info_loss = info_loss is not None
        # 0.025 is the fixed weight applied to info_loss when it is added to rec_loss.
        loss = rec_loss + 0.025 * info_loss if has_info_loss else rec_loss
        if mode == "train":
            self.backward(loss.sum())
        else:
            self.rec_evaluate(rec_predict, batch[-1])
        self.evaluator.optim_metrics.add("rec_loss", AverageMetric(rec_loss.sum().item()))
        if has_info_loss:
            self.evaluator.optim_metrics.add("info_loss", AverageMetric(info_loss.sum().item()))
    elif stage == "conv":
        if mode != "test":
            # train + valid: compute the loss, from which ppl is also derived.
            gen_loss, pred = self.model.forward(batch, stage, mode)
            gen_loss = gen_loss.sum()
            if mode == 'train':
                self.backward(gen_loss)
            else:
                self.conv_evaluate(pred, batch[-1])
            gen_loss = gen_loss.item()
            self.evaluator.optim_metrics.add("gen_loss", AverageMetric(gen_loss))
            self.evaluator.gen_metrics.add("ppl", PPLMetric(gen_loss))
        else:
            # test: the model returns predictions only.
            pred = self.model.forward(batch, stage, mode)
            self.conv_evaluate(pred, batch[-1])
    else:
        # A bare ``raise`` here only produced "No active exception to reraise".
        raise ValueError(f"unsupported stage: {stage!r}")