示例#1
0
class OpenPoseTest(object):
    """Multi-person pose test/debug runner for an OpenPose-style network.

    Decoding follows the bottom-up pipeline implemented below:
    1. peak detection on the scale-averaged keypoint heatmaps;
    2. line-integral scoring of candidate limbs along the part affinity fields (PAFs);
    3. greedy assembly of the accepted limbs into per-person subsets.
    """

    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.pose_visualizer = PoseVisualizer(configer)
        self.pose_parser = PoseParser(configer)
        self.pose_model_manager = ModelManager(configer)
        self.pose_data_loader = DataLoader(configer)
        self.device = torch.device('cpu' if self.configer.get('gpu') is None else 'cuda')
        self.pose_net = None

        self._init_model()

    def _init_model(self):
        """Build the pose network, load its weights and switch it to eval mode."""
        self.pose_net = self.pose_model_manager.get_pose_model()
        self.pose_net = RunnerHelper.load_net(self, self.pose_net)
        self.pose_net.eval()

    def _get_blob(self, ori_image, scale=None):
        """Convert ``ori_image`` into a network input tensor at ``scale``.

        When ``('test', 'fit_stride')`` is configured, the tensor is zero-padded
        on the right and bottom so its height and width become multiples of it.

        Args:
            ori_image: Image as returned by ``ImageHelper.read_image``.
            scale: Resize factor forwarded to ``BlobHelper.make_input`` (required).

        Returns:
            ``(image, border_hw)``. ``image`` is the (possibly padded)
            ``(b, c, h, w)`` tensor. ``border_hw`` is ``[h, w]`` before padding.
        """
        assert scale is not None
        image = self.blob_helper.make_input(image=ori_image, scale=scale)

        b, c, h, w = image.size()
        border_hw = [h, w]
        if self.configer.exists('test', 'fit_stride'):
            stride = self.configer.get('test', 'fit_stride')

            pad_w = 0 if (w % stride == 0) else stride - (w % stride)  # right
            pad_h = 0 if (h % stride == 0) else stride - (h % stride)  # down

            expand_image = torch.zeros((b, c, h + pad_h, w + pad_w)).to(image.device)
            expand_image[:, :, 0:h, 0:w] = image
            image = expand_image

        return image, border_hw

    @staticmethod
    def _resize_net_output(net_out, stride, border_hw, ori_size):
        """Map one network output back to an ``(H, W, C)`` array at the original size.

        The tensor is squeezed on dim 0 (a batch of one is assumed) and moved to
        channels-last. It is then upsampled by ``stride`` and cropped to
        ``border_hw`` to drop the padding added by ``_get_blob``. Finally it is
        resized to ``ori_size`` = ``(width, height)``.
        """
        out = net_out.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        out = cv2.resize(out, None, fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
        return cv2.resize(out[:border_hw[0], :border_hw[1]], ori_size,
                          interpolation=cv2.INTER_CUBIC)

    def __test_img(self, image_path, json_path, raw_path, vis_path):
        """Run multi-scale inference on one image and save the decoded poses.

        The network runs at every scale in ``('test', 'scale_search')``, taken
        relative to ``('test', 'input_size')[1] / image height``. The last
        heatmap/PAF outputs are mapped back to the original resolution and
        averaged over scales.

        Outputs:
            ``vis_path``: the poses drawn on the image.
            ``raw_path``: the BGR image itself.
            ``json_path``: the pose dict, saved as JSON.
        """
        Log.info('Image Path: {}'.format(image_path))
        input_mode = self.configer.get('data', 'input_mode')
        ori_image = ImageHelper.read_image(image_path,
                                           tool=self.configer.get('data', 'image_tool'),
                                           mode=input_mode)

        ori_width, ori_height = ImageHelper.get_size(ori_image)
        ori_size = (ori_width, ori_height)
        ori_img_bgr = ImageHelper.get_cv2_bgr(ori_image, mode=input_mode)
        heatmap_avg = np.zeros((ori_height, ori_width, self.configer.get('network', 'heatmap_out')))
        paf_avg = np.zeros((ori_height, ori_width, self.configer.get('network', 'paf_out')))
        multiplier = [scale * self.configer.get('test', 'input_size')[1] / ori_height
                      for scale in self.configer.get('test', 'scale_search')]
        stride = self.configer.get('network', 'stride')
        for scale in multiplier:
            image, border_hw = self._get_blob(ori_image, scale=scale)
            with torch.no_grad():
                paf_out_list, heatmap_out_list = self.pose_net(image)

            # Only the last stage's predictions are used.
            heatmap = self._resize_net_output(heatmap_out_list[-1], stride, border_hw, ori_size)
            paf = self._resize_net_output(paf_out_list[-1], stride, border_hw, ori_size)

            heatmap_avg = heatmap_avg + heatmap / len(multiplier)
            paf_avg = paf_avg + paf / len(multiplier)

        all_peaks = self.__extract_heatmap_info(heatmap_avg)
        special_k, connection_all = self.__extract_paf_info(ori_img_bgr, paf_avg, all_peaks)
        subset, candidate = self.__get_subsets(connection_all, special_k, all_peaks)
        json_dict = self.__get_info_tree(ori_img_bgr, subset, candidate)

        image_canvas = self.pose_parser.draw_points(ori_img_bgr.copy(), json_dict)
        image_canvas = self.pose_parser.link_points(image_canvas, json_dict)

        ImageHelper.save(image_canvas, vis_path)
        ImageHelper.save(ori_img_bgr, raw_path)
        Log.info('Json Save Path: {}'.format(json_path))
        JsonHelper.save_file(json_dict, json_path)

    def __get_info_tree(self, image_raw, subset, candidate):
        """Convert assembled person subsets into a JSON-friendly dict.

        A subset row is kept only when both hold:
        - it has at least ``('res', 'num_threshold')`` parts;
        - its average score (total score / part count) is at least
          ``('res', 'avg_threshold')``.

        Args:
            image_raw: Image array whose ``shape`` starts with ``(height, width)``.
            subset: Rows as produced by ``__get_subsets``.
            candidate: ``(x, y, score, id)`` rows as produced by ``__get_subsets``.

        Returns:
            Dict with the keys ``'image_height'``, ``'image_width'`` and ``'objects'``.
            Each object holds ``'kpts'`` and ``'score'``:
            - ``'kpts'``: ``num_kpts`` entries ``[x, y, 1]``, or ``[-1, -1, -1]``
              for a missing part;
            - ``'score'``: the subset's accumulated score.
        """
        height, width, _ = image_raw.shape
        num_kpts = self.configer.get('data', 'num_kpts')
        num_threshold = self.configer.get('res', 'num_threshold')
        avg_threshold = self.configer.get('res', 'avg_threshold')

        object_list = []
        for person in subset:
            if person[-1] < num_threshold:
                continue
            if person[-2] / person[-1] < avg_threshold:
                continue

            kpts = []
            for index in person[:num_kpts]:
                if index == -1:
                    kpts.append([-1, -1, -1])
                else:
                    x, y = candidate[int(index)][:2]
                    kpts.append([x, y, 1])

            object_list.append({'kpts': kpts, 'score': person[-2]})

        return {
            'image_height': height,
            'image_width': width,
            'objects': object_list,
        }

    def __extract_heatmap_info(self, heatmap_avg):
        """Detect keypoint candidates as local maxima of the smoothed heatmaps.

        Each channel is smoothed with a Gaussian (sigma=3). A pixel is a peak
        when both hold:
        - it is >= its four axis-aligned neighbours (out-of-image neighbours
          count as 0);
        - it is > ``('res', 'part_threshold')``.

        Args:
            heatmap_avg: ``(H, W, C)`` array. Channels ``0 .. num_kpts - 1`` are used.

        Returns:
            One list per keypoint type, holding ``(x, y, score, id)`` tuples.
            ``score`` comes from the unsmoothed map. ``id`` is a running index,
            unique across all keypoint types.
        """
        all_peaks = []
        peak_counter = 0
        part_threshold = self.configer.get('res', 'part_threshold')

        for part in range(self.configer.get('data', 'num_kpts')):
            map_ori = heatmap_avg[:, :, part]
            map_gau = gaussian_filter(map_ori, sigma=3)

            # Neighbour maps shifted by one pixel along each axis, zero-filled at the border.
            map_prev_row = np.zeros(map_gau.shape)
            map_prev_row[1:, :] = map_gau[:-1, :]
            map_next_row = np.zeros(map_gau.shape)
            map_next_row[:-1, :] = map_gau[1:, :]
            map_prev_col = np.zeros(map_gau.shape)
            map_prev_col[:, 1:] = map_gau[:, :-1]
            map_next_col = np.zeros(map_gau.shape)
            map_next_col[:, :-1] = map_gau[:, 1:]

            peaks_binary = np.logical_and.reduce(
                (map_gau >= map_prev_row, map_gau >= map_next_row, map_gau >= map_prev_col,
                 map_gau >= map_next_col, map_gau > part_threshold))

            ys, xs = np.nonzero(peaks_binary)
            peaks = list(zip(xs, ys))  # (x, y) order, i.e. reversed w.r.t. array indexing

            all_peaks.append([(x, y, map_ori[y, x], peak_counter + offset)
                              for offset, (x, y) in enumerate(peaks)])
            peak_counter += len(peaks)

        return all_peaks

    def __extract_paf_info(self, img_raw, paf_avg, all_peaks):
        """Score candidate limbs with the PAFs and keep a one-to-one matching per limb type.

        ``('details', 'limb_seq')`` lists 1-based keypoint index pairs. For
        limb ``k``, each candidate endpoint pair is scored as follows:
        1. Sample ``('res', 'mid_point_num')`` points along the segment.
        2. Read PAF channels ``2k`` / ``2k + 1`` at those points and project
           them onto the unit limb direction.
        3. Accept the pair when both criteria hold:
           - more than ``limb_pos_ratio`` of the samples exceed ``limb_threshold``;
           - the mean projection stays positive after a penalty for segments
             longer than half the image height.
        Accepted pairs are then matched greedily by descending score, so each
        endpoint is used at most once.

        Args:
            img_raw: Image array; only ``shape[0]`` (its height) is used.
            paf_avg: ``(H, W, 2 * num_limbs)`` PAF array.
            all_peaks: Output of ``__extract_heatmap_info``.

        Returns:
            ``(special_k, connection_all)``.
            - ``special_k``: limb indices that lack candidates for at least one endpoint.
            - ``connection_all[k]``: an ``(n, 5)`` array of ``[idA, idB, score, i, j]``
              rows, or ``[]`` for limbs in ``special_k``.
        """
        connection_all = []
        special_k = []
        mid_num = self.configer.get('res', 'mid_point_num')
        limb_seq = self.configer.get('details', 'limb_seq')
        limb_threshold = self.configer.get('res', 'limb_threshold')
        limb_pos_ratio = self.configer.get('res', 'limb_pos_ratio')

        for k, limb in enumerate(limb_seq):
            score_mid = paf_avg[:, :, [k * 2, k * 2 + 1]]
            candA = all_peaks[limb[0] - 1]
            candB = all_peaks[limb[1] - 1]
            nA = len(candA)
            nB = len(candB)
            if nA == 0 or nB == 0:
                special_k.append(k)
                connection_all.append([])
                continue

            connection_candidate = []
            for i in range(nA):
                for j in range(nB):
                    vec = np.subtract(candB[j][:2], candA[i][:2])
                    norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + 1e-9
                    vec = np.divide(vec, norm)

                    xs = np.linspace(candA[i][0], candB[j][0], num=mid_num)
                    ys = np.linspace(candA[i][1], candB[j][1], num=mid_num)
                    samples = list(zip(xs, ys))

                    vec_x = np.array([score_mid[int(round(y)), int(round(x)), 0] for x, y in samples])
                    vec_y = np.array([score_mid[int(round(y)), int(round(x)), 1] for x, y in samples])

                    score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                    score_with_dist_prior = sum(score_midpts) / len(score_midpts)
                    # Penalty becomes negative once the limb is longer than half the image height.
                    score_with_dist_prior += min(0.5 * img_raw.shape[0] / norm - 1, 0)

                    num_positive = len(np.nonzero(score_midpts > limb_threshold)[0])
                    criterion1 = num_positive > int(limb_pos_ratio * len(score_midpts))
                    criterion2 = score_with_dist_prior > 0
                    if criterion1 and criterion2:
                        connection_candidate.append(
                            [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])

            connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
            connection = np.zeros((0, 5))
            for candidate_row in connection_candidate:
                i, j, s = candidate_row[0:3]
                if i not in connection[:, 3] and j not in connection[:, 4]:
                    connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                    if len(connection) >= min(nA, nB):
                        break

            connection_all.append(connection)

        return special_k, connection_all

    def __get_subsets(self, connection_all, special_k, all_peaks):
        """Greedily assemble accepted limbs into per-person subsets.

        Limb types are visited in the order given by ``('details', 'mini_tree')``.
        - A limb joins the subset that already holds one of its endpoints.
        - If two subsets hold its endpoints, they are merged when they share no
          keypoint type. Otherwise the limb is attached to the first of them.
        - A limb touching no subset starts a new one.

        Returns:
            ``(subset, candidate)``.
            - ``subset`` has ``num_kpts + 2`` columns: the global candidate id of
              each keypoint (``-1`` if missing), then the accumulated score, then
              the number of assigned parts.
            - ``candidate`` stacks every ``(x, y, score, id)`` peak from ``all_peaks``.
        """
        num_kpts = self.configer.get('data', 'num_kpts')
        limb_seq = self.configer.get('details', 'limb_seq')
        subset = -1 * np.ones((0, num_kpts + 2))
        candidate = np.array([item for sublist in all_peaks for item in sublist])

        for k in self.configer.get('details', 'mini_tree'):
            if k in special_k:
                continue

            connection = connection_all[k]
            partAs = connection[:, 0]
            partBs = connection[:, 1]
            indexA, indexB = np.array(limb_seq[k]) - 1

            for i in range(len(connection)):
                found = 0
                subset_idx = [-1, -1]
                for j in range(len(subset)):
                    if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                        subset_idx[found] = j
                        found += 1

                if found == 1:
                    j = subset_idx[0]
                    if subset[j][indexB] != partBs[i]:
                        subset[j][indexB] = partBs[i]
                        subset[j][-1] += 1
                        subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection[i][2]
                elif found == 2:
                    j1, j2 = subset_idx
                    membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
                    if len(np.nonzero(membership == 2)[0]) == 0:
                        # Disjoint subsets: merge j2 into j1 (ids are -1 where absent, hence the +1).
                        subset[j1][:-2] += (subset[j2][:-2] + 1)
                        subset[j1][-2:] += subset[j2][-2:]
                        subset[j1][-2] += connection[i][2]
                        subset = np.delete(subset, j2, 0)
                    else:
                        # Overlapping subsets: attach the limb to j1 as in the single-match case.
                        subset[j1][indexB] = partBs[i]
                        subset[j1][-1] += 1
                        subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection[i][2]
                elif not found:
                    # No subset contains either endpoint: start a new person.
                    row = -1 * np.ones(num_kpts + 2)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    row[-1] = 2
                    row[-2] = sum(candidate[connection[i, :2].astype(int), 2]) + connection[i][2]
                    subset = np.vstack([subset, row])

        return subset, candidate

    def debug(self, vis_dir):
        """Decode and display ground-truth poses for the first training images.

        The ground-truth heatmaps/PAFs of each training image go through the
        same decoding pipeline used at test time. For each image:
        - ``{i}_{j}_vis.png`` is written into ``vis_dir``;
        - the result is shown in an OpenCV window, which blocks on a key press.
        ``exit(1)`` is called once more than 10 images have been processed.
        """
        count = 0  # previously never initialised -> UnboundLocalError on the first image
        stride = self.configer.get('network', 'stride')
        for i, data_dict in enumerate(self.pose_data_loader.get_trainloader()):
            inputs = data_dict['img']
            maskmap = data_dict['maskmap']
            heatmap = data_dict['heatmap']
            vecmap = data_dict['vecmap']
            for j in range(inputs.size(0)):
                count = count + 1
                if count > 10:
                    exit(1)

                Log.info(heatmap.size())
                image_bgr = self.blob_helper.tensor2bgr(inputs[j])
                mask_canvas = maskmap[j].repeat(3, 1, 1).numpy().transpose(1, 2, 0)
                mask_canvas = (mask_canvas * 255).astype(np.uint8)
                mask_canvas = cv2.resize(mask_canvas, (0, 0), fx=stride, fy=stride,
                                         interpolation=cv2.INTER_CUBIC)

                image_bgr = cv2.addWeighted(image_bgr, 0.6, mask_canvas, 0.4, 0)
                heatmap_avg = cv2.resize(heatmap[j].numpy().transpose(1, 2, 0), (0, 0),
                                         fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
                paf_avg = cv2.resize(vecmap[j].numpy().transpose(1, 2, 0), (0, 0),
                                     fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
                self.pose_visualizer.vis_peaks(heatmap_avg, image_bgr)
                self.pose_visualizer.vis_paf(paf_avg, image_bgr)
                all_peaks = self.__extract_heatmap_info(heatmap_avg)
                special_k, connection_all = self.__extract_paf_info(image_bgr, paf_avg, all_peaks)
                subset, candidate = self.__get_subsets(connection_all, special_k, all_peaks)
                json_dict = self.__get_info_tree(image_bgr, subset, candidate)
                image_canvas = self.pose_parser.draw_points(image_bgr, json_dict)
                image_canvas = self.pose_parser.link_points(image_canvas, json_dict)
                cv2.imwrite(os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)), image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()
示例#2
0
class ImageTranslatorTest(object):
    """Run a trained image-to-image GAN over the A/B test sets and save every output image."""

    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.model_manager = ModelManager(configer)
        self.test_loader = TestDataLoader(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.gan_net = None

        self._init_model()

    def _init_model(self):
        """Build the GAN, load its weights and switch it to eval mode."""
        self.gan_net = self.model_manager.gan_model()
        self.gan_net = RunnerHelper.load_net(self, self.gan_net)
        self.gan_net.eval()

    def _get_test_loaders(self, test_dir):
        """Return ``(test_loader_A, test_loader_B)`` for ``test_dir``.

        - ``'nir2vis'`` mode reads ``val_label{tag}A.json`` / ``val_label{tag}B.json``.
        - Every other mode reads the ``imageA`` / ``imageB`` directories. This
          covers ``'pix2pix'`` and a missing mode, which previously had two
          identical branches.
        A loader is ``None`` when its source does not exist.
        """
        if self.configer.exists('test', 'mode') and self.configer.get(
                'test', 'mode') == 'nir2vis':
            tag = self.configer.get('data', 'tag')
            loader_kwarg = 'json_path'
            paths = [os.path.join(test_dir, 'val_label{}A.json'.format(tag)),
                     os.path.join(test_dir, 'val_label{}B.json'.format(tag))]
        else:
            loader_kwarg = 'test_dir'
            paths = [os.path.join(test_dir, 'imageA'),
                     os.path.join(test_dir, 'imageB')]

        loaders = []
        for path in paths:
            if os.path.exists(path):
                loaders.append(self.test_loader.get_testloader(**{loader_kwarg: path}))
            else:
                loaders.append(None)
        return loaders[0], loaders[1]

    def _translate_and_save(self, test_loader, input_key, out_dir):
        """Translate every batch of ``test_loader`` and save each output image.

        Args:
            test_loader: Iterable of batches with ``'img'`` and ``'meta'`` entries.
            input_key: Key under which the batch images are passed to the GAN
                (``'imgA'`` or ``'imgB'``).
            out_dir: Directory receiving ``{filename}_{output key}.jpg`` files.
        """
        for data_dict in test_loader:
            new_data_dict = {input_key: data_dict['img'], 'testing': True}
            with torch.no_grad():
                out_dict = self.gan_net(new_data_dict)

            meta_list = DCHelper.tolist(data_dict['meta'])
            for key, value in out_dict.items():
                for i in range(len(value)):
                    img_bgr = self.blob_helper.tensor2bgr(value[i])
                    Log.info('Image Path: {}'.format(meta_list[i]['img_path']))
                    ImageHelper.save(
                        img_bgr,
                        os.path.join(
                            out_dir,
                            '{}_{}.jpg'.format(meta_list[i]['filename'], key)))

    def test(self, test_dir, out_dir):
        """Translate the A and/or B test sets found under ``test_dir`` into ``out_dir``."""
        test_loader_A, test_loader_B = self._get_test_loaders(test_dir)

        if test_loader_A is not None:
            self._translate_and_save(test_loader_A, 'imgA', out_dir)

        if test_loader_B is not None:
            self._translate_and_save(test_loader_B, 'imgB', out_dir)
示例#3
0
class FaceGANTest(object):
    """Evaluate a face GAN as a probe/gallery recognition model.

    Probe images (set A) and gallery images (set B) are passed through the
    network. The generated images are saved, and the returned feature vectors
    are scored for rank-1 identification accuracy and VR@FAR=0.1%.
    """

    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.model_manager = ModelManager(configer)
        self.test_loader = TestDataLoader(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.gan_net = None

        self._init_model()

    def _init_model(self):
        """Build the GAN, load its weights and switch it to eval mode."""
        self.gan_net = self.model_manager.gan_model()
        self.gan_net = RunnerHelper.load_net(self, self.gan_net)
        self.gan_net.eval()

    def _run_loader(self, test_loader, input_key, feat_key, out_dir):
        """Feed ``test_loader`` through the GAN, save generated images and collect features.

        Args:
            test_loader: Iterable of batches with ``'img'`` and ``'meta'`` entries.
            input_key: Key under which the batch images are passed to the GAN.
            feat_key: Output key holding one feature vector per sample.
            out_dir: Root directory. Images are saved to
                ``out_dir/<output key>/<filename>``.

        Returns:
            ``(features, labels)``: per-sample numpy features and ``meta['label']`` values.
        """
        features, labels = [], []
        for data_dict in test_loader:
            with torch.no_grad():
                out_dict = self.gan_net({input_key: data_dict['img']}, testing=True)

            meta_list = DCHelper.tolist(data_dict['meta'])
            for idx, meta in enumerate(meta_list):
                features.append(out_dict[feat_key][idx].cpu().numpy())
                labels.append(meta['label'])

            for key, value in out_dict.items():
                if 'feat' in key:  # feature tensors are not images
                    continue

                for i in range(len(value)):
                    img_bgr = self.blob_helper.tensor2bgr(value[i])
                    Log.info('Image Path: {}'.format(meta_list[i]['img_path']))
                    img_bgr = ImageHelper.resize(img_bgr,
                                                 target_size=self.configer.get(
                                                     'test', 'out_size'),
                                                 interpolation='linear')
                    ImageHelper.save(
                        img_bgr,
                        os.path.join(out_dir, key, meta_list[i]['filename']))

        return features, labels

    def test(self, test_dir, out_dir):
        """Run the probe (A) and gallery (B) sets through the GAN and log recognition metrics.

        Only ``('test', 'mode') == 'nir2vis'`` is supported. The two sets are
        read from ``val_label{tag}A.json`` and ``val_label{tag}B.json`` inside
        ``test_dir``. Any other mode logs an error and exits the process.
        """
        if self.configer.exists('test', 'mode') and self.configer.get(
                'test', 'mode') == 'nir2vis':
            tag = self.configer.get('data', 'tag')
            jsonA_path = os.path.join(test_dir, 'val_label{}A.json'.format(tag))
            test_loader_A = self.test_loader.get_testloader(
                json_path=jsonA_path) if os.path.exists(jsonA_path) else None
            jsonB_path = os.path.join(test_dir, 'val_label{}B.json'.format(tag))
            test_loader_B = self.test_loader.get_testloader(
                json_path=jsonB_path) if os.path.exists(jsonB_path) else None
        else:
            test_loader_A, test_loader_B = None, None
            Log.error('Test Mode not Exists!')
            exit(1)

        assert test_loader_A is not None and test_loader_B is not None
        probe_features, probe_labels = self._run_loader(
            test_loader_A, 'imgA', 'featA', out_dir)
        # NOTE(review): the probe pass reads 'featA' but the gallery pass reads
        # 'feat'. This is kept from the original code. Confirm which key the
        # network returns for 'imgB' inputs; 'featB' would be the symmetric choice.
        gallery_features, gallery_labels = self._run_loader(
            test_loader_B, 'imgB', 'feat', out_dir)

        r_acc, tpr = self.decode(probe_features, gallery_features,
                                 probe_labels, gallery_labels)
        Log.info('Final Rank1 accuracy is {}'.format(r_acc))
        Log.info('Final VR@FAR=0.1% accuracy is {}'.format(tpr))

    @staticmethod
    def decode(probe_features, gallery_features, probe_labels, gallery_labels):
        """Compute rank-1 accuracy and VR@FAR=0.1% from probe/gallery features.

        Similarity is the cosine similarity between every probe and every
        gallery feature. Each probe label must occur exactly once in
        ``gallery_labels`` (checked with an assertion). That gallery entry is
        the probe's genuine pair for the ROC computation.

        Returns:
            ``(r_acc, tpr)``.
            - ``r_acc``: fraction of probes whose most similar gallery entry has
              the same label (compared as ``int``).
            - ``tpr``: the true-positive rate at the last ROC point whose
              false-positive rate is <= 0.001.
        """
        probe_features = np.array(probe_features)
        gallery_features = np.array(gallery_features)
        # score[i, j]: similarity of probe i to gallery j.
        score = cosine_similarity(gallery_features, probe_features).T
        print('===> compute metric')

        label = np.zeros_like(score)
        max_index = np.argmax(score, axis=1)
        gallery_array = np.asarray(gallery_labels)
        count = 0
        for i, best in enumerate(max_index):
            matches = np.flatnonzero(np.equal(gallery_array, probe_labels[i]))
            assert len(matches) == 1
            label[i][matches[0]] = 1

            if int(probe_labels[i]) == int(gallery_labels[best]):
                count += 1

        # Exact ratio: the former "+ 1e-5" in the denominator biased the accuracy
        # low (e.g. 100 correct out of 100 was reported as 0.9999999).
        r_acc = count / len(probe_labels)
        fpr, tpr, _ = roc_curve(label.flatten(), score.flatten())
        return r_acc, tpr[fpr <= 0.001][-1]
示例#4
0
class ImageClassifierTest(object):
    """Single-image classification test and training-data debug runner."""

    def __init__(self, configer):
        self.configer = configer
        self.blob_helper = BlobHelper(configer)
        self.cls_model_manager = ModelManager(configer)
        self.cls_data_loader = DataLoader(configer)
        self.cls_parser = ClsParser(configer)
        self.device = torch.device(
            'cpu' if self.configer.get('gpu') is None else 'cuda')
        self.cls_net = None
        if self.configer.get('dataset') == 'imagenet':
            # Register entry [1] of each ImageNet class-index record (presumably
            # the readable class name) under ('details', 'name_seq').
            index_path = os.path.join(
                self.configer.get('project_dir'),
                'datasets/cls/imagenet/imagenet_class_index.json')
            with open(index_path, encoding='utf-8') as json_stream:
                name_dict = json.load(json_stream)

            name_seq = [
                name_dict[str(i)][1]
                for i in range(self.configer.get('data', 'num_classes'))
            ]
            self.configer.add(['details', 'name_seq'], name_seq)

        self._init_model()

    def _init_model(self):
        """Build the classifier, load its weights and switch it to eval mode."""
        self.cls_net = self.cls_model_manager.get_cls_model()
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        self.cls_net.eval()

    def __test_img(self, image_path, json_path, raw_path, vis_path):
        """Classify one image and save the visualisation, raw image and JSON result.

        Only ``dataset == 'imagenet'`` is supported. Its preprocessing is:
        resize the shorter side to 256, then center-crop to 224.

        Returns:
            The prediction dict produced by ``__get_info_tree``.
        """
        Log.info('Image Path: {}'.format(image_path))
        img = ImageHelper.read_image(
            image_path,
            tool=self.configer.get('data', 'image_tool'),
            mode=self.configer.get('data', 'input_mode'))

        trans = None
        if self.configer.get('dataset') == 'imagenet':
            if self.configer.get('data', 'image_tool') == 'cv2':
                img = Image.fromarray(img)

            # transforms.Scale was a deprecated alias of Resize and no longer
            # exists in current torchvision. Resize(256) behaves identically.
            trans = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
            ])

        assert trans is not None
        img = trans(img)

        ori_img_bgr = ImageHelper.get_cv2_bgr(img,
                                              mode=self.configer.get(
                                                  'data', 'input_mode'))

        inputs = self.blob_helper.make_input(img,
                                             input_size=self.configer.get(
                                                 'test', 'input_size'),
                                             scale=1.0)

        with torch.no_grad():
            outputs = self.cls_net(inputs)

        json_dict = self.__get_info_tree(outputs, image_path)

        image_canvas = self.cls_parser.draw_label(ori_img_bgr.copy(),
                                                  json_dict['label'])
        cv2.imwrite(vis_path, image_canvas)
        cv2.imwrite(raw_path, ori_img_bgr)

        Log.info('Json Path: {}'.format(json_path))
        # NOTE(review): json_dict holds torch tensors ('label', 'label_top*');
        # confirm JsonHelper.save_file can serialise them.
        JsonHelper.save_file(json_dict, json_path)
        return json_dict

    def __get_info_tree(self, outputs, image_path=None):
        """Summarise class scores as top-1/3/5 predictions.

        Args:
            outputs: 1-D tensor of class scores. A 2-D tensor with a single row
                (a batch of one) is reduced to that row first.
            image_path: Optional path, stored under ``'image_path'``.

        Returns:
            Dict with these entries:
            - ``'label'``: top-1 class index tensor;
            - ``'label_top3'`` / ``'label_top5'``: index tensors, best first.
        """
        json_dict = dict()
        if image_path is not None:
            json_dict['image_path'] = image_path

        # topk runs along dim 0. For a (1, num_classes) batch that would mean
        # k > dim size, which raises, so drop the singleton batch dimension.
        if outputs.dim() > 1 and outputs.size(0) == 1:
            outputs = outputs[0]

        topk = (1, 3, 5)
        _, pred = outputs.topk(max(topk), 0, True, True)
        for k in topk:
            if k == 1:
                json_dict['label'] = pred[0]
            else:
                json_dict['label_top{}'.format(k)] = pred[:k]

        return json_dict

    def debug(self, vis_dir):
        """Show ground-truth labels drawn on the first training images.

        For each image:
        - ``{i}_{j}_vis.png`` is written into ``vis_dir``;
        - the result is shown in an OpenCV window, which blocks on a key press.
        ``exit(1)`` is called once more than 20 images have been processed.
        """
        count = 0
        num_classes = self.configer.get('data', 'num_classes')
        eye_matrix = torch.eye(num_classes)  # one-hot lookup table, built once
        for i, data_dict in enumerate(self.cls_data_loader.get_trainloader()):
            inputs = data_dict['img']
            labels = data_dict['label']
            labels_target = eye_matrix[labels.view(-1)].view(
                inputs.size(0), num_classes)

            for j in range(inputs.size(0)):
                count = count + 1
                if count > 20:
                    exit(1)

                ori_img_bgr = self.blob_helper.tensor2bgr(inputs[j])

                json_dict = self.__get_info_tree(labels_target[j])
                image_canvas = self.cls_parser.draw_label(
                    ori_img_bgr.copy(), json_dict['label'])

                cv2.imwrite(
                    os.path.join(vis_dir, '{}_{}_vis.png'.format(i, j)),
                    image_canvas)
                cv2.imshow('main', image_canvas)
                cv2.waitKey()