예제 #1
0
    def result_analysis(self):
        """Decode the current batch and accumulate edit-distance counts.

        The decoded sequence for sample ``i`` is ``self.decode_batch()[i]``.
        Its reference is ``self.label[i]`` with every ``-1`` entry removed and
        every remaining id shifted up by one.  Each pair is compared with
        ``cal_distance``.

        Returns:
            tuple: ``(prediction, result)``.  ``prediction`` is exactly what
            ``self.decode_batch()`` returned.  ``result`` is the list
            ``[delete_total, replace_total, insert_total, len_total,
            word_total]`` summed over the batch, where ``len_total`` counts
            reference symbols and ``word_total`` counts samples whose distance
            is exactly 0.
        """
        prediction = self.decode_batch()
        # NOTE(review): the batch size is read from axis 1 of self.softmax,
        # which suggests a (time, batch, ...) layout -- confirm.
        n_samples = self.softmax.size(1)

        deletions = substitutions = insertions = 0
        label_chars = exact_matches = 0
        for idx in range(n_samples):
            hypothesis = prediction[idx]
            row = self.label[idx]
            # -1 marks padding.  The +1 shift presumably aligns label ids with
            # the decoder's output ids (e.g. a reserved id 0) -- TODO confirm.
            reference = [int(tok) + 1 for tok in row[row != -1].tolist()]

            dist, (n_del, n_sub, n_ins) = cal_distance(reference, hypothesis)
            deletions += n_del
            substitutions += n_sub
            insertions += n_ins
            label_chars += len(reference)
            if dist == 0:
                exact_matches += 1

        return prediction, [
            deletions, substitutions, insertions, label_chars, exact_matches
        ]
예제 #2
0
    def result_analysis_recall(self):
        """Decode the current batch and accumulate edit-distance and recall counts.

        The decoded sequence for sample ``i`` is ``self.decode_batch()[i]``.
        Its reference is ``self.label[i]`` with every ``-1`` entry removed and
        every remaining id shifted up by one.  Each pair is compared with
        ``cal_distance``.

        Returns:
            tuple: ``(prediction, result)``.  ``prediction`` is what
            ``self.decode_batch()`` returned.  ``result`` is the list
            ``[delete_total, replace_total, insert_total, len_total,
            correct_count, len_total, pre_total, word_total, all_total]``
            summed over the batch (``len_total`` appears twice):

            * ``len_total``: number of reference symbols
            * ``pre_total``: number of predicted symbols
            * ``correct_count``: reference symbols neither deleted nor
              replaced, presumably the numerator for recall (over
              ``len_total``) and precision (over ``pre_total``)
            * ``word_total``: samples decoded with distance 0
            * ``all_total``: samples processed
        """
        prediction = self.decode_batch()
        # NOTE(review): batch size is read from axis 1 of self.softmax,
        # suggesting a (time, batch, ...) layout -- confirm.
        batch_size = self.softmax.size(1)
        delete_total = 0
        replace_total = 0
        insert_total = 0
        len_total = 0
        correct_count = 0
        pre_total = 0
        word_total = 0
        all_total = 0
        for i in range(batch_size):
            pre_list = prediction[i]
            # -1 marks padding; the +1 shift presumably aligns label ids with
            # the decoder's output ids (reserved id 0) -- TODO confirm.
            label_list = self.label[i][self.label[i] != -1].tolist()
            label_list = [int(ele) + 1 for ele in label_list]

            distance, (delete, replace,
                       insert) = cal_distance(label_list, pre_list)
            delete_total += delete
            replace_total += replace
            insert_total += insert
            len_total += len(label_list)
            pre_total += len(pre_list)
            # Reference symbols that survived without deletion/substitution.
            correct_count += len(label_list) - delete - replace
            if distance == 0:
                word_total += 1
            all_total += 1
        result = [
            delete_total, replace_total, insert_total, len_total,
            correct_count, len_total, pre_total, word_total, all_total
        ]
        return prediction, result
