def infer_fg(self, img):
        """
        img: BGR image of shape (H, W, C)
        returns: binary mask image of shape (H, W), 255 for fg, 0 for bg
        """
        ori_h, ori_w = img.shape[0:2]
        new_h, new_w = self.get_working_size(ori_h, ori_w)
        img = cv2.resize(img, (new_w, new_h))

        # Get results of original image
        multiplier = get_multiplier(img)

        with torch.no_grad():
            orig_paf, orig_heat = get_outputs(multiplier, img, self.model,
                                              'rtpose')

            # Get results of flipped image
            swapped_img = img[:, ::-1, :]
            flipped_paf, flipped_heat = get_outputs(multiplier, swapped_img,
                                                    self.model, 'rtpose')

            # compute averaged heatmap and paf
            paf, heatmap = handle_paf_and_heat(orig_heat, flipped_heat,
                                               orig_paf, flipped_paf)

        param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
        to_plot, canvas, candidate, subset = decode_pose_fg(
            img, param, heatmap, paf)

        canvas = cv2.resize(canvas, (ori_w, ori_h))
        fg_map = canvas > 128
        canvas[fg_map] = 255
        canvas[~fg_map] = 0
        return canvas[:, :, 0]
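A minimal usage sketch (hedged: the enclosing class is not shown in this snippet, so `detector` stands in for an initialized instance): apply the returned mask to keep only the foreground pixels of a frame.

import cv2

frame = cv2.imread('frame.png')             # BGR, (H, W, 3)
mask = detector.infer_fg(frame)             # (H, W), 255 = foreground, 0 = background
fg_only = cv2.bitwise_and(frame, frame, mask=mask)
cv2.imwrite('fg_only.png', fg_only)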
Example #2
def skeleton_frame(idx):
    img_path = img_dir.joinpath('{:05d}.png'.format(idx))

    img = cv2.imread(str(img_path))

    shape_dst = np.min(img.shape[:2])
    oh = (img.shape[0] - shape_dst) // 2
    ow = (img.shape[1] - shape_dst) // 2

    img = img[oh:oh + shape_dst, ow:ow + shape_dst]
    img = cv2.resize(img, (512, 512))
    multiplier = get_multiplier(img)
    with torch.no_grad():
        paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
    r_heatmap = np.array([remove_noise(ht)
                          for ht in heatmap.transpose(2, 0, 1)[:-1]]).transpose(1, 2, 0)
    heatmap[:, :, :-1] = r_heatmap
    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    label, cord = get_pose(param, heatmap, paf)

    mask = label > 0

    intensity = 0.80
    img[mask, :] = int(255 * intensity)

    fig.clear()
    plt.axis('off')

    plt.imshow(img)
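skeleton_frame clears and redraws a global matplotlib figure, so it is shaped like an animation callback. A hedged driver sketch, assuming fig, img_dir and model are already defined as module-level globals:

from matplotlib import animation

# render frames 0..99 by calling skeleton_frame(idx) once per index
anim = animation.FuncAnimation(fig, skeleton_frame, frames=100)
anim.save('skeleton.mp4')  # writing video requires ffmpeg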
def process(model, oriImg, process_speed):
    # Get results of original image
    multiplier = get_multiplier(oriImg, process_speed)
    with torch.no_grad():
        orig_paf, orig_heat = get_outputs(multiplier, oriImg, model, 'rtpose')

        # Get results of flipped image
        swapped_img = oriImg[:, ::-1, :]
        flipped_paf, flipped_heat = get_outputs(multiplier, swapped_img, model,
                                                'rtpose')

        # compute averaged heatmap and paf
        paf, heatmap = handle_paf_and_heat(orig_heat, flipped_heat, orig_paf,
                                           flipped_paf)
    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    to_plot, canvas, joint_list, person_to_joint_assoc = decode_pose(
        oriImg, param, heatmap, paf)
    return to_plot, canvas, joint_list, person_to_joint_assoc
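A hedged usage sketch for process(): run flip-averaged inference on a single image and save the rendered skeleton. model is assumed to be a loaded rtpose network as in the surrounding examples, and process_speed=4 is an illustrative value that controls how many scales get_multiplier returns.

img = cv2.imread('example.jpg')  # hypothetical input file
to_plot, canvas, joint_list, person_to_joint_assoc = process(model, img, 4)
cv2.imwrite('skeleton.png', to_plot)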
def save(idx):
    global pose_cords
    if not os.path.exists(str(train_img_dir.joinpath(
            '{:05}.png'.format(idx)))):
        try:
            img_path = img_dir.joinpath('{:05}.png'.format(idx))
            img = cv2.imread(str(img_path))
            shape_dst = np.min(img.shape[:2])
            oh = (img.shape[0] - shape_dst) // 2
            ow = (img.shape[1] - shape_dst) // 2

            img = img[oh:oh + shape_dst, ow:ow + shape_dst]
            img = cv2.resize(img, (512, 512))
            multiplier = get_multiplier(img)
            with torch.no_grad():
                paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
            r_heatmap = np.array([
                remove_noise(ht) for ht in heatmap.transpose(2, 0, 1)[:-1]
            ]).transpose(1, 2, 0)
            heatmap[:, :, :-1] = r_heatmap
            param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
            # TODO get_pose
            label, cord = get_pose(param, heatmap, paf)
            index = 13
            crop_size = 25
            try:
                head_cord = cord[index]
            except IndexError:
                try:
                    # if there is no head point in this frame, reuse the last one
                    head_cord = pose_cords[-1]
                except IndexError:
                    head_cord = None

            pose_cords.append(head_cord)
            if head_cord is not None:
                head = img[int(head_cord[1] - crop_size):int(head_cord[1] + crop_size),
                           int(head_cord[0] - crop_size):int(head_cord[0] + crop_size), :]
                plt.imshow(head)
                plt.savefig(str(train_head_dir.joinpath(
                    'pose_{}.jpg'.format(idx))))
                plt.clf()
            cv2.imwrite(str(train_img_dir.joinpath('{:05}.png'.format(idx))),
                        img)
            cv2.imwrite(str(train_label_dir.joinpath('{:05}.png'.format(idx))),
                        label)
            return True
        except Exception:
            return False

    else:
        return False
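save() relies on module-level state (img_dir, train_img_dir, train_label_dir, train_head_dir, model, pose_cords). A hedged driver sketch, mirroring the extract_poses example further down:

pose_cords = []
for idx in range(len(os.listdir(str(img_dir)))):
    save(idx)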
Example #5
def extract_poses(model, save_dir):
    '''make label images for pix2pix'''
    test_img_dir = os.path.join(save_dir, 'test_img')
    os.makedirs(test_img_dir, exist_ok=True)
    test_label_dir = os.path.join(save_dir, 'test_label_ori')
    os.makedirs(test_label_dir, exist_ok=True)
    test_head_dir = os.path.join(save_dir, 'test_head_ori')
    os.makedirs(test_head_dir, exist_ok=True)

    img_dir = os.path.join(save_dir, 'images')

    pose_cords = []
    for idx in tqdm(range(len(os.listdir(img_dir)))):
        img_path = os.path.join(img_dir, '{:05}.png'.format(idx))
        img = cv2.imread(img_path)

        shape_dst = np.min(img.shape[:2])
        oh = (img.shape[0] - shape_dst) // 2
        ow = (img.shape[1] - shape_dst) // 2

        img = img[oh:oh + shape_dst, ow:ow + shape_dst]
        img = cv2.resize(img, (512, 512))
        multiplier = get_multiplier(img)
        with torch.no_grad():
            paf, heatmap = get_outputs(multiplier, img, model, 'rtpose', device)
        r_heatmap = np.array([remove_noise(ht)
                              for ht in heatmap.transpose(2, 0, 1)[:-1]]).transpose(1, 2, 0)
        heatmap[:, :, :-1] = r_heatmap
        param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
        label, cord = get_pose(param, heatmap, paf)
        index = 13
        crop_size = 25
        try:
            head_cord = cord[index]
        except IndexError:
            # if there is no head point in this frame, reuse the last one
            head_cord = pose_cords[-1]

        pose_cords.append(head_cord)
        head = img[int(head_cord[1] - crop_size): int(head_cord[1] + crop_size),
                   int(head_cord[0] - crop_size): int(head_cord[0] + crop_size), :]
        plt.imshow(head)
        plt.savefig(os.path.join(test_head_dir, 'pose_{}.jpg'.format(idx)))
        plt.clf()
        cv2.imwrite(os.path.join(test_img_dir, '{:05}.png'.format(idx)), img)
        cv2.imwrite(os.path.join(test_label_dir, '{:05}.png'.format(idx)), label)
        if idx % 100 == 0 and idx != 0:
            pose_cords_arr = np.array(pose_cords, dtype=int)  # np.int was removed in NumPy 1.24
            np.save(os.path.join(save_dir, 'pose_source.npy'), pose_cords_arr)

    pose_cords_arr = np.array(pose_cords, dtype=int)
    np.save(os.path.join(save_dir, 'pose_source.npy'), pose_cords_arr)
    torch.cuda.empty_cache()
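A hedged follow-up: the head coordinates that extract_poses writes to pose_source.npy can be reloaded later, e.g. to crop head regions from the generated frames.

coords = np.load(os.path.join(save_dir, 'pose_source.npy'))
print(coords.shape)  # expected (num_frames, 2): one (x, y) head point per frame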
def get_pose_sparse_img(video_name, index, model):
    path = '/home/molly/UCF_data/jpegs_256/v_' + video_name
    img_path = path + '/frame' + str(index).zfill(6) + '.jpg'
    print(img_path)
    img = cv2.imread(img_path)
    shape_dst = np.min(img.shape[0:2])
    with torch.no_grad():
        paf, heatmap, im_scale = get_outputs(img, model, 'rtpose')
    humans = paf_to_pose_cpp(heatmap, paf, cfg)
    image_h, image_w = img.shape[:2]
    pose_image = np.zeros((image_h, image_w), dtype="uint8")
    pose_image = cv2.cvtColor(pose_image, cv2.COLOR_GRAY2BGR)
    centers = {}
    for human in humans:
        # draw point
        for i in range(CocoPart.Background.value):
            if i not in human.body_parts.keys():
                continue
            body_part = human.body_parts[i]
            #print(body_part)
            center = (int(body_part.x * image_w + 0.5),
                      int(body_part.y * image_h + 0.5))
            centers[i] = center
            cv2.circle(pose_image,
                       center,
                       3,
                       CocoColors[i],
                       thickness=3,
                       lineType=8,
                       shift=0)

        # draw line
        for pair_order, pair in enumerate(CocoPairsRender):
            if (pair[0] not in human.body_parts.keys()
                    or pair[1] not in human.body_parts.keys()):
                continue
            cv2.line(pose_image, centers[pair[0]], centers[pair[1]],
                     CocoColors[pair_order], 3)
    pose_img_resize = cv2.resize(pose_image, (224, 224))
    #print(pose_img_resize.shape)
    out_dir = os.path.join('/home/molly/UCF_data/pose_flow', video_name)
    os.makedirs(out_dir, exist_ok=True)  # avoids the exists-then-mkdir race
    cv2.imwrite(os.path.join(out_dir, 'frame' + str(index).zfill(6) + '.jpg'),
                pose_img_resize)
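A hedged usage sketch (the clip name is hypothetical): render and save the pose image for one frame of a UCF101 video.

get_pose_sparse_img('ApplyEyeMakeup_g01_c01', 1, model)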
Example #7
def detect_keypoints(oriImg):
    # Get results of original image
    with torch.no_grad():
        paf, heatmap, im_scale = get_outputs(oriImg, model, 'rtpose')

    # heatmap2d(paf[:, :, 1])
    print(paf.shape, heatmap.shape, im_scale)
    humans = paf_to_pose_cpp(heatmap, paf, cfg)

    image_h, image_w = oriImg.shape[:2]
    human_keypoints = []
    for human in humans:
        # draw point
        centers = {}
        for i in range(CocoPart.Background.value):
            if i not in human.body_parts.keys():
                continue

            body_part = human.body_parts[i]
            center = (int(body_part.x * image_w + 0.5), int(body_part.y * image_h + 0.5))
            centers[i] = {'center': center, 'score': body_part.score}
        human_keypoints.append(centers)

    return human_keypoints
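A hedged usage sketch for detect_keypoints(): list every detected person's joints with their confidence scores.

img = cv2.imread('people.jpg')  # hypothetical input file
for person_id, keypoints in enumerate(detect_keypoints(img)):
    for part_id, info in keypoints.items():
        print(person_id, part_id, info['center'], round(info['score'], 3))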
Example #8
if __name__ == "__main__":
    
    video_capture = cv2.VideoCapture(0)

    while True:
        # Capture frame-by-frame
        ret, oriImg = video_capture.read()
        
        shape_dst = np.min(oriImg.shape[0:2])

        # Get results of original image
        multiplier = get_multiplier(oriImg)

        with torch.no_grad():
            paf, heatmap = get_outputs(multiplier, oriImg, model, 'rtpose')

        # post-processing must run inside the capture loop, once per frame
        heatmap_peaks = np.zeros_like(heatmap)
        for i in range(19):
            heatmap_peaks[:, :, i] = find_peaks(heatmap[:, :, i])
        heatmap_peaks = heatmap_peaks.astype(np.float32)
        heatmap = heatmap.astype(np.float32)
        paf = paf.astype(np.float32)

        # C++ postprocessing
        pafprocess.process_paf(heatmap_peaks, heatmap, paf)

        humans = []
        for human_id in range(pafprocess.get_num_humans()):
            human = Human([])
            is_added = False
Example #9
def main(args):

    sys.path.append(
        args.openpose_dir)  # In case calling from an external script
    from lib.network.rtpose_vgg import get_model
    from lib.network.rtpose_vgg import use_vgg
    from lib.network import im_transform
    from evaluate.coco_eval import get_outputs, handle_paf_and_heat
    from lib.utils.common import Human, BodyPart, CocoPart, CocoColors, CocoPairsRender, draw_humans
    from lib.utils.paf_to_pose import paf_to_pose_cpp
    from lib.config import cfg, update_config

    update_config(cfg, args)

    model = get_model('vgg19')
    model = torch.nn.DataParallel(model).cuda()
    use_vgg(model)

    # model.load_state_dict(torch.load(args.weight))
    checkpoint = torch.load(args.weight)
    epoch = checkpoint['epoch']
    best_loss = checkpoint['best_loss']
    state_dict = checkpoint['state_dict']
    # state_dict = {key.replace("module.",""):value for key, value in state_dict.items()} # Remove "module." from vgg keys
    model.load_state_dict(state_dict)
    # optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(args.weight, epoch))

    model.float()
    model.eval()

    image_folders = args.image_folders.split(',')

    for i, image_folder in enumerate(image_folders):
        print(
            f"\nProcessing {i + 1} of {len(image_folders)}: {' '.join(image_folder.split('/')[-4:-2])}"
        )

        if args.all_frames:  # Split video and run inference on all frames
            output_dir = os.path.join(os.path.dirname(image_folder),
                                      'predictions', 'pose2d',
                                      'openpose_pytorch_ft_all')
            os.makedirs(output_dir, exist_ok=True)
            video_path = os.path.join(
                image_folder,
                'scan_video.avi')  # break up video and run on all frames
            temp_folder = image_folder.split('/')[-3] + '_openpose'
            image_folder = os.path.join(
                '/tmp', f'{temp_folder}')  # Overwrite image_folder
            os.makedirs(image_folder, exist_ok=True)
            split_video(video_path, image_folder)
        else:  # Just use GT-annotated frames
            output_dir = os.path.join(os.path.dirname(image_folder),
                                      'predictions', 'pose2d',
                                      'openpose_pytorch_ft')
            os.makedirs(output_dir, exist_ok=True)

        img_mask = os.path.join(image_folder, '??????.png')
        img_names = glob(img_mask)
        for img_name in img_names:
            image_file_path = img_name

            oriImg = cv2.imread(image_file_path)  # B,G,R order
            shape_dst = np.min(oriImg.shape[0:2])

            with torch.no_grad():
                paf, heatmap, im_scale = get_outputs(oriImg, model, 'rtpose')

            humans = paf_to_pose_cpp(heatmap, paf, cfg)

            # Save joints in OpenPose format
            image_h, image_w = oriImg.shape[:2]
            people = []
            for person_idx, human in enumerate(humans):  # avoid shadowing the folder index i
                keypoints = []
                for j in range(18):
                    if j == 8:
                        # Add extra joint (midhip) to correspond to body_25
                        keypoints.extend([0, 0, 0])
                    if j not in human.body_parts.keys():
                        keypoints.extend([0, 0, 0])
                    else:
                        body_part = human.body_parts[j]
                        keypoints.extend([
                            body_part.x * image_w, body_part.y * image_h,
                            body_part.score
                        ])
                person = {"person_id": [person_idx - 1],
                          "pose_keypoints_2d": keypoints}
                people.append(person)
            people_dict = {"people": people}

            _, filename = os.path.split(image_file_path)
            name, _ = os.path.splitext(filename)
            frame_id = int(name)
            with open(
                    os.path.join(output_dir,
                                 f"scan_video_{frame_id:012}_keypoints.json"),
                    'w') as outfile:
                json.dump(people_dict, outfile)

        if args.all_frames:
            shutil.rmtree(image_folder)  # Delete image_folder
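For reference, a hedged sketch of the JSON layout the loop above writes (one file per frame; 18 COCO joints plus the inserted midhip give 19 (x, y, score) triplets per person):

# scan_video_000000000042_keypoints.json (illustrative)
# {
#   "people": [
#     {"person_id": [-1],
#      "pose_keypoints_2d": [x0, y0, c0, x1, y1, c1, ...]}   # 57 values
#   ]
# }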
Example #10
if __name__ == "__main__":

    video_capture = cv2.VideoCapture(0)

    while True:
        # Capture frame-by-frame
        ret, oriImg = video_capture.read()

        shape_dst = np.min(oriImg.shape[0:2])

        # Get results of original image
        multiplier = get_multiplier(oriImg)

        with torch.no_grad():
            # get_outputs takes the multiplier and the image, not shape_dst
            paf, heatmap = get_outputs(multiplier, oriImg, model, 'rtpose')

        heatmap_peaks = np.zeros_like(heatmap)
        for i in range(19):
            heatmap_peaks[:, :, i] = find_peaks(heatmap[:, :, i])
        heatmap_peaks = heatmap_peaks.astype(np.float32)
        heatmap = heatmap.astype(np.float32)
        paf = paf.astype(np.float32)

        #C++ postprocessing
        pafprocess.process_paf(heatmap_peaks, heatmap, paf)

        humans = []
        for human_id in range(pafprocess.get_num_humans()):
            human = Human([])
            is_added = False

parser = argparse.ArgumentParser()  # assumed; the parser setup was truncated in the source
parser.add_argument('opts',
                    help="Modify config options using the command-line",
                    default=None,
                    nargs=argparse.REMAINDER)
args = parser.parse_args()

# update config file
update_config(cfg, args)

weight_name = '/home/tensorboy/Downloads/pose_model.pth'

model = get_model('vgg19')
model.load_state_dict(torch.load(weight_name))
model = torch.nn.DataParallel(model).cuda()
model.float()
model.eval()

test_image = './readme/ski.jpg'
oriImg = cv2.imread(test_image)  # B,G,R order
shape_dst = np.min(oriImg.shape[0:2])

# Get results of original image

with torch.no_grad():
    paf, heatmap, im_scale = get_outputs(oriImg, model, 'rtpose')

print(im_scale)
humans = paf_to_pose_cpp(heatmap, paf, cfg)

out = draw_humans(oriImg, humans)
cv2.imwrite('result.png', out)
model = get_model(trunk='vgg19')
model = torch.nn.DataParallel(model).cuda()
model.load_state_dict(torch.load(weight_name))
model.float()
model.eval()

test_image = 'kids.jpg'
oriImg = cv2.imread(test_image)  # B,G,R order
shape_dst = np.min(oriImg.shape[0:2])

# Get results of original image
multiplier = get_multiplier(oriImg)

with torch.no_grad():
    orig_paf, orig_heat = get_outputs(multiplier, oriImg, model, 'rtpose')

    # Get results of flipped image
    swapped_img = oriImg[:, ::-1, :]
    flipped_paf, flipped_heat = get_outputs(multiplier, swapped_img, model,
                                            'rtpose')

    # compute averaged heatmap and paf
    paf, heatmap = handle_paf_and_heat(orig_heat, flipped_heat, orig_paf,
                                       flipped_paf)

param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
to_plot, canvas, candidate, subset = decode_pose(oriImg, param, heatmap, paf)

cv2.imwrite('result_photo.png', to_plot)
Example #13
def generate(origin_img, img_dir, label_dir, size_dst, size_crop, crop_from, pose_transform=False):
    # Pose estimation (OpenPose)
    openpose_dir = Path('../src/pytorch_Realtime_Multi-Person_Pose_Estimation/')

    sys.path.append(str(openpose_dir))
    sys.path.append('../src/utils')
    # from Pose estimation
    from evaluate.coco_eval import get_multiplier, get_outputs
    # utils
    from openpose_utils import remove_noise, get_pose, get_pose_coord, get_pose_new


    model = pose_model()

    total = len(list(origin_img.iterdir()))
    img_idx = range(total)

    if pose_transform:
        ratio_src, ratio_tar = '../data/source/ratio_a.png', '../data/target/ratio_b.png'
        if not os.path.isfile(ratio_src):
            raise FileNotFoundError('File does not exist: {}'.format(ratio_src))
        if not os.path.isfile(ratio_tar):
            raise FileNotFoundError('File does not exist: {}'.format(ratio_tar))

        imgset = [ratio_src, ratio_tar]
        origin = []
        height = []
        ratio = {'0-1': None, '1-2': None, '2-3': None, '3-4': None, '1-8': None, '8-9': None,
                 '9-10': None, '0-14': None, '14-16': None}  # target/source
        coord = {'0-1': [], '1-2': [], '2-3': [], '3-4': [], '1-8': [], '8-9': [],
                 '9-10': [], '0-14': [], '14-16': []}  # length of each joint segment
        # co_tar = {'0-1':None, '1-2':None, '2-3':None,'3-4':None,'1-8':None,'8-9':None,'9-10':None}

        for img_path in imgset:
            img = cv2.imread(str(img_path))
            if not img.shape[:2] == size_dst[::-1]:  # format: (h, w)
                img = img_resize(img, size_crop, crop_from, size_dst)  # size_dst format: (W, H)
            multiplier = get_multiplier(img)
            with torch.no_grad():
                paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
            r_heatmap = np.array([remove_noise(ht)
                                  for ht in heatmap.transpose(2, 0, 1)[:-1]]).transpose(1, 2, 0)
            heatmap[:, :, :-1] = r_heatmap
            param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}  # only 'thre2' matters

            label, joint_list = get_pose_coord(img, param, heatmap, paf)
            # print ('joint list: \n',joint_list)

            origin.append(joint_list[1][0][:2])  # use joint no. 1 (neck) as the reference origin
            height_max = max(joint_list, key=lambda x: x[0][1])[0][1]
            height_min = min(joint_list, key=lambda x: x[0][1])[0][1]
            height.append(height_max - height_min)

            for k in ratio.keys():
                klist = k.split('-')
                j_1, j_2 = int(klist[0]), int(klist[-1])
                # assert j_1 == int(joint_list[j_1][0][-1]) and j_2 == int(
                #    joint_list[j_2][0][-1])  # may cause issue if empty array exists
                co_1, co_2 = list(joint_list[j_1][0][:2]), list(joint_list[j_2][0][:2])
                j_len = ((co_1[0] - co_2[0]) ** 2 + (co_1[1] - co_2[1]) ** 2) ** 0.5
                coord[k].append(j_len)

        for k, v in coord.items():
            src_len, tar_len = v[0], v[1]
            ratio[k] = tar_len / src_len

        ratio_body = height[1] / height[0]  # target / source height
        print('ratio:\n', ratio, '\nratio_body:', ratio_body)  # test only

    for idx in tqdm(img_idx):
        img_path = origin_img.joinpath('img_{:04d}.png'.format(idx))
        img = cv2.imread(str(img_path))

        if not img.shape[:2] == size_dst[::-1]:
            # set crop size and resize
            img = img_resize(img, size_crop, crop_from, size_dst)  # size format: (W, H)

        multiplier = get_multiplier(img)
        with torch.no_grad():
            paf, heatmap = get_outputs(multiplier, img, model, 'rtpose')
        r_heatmap = np.array([remove_noise(ht)
                              for ht in heatmap.transpose(2, 0, 1)[:-1]]).transpose(1, 2, 0)
        heatmap[:, :, :-1] = r_heatmap
        param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}  # only 'thre2' matters here

        if pose_transform:
            _, joint_list = get_pose_coord(img, param, heatmap, paf)
            #print('joint_list', '\n', joint_list)  # test only
            new_joint = translate(joint_list, ratio, origin, ratio_body)
            new_joint_list = new_joint.run()
            #print('joint_list new', '\n', new_joint_list)  # test only
            """
            with open('joint_list.txt','a') as f:
                f.write('joint_list_{}\n'.format(idx)+str(joint_list)+'\nnew_joint_list_{}\n'.format(idx)+str(new_joint_list)+'\n')
            """
            label = get_pose_new(img, param, heatmap, paf, new_joint_list)
        else:
            label = get_pose(img, param, heatmap, paf)  # size changed !!!

        cv2.imwrite(str(img_dir.joinpath('img_{:04d}.png'.format(idx))), img)
        cv2.imwrite(str(label_dir.joinpath('label_{:04d}.png'.format(idx))), label)

    torch.cuda.empty_cache()
    print('{} {} images generated'.format(total, origin_img.parent.name))
Example #14
# clear any previously written output images; the enclosing loop header is
# assumed to be the usual os.walk pattern implied by `root` and `files`
for root, dirs, files in os.walk('./LiveImages'):
    for file in files:
        os.remove(os.path.join(root, file))

# webcam image source
video_capture = cv2.VideoCapture(0)
imgnum = 0

while True:

    ret, frame = video_capture.read()

    shape_dst = np.min(frame.shape[0:2])

    with torch.no_grad():
        paf, heatmap, im_scale = get_outputs(frame, model, 'rtpose')

    humans = paf_to_pose_cpp(heatmap, paf, cfg)

    out = draw_humans(frame, humans)

    cv2.imshow('Video', out)

    # skip the first image and write the rest of the stream
    if imgnum > 0:
        cv2.imwrite('./LiveImages/image' + str(imgnum) + '.png', out)
    imgnum += 1  # was never incremented, so nothing was ever written

    # break loop with key press
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
    
for idx in range(200, 210):
    train_img_path = train.joinpath('train_set')
    train_img_name = "image%0d.jpg" % idx
    train_img_path = train_img_path.joinpath(train_img_name)
    train_image = cv2.resize(cv2.imread(str(train_img_path)), (512, 512))
    train_multiplier = get_multiplier(train_image)

    test_img_path = test.joinpath('test_set')
    test_img_name = "image%0d.jpg" % idx
    test_img_path = test_img_path.joinpath(test_img_name)
    test_image = cv2.resize(cv2.imread(str(test_img_path)), (512, 512))
    test_multiplier = get_multiplier(test_image)

    with torch.no_grad():
        train_paf, train_heatmap = get_outputs(train_multiplier, train_image, model, 'rtpose')
        test_paf, test_heatmap = get_outputs(test_multiplier, test_image, model, 'rtpose')

        # use [::-1] to flip horizontally
        train_swapped_img = train_image[:, ::-1, :]
        test_swapped_img = test_image[:, ::-1, :]

        train_flipped_paf, train_flipped_heat = get_outputs(train_multiplier, train_swapped_img, model, 'rtpose')
        test_flipped_paf, test_flipped_heat = get_outputs(test_multiplier, test_swapped_img, model, 'rtpose')

        train_paf, train_heatmap = handle_paf_and_heat(train_heatmap, train_flipped_heat, train_paf, train_flipped_paf)
        test_paf, test_heatmap = handle_paf_and_heat(test_heatmap, test_flipped_heat, test_paf, test_flipped_paf)

    param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}