def infer_fg(self, img):
        """Predict a binary foreground mask for a BGR image.

        The image is resized to the size returned by ``self.get_working_size``,
        run through ``self.model`` together with its mirrored copy, and the
        merged PAF/heatmap outputs are decoded by ``decode_pose_fg``. The
        decoded canvas is resized back to the input resolution and binarised.

        Args:
            img: BGR image of shape (H, W, C).

        Returns:
            Mask of shape (H, W): 255 where channel 0 of the resized canvas is
            greater than 128 (foreground), 0 everywhere else (background).
        """
        height, width = img.shape[:2]
        work_h, work_w = self.get_working_size(height, width)
        resized = cv2.resize(img, (work_w, work_h))

        scales = get_multiplier(resized)
        # Reverse axis 1 (width) to obtain the mirrored input.
        mirrored = resized[:, ::-1, :]

        with torch.no_grad():
            # Run the network on the image and on its mirror, then merge both.
            paf_fwd, heat_fwd = get_outputs(scales, resized, self.model,
                                            'rtpose')
            paf_mir, heat_mir = get_outputs(scales, mirrored, self.model,
                                            'rtpose')
            paf, heatmap = handle_paf_and_heat(heat_fwd, heat_mir,
                                               paf_fwd, paf_mir)

        thresholds = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
        _, drawn, _, _ = decode_pose_fg(resized, thresholds, heatmap, paf)

        # Back to the caller's resolution; cv2.resize takes (width, height).
        drawn = cv2.resize(drawn, (width, height))
        # Only channel 0 is returned, so binarise just that plane in place.
        mask = drawn[:, :, 0]
        is_fg = mask > 128
        mask[is_fg] = 255
        mask[~is_fg] = 0
        return mask
def process(model, oriImg, process_speed):
    """Estimate multi-person poses in a single BGR image.

    The network is evaluated on ``oriImg`` and on its mirrored copy using the
    scales from ``get_multiplier(oriImg, process_speed)``. The two sets of
    outputs are merged by ``handle_paf_and_heat`` and decoded by
    ``decode_pose``.

    Returns:
        The four values unpacked from ``decode_pose``, in order:
        (to_plot, canvas, joint_list, person_to_joint_assoc).
    """
    scales = get_multiplier(oriImg, process_speed)
    # Reverse axis 1 (width) to obtain the mirrored input.
    mirrored = oriImg[:, ::-1, :]

    with torch.no_grad():
        paf_fwd, heat_fwd = get_outputs(scales, oriImg, model, 'rtpose')
        paf_mir, heat_mir = get_outputs(scales, mirrored, model, 'rtpose')
        # Merge the forward and mirrored predictions.
        paf, heatmap = handle_paf_and_heat(heat_fwd, heat_mir,
                                           paf_fwd, paf_mir)

    thresholds = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
    plot_img, canvas_img, joints, assoc = decode_pose(
        oriImg, thresholds, heatmap, paf)
    return plot_img, canvas_img, joints, assoc
# Build the VGG19-trunk pose model, wrap it in DataParallel on the GPU and
# load the trained weights. `weight_name` is defined outside this view.
model = get_model(trunk='vgg19')
model = torch.nn.DataParallel(model).cuda()
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(torch.load(weight_name))
model.float()
model.eval()

test_image = 'kids.jpg'
oriImg = cv2.imread(test_image)  # B,G,R order
if oriImg is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than deep inside
    # the inference helpers.
    raise FileNotFoundError('could not read image: %s' % test_image)
shape_dst = np.min(oriImg.shape[0:2])  # NOTE(review): unused in this script

# Get results of original image
multiplier = get_multiplier(oriImg)

with torch.no_grad():
    orig_paf, orig_heat = get_outputs(multiplier, oriImg, model, 'rtpose')

    # Get results of flipped image (axis 1 reversed)
    swapped_img = oriImg[:, ::-1, :]
    flipped_paf, flipped_heat = get_outputs(multiplier, swapped_img, model,
                                            'rtpose')

    # compute averaged heatmap and paf
    paf, heatmap = handle_paf_and_heat(orig_heat, flipped_heat, orig_paf,
                                       flipped_paf)

param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
# Unpack in the same order `process` uses for decode_pose
# (to_plot, canvas, joint list, person->joint assoc), so the image saved
# below really is decode_pose's `to_plot` output.
to_plot, canvas, candidate, subset = decode_pose(oriImg, param, heatmap, paf)

# cv2.imwrite signals failure by returning False rather than raising.
if not cv2.imwrite('result_photo.png', to_plot):
    raise OSError('could not write result_photo.png')
        # Body of a per-index loop whose header is outside this view: runs
        # flip-averaged pose inference on a train image and on the idx-th
        # test image, then writes both pose renderings as "pose<idx>.jpg".
        # NOTE(review): train_image, train_multiplier, idx, test, test_pose_dir
        # and train_pose_dir are all defined outside this view.

        # Locate test_set/image<idx>.jpg and resize it to 512x512.
        # NOTE(review): "%0d" formats exactly like "%d" (the 0 flag has no
        # effect without a width such as "%04d"); confirm intended naming.
        test_img_path = test.joinpath('test_set')
        test_img_name = "image%0d.jpg" % idx
        test_img_path = test_img_path.joinpath(test_img_name)
        # NOTE(review): cv2.imread returns None for a missing file, which
        # would make cv2.resize fail here with an opaque error.
        test_image = cv2.resize( cv2.imread(str(test_img_path)), (512, 512))
        test_multiplier = get_multiplier(test_image)

        with torch.no_grad():
            # Forward pass on the unflipped train and test images.
            train_paf, train_heatmap = get_outputs(train_multiplier, train_image, model, 'rtpose')
            test_paf, test_heatmap = get_outputs(test_multiplier, test_image, model, 'rtpose')

            # Mirrored copies: [::-1] on axis 1 reverses the width axis.
            train_swapped_img = train_image[:, ::-1, :]
            test_swapped_img = test_image[:, ::-1, :]


            # Forward pass on the mirrored images.
            train_flipped_paf, train_flipped_heat = get_outputs(train_multiplier, train_swapped_img, model, 'rtpose')
            test_flipped_paf, test_flipped_heat = get_outputs(test_multiplier, test_swapped_img, model, 'rtpose')

            # Merge forward and mirrored outputs; note the argument order is
            # (heat, flipped_heat, paf, flipped_paf) and the result is (paf, heat).
            train_paf, train_heatmap = handle_paf_and_heat(train_heatmap, train_flipped_heat, train_paf, train_flipped_paf)
            test_paf, test_heatmap = handle_paf_and_heat(test_heatmap, test_flipped_heat, test_paf, test_flipped_paf)


        # Same decoding thresholds as the other call sites in this file.
        param = {'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5}
        train_pose = get_pose(param, train_heatmap, train_paf)
        test_pose = get_pose(param, test_heatmap, test_paf)
        # Both poses share the file name; they go to separate directories.
        pose_name = "pose%0d.jpg" % idx
        cv2.imwrite(str(test_pose_dir.joinpath(pose_name)), test_pose)
        cv2.imwrite(str(train_pose_dir.joinpath(pose_name)), train_pose)