Example #1
    def test(self):
        def f1(a, target):
            TP, FP, FN = 0, 0, 0
            real = target.nonzero().tolist()
            predict = a.nonzero().tolist()
            for item in real:
                if item in predict:
                    TP += 1
                else:
                    FN += 1
            for item in predict:
                if item not in real:
                    FP += 1
            return TP, FP, FN

        a_TP, a_FP, a_FN = 0, 0, 0
        for i, data in enumerate(self.data_test):
            s, target_a = to_device(data)
            a_weights = self.policy(s)
            a = a_weights.ge(0)
            TP, FP, FN = f1(a, target_a)
            a_TP += TP
            a_FP += FP
            a_FN += FN

        prec = a_TP / (a_TP + a_FP)
        rec = a_TP / (a_TP + a_FN)
        F1 = 2 * prec * rec / (prec + rec)
        print(a_TP, a_FP, a_FN, F1)
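The final precision/recall division in Example #1 assumes at least one predicted and one gold action, otherwise it raises ZeroDivisionError. A minimal guarded variant of the same micro-averaged F1 computation could look like this (a sketch, not part of the original code):

    def micro_f1(TP, FP, FN):
        # Micro-averaged precision/recall/F1 over accumulated counts,
        # guarding against division by zero when nothing is predicted or gold.
        prec = TP / (TP + FP) if TP + FP > 0 else 0.0
        rec = TP / (TP + FN) if TP + FN > 0 else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec > 0 else 0.0
        return prec, rec, f1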
Example #2
    def mle_loop(self, data):
        s, target_a = to_device(data)
        a = self.policy.policy(s)

        # supervised (maximum-likelihood) pretraining: fit the policy outputs to the target actions
        loss_a = self.multi_entropy_loss(a, target_a)
        return loss_a
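`multi_entropy_loss` is not defined in the snippet; assuming it is a multi-label binary cross-entropy over logits (a common choice for multi-hot dialog-act targets), it could be sketched with `torch.nn.BCEWithLogitsLoss`. The shapes below are illustrative only.

    import torch
    import torch.nn as nn

    # Assumption: multi_entropy_loss behaves like BCE-with-logits over multi-hot actions.
    multi_entropy_loss = nn.BCEWithLogitsLoss()
    logits = torch.randn(4, 20)                      # policy outputs for a batch of 4 states
    target_a = torch.randint(0, 2, (4, 20)).float()  # multi-hot target action vectors
    loss_a = multi_entropy_loss(logits, target_a)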
Example #3
    def irl_loop(self, data_real, data_gen):
        s_real, a_real, next_s_real = to_device(data_real)
        s, a, next_s = data_gen

        # train with real data
        weight_real = self.irl(s_real, a_real, next_s_real)
        loss_real = -weight_real.mean()

        # train with generated data
        weight = self.irl(s, a, next_s)
        loss_gen = weight.mean()
        return loss_real, loss_gen
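The two returned terms push the learned reward up on real transitions and down on generated ones. A self-contained toy sketch of that update is shown below; the reward-model architecture, shapes, and optimizer are assumptions for illustration, not the original `self.irl` module.

    import torch
    import torch.nn as nn

    # Toy stand-in for self.irl: a reward model scoring (s, a, s') transitions.
    reward_model = nn.Sequential(nn.Linear(3 * 8, 32), nn.ReLU(), nn.Linear(32, 1))
    irl_optim = torch.optim.Adam(reward_model.parameters(), lr=1e-3)

    s_real, a_real, ns_real = torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 8)
    s_gen, a_gen, ns_gen = torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 8)

    loss_real = -reward_model(torch.cat((s_real, a_real, ns_real), 1)).mean()
    loss_gen = reward_model(torch.cat((s_gen, a_gen, ns_gen), 1)).mean()

    irl_optim.zero_grad()
    (loss_real + loss_gen).backward()  # raise reward on real transitions, lower it on generated ones
    irl_optim.step()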
Example #4
    def user_loop(self, data):
        batch_input = to_device(padding_data(data))
        a_weights, t_weights, argu = self.user(
            batch_input['goals'], batch_input['goals_length'],
            batch_input['posts'], batch_input['posts_length'],
            batch_input['origin_responses'])

        targets_a = batch_input['origin_responses'][:, 1:]  # remove sos_id
        loss_a = 0
        for i, a_weight in enumerate(a_weights):
            loss_a += self.nll_loss(a_weight, targets_a[:, i])
        loss_a /= len(a_weights)  # average the NLL over decoding steps
        loss_t = self.bce_loss(t_weights, batch_input['terminated'])
        loss_a += self.alpha * kl_gaussian(argu)
        return loss_a, loss_t
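`kl_gaussian` is not shown; assuming `argu` carries the mean and log-variance of a diagonal Gaussian posterior (a VAE-style regularizer), one common closed form is the KL divergence against a standard normal, sketched below. This is an assumption about the helper, not taken from the snippet.

    import torch

    def kl_gaussian(argu):
        # Assumption: argu = (mu, logvar) of a diagonal Gaussian q(z|x).
        # Closed-form KL(q || N(0, I)), summed over latent dims, averaged over the batch.
        mu, logvar = argu
        return -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))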
Example #5
    def discriminator_loop(self, data):
        s, target_a = to_device(data)

        # real state-action pairs should be scored as real (label 1)
        predict = self.discriminator(torch.cat((s, target_a), 1)).view(-1)
        label = torch.full((predict.size(0),), 1., device=DEVICE)
        loss_r = self.criterion(predict, label)

        # policy-generated pairs should be scored as fake (label 0)
        a_weights = self.policy.policy(s)
        a = select_action(a_weights, False).float()
        predict = self.discriminator(torch.cat((s, a), 1)).view(-1)
        label = torch.full((predict.size(0),), 0., device=DEVICE)
        loss_f = self.criterion(predict, label)
        return loss_r, loss_f
Example #6
    def policy_loop(self, data):
        s, target_a = to_device(data)
        a_weights = self.policy.policy(s)
        a = select_action(a_weights, False).float()

        # generator objective: the discriminator should score policy actions as real (label 1)
        predict = self.discriminator(torch.cat((s, a), 1)).view(-1)
        label = torch.full((predict.size(0),), 1., device=DEVICE)
        loss_p = self.criterion(predict, label)
        return loss_p
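Examples #5 and #6 form a GAN-style pair: the discriminator is trained to label real state-action pairs 1 and policy-generated pairs 0, while the policy is trained so that its actions are scored as real. The self-contained sketch below shows the alternating updates; the toy networks, shapes, optimizers, and the relaxed (sigmoid) policy actions in the generator step are assumptions, not the original classes.

    import torch
    import torch.nn as nn

    policy = nn.Linear(8, 20)                                  # maps state -> action logits
    discriminator = nn.Sequential(nn.Linear(8 + 20, 32), nn.ReLU(), nn.Linear(32, 1))
    criterion = nn.BCEWithLogitsLoss()
    d_optim = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
    p_optim = torch.optim.Adam(policy.parameters(), lr=1e-3)

    s = torch.randn(16, 8)                                     # batch of dialog states
    target_a = torch.randint(0, 2, (16, 20)).float()           # human (real) multi-hot actions

    # Discriminator step: real pairs -> 1, policy-generated pairs -> 0.
    a_fake = (policy(s) >= 0).float().detach()
    loss_r = criterion(discriminator(torch.cat((s, target_a), 1)).view(-1), torch.ones(16))
    loss_f = criterion(discriminator(torch.cat((s, a_fake), 1)).view(-1), torch.zeros(16))
    d_optim.zero_grad()
    (loss_r + loss_f).backward()
    d_optim.step()

    # Policy (generator) step: make generated pairs look real to the discriminator.
    a_gen = torch.sigmoid(policy(s))                           # relaxed actions so gradients reach the policy
    loss_p = criterion(discriminator(torch.cat((s, a_gen), 1)).view(-1), torch.ones(16))
    p_optim.zero_grad()
    loss_p.backward()
    p_optim.step()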
Example #7
    def test(self):
        def sequential(da_seq):
            da = []
            cur_act = None
            for word in da_seq:
                if word in ['<PAD>', '<UNK>', '<SOS>', '<EOS>', '(', ')']:
                    continue
                if '-' in word:
                    cur_act = word
                else:
                    if cur_act is None:
                        continue
                    da.append(cur_act + '-' + word)
            return da

        def f1(pred, real):
            if not real:
                return 0, 0, 0
            TP, FP, FN = 0, 0, 0
            for item in real:
                if item in pred:
                    TP += 1
                else:
                    FN += 1
            for item in pred:
                if item not in real:
                    FP += 1
            return TP, FP, FN

        data_test_iter = batch_iter(self.data_test[0], self.data_test[1],
                                    self.data_test[2], self.data_test[3])
        a_TP, a_FP, a_FN, t_corr, t_tot = 0, 0, 0, 0, 0
        eos_id = self.user.usr_decoder.eos_id
        for i, data in enumerate(data_test_iter):
            batch_input = to_device(padding_data(data))
            a_weights, t_weights, argu = self.user(
                batch_input['goals'], batch_input['goals_length'],
                batch_input['posts'], batch_input['posts_length'],
                batch_input['origin_responses'])
            usr_a = []
            for a_weight in a_weights:
                usr_a.append(a_weight.argmax(1).cpu().numpy())
            usr_a = np.array(usr_a).T.tolist()
            a = []
            for ua in usr_a:
                if eos_id in ua:
                    ua = ua[:ua.index(eos_id)]
                a.append(sequential(self.manager.id2sentence(ua)))
            targets_a = []
            for ua_sess in data[1]:
                for ua in ua_sess:
                    targets_a.append(
                        sequential(self.manager.id2sentence(ua[1:-1])))
            TP, FP, FN = f1(a, targets_a)
            a_TP += TP
            a_FP += FP
            a_FN += FN

            t = t_weights.ge(0).cpu().tolist()
            targets_t = batch_input['terminated'].cpu().long().tolist()
            judge = np.array(t) == np.array(targets_t)
            t_corr += judge.sum()
            t_tot += judge.size

        prec = a_TP / (a_TP + a_FP)
        rec = a_TP / (a_TP + a_FN)
        F1 = 2 * prec * rec / (prec + rec)
        print(a_TP, a_FP, a_FN, F1)
        print(t_corr, t_tot, t_corr / t_tot)
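To make the `sequential` helper concrete: tokens containing '-' set the current act, the following slot tokens are attached to it, and special tokens and parentheses are skipped. The token sequence below is made up for illustration, not taken from the test data.

    # Illustrative only: the token names are hypothetical.
    tokens = ['<SOS>', 'Inform-Hotel', '(', 'Price', 'Area', ')', 'Request-Hotel', '(', 'Phone', ')', '<EOS>']
    # sequential(tokens) -> ['Inform-Hotel-Price', 'Inform-Hotel-Area', 'Request-Hotel-Phone']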
Example #8
    def policy_loop(self, data):
        s, target_a = to_device(data)
        a_weights = self.policy(s)

        loss_a = self.multi_entropy_loss(a_weights, target_a)
        return loss_a
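Example #8 trains the policy on raw logits, and Example #1 evaluates it by thresholding the same logits at zero (`a_weights.ge(0)`). Under the multi-label BCE-with-logits assumption sketched earlier, a logit of 0 corresponds to a predicted probability of 0.5, so the two are consistent:

    import torch

    logits = torch.tensor([1.2, -0.3, 0.0])
    probs = torch.sigmoid(logits)   # ~[0.77, 0.43, 0.50]
    predicted = logits.ge(0)        # [True, False, True] -- same as probs >= 0.5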