# Shared context assumed from the rest of the repo: to_device, padding_data,
# select_action, batch_iter, kl_gaussian, DEVICE, and the model attributes
# (self.policy, self.user, self.discriminator, ...) are defined elsewhere.
import numpy as np
import torch


def test(self):
    def f1(a, target):
        # count true positives, false positives, and false negatives by
        # comparing the nonzero (active) action indices
        TP, FP, FN = 0, 0, 0
        real = target.nonzero().tolist()
        predict = a.nonzero().tolist()
        for item in real:
            if item in predict:
                TP += 1
            else:
                FN += 1
        for item in predict:
            if item not in real:
                FP += 1
        return TP, FP, FN

    a_TP, a_FP, a_FN = 0, 0, 0
    for i, data in enumerate(self.data_test):
        s, target_a = to_device(data)
        a_weights = self.policy(s)
        a = a_weights.ge(0)  # threshold logits at 0 to get a multi-hot action
        TP, FP, FN = f1(a, target_a)
        a_TP += TP
        a_FP += FP
        a_FN += FN

    prec = a_TP / (a_TP + a_FP)
    rec = a_TP / (a_TP + a_FN)
    F1 = 2 * prec * rec / (prec + rec)
    print(a_TP, a_FP, a_FN, F1)
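# A standalone, hedged sketch (not part of the class): the TP/FP/FN counts in
# f1() above can equivalently be computed with vectorized tensor ops instead
# of Python loops over nonzero indices. `pred` and `target` are hypothetical
# {0,1} multi-hot action tensors of the same shape.
def _f1_counts_vectorized(pred, target):
    pred, target = pred.bool(), target.bool()
    TP = (pred & target).sum().item()    # active in both
    FP = (pred & ~target).sum().item()   # predicted but not real
    FN = (~pred & target).sum().item()   # real but not predicted
    return TP, FP, FN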
def mle_loop(self, data):
    s, target_a = to_device(data)
    a = self.policy.policy(s)
    # supervised multi-label loss between predicted and target dialog acts
    loss_a = self.multi_entropy_loss(a, target_a)
    return loss_a
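# Hedged note: `multi_entropy_loss` is defined elsewhere in the repo. For
# multi-hot dialog-act targets it typically amounts to binary cross-entropy
# over per-act logits; the standalone lines below illustrate that assumed
# form with made-up shapes.
_example_logits = torch.randn(4, 10)                     # hypothetical batch
_example_targets = torch.randint(0, 2, (4, 10)).float()  # multi-hot targets
_example_loss = torch.nn.functional.binary_cross_entropy_with_logits(
    _example_logits, _example_targets)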
def irl_loop(self, data_real, data_gen):
    s_real, a_real, next_s_real = to_device(data_real)
    s, a, next_s = data_gen
    # train with real data: push the estimated reward (weight) up
    weight_real = self.irl(s_real, a_real, next_s_real)
    loss_real = -weight_real.mean()
    # train with generated data: push the estimated reward down
    weight = self.irl(s, a, next_s)
    loss_gen = weight.mean()
    return loss_real, loss_gen
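# Standalone hedged sketch of the objective irl_loop implements: the reward
# estimator should assign high weight to real (s, a, s') transitions and low
# weight to generated ones, so minimizing loss_real + loss_gen does both.
# Every name and shape below is illustrative, not the repo's API.
_reward_net = torch.nn.Sequential(
    torch.nn.Linear(8, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1))
_opt = torch.optim.Adam(_reward_net.parameters(), lr=1e-3)
_real = torch.randn(32, 8)   # stand-in for encoded real transitions
_gen = torch.randn(32, 8)    # stand-in for encoded generated transitions
_loss = -_reward_net(_real).mean() + _reward_net(_gen).mean()
_opt.zero_grad()
_loss.backward()
_opt.step()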
def user_loop(self, data):
    batch_input = to_device(padding_data(data))
    a_weights, t_weights, argu = self.user(
        batch_input['goals'], batch_input['goals_length'],
        batch_input['posts'], batch_input['posts_length'],
        batch_input['origin_responses'])

    loss_a, targets_a = 0, batch_input['origin_responses'][:, 1:]  # remove sos_id
    for i, a_weight in enumerate(a_weights):
        loss_a += self.nll_loss(a_weight, targets_a[:, i])
    loss_a /= len(a_weights)  # average the NLL over decoding steps
    loss_t = self.bce_loss(t_weights, batch_input['terminated'])
    loss_a += self.alpha * kl_gaussian(argu)  # variational regularizer
    return loss_a, loss_t
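# Hedged note: `kl_gaussian` is a repo helper whose exact signature is not
# shown here. A common choice in VAE-style user simulators is the closed-form
# KL of a diagonal Gaussian posterior against a standard normal prior; the
# standalone function below illustrates that assumed form only.
def _kl_standard_normal(mu, logvar):
    # KL(N(mu, exp(logvar)) || N(0, 1)), summed over latent dims,
    # averaged over the batch
    return -0.5 * torch.mean(
        torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))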
def discriminator_loop(self, data):
    s, target_a = to_device(data)
    # real state-action pairs are labeled 1
    predict = self.discriminator(torch.cat((s, target_a), 1)).view(-1)
    label = torch.full((predict.size(0),), 1., device=DEVICE)
    loss_r = self.criterion(predict, label)
    # actions sampled from the current policy are labeled 0
    # (a standalone sketch of the full alternating update follows policy_loop below)
    a_weights = self.policy.policy(s)
    a = select_action(a_weights, False).float()
    predict = self.discriminator(torch.cat((s, a), 1)).view(-1)
    label = torch.full((predict.size(0),), 0., device=DEVICE)
    loss_f = self.criterion(predict, label)
    return loss_r, loss_f
def policy_loop(self, data):
    s, target_a = to_device(data)
    a_weights = self.policy.policy(s)
    a = select_action(a_weights, False).float()
    # generator step: the policy's own samples should be scored as real (label 1)
    predict = self.discriminator(torch.cat((s, a), 1)).view(-1)
    label = torch.full((predict.size(0),), 1., device=DEVICE)
    loss_p = self.criterion(predict, label)
    return loss_p
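# Standalone hedged sketch of the alternating update implied by
# discriminator_loop and policy_loop: the discriminator learns to separate
# real state-action pairs (label 1) from policy samples (label 0), then the
# policy is updated with flipped labels so its samples score as real. All
# modules, names, and shapes below are illustrative.
_disc = torch.nn.Sequential(torch.nn.Linear(6, 16), torch.nn.ReLU(),
                            torch.nn.Linear(16, 1))
_pol = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(),
                           torch.nn.Linear(16, 2))
_bce = torch.nn.BCEWithLogitsLoss()
_d_opt = torch.optim.Adam(_disc.parameters(), lr=1e-3)
_g_opt = torch.optim.Adam(_pol.parameters(), lr=1e-3)

_s = torch.randn(32, 4)                         # stand-in states
_real_a = torch.randint(0, 2, (32, 2)).float()  # stand-in real actions

# discriminator step: real -> 1, generated -> 0
_fake_a = torch.sigmoid(_pol(_s)).detach()
_d_loss = (_bce(_disc(torch.cat((_s, _real_a), 1)).view(-1), torch.ones(32)) +
           _bce(_disc(torch.cat((_s, _fake_a), 1)).view(-1), torch.zeros(32)))
_d_opt.zero_grad()
_d_loss.backward()
_d_opt.step()

# policy (generator) step: label its own samples as real
_fake_a = torch.sigmoid(_pol(_s))
_g_loss = _bce(_disc(torch.cat((_s, _fake_a), 1)).view(-1), torch.ones(32))
_g_opt.zero_grad()
_g_loss.backward()
_g_opt.step()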
def test(self):
    def sequential(da_seq):
        # rebuild 'act-slot' dialog acts from a decoded token sequence: a
        # token containing '-' names the current act; the tokens that follow
        # are its slots
        da = []
        cur_act = None
        for word in da_seq:
            if word in ['<PAD>', '<UNK>', '<SOS>', '<EOS>', '(', ')']:
                continue
            if '-' in word:
                cur_act = word
            else:
                if cur_act is None:
                    continue
                da.append(cur_act + '-' + word)
        return da

    def f1(pred, real):
        # skip examples with no gold acts; otherwise count TP/FP/FN
        if not real:
            return 0, 0, 0
        TP, FP, FN = 0, 0, 0
        for item in real:
            if item in pred:
                TP += 1
            else:
                FN += 1
        for item in pred:
            if item not in real:
                FP += 1
        return TP, FP, FN

    data_test_iter = batch_iter(self.data_test[0], self.data_test[1],
                                self.data_test[2], self.data_test[3])
    a_TP, a_FP, a_FN, t_corr, t_tot = 0, 0, 0, 0, 0
    eos_id = self.user.usr_decoder.eos_id
    for i, data in enumerate(data_test_iter):
        batch_input = to_device(padding_data(data))
        a_weights, t_weights, argu = self.user(
            batch_input['goals'], batch_input['goals_length'],
            batch_input['posts'], batch_input['posts_length'],
            batch_input['origin_responses'])

        # greedy-decode user actions and truncate at the first EOS
        usr_a = []
        for a_weight in a_weights:
            usr_a.append(a_weight.argmax(1).cpu().numpy())
        usr_a = np.array(usr_a).T.tolist()
        a = []
        for ua in usr_a:
            if eos_id in ua:
                ua = ua[:ua.index(eos_id)]
            a.append(sequential(self.manager.id2sentence(ua)))

        # gold user actions, with SOS/EOS stripped
        targets_a = []
        for ua_sess in data[1]:
            for ua in ua_sess:
                targets_a.append(sequential(self.manager.id2sentence(ua[1:-1])))

        TP, FP, FN = f1(a, targets_a)
        a_TP += TP
        a_FP += FP
        a_FN += FN

        # session-termination prediction accuracy
        t = t_weights.ge(0).cpu().tolist()
        targets_t = batch_input['terminated'].cpu().long().tolist()
        judge = np.array(t) == np.array(targets_t)
        t_corr += judge.sum()
        t_tot += judge.size

    prec = a_TP / (a_TP + a_FP)
    rec = a_TP / (a_TP + a_FN)
    F1 = 2 * prec * rec / (prec + rec)
    print(a_TP, a_FP, a_FN, F1)
    print(t_corr, t_tot, t_corr / t_tot)
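# Worked example of sequential() above, with made-up tokens: an act token
# (one containing '-') is attached to each slot token that follows it.
#   sequential(['<SOS>', 'Inform-Hotel', '(', 'Price', 'Area', ')', '<EOS>'])
#   -> ['Inform-Hotel-Price', 'Inform-Hotel-Area']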
def policy_loop(self, data):
    s, target_a = to_device(data)
    a_weights = self.policy(s)
    loss_a = self.multi_entropy_loss(a_weights, target_a)
    return loss_a