Example #1
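Multi-scale, flip-averaged evaluation of the GrapyMutualLearning parser on CIHP, PASCAL-Person-Part, or ATR, followed by a numpy-based comparison against the ground-truth annotations.
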
def main(opts):

    with open(opts.txt_file, 'r') as f:
        img_list = f.readlines()

    net = grapy_net.GrapyMutualLearning(os=16,
                                        hidden_layers=opts.hidden_graph_layers)

    if gpu_id >= 0:  # gpu_id is defined at module level in the source script
        net.cuda()

    if opts.resume_model != '':
        x = torch.load(opts.resume_model)
        net.load_state_dict(x)

        print('resume model:', opts.resume_model)

    else:
        print('we are not resuming from any model')

    if opts.dataset == 'cihp':
        val = cihp.VOCSegmentation
        val_flip = cihp.VOCSegmentation

        vis_dir = '/cihp_output_vis/'
        mat_dir = '/cihp_output/'

        num_dataset_lbl = 0

    elif opts.dataset == 'pascal':

        val = pascal.VOCSegmentation
        val_flip = pascal.VOCSegmentation

        vis_dir = '/pascal_output_vis/'
        mat_dir = '/pascal_output/'

        num_dataset_lbl = 1

    elif opts.dataset == 'atr':
        val = atr.VOCSegmentation
        val_flip = atr.VOCSegmentation

        vis_dir = '/atr_output_vis/'
        mat_dir = '/atr_output/'

        print("atr_num")
        num_dataset_lbl = 2

    ## multi scale
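    # one (plain, mirrored) DataLoader pair is built per scale; the
    # per-scale logits are later upsampled to a common size and summed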
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [tr.Scale_(pv),
             tr.Normalize_xception_tf(),
             tr.ToTensor_()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()
        ])

        voc_val = val(split='val', transform=composed_transforms_ts)
        voc_val_f = val_flip(split='val',
                             transform=composed_transforms_ts_flip)

        testloader = DataLoader(voc_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=4)
        testloader_flip = DataLoader(voc_val_f,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=4)

        testloader_list.append(copy.deepcopy(testloader))
        testloader_flip_list.append(copy.deepcopy(testloader_flip))

    print("Eval Network")

    if not os.path.exists(opts.output_path + vis_dir):
        os.makedirs(opts.output_path + vis_dir)
    if not os.path.exists(opts.output_path + mat_dir):
        os.makedirs(opts.output_path + mat_dir)

    start_time = timeit.default_timer()
    # One testing epoch
    total_iou = 0.0

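    # two-granularity category groupings passed to the network below:
    # (c1, c2) for CIHP, (p1, p2) for PASCAL-Person-Part, (a1, a2) for ATR;
    # each inner list appears to group fine part labels under a coarse node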
    c1 = [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]
    c2 = [[0], [1, 2, 4, 13], [5, 6, 7, 10, 11, 12], [3, 14, 15], [8, 9, 16, 17, 18, 19]]
    p1 = [[0], [1, 2, 3, 4, 5, 6]]
    p2 = [[0], [1], [2], [3, 4], [5, 6]]
    a1 = [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]]
    a2 = [[0], [1, 2, 3, 11], [4, 5, 7, 8, 16, 17], [14, 15], [6, 9, 10, 12, 13]]

    net.set_category_list(c1, c2, p1, p2, a1, a2)

    net.eval()

    with torch.no_grad():
        for ii, large_sample_batched in enumerate(
                zip(*testloader_list, *testloader_flip_list)):
            print(ii)
            #1 0.5 0.75 1.25 1.5 1.75 ; flip:
            sample1 = large_sample_batched[:6]
            sample2 = large_sample_batched[6:]

            for iii, sample_batched in enumerate(zip(sample1, sample2)):
                inputs, labels_single = sample_batched[0][
                    'image'], sample_batched[0]['label']
                inputs_f, labels_single_f = sample_batched[1][
                    'image'], sample_batched[1]['label']
                inputs = torch.cat((inputs, inputs_f), dim=0)
                labels = torch.cat((labels_single, labels_single_f), dim=0)

                if iii == 0:
                    _, _, h, w = inputs.size()
                # assert inputs.size() == inputs_f.size()

                # Forward pass of the mini-batch
                inputs, labels = Variable(
                    inputs, requires_grad=False), Variable(labels)

                with torch.no_grad():
                    if gpu_id >= 0:
                        inputs, labels, labels_single = inputs.cuda(
                        ), labels.cuda(), labels_single.cuda()
                    # outputs = net.forward(inputs)
                    # pdb.set_trace()
                    outputs, outputs_aux = net.forward(
                        (inputs, num_dataset_lbl), training=False)

                    # print(outputs.shape, outputs_aux.shape)
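                    # average the logits of the plain image with the
                    # un-mirrored logits of the flipped copy; CIHP and ATR
                    # additionally swap their left/right part channels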
                    if opts.dataset == 'cihp':
                        outputs = (outputs[0] +
                                   flip(flip_cihp(outputs[1]), dim=-1)) / 2
                    elif opts.dataset == 'pascal':
                        outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
                    else:
                        outputs = (outputs[0] +
                                   flip(flip_atr(outputs[1]), dim=-1)) / 2

                    outputs = outputs.unsqueeze(0)

                    if iii > 0:
                        outputs = F.interpolate(outputs,
                                                size=(h, w),
                                                mode='bilinear',
                                                align_corners=True)
                        outputs_final = outputs_final + outputs
                    else:
                        outputs_final = outputs.clone()

            ################ plot pic
            predictions = torch.max(outputs_final, 1)[1]
            prob_predictions = torch.max(outputs_final, 1)[0]
            results = predictions.cpu().numpy()
            prob_results = prob_predictions.cpu().numpy()
            vis_res = decode_labels(results)

            parsing_im = Image.fromarray(vis_res[0])
            parsing_im.save(opts.output_path + vis_dir +
                            '{}.png'.format(img_list[ii][:-1]))
            cv2.imwrite(
                opts.output_path + mat_dir +
                '{}.png'.format(img_list[ii][:-1]), results[0, :, :])

        # total_iou += utils.get_iou(predictions, labels)
    end_time = timeit.default_timer()
    print('time used for ' + str(ii) + ' is: ' + str(end_time - start_time))

    # Eval
    pred_path = opts.output_path + mat_dir
    eval_with_numpy(pred_path=pred_path,
                    gt_path=opts.gt_path,
                    classes=opts.classes,
                    txt_file=opts.txt_file,
                    dataset=opts.dataset)
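
All three examples share the same test-time augmentation: run the network once per scale on the image and once on its horizontal mirror, un-flip and upsample every output to the original resolution, and accumulate the logits before taking the arg-max. A minimal, self-contained sketch of that pattern, where net stands in for any 1xCxHxW -> 1xKxHxW segmentation model (an illustration, not the repository's code):

import torch
import torch.nn.functional as F

def multiscale_flip_logits(net, image,
                           scales=(1.0, 0.5, 0.75, 1.25, 1.5, 1.75)):
    # image: 1xCxHxW float tensor; returns 1xKxHxW averaged logits
    _, _, h, w = image.shape
    summed = None
    with torch.no_grad():
        for s in scales:
            scaled = F.interpolate(image, scale_factor=s,
                                   mode='bilinear', align_corners=True)
            out = net(scaled)                            # plain pass
            out_f = net(torch.flip(scaled, dims=[-1]))   # mirrored pass
            out = (out + torch.flip(out_f, dims=[-1])) / 2
            out = F.interpolate(out, size=(h, w),
                                mode='bilinear', align_corners=True)
            summed = out if summed is None else summed + out
    return summed / len(scales)  # dividing leaves the arg-max unchanged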
Example #2

def main(opts):
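    # the three adjacency matrices below appear to drive the graph-transfer
    # heads: a CIHP-to-PASCAL mapping (20x7, transposed), the normalized
    # 7-node PASCAL part graph, and the normalized 20-node CIHP part graph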
    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = adj2_.unsqueeze(0).unsqueeze(0) \
                     .expand(1, 1, 7, 20).cuda().transpose(2, 3)

    adj1_ = Variable(
        torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float())
    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).cuda()

    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
    adj3_ = Variable(torch.from_numpy(cihp_adj).float())
    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).cuda()

    p = OrderedDict()  # Parameters to include in report
    p["trainBatch"] = opts.batch  # Training batch size
    p["nAveGrad"] = 1  # Average the gradient of several iterations
    p["lr"] = opts.lr  # Learning rate
    p["lrFtr"] = 1e-5
    p["lraspp"] = 1e-5
    p["lrpro"] = 1e-5
    p["lrdecoder"] = 1e-5
    p["lrother"] = 1e-5
    p["wd"] = 5e-4  # Weight decay
    p["momentum"] = 0.9  # Momentum
    p["epoch_size"] = 10  # How many epochs to change learning rate
    p["num_workers"] = opts.numworker
    backbone = "xception"  # use xception or resnet as the feature extractor

    with open(opts.txt_file, "r") as f:
        img_list = f.readlines()

    max_id = 0
    save_dir_root = os.path.join(os.path.dirname(os.path.abspath(__file__)))
    exp_name = os.path.dirname(os.path.abspath(__file__)).split("/")[-1]
    runs = glob.glob(os.path.join(save_dir_root, "run", "run_*"))
    for r in runs:
        run_id = int(r.split("_")[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    # run_id = int(runs[-1].split('_')[-1]) + 1 if runs else 0

    # Network definition
    if backbone == "xception":
        net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
            n_classes=opts.classes,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
        )
    elif backbone == "resnet":
        # net = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
        raise NotImplementedError
    else:
        raise NotImplementedError

    if gpu_id >= 0:
        net.cuda()

    # net load weights
    if opts.loadmodel != "":
        x = torch.load(opts.loadmodel)
        net.load_source_model(x)
        print("load model:", opts.loadmodel)
    else:
        print("no model load !!!!!!!!")

    ## multi scale
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [tr.Scale_(pv),
             tr.Normalize_xception_tf(),
             tr.ToTensor_()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_(),
        ])

        voc_val = cihp.VOCSegmentation(split="test",
                                       transform=composed_transforms_ts)
        voc_val_f = cihp.VOCSegmentation(split="test",
                                         transform=composed_transforms_ts_flip)

        testloader = DataLoader(voc_val,
                                batch_size=1,
                                shuffle=False,
                                num_workers=p["num_workers"])
        testloader_flip = DataLoader(voc_val_f,
                                     batch_size=1,
                                     shuffle=False,
                                     num_workers=p["num_workers"])

        testloader_list.append(copy.deepcopy(testloader))
        testloader_flip_list.append(copy.deepcopy(testloader_flip))

    print("Eval Network")

    if not os.path.exists(opts.output_path + "cihp_output_vis/"):
        os.makedirs(opts.output_path + "cihp_output_vis/")
    if not os.path.exists(opts.output_path + "cihp_output/"):
        os.makedirs(opts.output_path + "cihp_output/")

    start_time = timeit.default_timer()
    # One testing epoch
    total_iou = 0.0
    net.eval()
    for ii, large_sample_batched in enumerate(
            zip(*testloader_list, *testloader_flip_list)):
        print(ii)
        # 1 0.5 0.75 1.25 1.5 1.75 ; flip:
        sample1 = large_sample_batched[:6]
        sample2 = large_sample_batched[6:]
        for iii, sample_batched in enumerate(zip(sample1, sample2)):
            inputs, labels = sample_batched[0]["image"], sample_batched[0][
                "label"]
            inputs_f, _ = sample_batched[1]["image"], sample_batched[1][
                "label"]
            inputs = torch.cat((inputs, inputs_f), dim=0)
            if iii == 0:
                _, _, h, w = inputs.size()
            # assert inputs.size() == inputs_f.size()

            # Forward pass of the mini-batch
            inputs, labels = Variable(inputs,
                                      requires_grad=False), Variable(labels)

            with torch.no_grad():
                if gpu_id >= 0:
                    inputs, labels = inputs.cuda(), labels.cuda()
                # outputs = net.forward(inputs)
                # pdb.set_trace()
                outputs = net.forward(inputs, adj1_test.cuda(),
                                      adj3_test.cuda(), adj2_test.cuda())
                outputs = (outputs[0] +
                           flip(flip_cihp(outputs[1]), dim=-1)) / 2
                outputs = outputs.unsqueeze(0)

                if iii > 0:
                    outputs = F.interpolate(outputs,
                                            size=(h, w),
                                            mode="bilinear",
                                            align_corners=True)
                    outputs_final = outputs_final + outputs
                else:
                    outputs_final = outputs.clone()
        ################ plot pic
        predictions = torch.max(outputs_final, 1)[1]
        prob_predictions = torch.max(outputs_final, 1)[0]
        results = predictions.cpu().numpy()
        prob_results = prob_predictions.cpu().numpy()
        vis_res = decode_labels(results)

        parsing_im = Image.fromarray(vis_res[0])
        parsing_im.save(opts.output_path +
                        "cihp_output_vis/{}.png".format(img_list[ii][:-1]))
        cv2.imwrite(
            opts.output_path + "cihp_output/{}.png".format(img_list[ii][:-1]),
            results[0, :, :],
        )
        # np.save('../../cihp_prob_output/{}.npy'.format(img_list[ii][:-1]), prob_results[0, :, :])
        # pred_list.append(predictions.cpu())
        # label_list.append(labels.squeeze(1).cpu())
        # loss = criterion(outputs, labels, batch_average=True)
        # running_loss_ts += loss.item()

        # total_iou += utils.get_iou(predictions, labels)
    end_time = timeit.default_timer()
    print("time use for " + str(ii) + " is :" + str(end_time - start_time))

    # Eval
    pred_path = opts.output_path + "cihp_output/"
    eval_(
        pred_path=pred_path,
        gt_path=opts.gt_path,
        classes=opts.classes,
        txt_file=opts.txt_file,
    )
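
flip_cihp above (and flip_atr in Examples #1 and #3) is imported from elsewhere in the repository and not shown. For a mirrored input, left/right part channels (arms, legs, shoes) have to be swapped before the logits are un-flipped, which is what such a helper presumably does. A sketch of that channel permutation with hypothetical index pairs; the real mapping lives in the repository's utilities:

import torch

SWAP_PAIRS = [(14, 15), (16, 17), (18, 19)]  # hypothetical left/right pairs

def flip_label_channels(logits, pairs=SWAP_PAIRS):
    # swap paired part channels of ...xKxHxW logits along the channel axis
    perm = list(range(logits.shape[-3]))
    for a, b in pairs:
        perm[a], perm[b] = perm[b], perm[a]
    dim = logits.dim() - 3  # channel dimension for ...xKxHxW layout
    idx = torch.tensor(perm, device=logits.device)
    return logits.index_select(dim, idx)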
Example #3
def main(opts):
    '''
    Namespace(
        batch=1,
        classes=7,
        dataset='pascal',
        epochs=100,
        gpus=1,
        gt_path='./data/datasets/pascal/SegmentationPart/',
        hidden_graph_layers=256,
        hidden_layers=128,
        loadmodel='',
        lr=1e-07,
        numworker=12,
        output_path='./result/gpm_ml_pascal',
        resume_model='./data/models/GPM-ML_finetune_PASCAL.pth',
        step=30,
        testepoch=10,
        txt_file='./data/datasets/pascal/list/val_id.txt'
    )
    '''

    opts = edict()  # hard-coded demo options; these override the arguments documented above
    opts.batch = 1
    opts.classes = 7
    opts.dataset = 'pascal'
    opts.epochs = 100
    opts.gpus = 1
    opts.gt_path = '/nethome/hkwon64/Research/imuTube/repos_v2/human_parsing/Grapy-ML/data/datasets/pascal/SegmentationPart/'
    opts.hidden_graph_layers = 256
    opts.hidden_layers = 128
    opts.loadmodel = ''
    opts.lr = 1e-07
    opts.numworker = 12
    opts.output_path = '/nethome/hkwon64/Research/imuTube/repos_v2/human_parsing/Grapy-ML/result/gpm_ml_demo'
    opts.resume_model = '/nethome/hkwon64/Research/imuTube/repos_v2/human_parsing/Grapy-ML/data/models/GPM-ML_finetune_PASCAL.pth'
    opts.step = 30
    opts.testepoch = 10
    opts.txt_file = '/nethome/hkwon64/Research/imuTube/repos_v2/human_parsing/Grapy-ML/data/datasets/pascal/list/val_id.txt'

    with open(opts.txt_file, 'r') as f:
        img_list = f.readlines()

    net = grapy_net.GrapyMutualLearning(os=16,
                                        hidden_layers=opts.hidden_graph_layers)

    if gpu_id >= 0:
        net.cuda()

    if opts.resume_model != '':
        x = torch.load(opts.resume_model)
        net.load_state_dict(x)

        print('resume model:', opts.resume_model)

    else:
        print('we are not resuming from any model')

    if opts.dataset == 'cihp':
        val = cihp.VOCSegmentation
        val_flip = cihp.VOCSegmentation

        vis_dir = '/cihp_output_vis/'
        mat_dir = '/cihp_output/'

        num_dataset_lbl = 0

    elif opts.dataset == 'pascal':

        val = pascal.VOCSegmentation
        val_flip = pascal.VOCSegmentation

        vis_dir = '/pascal_output_vis/'
        mat_dir = '/pascal_output/'

        num_dataset_lbl = 1

    elif opts.dataset == 'atr':
        val = atr.VOCSegmentation
        val_flip = atr.VOCSegmentation

        vis_dir = '/atr_output_vis/'
        mat_dir = '/atr_output/'

        print("atr_num")
        num_dataset_lbl = 2

    ## multi scale
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [tr.Scale_(pv),
             tr.Normalize_xception_tf(),
             tr.ToTensor_()])

        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_()
        ])

        # voc_val = val(split='val', transform=composed_transforms_ts)
        # voc_val_f = val_flip(split='val', transform=composed_transforms_ts_flip)

        # testloader = DataLoader(voc_val, batch_size=1, shuffle=False, num_workers=4)
        # testloader_flip = DataLoader(voc_val_f, batch_size=1, shuffle=False, num_workers=4)

        # the demo stores the raw transforms rather than DataLoaders;
        # the *loader* variable names are kept from the evaluation code above
        testloader_list.append(composed_transforms_ts)
        testloader_flip_list.append(composed_transforms_ts_flip)

    print("Eval Network")

    start_time = timeit.default_timer()
    # One testing epoch
    total_iou = 0.0

    c1 = [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]]
    c2 = [[0], [1, 2, 4, 13], [5, 6, 7, 10, 11, 12], [3, 14, 15], [8, 9, 16, 17, 18, 19]]
    p1 = [[0], [1, 2, 3, 4, 5, 6]]
    p2 = [[0], [1], [2], [3, 4], [5, 6]]
    a1 = [[0], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]]
    a2 = [[0], [1, 2, 3, 11], [4, 5, 7, 8, 16, 17], [14, 15], [6, 9, 10, 12, 13]]

    net.set_category_list(c1, c2, p1, p2, a1, a2)

    net.eval()

    # load image
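    # demo-input selection: exactly one 'if 1:' branch below is active;
    # the 'if 0:' branches are alternative inputs kept for reference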
    if 1:
        if 0:
            im_name = 'demo.jpg'
            dir_frames = f'/nethome/hkwon64/Research/imuTube/repos_v2/human_parsing/Self-Correction-Human-Parsing/mhp_extension/data/DemoDataset/global_pic/'
            list_frame = [im_name]

        if 0:
            dir_name = 'freeweights'

            # One-Arm_Dumbbell_Row/4IoyUvtF7do/3.000_39.188
            class_name = 'One-Arm_Dumbbell_Row'
            vid = '4IoyUvtF7do'
            clip = '3.000_39.188'
            dir_fw = '/nethome/hkwon64/Research/imuTube/dataset/imutube_v2'
            dir_frames = dir_fw + f'/{dir_name}/{class_name}/{vid}/{clip}/frames'

            list_frame = os.listdir(dir_frames)
            list_frame = [item for item in list_frame if item[-4:] == '.png']
            list_frame.sort()

        if 0:
            dir_name = 'freeweights'

            # Incline_Dumbbell_Press/4UZ8G8eW5MU/17.000_36.726
            class_name = 'Incline_Dumbbell_Press'
            vid = '4UZ8G8eW5MU'
            clip = '17.000_36.726'

        if 1:
            dir_name = 'freeweights'

            # One-Arm_Dumbbell_Row/Hfxxc4zg5zs/138.544_231.846
            class_name = 'One-Arm_Dumbbell_Row'
            vid = 'Hfxxc4zg5zs'
            clip = '138.544_231.846'

        sample_info = f'{class_name}/{vid}/{clip}'

        dir_fw = '/nethome/hkwon64/Research/imuTube/dataset/imutube_v2'
        dir_clip = dir_fw + f'/{dir_name}/{sample_info}'

        dir_frames = dir_clip + f'/frames'

        list_frame = os.listdir(dir_frames)
        list_frame = [item for item in list_frame if item[-4:] == '.png']
        list_frame.sort()

        dir_pose2d = dir_clip + '/pose2D'
        file_ap = dir_pose2d + '/alphapose-results.json'

        ap_results = json.load(open(file_ap))
        print('load from ...', file_ap)
        # pprint (ap_results)
        # assert False

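        # group AlphaPose detections by frame index; each frame keeps the
        # keypoints, detection score, and a maximal person bounding box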
        frame_info = {}
        for result in ap_results:
            t = int(os.path.splitext(result['image_id'])[0])
            if t not in frame_info:
                frame_info[t] = {
                    'pIDs': [],
                    'kps': [],
                    'score': [],
                    'bbox': [],
                    'bbox_expand': []
                }

            # idx = result['idx']
            # if type(idx) is list:
            #   idx = idx[0]
            #   if type(idx) is list:
            #     idx = idx[0]
            # # if len (idx) > 1:
            # # 	pprint (result)
            # # 	print(len(result['keypoints']))
            # # 	assert False
            # # print (idx)
            # # assert False
            # frame_info[t]['idx'].append(idx)

            kps = np.array(result['keypoints']).reshape((-1, 3))
            frame_info[t]['kps'].append(kps)

            _p_score = result['score']
            frame_info[t]['score'].append(_p_score)

            # maximal bbox: union of the keypoint extent and the detector box
            start_point = np.amin(kps[:, :2], axis=0).astype(int)
            end_point = np.amax(kps[:, :2], axis=0).astype(int)

            x1, y1, w, h = result['box']
            if x1 < start_point[0]:
                start_point[0] = int(x1)
            if y1 < start_point[1]:
                start_point[1] = int(y1)
            if x1 + w > end_point[0]:
                end_point[0] = int(x1 + w)
            if y1 + h > end_point[1]:
                end_point[1] = int(y1 + h)

            x_min, y_min = start_point
            x_max, y_max = end_point
            bbox = np.array([x_min, y_min, x_max, y_max])
            frame_info[t]['bbox'].append(bbox)

            # # get expanded bbox
            # exp_x_min, exp_y_min, exp_x_max, exp_y_max = func_bbox_expand(cfg.video.h, cfg.video.w, bbox, exp_ratio)

            # if exp_x_min == 0 \
            # and exp_y_min == 0 \
            # and exp_x_max == cfg.video.w-1 \
            # and exp_y_max == cfg.video.h-1:
            #   # print (f'{dir_clip} [{t}] fills whole image')
            #   frame_info[t]['bbox_expand'].append(bbox)
            # else:
            #   frame_info[t]['bbox_expand'].append([exp_x_min, exp_y_min, exp_x_max, exp_y_max])

        vis_dir = f'/pascal_{dir_name}_vis/'
        mat_dir = f'/pascal_{dir_name}/'

    else:
        # im_name = '2008_000003.jpg'
        # im_name = '2008_000008.jpg'
        # im_name = '2008_000026.jpg'
        im_name = '2008_000041.jpg'
        # im_name = '2008_000034.jpg'
        dir_frames = '/nethome/hkwon64/Research/imuTube/repos_v2/human_parsing/Grapy-ML/data/datasets/pascal/JPEGImages/'
        list_frame = [im_name]

    if not os.path.exists(opts.output_path + vis_dir):
        os.makedirs(opts.output_path + vis_dir)
    if not os.path.exists(opts.output_path + mat_dir):
        os.makedirs(opts.output_path + mat_dir)

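    # each person box is enlarged by this factor before cropping, so the
    # parser sees some context around the person (assumed intent)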
    exp_ratio = 1.2

    with torch.no_grad():

        for t, im_name in enumerate(list_frame):
            # demo: pin a single frame; the break at the end of the loop
            # body stops after this one frame
            t = 279
            im_name = list_frame[t]
            file_input = dir_frames + f'/{im_name}'

            _img = Image.open(file_input).convert('RGB')  # returns an RGB image
            w, h = _img.size

            pID = 2  # hard-coded person index within the pinned frame
            bbox = frame_info[t]['bbox'][pID]
            exp_x_min, exp_y_min, exp_x_max, exp_y_max = func_bbox_expand(
                h, w, bbox, exp_ratio)
            bbox_expand = [exp_x_min, exp_y_min, exp_x_max, exp_y_max]

            x_min, y_min, x_max, y_max = bbox_expand
            kps = frame_info[t]['kps'][pID]
            # kps[:,:2] -= np.array([[x_min, y_min]])

            sample1 = []
            for composed_transforms_ts in testloader_list:
                _img = Image.open(file_input).convert(
                    'RGB')  # returns an RGB image
                _img = _img.crop(bbox_expand)

                if 0:
                    w, h = _img.size
                    ow = int(w * 0.5)
                    oh = int(h * 0.5)
                    _img = _img.resize((ow, oh), Image.BILINEAR)
                    # print (_img.size)
                    # assert False

                _img = composed_transforms_ts({'image': _img})
                sample1.append(_img)

            sample2 = []
            for composed_transforms_ts_flip in testloader_flip_list:
                _img = Image.open(file_input).convert(
                    'RGB')  # returns an RGB image
                _img = _img.crop(bbox_expand)

                if 0:
                    w, h = _img.size
                    ow = int(w * 0.5)
                    oh = int(h * 0.5)
                    _img = _img.resize((ow, oh), Image.BILINEAR)
                    # print (_img.size)
                    # assert False

                _img = composed_transforms_ts_flip({'image': _img})
                sample2.append(_img)

            # print(ii)
            #1 0.5 0.75 1.25 1.5 1.75 ; flip:
            # sample1 = large_sample_batched[:6]
            # sample2 = large_sample_batched[6:]

            # for iii,sample_batched in enumerate(zip(sample1,sample2)):
            # 	print (sample_batched[0]['image'].shape)
            # 	print (sample_batched[1]['image'].shape)
            # assert False

            for iii, sample_batched in enumerate(zip(sample1, sample2)):
                # print (sample_batched[0]['image'].shape)
                # print (sample_batched[1]['image'].shape)

                inputs = sample_batched[0]['image']
                inputs_f = sample_batched[1]['image']
                inputs = torch.cat((inputs, inputs_f), dim=0)

                if iii == 0:
                    _, _, h, w = inputs.size()
                # assert inputs.size() == inputs_f.size()

                # Forward pass of the mini-batch
                inputs = Variable(inputs, requires_grad=False)

                with torch.no_grad():
                    if gpu_id >= 0:
                        inputs = inputs.cuda()
                    # outputs = net.forward(inputs)
                    # pdb.set_trace()
                    outputs, outputs_aux = net.forward(
                        (inputs, num_dataset_lbl), training=False)

                    # print(outputs.shape, outputs_aux.shape)
                    if opts.dataset == 'cihp':
                        outputs = (outputs[0] +
                                   flip(flip_cihp(outputs[1]), dim=-1)) / 2
                    elif opts.dataset == 'pascal':
                        outputs = (outputs[0] + flip(outputs[1], dim=-1)) / 2
                    else:
                        outputs = (outputs[0] +
                                   flip(flip_atr(outputs[1]), dim=-1)) / 2

                    outputs = outputs.unsqueeze(0)

                    if iii > 0:
                        outputs = F.interpolate(outputs,
                                                size=(h, w),
                                                mode='bilinear',
                                                align_corners=True)
                        outputs_final = outputs_final + outputs
                    else:
                        outputs_final = outputs.clone()

            ################ plot pic
            predictions = torch.max(outputs_final, 1)[1]
            prob_predictions = torch.max(outputs_final, 1)[0]
            results = predictions.cpu().numpy()
            # print (np.unique(results))
            # assert False

            prob_results = prob_predictions.cpu().numpy()
            vis_res = decode_labels(results)

            dir_im = opts.output_path + vis_dir + f'/{sample_info}'
            os.makedirs(dir_im, exist_ok=True)
            dir_mat = opts.output_path + mat_dir + f'/{sample_info}'
            os.makedirs(dir_mat, exist_ok=True)
            parsing_im = Image.fromarray(vis_res[0])
            parsing_im.save(dir_im + f'/{im_name}.png')
            cv2.imwrite(dir_mat + f'/{im_name}', results[0, :, :])
            print('save in ...', dir_mat + f'/{im_name}')

            # draw mask
            img = cv2.imread(file_input)
            print('load from ...', file_input)
            # img = img[y_min:y_max, x_min:x_max]
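            # the parser ran on the expanded crop, so each class mask is
            # pasted back at [y_min:y_max, x_min:x_max] of the full frame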

            for c in range(1, len(classes)):  # classes/colors are globals in the source script
                mask = results[0, :, :] == c
                _mask = np.zeros(img.shape[:2], dtype=bool)
                _mask[y_min:y_max, x_min:x_max] = mask
                img = draw_mask(img, _mask, thickness=3, color=colors[c - 1])

            img = draw_skeleton(img, kps)

            dir_mask = dir_im + '_mask'
            os.makedirs(dir_mask, exist_ok=True)

            file_mask = dir_mask + f'/{im_name}_mask.png'
            cv2.imwrite(file_mask, img)
            print('save in ...', file_mask)
            break  # demo: stop after the single pinned frame

            # total_iou += utils.get_iou(predictions, labels)
        end_time = timeit.default_timer()
        print('time used for ' + f'{im_name}' + ' is: ' +
              str(end_time - start_time))
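
func_bbox_expand is likewise imported from elsewhere in the repository. A minimal sketch of a symmetric expansion clamped to the image bounds, matching the call site above (an assumption about its behavior, not the repository's exact implementation):

def func_bbox_expand(img_h, img_w, bbox, ratio):
    # expand an [x_min, y_min, x_max, y_max] box about its center by
    # `ratio` and clamp it to the image; sketch only
    x_min, y_min, x_max, y_max = bbox
    cx, cy = (x_min + x_max) / 2.0, (y_min + y_max) / 2.0
    half_w = (x_max - x_min) * ratio / 2.0
    half_h = (y_max - y_min) * ratio / 2.0
    exp_x_min = max(0, int(round(cx - half_w)))
    exp_y_min = max(0, int(round(cy - half_h)))
    exp_x_max = min(img_w - 1, int(round(cx + half_w)))
    exp_y_max = min(img_h - 1, int(round(cy + half_h)))
    return exp_x_min, exp_y_min, exp_x_max, exp_y_max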