Code example #1
    def test(self):
        def f1(a, target):
            # Count true positives, false positives and false negatives
            # between the predicted and target multi-hot tensors, comparing
            # the index lists of their nonzero entries.
            TP, FP, FN = 0, 0, 0
            real = target.nonzero().tolist()
            predict = a.nonzero().tolist()
            for item in real:
                if item in predict:
                    TP += 1
                else:
                    FN += 1
            for item in predict:
                if item not in real:
                    FP += 1
            return TP, FP, FN

        a_TP, a_FP, a_FN = 0, 0, 0
        for i, data in enumerate(self.data_test):
            s, target_a = to_device(data)
            a_weights = self.policy(s)
            # Threshold the logits at 0 (i.e. sigmoid(x) > 0.5) to get
            # binary action predictions.
            a = a_weights.ge(0)
            # TODO: fix batch F1
            TP, FP, FN = f1(a, target_a)
            a_TP += TP
            a_FP += FP
            a_FN += FN

        # Micro-averaged precision, recall and F1 over the test set.
        prec = a_TP / (a_TP + a_FP)
        rec = a_TP / (a_TP + a_FN)
        F1 = 2 * prec * rec / (prec + rec)
        print(a_TP, a_FP, a_FN, F1)
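
For reference, a minimal standalone sketch of the same counting logic on toy multi-hot tensors (the example tensors are invented for illustration; assumes PyTorch):

import torch

def f1_counts(pred, target):
    # Same TP/FP/FN counting as the nested f1 helper above.
    real = target.nonzero().tolist()
    predict = pred.nonzero().tolist()
    TP = sum(1 for item in real if item in predict)
    FN = len(real) - TP
    FP = sum(1 for item in predict if item not in real)
    return TP, FP, FN

pred = torch.tensor([[1, 0, 1], [0, 1, 0]])
target = torch.tensor([[1, 1, 0], [0, 1, 0]])
print(f1_counts(pred, target))  # (2, 1, 1)
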
Code example #2
    def irl_loop(self, data_real, data_gen):
        s_real, a_real, next_s_real = to_device(data_real)
        s, a, next_s = data_gen

        # Train with real data: maximize the reward estimate on real
        # (s, a, s') transitions, hence the negated mean.
        weight_real = self.irl(s_real, a_real, next_s_real)
        loss_real = -weight_real.mean()

        # Train with generated data: minimize the reward estimate on
        # transitions produced by the current policy.
        weight = self.irl(s, a, next_s)
        loss_gen = weight.mean()
        return loss_real, loss_gen
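
The sign convention mirrors a WGAN-style critic: the reward model is pushed up on real transitions and down on generated ones. A minimal self-contained sketch of one such update, with a toy linear scorer standing in for self.irl (the network, optimizer, and feature shapes here are assumptions, not taken from the source):

import torch

irl = torch.nn.Linear(4, 1)                    # stand-in reward model
optim = torch.optim.RMSprop(irl.parameters(), lr=1e-3)

x_real = torch.randn(8, 4)                     # placeholder real (s, a, s') features
x_gen = torch.randn(8, 4)                      # placeholder generated features

loss_real = -irl(x_real).mean()                # raise scores on real data
loss_gen = irl(x_gen).mean()                   # lower scores on generated data
optim.zero_grad()
(loss_real + loss_gen).backward()
optim.step()
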
Code example #3
File: train.py Project: youngornever/tatk
    def user_loop(self, data):
        batch_input = to_device(padding_data(data))
        a_weights, t_weights, argu = self.user(batch_input['goals'], batch_input['goals_length'],
                                               batch_input['posts'], batch_input['posts_length'],
                                               batch_input['origin_responses'])

        # Token-level NLL over the decoded response; the targets are
        # shifted by one to remove the sos_id.
        loss_a, targets_a = 0, batch_input['origin_responses'][:, 1:]
        for i, a_weight in enumerate(a_weights):
            loss_a += self.nll_loss(a_weight, targets_a[:, i])
        loss_a /= len(a_weights)  # average the NLL over all timesteps
        loss_t = self.bce_loss(t_weights, batch_input['terminal'])
        loss_a += self.alpha * kl_gaussian(argu)
        return loss_a, loss_t
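
kl_gaussian is not defined in this snippet. In similar VAE-style user models, argu is typically the (mu, logvar) pair of a recognition network, and the term is the closed-form KL divergence from N(mu, diag(exp(logvar))) to a standard normal prior; a sketch under that assumption:

import torch

def kl_gaussian(argu):
    # KL(N(mu, sigma^2) || N(0, I)) in closed form, averaged over the batch.
    mu, logvar = argu
    return -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))

mu, logvar = torch.zeros(4, 8), torch.zeros(4, 8)
print(kl_gaussian((mu, logvar)))  # tensor(0.) when the posterior equals the prior
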
Code example #4
File: train.py Project: youngornever/tatk
    def test(self):
        def sequential(da_seq):
            # Reassemble "act-slot" dialog act strings from a decoded token
            # sequence: tokens containing '-' set the current act, and the
            # plain words that follow are appended as its slots.
            da = []
            cur_act = None
            for word in da_seq:
                if word in ['<PAD>', '<UNK>', '<SOS>', '<EOS>', '(', ')']:
                    continue
                if '-' in word:
                    cur_act = word
                else:
                    if cur_act is None:
                        continue
                    da.append(cur_act + '-' + word)
            return da

        def f1(pred, real):
            # Count TP/FP/FN by exact match between predicted and
            # reference items.
            if not real:
                return 0, 0, 0
            TP, FP, FN = 0, 0, 0
            for item in real:
                if item in pred:
                    TP += 1
                else:
                    FN += 1
            for item in pred:
                if item not in real:
                    FP += 1
            return TP, FP, FN
    
        data_test_iter = batch_iter(self.data_test[0], self.data_test[1],
                                    self.data_test[2], self.data_test[3])
        a_TP, a_FP, a_FN, t_corr, t_tot = 0, 0, 0, 0, 0
        eos_id = self.user.usr_decoder.eos_id
        for i, data in enumerate(data_test_iter):
            batch_input = to_device(padding_data(data))
            a_weights, t_weights, argu = self.user(batch_input['goals'], batch_input['goals_length'],
                                                   batch_input['posts'], batch_input['posts_length'],
                                                   batch_input['origin_responses'])
            # Greedy decode: take the argmax token at each timestep, then
            # transpose so each row holds one response in the batch.
            usr_a = []
            for a_weight in a_weights:
                usr_a.append(a_weight.argmax(1).cpu().numpy())
            usr_a = np.array(usr_a).T.tolist()
            a = []
            for ua in usr_a:
                # Truncate at the first eos_id before mapping ids back to
                # dialog-act tokens.
                if eos_id in ua:
                    ua = ua[:ua.index(eos_id)]
                a.append(sequential(self.manager.id2sentence(ua)))
            targets_a = []
            for ua_sess in data[1]:
                for ua in ua_sess:
                    # Strip sos_id/eos_id from the reference before decoding.
                    targets_a.append(sequential(self.manager.id2sentence(ua[1:-1])))
            TP, FP, FN = f1(a, targets_a)
            a_TP += TP
            a_FP += FP
            a_FN += FN

            # Terminal-state prediction accuracy.
            t = t_weights.ge(0).cpu().tolist()
            targets_t = batch_input['terminal'].cpu().long().tolist()
            judge = np.array(t) == np.array(targets_t)
            t_corr += judge.sum()
            t_tot += judge.size

        # Micro-averaged precision, recall and F1 over all batches.
        prec = a_TP / (a_TP + a_FP)
        rec = a_TP / (a_TP + a_FN)
        F1 = 2 * prec * rec / (prec + rec)
        print(a_TP, a_FP, a_FN, F1)
        print(t_corr, t_tot, t_corr / t_tot)
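
To illustrate what sequential() produces, here is a hand-written decode (the token strings are invented for illustration, and the call assumes sequential is lifted out of test so it can be reused):

da_seq = ['<SOS>', 'Inform-Hotel', '(', 'Price', 'Area', ')',
          'Request-Hotel', '(', 'Phone', ')', '<EOS>']
print(sequential(da_seq))
# -> ['Inform-Hotel-Price', 'Inform-Hotel-Area', 'Request-Hotel-Phone']
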
Code example #5
    def policy_loop(self, data):
        s, target_a = to_device(data)
        a_weights = self.policy(s)

        # Multi-label loss between the action logits and the target
        # multi-hot action vector.
        loss_a = self.multi_entropy_loss(a_weights, target_a)
        return loss_a
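
multi_entropy_loss is not defined in this snippet. Since example #1 thresholds the same policy's logits at zero (a_weights.ge(0), i.e. sigmoid(x) > 0.5), a plausible reading is a multi-label binary cross-entropy on logits; a minimal sketch under that assumption:

import torch

multi_entropy_loss = torch.nn.BCEWithLogitsLoss()    # assumed loss
a_weights = torch.randn(8, 20)                       # batch of action logits
target_a = torch.randint(0, 2, (8, 20)).float()      # multi-hot action labels
loss_a = multi_entropy_loss(a_weights, target_a)
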