예제 #3
0
    def result_analysis_recall(self):
        """Score the decoded batch against the labels for recall/precision.

        Every sample's prediction (from ``self.decode_batch()``) is compared
        with its reference using ``cal_distance``.  The reference is
        ``self.label[b]`` with ``-1`` entries removed and ids shifted by +1.

        Returns:
            tuple: ``(prediction, result)`` where ``result`` is
            ``[delete_total, replace_total, insert_total, len_total,
            correct_count, len_total, pre_total, word_total, all_total]``
            summed over the batch (``len_total`` appears twice).
            ``correct_count`` counts reference symbols that were neither
            deleted nor substituted, presumably for recall
            (``correct_count / len_total``) and precision
            (``correct_count / pre_total``).
        """
        prediction = self.decode_batch()
        # NOTE(review): axis 1 of self.softmax is taken as the batch axis.
        n_batch = self.softmax.size(1)

        n_del = n_sub = n_ins = 0
        ref_chars = hits = hyp_chars = 0
        exact = seen = 0

        for b in range(n_batch):
            hyp = prediction[b]
            labels = self.label[b]
            # Strip -1 padding and shift ids by one (presumably matching the
            # decoder's id space -- TODO confirm).
            ref = [int(x) + 1 for x in labels[labels != -1].tolist()]

            dist, (d, s, ins) = cal_distance(ref, hyp)
            n_del += d
            n_sub += s
            n_ins += ins
            ref_chars += len(ref)
            hyp_chars += len(hyp)
            hits += len(ref) - d - s

            if dist == 0:
                exact += 1
            seen += 1

        return prediction, [
            n_del, n_sub, n_ins, ref_chars,
            hits, ref_chars, hyp_chars, exact, seen
        ]
예제 #4
0
    def result_analysis_recall(self):
        """Greedy-decode ``self.softmax`` and score it against ``self.label``.

        For each time step the arg-max class of ``self.softmax`` is taken.
        A predicted sequence stops at the first ``self.EOS_token`` and a label
        sequence stops at the first ``-1``.  Each (label, prediction) pair is
        compared with ``cal_distance``.  The first sample of the batch is
        rendered as text and printed on every training call and on roughly 1%
        of other calls.

        Returns:
            tuple: ``(result, rec_result)``.
            ``result`` is the arg-max index LongTensor, indexed as
            ``result[:, i]`` for sample ``i``.
            ``rec_result`` is ``[delete_total, replace_total, insert_total,
            len_total, correct_count, len_total, pre_total, word_total,
            all_total]`` accumulated over the batch (``len_total`` appears
            twice).
        """
        # topk(1) keeps a trailing size-1 dim; squeeze only that dim.  A bare
        # squeeze() also dropped the time/batch dim whenever it was 1, which
        # made ``result[:, i]`` below raise.
        # NOTE(review): assumes self.softmax is (time, batch, classes) -- the
        # [:, i] indexing requires a 2-D result after the squeeze.
        result = self.softmax.data.topk(1)[1].squeeze(-1)
        # .type(torch.LongTensor) yields a CPU int64 tensor.
        result = result.type(torch.LongTensor)

        batch_size, _ = self.label.size()
        # Labels are (batch, time); transpose so they are indexed like result.
        label_batch = torch.transpose(self.label, 0, 1)

        # Index 0 is the '_' placeholder, hence the +1 when rendering ids.
        alphabet = '_$!$#$"$\'$&$)$($+$*$-$,$/$.$1$0$3$2$5$4$7$6$9$8$;$:$?$A$C$B$E$D$G$F$I$H$K$J$M$L$O$N$Q$P$S$R$U$T$W$V$Y$X$Z$a$c$b$e$d$g$f$i$h$k$j$m$l$o$n$q$p$s$r$u$t$w$v$y$x$z$|'.split('$')

        delete_total = 0
        replace_total = 0
        insert_total = 0
        len_total = 0
        word_total = 0
        all_total = 0
        pre_total = 0
        correct_count = 0
        show_result_pred = ''
        show_result_label = ''
        for i in range(batch_size):
            # Prediction: arg-max ids up to (not including) the first EOS.
            pre_list = []
            for token in result[:, i].tolist():
                if token == self.EOS_token:
                    break
                pre_list.append(token)

            # Label: ids up to (not including) the first -1 padding value.
            label_list = []
            for token in label_batch[:, i].tolist():
                if token == -1:
                    break
                label_list.append(int(token))

            if i == 0:
                show_result_pred = show_result_pred + '  ' + ''.join(
                    [alphabet[c + 1] for c in pre_list])
                show_result_label = show_result_label + '  ' + ''.join(
                    [alphabet[c + 1] for c in label_list])

            distance, (delete, replace, insert) = cal_distance(label_list, pre_list)
            delete_total += delete
            replace_total += replace
            insert_total += insert
            len_total += len(label_list)
            pre_total += len(pre_list)
            # Label symbols that were neither deleted nor substituted.
            correct_count += len(label_list) - delete - replace
            if distance == 0:
                word_total += 1
            all_total += 1

        # Always log while training; otherwise sample ~1% of calls.
        if self.training or np.random.rand() < 0.01:
            print('label:', show_result_label)
            print('pred: ', show_result_pred)

        rec_result = [delete_total, replace_total, insert_total, len_total,
                      correct_count, len_total, pre_total, word_total, all_total]
        return result, rec_result