def user_loop(self, data):
    """Compute the user-simulator training losses for one batch.

    The raw batch is passed through ``padding_data`` and ``to_device`` and
    then fed to ``self.user``.

    * The action loss is ``self.nll_loss`` of each decoding step against the
      matching target column, averaged over the number of steps.
    * ``self.alpha * kl_gaussian(argu)`` is then added to the action loss.
      ``argu`` is presumably the latent Gaussian parameters; confirm this.
    * The terminal loss is ``self.bce_loss`` between ``t_weights`` and
      ``batch_input['terminal']``.

    Args:
        data: one raw batch, in the form accepted by ``padding_data``.

    Returns:
        ``(loss_a, loss_t)``: the action loss and the terminal loss.

    Raises:
        ValueError: if ``self.user`` returned no decoding steps, so there is
            nothing to average.
    """
    batch_input = to_device(padding_data(data))
    a_weights, t_weights, argu = self.user(
        batch_input['goals'], batch_input['goals_length'],
        batch_input['posts'], batch_input['posts_length'],
        batch_input['origin_responses'])

    # Remove the leading sos_id so that decoding step k is scored against
    # target column k.
    targets_a = batch_input['origin_responses'][:, 1:]
    loss_a, num_steps = 0, 0
    for step, a_weight in enumerate(a_weights):
        loss_a = loss_a + self.nll_loss(a_weight, targets_a[:, step])
        num_steps += 1
    if num_steps == 0:
        raise ValueError('user model produced no decoding steps')
    # Mean over decoding steps. Divide by the step count, not by the last
    # enumerate index (which is one smaller and is 0 for a single step).
    loss_a = loss_a / num_steps
    loss_t = self.bce_loss(t_weights, batch_input['terminal'])
    loss_a = loss_a + self.alpha * kl_gaussian(argu)
    return loss_a, loss_t
def test(self):
    """Evaluate the user simulator on ``self.data_test`` and print metrics.

    ``self.user`` is run on every batch yielded by ``batch_iter`` over
    ``self.data_test``, and two metrics are accumulated:

    * Dialogue-act counts (TP / FP / FN) and their F1.
      - Predictions come from greedy (argmax) decoding, truncated at the
        decoder's ``eos_id``.
      - They are compared with the reference user turns in ``data[1]``.
    * Terminal-prediction accuracy, counting ``t_weights >= 0`` as positive.

    Output is two printed lines: ``a_TP a_FP a_FN F1`` and
    ``t_corr t_tot accuracy``. A ratio whose denominator is zero is reported
    as 0.0 instead of raising ZeroDivisionError.
    """
    import torch  # only used to disable autograd during evaluation

    def sequential(da_seq):
        """Convert a decoded token sequence into ``act-word`` strings.

        Special tokens and parentheses are skipped. A token containing '-'
        becomes the current act. Every other token is emitted as
        ``cur_act + '-' + token``. Tokens seen before any act are dropped.
        """
        da = []
        cur_act = None
        for word in da_seq:
            if word in ['<PAD>', '<UNK>', '<SOS>', '<EOS>', '(', ')']:
                continue
            if '-' in word:
                cur_act = word
            elif cur_act is not None:
                da.append(cur_act + '-' + word)
        return da

    def f1(pred, real):
        """Return ``(TP, FP, FN)`` counts of the items of ``pred`` vs ``real``.

        Membership is tested with ``in``. An empty ``real`` yields
        ``(0, 0, 0)``, so predictions are then not counted as false
        positives.
        """
        if not real:
            return 0, 0, 0
        TP, FP, FN = 0, 0, 0
        for item in real:
            if item in pred:
                TP += 1
            else:
                FN += 1
        for item in pred:
            if item not in real:
                FP += 1
        return TP, FP, FN

    data_test_iter = batch_iter(self.data_test[0], self.data_test[1],
                                self.data_test[2], self.data_test[3])
    a_TP, a_FP, a_FN, t_corr, t_tot = 0, 0, 0, 0, 0
    eos_id = self.user.usr_decoder.eos_id
    for data in data_test_iter:
        batch_input = to_device(padding_data(data))
        with torch.no_grad():
            a_weights, t_weights, _ = self.user(
                batch_input['goals'], batch_input['goals_length'],
                batch_input['posts'], batch_input['posts_length'],
                batch_input['origin_responses'])

        # One argmax row per decoding step, transposed to one row per batch
        # item. This assumes argmax(1) yields a 1-D array per step; confirm.
        usr_a = np.array([a_weight.argmax(1).cpu().numpy()
                          for a_weight in a_weights]).T.tolist()
        a = []
        for ua in usr_a:
            if eos_id in ua:
                ua = ua[:ua.index(eos_id)]  # keep tokens before the first EOS
            a.append(sequential(self.manager.id2sentence(ua)))
        # Reference user turns, flattened over sessions. [1:-1] drops the
        # first and last token (presumably SOS/EOS; confirm).
        targets_a = [sequential(self.manager.id2sentence(ua[1:-1]))
                     for ua_sess in data[1] for ua in ua_sess]

        # NOTE(review): f1 is applied to the whole batch, and each item is one
        # turn's act list. A reference turn is therefore a TP only if an
        # identical act list appears anywhere among the batch predictions.
        # This is neither act-level nor position-aligned; confirm it is the
        # intended metric.
        TP, FP, FN = f1(a, targets_a)
        a_TP += TP
        a_FP += FP
        a_FN += FN

        t = t_weights.ge(0).cpu().tolist()
        targets_t = batch_input['terminal'].cpu().long().tolist()
        judge = np.array(t) == np.array(targets_t)
        t_corr += judge.sum()
        t_tot += judge.size

    # Guard every ratio so an empty test set or zero predictions cannot crash.
    prec = a_TP / (a_TP + a_FP) if a_TP + a_FP else 0.0
    rec = a_TP / (a_TP + a_FN) if a_TP + a_FN else 0.0
    F1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    print(a_TP, a_FP, a_FN, F1)
    print(t_corr, t_tot, t_corr / t_tot if t_tot else 0.0)