def validate(val_loader,
             model,
             criterion,
             num_classes,
             checkpoint,
             debug=False,
             flip=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acces = AverageMeter()

    # switch to evaluate mode
    model.eval()

    gt_win, pred_win = None, None
    iou = None
    end = time.time()
    bar = Bar('Eval ', max=len(val_loader))
    with torch.no_grad():
        for i, (input, input_depth, input_mask, target,
                meta) in enumerate(val_loader):
            # if RELABEL and i == 2 : break

            # measure data loading time
            data_time.update(time.time() - end)

            input, input_mask, target = input.to(device), input_mask.to(
                device), target.to(device, non_blocking=True)
            input_depth = input_depth.to(device)

            batch_size = input.shape[0]
            loss = 0
            last_state = None
            acc_list = []

            # process the clip frame by frame
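            # (each sample is a 6-frame clip; RGB, depth and mask channels are
            # concatenated along dim 1 before the forward pass. With
            # LSTM_STATE == 'stateful' the recurrent state from the previous
            # frame is fed back via input_state; 'stateness' (presumably
            # "stateless") processes every frame independently.)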
            for j in range(6):
                input_now = input[:, j]  # [B, 3, 256, 256]
                input_depth_now = input_depth[:, j]
                input_mask_now = input_mask[:, j]
                target_now = target[:, j]
                if j == 0:
                    output, output_state = model(
                        torch.cat((input_now, input_depth_now, input_mask_now),
                                  1))
                else:
                    if LSTM_STATE == 'stateness':
                        output, output_state = model(
                            torch.cat(
                                (input_now, input_depth_now, input_mask_now),
                                1))
                    elif LSTM_STATE == 'stateful':
                        output, output_state = model(torch.cat(
                            (input_now, input_depth_now, input_mask_now), 1),
                                                     input_state=last_state)
                last_state = output_state

                # print(output.shape)
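                # output is a per-frame probability; rounding at 0.5 yields the
                # predicted binary label that is compared with the ground truth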

                round_output = torch.round(output).float()
                loss += criterion(output, target_now)

                temp_acc = float(
                    (round_output == target_now).sum()) / batch_size
                acc_list.append(temp_acc)

                round_output = round_output.cpu()
                # print(round_output)

                if RELABEL:
                    # save in same checkpoint
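                    # re-root the raw mask path under <checkpoint>/pred_vis so
                    # the visualisations are saved next to the checkpoint that
                    # produced them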
                    raw_mask_path = meta['mask_path_list'][j][0]
                    img_index = meta['image_index_list'][j][0]
                    temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                    temp_tail = ('/').join(raw_mask_path.split('/')[-5:])
                    temp = os.path.join(temp_head, 'code/train_two_steps',
                                        checkpoint, 'pred_vis', temp_tail)
                    relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                    relabel_mask_dir = os.path.dirname(relabel_mask_dir)

                    raw_mask_rgb_path = os.path.join(
                        os.path.dirname(os.path.dirname(raw_mask_path)),
                        'first_mask_rgb', relabel_mask_name)
                    new_mask_rgb_path = os.path.join(relabel_mask_dir,
                                                     'gt_' + relabel_mask_name)
                    raw_rgb_frame_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'raw_frames', \
                        relabel_mask_name[:-4] + '.png')

                    from PIL import Image
                    import numpy as np
                    if os.path.exists(raw_mask_rgb_path):
                        gt_mask_rgb = np.array(Image.open(raw_mask_rgb_path))
                    else:
                        gt_mask_rgb = np.array(Image.open(raw_rgb_frame_path))

                    if not isdir(relabel_mask_dir):
                        mkdir_p(relabel_mask_dir)

                    gt_label_str = None
                    pred_label_str = None

                    if target_now[0][0] == 0:
                        gt_label_str = "GT : False"
                    elif target_now[0][0] == 1:
                        gt_label_str = "GT : True"

                    if round_output[0][0] == 0:
                        pred_label_str = "Pred : False"
                    elif round_output[0][0] == 1:
                        pred_label_str = "Pred : True"
                    output_str = gt_label_str + '. ' + pred_label_str

                    # if target_now[0][0] != round_output[0][0] :
                    #     print(raw_rgb_frame_path)

                    if not gt_win:
                        plt.plot()
                        plt.title(output_str)
                        gt_win = plt.imshow(gt_mask_rgb)
                    else:
                        plt.title(output_str)
                        gt_win.set_data(gt_mask_rgb)

                    plt.plot()
                    index_name = "%05d.jpg" % (img_index)
                    plt.savefig(
                        os.path.join(relabel_mask_dir, 'vis_' + index_name))

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))
            acces.update(sum(acc_list) / len(acc_list), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f}'.format(
                batch=i + 1,
                size=len(val_loader),
                data=data_time.val,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg)
            bar.next()
        bar.finish()
    return losses.avg, acces.avg
Example #2
def main(args):
    global best_iou
    global idx
    global output_res
    output_res = args.out_res

    # 2020.3.2
    global REDRAW

    # 2020.3.4
    # if args.resume is given,
    # args.checkpoint is derived from args.resume

    if args.pre_train:
        # pre train lr = 5e-4
        args.lr = 5e-5

    if args.resume != '' and args.pre_train == False:
        args.checkpoint = ('/').join(args.resume.split('/')[:2])

    if args.relabel == True:
        args.test_batch = 1
    if args.test == True:
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 20

    if args.evaluate and args.relabel == False:
        args.test_batch = 1

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    # if args.dataset in ['mpii', 'lsp']:
    #     idx = [1,2,3,4,5,6,11,12,15,16]
    # elif args.dataset == 'coco':
    #     idx = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17]
    # elif args.dataset == 'sad' or args.dataset == 'sad_step_1':
    #     idx = [1] # support affordance
    # else:
    #     print("Unknown dataset: {}".format(args.dataset))
    #     assert False
    idx = [1]

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    # 2020.6.7
    # freeze feature extraction and first one hg model paras
    # freeze_list = [model.conv1, model.bn1, model.layer1, model.layer2, model.layer3, \
    #     model.hg[0], model.res[0], model.fc[0], \
    #     model.score[0],model.fc_[0], model.score_[0]]
    # for freeze_layer in freeze_list :
    #     for param in freeze_layer.parameters():
    #         param.requires_grad = False

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.IoULoss().to(device)
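    # this variant supervises the predicted affordance region with an IoU loss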

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
        # optimizer = torch.optim.Adam(
        #     filter(lambda p: p.requires_grad, model.parameters()),
        #     lr=args.lr,
        # )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.pre_train:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            # start from epoch 0
            args.start_epoch = 0
            best_iou = 0
            model.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(
                ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val IoU'])
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # args.start_epoch = checkpoint['epoch']
            # best_iou = checkpoint['best_iou']

            # start from epoch 0
            args.start_epoch = 0
            best_iou = 0

            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)

        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val IoU'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](
        is_train=True,
        **vars(args))  # the dataset class is selected by args.dataset
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    '''
    for i, (input, input_depth, target, meta) in enumerate(train_loader):
        print(len(input))
        print(input[0].shape)
        print(input_depth[0].shape)
        print(target[0].shape)
        return
    '''

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label :
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            new_checkpoint = 'checkpoint_0701_bbox_hide'
            mkdir_p(new_checkpoint)
            loss, iou, predictions = validate(val_loader, model, criterion,
                                              njoints, new_checkpoint,
                                              args.debug, args.flip)
            # Because test and val are both included here, the IoU is not meaningful
            # print("Val IoU: %.3f" % (iou))
            return

    # evaluation only
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        if args.debug:
            print('Draw pred/gt heatmap')
        JUST_EVALUATE = True
        new_checkpoint = 'checkpoint_0701_bbox_hide'
        mkdir_p(new_checkpoint)
        loss, iou, predictions = validate(val_loader, model, criterion,
                                          njoints, new_checkpoint, args.debug,
                                          args.flip)
        print("Val IoU: %.3f" % (iou))
        return

    ## backup when training starts
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system('cp ../affordance/models/hourglass.py %s/%s/hourglass.py' %
              (args.checkpoint, code_backup_dir))
    os.system('cp ../affordance/datasets/sad.py %s/%s/sad.py' %
              (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir, this_file_name)))

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer,
                           args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_iou, predictions = validate(val_loader, model,
                                                      criterion, njoints,
                                                      args.checkpoint,
                                                      args.debug, args.flip)
        print("Val IoU: %.3f" % (valid_iou))

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, valid_iou])

        # remember best acc and save checkpoint
        is_best_iou = valid_iou > best_iou
        best_iou = max(valid_iou, best_iou)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            },
            is_best_iou,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()

    print("Best iou = %.3f" % (best_iou))
def main(args):
    global best_iou
    global idx
    global output_res
    output_res = args.out_res

    # 2020.3.2
    global REDRAW

    # 2020.3.4
    # if args.resume is given,
    # args.checkpoint is derived from args.resume
    if args.resume != '':
        args.checkpoint = ('/').join(args.resume.split('/')[:2])

    if args.relabel == True:
        args.test_batch = 1
    elif args.test == True:
        args.train_batch = 1
        args.test_batch = 1
        args.epochs = 20
        # args.train_batch = 2
        # args.test_batch = 2
        # args.epochs = 10

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    elif args.dataset == 'sad' or args.dataset == 'sad_step_2' or args.dataset == 'sad_step_2_eval':
        idx = [1]  # support affordance
    else:
        print("Unknown dataset: {}".format(args.dataset))
        assert False

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.BCELoss().to(device)
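    # this variant optimises a binary (existence) label with BCE, matching the
    # accuracy metric logged as 'Val Acc' below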

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](
        is_train=True,
        **vars(args))  # the dataset class is selected by args.dataset
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # for i, (input, input_depth, input_mask, target, meta) in enumerate(train_loader):
    #     print(len(input))
    #     print(input[0].shape)
    #     print(input_mask[0].shape)
    #     print(target[0].shape)
    #     return

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label :
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            loss, acc = validate(val_loader, model, criterion, njoints,
                                 args.checkpoint, args.debug, args.flip)
            print("Val acc: %.3f" % (acc))
            return

    # evaluation only
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        loss, acc = validate(val_loader, model, criterion, njoints,
                             args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (acc))
        return

    ## backup when training starts
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/models/affordance_classification.py %s/%s/affordance_classification.py'
        % (args.checkpoint, code_backup_dir))
    os.system('cp ../affordance/models/convlstm.py %s/%s/convlstm.py' %
              (args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/datasets/sad_step_2_eval.py %s/%s/sad_step_2_eval.py'
        % (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir, this_file_name)))

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer,
                           args.debug, args.flip)

        # evaluate on validation set
        valid_loss, valid_acc = validate(val_loader, model, criterion, njoints,
                                         args.checkpoint, args.debug,
                                         args.flip)
        print("Val acc: %.3f" % (valid_acc))

        # append logger file
        logger.append([epoch + 1, lr, train_loss, valid_loss, valid_acc])

        # remember best acc and save checkpoint
        is_best_iou = valid_acc > best_iou
        best_iou = max(valid_acc, best_iou)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_iou,
                'optimizer': optimizer.state_dict(),
            },
            is_best_iou,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()

    print("Best acc = %.3f" % (best_iou))
def main(args):
    global best_final_acc
    global idx
    global output_res
    output_res = args.out_res

    # 2020.3.2
    global REDRAW

    # 2020.3.4
    # if args.resume is given,
    # args.checkpoint is derived from args.resume

    if args.pre_train:
        # pre train lr = 5e-4
        args.lr = 5e-5

    if args.resume != '' and args.pre_train == False:
        args.checkpoint = ('/').join(args.resume.split('/')[:2])
    if args.relabel == True:
        args.test_batch = 1
    if args.test == True:
        args.train_batch = 4
        args.test_batch = 4
        args.epochs = 10

    if args.evaluate and args.relabel == False:
        args.test_batch = 4

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    idx = [1]

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion_iou = losses.IoULoss().to(device)
    criterion_bce = losses.BCELoss().to(device)
    criterion_focal = losses.FocalLoss().to(device)
    criterions = [criterion_iou, criterion_bce, criterion_focal]
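    # the three criteria are combined downstream in train()/validate():
    # IoU + BCE for the attention heatmap, focal + IoU for the region mask,
    # and BCE for the existence label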

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.pre_train:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)

            # start from epoch 0
            args.start_epoch = 0
            best_final_acc = 0
            model.load_state_dict(checkpoint['state_dict'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Epoch', 'LR', 'Train Attention Loss', 'Val Attention Loss', 'Val Attention IoU', \
                'Train Region Loss', 'Val Region Loss', 'Val Region IoU', \
                'Train Existence Loss', 'Val Existence Loss', 'Val Existence Acc', 'Val final acc'])

        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    elif args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # args.start_epoch = checkpoint['epoch']
            # best_iou = checkpoint['best_iou']

            # start from epoch 0
            args.start_epoch = 0
            best_final_acc = 0

            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)

        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Attention Loss', 'Val Attention Loss', 'Val Attention IoU', \
            'Train Region Loss', 'Val Region Loss', 'Val Region IoU', \
            'Train Existence Loss', 'Val Existence Loss', 'Val Existence Acc', 'Val final acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](
        is_train=True,
        **vars(args))  # the dataset class is selected by args.dataset
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    '''
    for i, (input, input_depth, target_heatmap, target_mask, target_label, meta) in enumerate(train_loader):
        print(len(input))
        print(input[0].shape)
        print(input_depth[0].shape)
        print(target_heatmap[0].shape)
        print(target_mask[0].shape)
        print(target_label[0].shape)
        return
    '''

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # redraw training / test label :
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
                val_existence_loss, val_existence_acc , val_final_acc \
                    = validate(val_loader, model, criterions, njoints, args.checkpoint, args.debug, args.flip)
            print("Val final acc: %.3f" % (val_final_acc))
            # Because test and val are both included here, the IoU is not meaningful
            # print("Val IoU: %.3f" % (iou))
            return

    # evaluation only
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
            val_existence_loss, val_existence_acc , val_final_acc \
                = validate(val_loader, model, criterions, njoints, args.checkpoint, args.debug, args.flip)
        print("Val final acc: %.3f" % (val_final_acc))
        # print( val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
        #     val_existence_loss, val_existence_acc , val_final_acc)

        return

    ## backup when training starts
    code_backup_dir = 'code_backup'
    mkdir_p(os.path.join(args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/models/hourglass_final.py %s/%s/hourglass_final.py' %
        (args.checkpoint, code_backup_dir))
    os.system(
        'cp ../affordance/datasets/sad_attention.py %s/%s/sad_attention.py' %
        (args.checkpoint, code_backup_dir))
    this_file_name = os.path.split(os.path.abspath(__file__))[1]
    os.system('cp ./%s %s' %
              (this_file_name,
               os.path.join(args.checkpoint, code_backup_dir, this_file_name)))

    # train and eval
    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_att_loss, train_region_loss, train_existence_loss \
            = train(train_loader, model, criterions, optimizer, args.debug, args.flip)

        # evaluate on validation set
        val_att_loss, val_att_iou, val_region_loss, val_region_iou, \
            val_existence_loss, val_existence_acc , val_final_acc \
                = validate(val_loader, model, criterions, njoints, args.checkpoint, args.debug, args.flip)
        print("Val region IoU: %.3f" % (val_region_iou))
        print("Val label acc: %.3f" % (val_existence_acc))
        val_final_acc = val_region_iou + val_existence_acc
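        # this combined score (region IoU + existence accuracy) overrides the
        # final acc returned by validate() and drives checkpoint selection below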

        # append logger file
        logger.append([epoch + 1, lr, train_att_loss, val_att_loss, val_att_iou, \
            train_region_loss, val_region_loss, val_region_iou, \
            train_existence_loss, val_existence_loss, val_existence_acc,
            val_final_acc])

        # remember best acc and save checkpoint
        is_best_acc = val_final_acc > best_final_acc
        best_final_acc = max(val_final_acc, best_final_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_iou': best_final_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best_acc,
            checkpoint=args.checkpoint,
            snapshot=args.snapshot)

    logger.close()

    print("Best val final acc = %.3f" % (best_final_acc))
Example #5
def validate(val_loader,
             model,
             criterion,
             num_classes,
             checkpoint,
             debug=False,
             flip=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    ioues = AverageMeter()

    # predictions
    predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2)

    # switch to evaluate mode
    model.eval()

    gt_win, pred_win = None, None
    iou = None
    end = time.time()
    bar = Bar('Eval ', max=len(val_loader))
    with torch.no_grad():
        for i, (input, input_depth, target, meta) in enumerate(val_loader):
            if RELABEL and i == 1: break

            # measure data loading time
            data_time.update(time.time() - end)

            input = input.to(device, non_blocking=True)
            input_depth = input_depth.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            batch_size = input.shape[0]
            loss = 0
            iou_list = []

            # cache features from the first two hourglass stacks (2 = stacks, 6 = clip length)
            video_feature_cache = torch.zeros(batch_size, 6, 2, 256,
                                              output_res, output_res)

            # first pass: run every frame through the model to fill the cache
            for j in range(6):
                input_now = input[:, j]  # [B, 3, 256, 256]
                input_depth_now = input_depth[:, j]
                target_now = target[:, j]
                _, out_tsm_feature = model(
                    torch.cat((input_now, input_depth_now),
                              1))  # [B, 4, 256, 256]
                for k in range(2):
                    video_feature_cache[:, j, k] = out_tsm_feature[k]

            # TSM module
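            # temporal shift: for each cached stack, the first 1/8 of the
            # channels take features from the next frame (shift left), the next
            # 1/8 take features from the previous frame (shift right), and the
            # remaining channels are copied unchanged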
            b, t, _, c, h, w = video_feature_cache.size()
            fold_div = 8
            fold = c // fold_div
            new_tsm_feature = torch.zeros(batch_size, 6, 2, 256, output_res,
                                          output_res)
            for j in range(2):
                x = video_feature_cache[:, :, j]
                temp = torch.zeros(batch_size, 6, 256, output_res, output_res)
                temp[:, :-1, :fold] = x[:, 1:, :fold]  # shift left
                temp[:, 1:, fold:2 * fold] = x[:, :-1,
                                               fold:2 * fold]  # shift right
                temp[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shift
                new_tsm_feature[:, :, j] = temp

            new_tsm_feature = new_tsm_feature.to(device)
            # second pass: predict using the temporally shifted features
            for j in range(6):
                input_now = input[:, j]  # [B, 3, 256, 256]
                input_depth_now = input_depth[:, j]
                target_now = target[:, j]
                output, _ = model(torch.cat((input_now, input_depth_now), 1),
                                  True, new_tsm_feature[:, j])

                if type(output) == list:  # multiple outputs because of intermediate prediction
                    for o in output:
                        loss += criterion(o, target_now)
                    output = output[-1]
                else:  # single output
                    pass
                '''
                testing now
                '''
                output = output.cpu()
                target_now = target_now.cpu()

                raw_mask_path = meta['mask_path_list'][j][0]
                img_index = meta['image_index_list'][j][0]
                temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                temp_tail = ('/').join(raw_mask_path.split('/')[-6:])
                temp = os.path.join(temp_head, 'code/train_two_steps',
                                    checkpoint, 'pred_vis', temp_tail)
                relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                relabel_mask_dir = os.path.dirname(relabel_mask_dir)
                area_head = '/home/s5078345/Affordance-Detection-on-Video/faster-rcnn.pytorch/data_affordance_bbox'
                area_tail = ('/').join(raw_mask_path.split('/')[6:10])
                area_to_detect_data_path = os.path.join(
                    area_head, area_tail, relabel_mask_name[:-4] + '.txt')
                area_to_detect_list = []
                with open(area_to_detect_data_path) as f:
                    for line in f:
                        inner_list = [
                            int(elt.strip()) for elt in line.split(' ')
                        ]
                        # alternatively, if the file content is comma-separated numbers:
                        # inner_list = [int(elt.strip()) for elt in line.split(',')]
                        area_to_detect_list.append(inner_list)
                if len(area_to_detect_list) == 0:
                    area_to_detect_list = None

                area_to_detect = area_to_detect_list
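                # restrict the IoU computation to the detected bounding boxes,
                # rescaled from the 640x480 frame to the 64x64 output maps;
                # pixels outside the boxes are zeroed in both prediction and GT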
                if area_to_detect is not None:
                    out_resized_area = torch.zeros((1, 1, 64, 64))
                    gt_resized_area = torch.zeros((1, 1, 64, 64))
                    for _i in range(len(area_to_detect)):
                        x_min, y_min, x_max, y_max = area_to_detect[_i]

                        x_min = math.floor(x_min / 640 * 64)
                        y_min = math.floor(y_min / 480 * 64)
                        x_max = math.ceil(x_max / 640 * 64)
                        y_max = math.ceil(y_max / 480 * 64)

                        # clip pred
                        out_resized_area[0, 0, y_min:y_max,
                                         x_min:x_max] = output[0, 0,
                                                               y_min:y_max,
                                                               x_min:x_max]

                        # clip GT
                        gt_resized_area[0, 0, y_min:y_max,
                                        x_min:x_max] = target_now[0, 0,
                                                                  y_min:y_max,
                                                                  x_min:x_max]

                    output = out_resized_area
                    target_now = gt_resized_area

                else:
                    pass

                temp_iou = intersectionOverUnion(output.cpu(),
                                                 target_now.cpu(),
                                                 idx)  # have not tested
                iou_list.append(temp_iou)
                score_map = output[-1].cpu() if type(
                    output) == list else output.cpu()

                if RELABEL:
                    # save in same checkpoint
                    raw_mask_path = meta['mask_path_list'][j][0]
                    img_index = meta['image_index_list'][j][0]
                    temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                    temp_tail = ('/').join(raw_mask_path.split('/')[-6:])
                    temp = os.path.join(temp_head, 'code/train_two_steps',
                                        checkpoint, 'pred_vis', temp_tail)
                    relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                    relabel_mask_dir = os.path.dirname(relabel_mask_dir)

                    raw_mask_rgb_path = os.path.join(
                        os.path.dirname(os.path.dirname(raw_mask_path)),
                        'first_mask_rgb', relabel_mask_name)
                    new_mask_rgb_path = os.path.join(relabel_mask_dir,
                                                     'gt_' + relabel_mask_name)
                    raw_rgb_frame_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'raw_frames', \
                        relabel_mask_name[:-4] + '.png')

                    # print(relabel_mask_dir)
                    # print(relabel_mask_name)
                    from PIL import Image
                    import numpy as np
                    if os.path.exists(raw_mask_rgb_path):
                        gt_mask_rgb = np.array(Image.open(raw_mask_rgb_path))
                    else:
                        gt_mask_rgb = np.array(Image.open(raw_rgb_frame_path))
                    # print(input_now.shape)
                    # print(score_map) # [1, 1, 64, 64]

                    # 2020.7.1
                    # load info about hide area
                    # from faster rcnn
                    area_head = '/home/s5078345/Affordance-Detection-on-Video/faster-rcnn.pytorch/data_affordance_bbox'
                    area_tail = ('/').join(raw_mask_path.split('/')[6:10])
                    area_to_detect_data_path = os.path.join(
                        area_head, area_tail, relabel_mask_name[:-4] + '.txt')
                    # print(area_to_use_data_path)
                    area_to_detect_list = []
                    with open(area_to_detect_data_path) as f:
                        for line in f:
                            inner_list = [
                                int(elt.strip()) for elt in line.split(' ')
                            ]
                            # alternatively, if the file content is comma-separated numbers:
                            # inner_list = [int(elt.strip()) for elt in line.split(',')]
                            area_to_detect_list.append(inner_list)
                    if len(area_to_detect_list) == 0:
                        area_to_detect_list = None
                    # print(area_to_detect_list)
                    pred_batch_img, pred_mask = relabel_heatmap(
                        input_now,
                        score_map,
                        'pred',
                        area_to_detect=area_to_detect_list
                    )  # return an Image object

                    if not isdir(relabel_mask_dir):
                        mkdir_p(relabel_mask_dir)

                    if not gt_win or not pred_win:
                        ax1 = plt.subplot(121)
                        ax1.title.set_text('MASK_RGB_GT')
                        gt_win = plt.imshow(gt_mask_rgb)
                        ax2 = plt.subplot(122)
                        ax2.title.set_text('Mask_RGB_PRED')
                        pred_win = plt.imshow(pred_batch_img)
                    else:
                        gt_win.set_data(gt_mask_rgb)
                        pred_win.set_data(pred_batch_img)
                    plt.plot()
                    index_name = "%05d.jpg" % (img_index)
                    plt.savefig(
                        os.path.join(relabel_mask_dir, 'vis_' + index_name))
                    pred_mask.save(os.path.join(relabel_mask_dir, index_name))

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))
            # acces.update(acc[0], input.size(0))
            ioues.update(sum(iou_list) / len(iou_list), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f}'.format(
                batch=i + 1,
                size=len(val_loader),
                data=data_time.val,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg)
            bar.next()
        bar.finish()
    return losses.avg, ioues.avg, predictions
def validate(val_loader,
             model,
             criterions,
             num_classes,
             checkpoint,
             debug=False,
             flip=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    total_losses = AverageMeter()
    heatmap_losses = AverageMeter()
    mask_losses = AverageMeter()
    label_losses = AverageMeter()

    heatmap_ioues = AverageMeter()
    mask_ioues = AverageMeter()
    label_acces = AverageMeter()

    # a frame counts as correct when the IoU > 50% and the step-2 label is right
    # if the GT label is false (and the pred is false too) -> also correct
    final_acces = AverageMeter()

    # for statistic
    gt_trues = AverageMeter()  # positive
    gt_falses = AverageMeter()  # negative
    pred_trues = AverageMeter()  # true == true and iou > 50%
    pred_falses = AverageMeter()
    pred_trues_first = AverageMeter()  # true == true

    # Loss
    criterion_iou, criterion_bce, criterion_focal = criterions

    # switch to evaluate mode
    model.eval()

    gt_win, pred_win = None, None
    iou = None
    end = time.time()
    bar = Bar('Eval ', max=len(val_loader))
    with torch.no_grad():
        for i, (input, input_depth, target_heatmap, target_mask, target_label,
                meta) in enumerate(val_loader):
            # if RELABEL and i == 10 : break

            # measure data loading time
            data_time.update(time.time() - end)

            input, input_depth = input.to(device), input_depth.to(device)
            target_heatmap, target_mask, target_label = target_heatmap.to(device, non_blocking=True), target_mask.to(device, non_blocking=True), \
                target_label.to(device, non_blocking=True)

            batch_size = input.shape[0]
            total_loss, heatmap_loss, mask_loss, label_loss = 0, 0, 0, 0
            last_state = None
            last_tsm_buffer = None
            heatmap_iou_list, mask_iou_list, label_acc_list = [], [], []
            final_acc_list = []

            # for statistic
            gt_true_list, gt_false_list, pred_true_list, pred_false_list, pred_true_first_list = [], [], [], [], []

            for j in range(6):
                input_now = input[:, j]  # [B, 3, 256, 256]
                input_depth_now = input_depth[:, j]  # [B, 1, 256, 256]
                target_heatmap_now = target_heatmap[:, j]  # [B, 1, 64, 64]
                target_mask_now = target_mask[:, j]  # [B, 1, 64, 64]
                target_label_now = target_label[:, j]  # [B, 1]

                # print(target_label_now)

                if j == 0:
                    output_heatmap, output_mask, output_label, output_state, output_tsm = model(
                        torch.cat((input_now, input_depth_now), 1))
                else:
                    output_heatmap, output_mask, output_label, output_state, output_tsm = model(torch.cat((input_now, input_depth_now), 1), \
                        input_state = last_state, tsm_input = last_tsm_buffer)
                last_state = output_state
                last_tsm_buffer = output_tsm

                # Loss computation
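                # intermediate supervision: every stack's heatmap gets a small
                # (0.05-weighted) IoU + BCE term, every mask output gets
                # focal + IoU, and the existence label gets a plain BCE term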
                for o_heatmap in output_heatmap:
                    temp = criterion_iou(
                        o_heatmap, target_heatmap_now) * 0.05 + criterion_bce(
                            o_heatmap, target_heatmap_now) * 0.05
                    total_loss += temp
                    heatmap_loss += temp
                for o_mask in output_mask:
                    # temp = criterion_iou(o_mask, target_mask_now) + criterion_bce(o_mask, target_mask_now)
                    # temp = criterion_focal(o_mask, target_mask_now) # exp 1
                    temp = criterion_focal(
                        o_mask, target_mask_now) + criterion_iou(
                            o_mask, target_mask_now)  # exp 2
                    total_loss += temp
                    mask_loss += temp
                temp = criterion_bce(output_label, target_label_now)
                total_loss += temp
                label_loss += temp

                # choose last one as prediction
                output_heatmap = output_heatmap[-1]
                output_mask = output_mask[-1]

                # evaluation metric
                heatmap_iou = intersectionOverUnion(output_heatmap.cpu(),
                                                    target_heatmap_now.cpu(),
                                                    idx)  # experiment
                heatmap_iou_list.append(heatmap_iou)
                mask_iou = intersectionOverUnion(output_mask.cpu(),
                                                 target_mask_now.cpu(),
                                                 idx,
                                                 return_list=True)
                mask_iou_list.append((sum(mask_iou) / len(mask_iou))[0])

                round_output_label = torch.round(output_label).float()
                label_acc = float((round_output_label
                                   == target_label_now).sum()) / batch_size
                label_acc_list.append(label_acc)

                score_map_mask = output_mask.cpu()
                # score_map_mask = output_heatmap.cpu()
                # print((sum(mask_iou) / len(mask_iou))[0])

                #########################
                # final evaluation accuracy
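                # a positive frame is correct only if both the predicted and GT
                # labels are 1 AND the mask IoU exceeds 0.5; a negative frame is
                # correct when both labels are 0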
                import numpy as np
                temp_1 = (round_output_label == 1) & (
                    target_label_now == 1)  # positve label correct
                temp_acc_1 = temp_1.cpu().numpy()
                temp_2 = (round_output_label == 0) & (
                    target_label_now == 0)  # negative label correct
                temp_acc_2 = temp_2.cpu().numpy()

                final_pred_1 = np.logical_and(
                    temp_acc_1,
                    mask_iou > 0.5)  # positve label correct + iou > 50%
                final_pred_2 = temp_acc_2  # negative label correct
                final_acc = np.logical_or(final_pred_1, final_pred_2)
                final_acc_list.append(np.sum(final_acc) / batch_size)

                # for statistic
                temp_1 = (target_label_now == 1).cpu().numpy()
                temp_2 = (target_label_now == 0).cpu().numpy()
                gt_true_list.append(np.sum(temp_1) / batch_size)
                gt_false_list.append(np.sum(temp_2) / batch_size)

                pred_true_list.append(np.sum(final_pred_1) / batch_size)
                pred_false_list.append(np.sum(final_pred_2) / batch_size)
                pred_true_first_list.append(np.sum(temp_acc_1) / batch_size)
                ###############################

                if RELABEL:
                    # save in same checkpoint
                    raw_mask_path = meta['mask_path_list'][j][0]
                    img_index = meta['image_index_list'][j][0]
                    temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                    temp_tail = ('/').join(raw_mask_path.split('/')[-6:])
                    temp = os.path.join(temp_head, 'code/train_two_steps',
                                        checkpoint, 'pred_vis', temp_tail)
                    relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                    relabel_mask_dir = os.path.dirname(relabel_mask_dir)

                    raw_mask_rgb_path = os.path.join(
                        os.path.dirname(os.path.dirname(raw_mask_path)),
                        'first_mask_rgb', relabel_mask_name)
                    new_mask_rgb_path = os.path.join(relabel_mask_dir,
                                                     'gt_' + relabel_mask_name)
                    raw_rgb_frame_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'raw_frames', \
                        relabel_mask_name[:-4] + '.png')

                    # print(relabel_mask_dir)
                    # print(relabel_mask_name)
                    from PIL import Image
                    import numpy as np
                    if os.path.exists(raw_mask_rgb_path):
                        gt_mask_rgb = np.array(Image.open(raw_mask_rgb_path))
                    else:
                        gt_mask_rgb = np.array(Image.open(raw_rgb_frame_path))
                    # print(input_now.shape)
                    # print(score_map.shape)
                    pred_batch_img, pred_mask = relabel_heatmap(
                        input_now.cpu(), score_map_mask,
                        'pred')  # return an Image object

                    if not isdir(relabel_mask_dir):
                        mkdir_p(relabel_mask_dir)

                    gt_label_str = None
                    pred_label_str = None
                    from PIL import Image
                    pred_Image = Image.fromarray(pred_batch_img)

                    if target_label_now[0][0] == 0:
                        gt_label_str = "GT : False"
                    elif target_label_now[0][0] == 1:
                        gt_label_str = "GT : True"

                    # print(gt_label_str)

                    if round_output_label[0][0] == 0:
                        pred_label_str = "Pred : False"
                    elif round_output_label[0][0] == 1:
                        pred_label_str = "Pred : True"

                    if not gt_win or not pred_win:
                        ax1 = plt.subplot(121)
                        ax1.title.set_text(gt_label_str)
                        gt_win = plt.imshow(gt_mask_rgb)
                        ax2 = plt.subplot(122)
                        ax2.title.set_text(pred_label_str)
                        pred_win = plt.imshow(pred_batch_img)
                    else:
                        gt_win.set_data(gt_mask_rgb)
                        pred_win.set_data(pred_batch_img)
                        ax1.title.set_text(gt_label_str)
                        ax2.title.set_text(pred_label_str)

                    plt.plot()
                    index_name = "%05d.jpg" % (img_index)
                    plt.savefig(
                        os.path.join(relabel_mask_dir, 'vis_' + index_name))
                    pred_mask.save(os.path.join(relabel_mask_dir, index_name))
                    pred_Image.save(
                        os.path.join(relabel_mask_dir, 'image_' + index_name))
                    # print(os.path.join(relabel_mask_dir, index_name))

            # record final acc
            final_acces.update(
                sum(final_acc_list) / len(final_acc_list), input.size(0))

            # for statistic
            gt_trues.update(
                sum(gt_true_list) / len(gt_true_list), input.size(0))
            gt_falses.update(
                sum(gt_false_list) / len(gt_false_list), input.size(0))
            pred_trues.update(
                sum(pred_true_list) / len(pred_true_list), input.size(0))
            pred_falses.update(
                sum(pred_false_list) / len(pred_false_list), input.size(0))
            pred_trues_first.update(
                sum(pred_true_first_list) / len(pred_true_first_list),
                input.size(0))

            # record loss
            total_losses.update(total_loss.item(), input.size(0))
            heatmap_losses.update(heatmap_loss.item(), input.size(0))
            mask_losses.update(mask_loss.item(), input.size(0))
            label_losses.update(label_loss.item(), input.size(0))

            # record metric
            heatmap_ioues.update(
                sum(heatmap_iou_list) / len(heatmap_iou_list), input.size(0))
            mask_ioues.update(
                sum(mask_iou_list) / len(mask_iou_list), input.size(0))
            label_acces.update(
                sum(label_acc_list) / len(label_acc_list), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f}'.format(
                batch=i + 1,
                size=len(val_loader),
                data=data_time.val,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=total_losses.avg)
            bar.next()
        bar.finish()

    print(heatmap_losses.avg, heatmap_ioues.avg, \
        mask_losses.avg, mask_ioues.avg, \
        label_losses.avg, label_acces.avg, \
        final_acces.avg)
    return heatmap_losses.avg, heatmap_ioues.avg, \
        mask_losses.avg, mask_ioues.avg, \
        label_losses.avg, label_acces.avg, \
        final_acces.avg
Example #7
def validate(val_loader,
             model,
             criterion,
             num_classes,
             checkpoint,
             debug=False,
             flip=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    ioues = AverageMeter()

    # predictions
    predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2)
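    # note: allocated in the usual keypoint-prediction shape (one (x, y) pair per class);
    # in this snippet it is returned below without being filled in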

    # switch to evaluate mode
    model.eval()

    gt_win, pred_win = None, None
    iou = None
    end = time.time()
    bar = Bar('Eval ', max=len(val_loader))
    with torch.no_grad():
        for i, (input, input_depth, target, meta) in enumerate(val_loader):
            # if RELABEL and i == 10 : break

            # measure data loading time
            data_time.update(time.time() - end)

            input = input.to(device, non_blocking=True)
            input_depth = input_depth.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            batch_size = input.shape[0]
            loss = 0
            last_state = None
            iou_list = []

            # compute using the TSM feature, frame by frame
            for j in range(6):
                input_now = input[:, j]  # [B, 3, 256, 256]
                input_depth_now = input_depth[:, j]
                target_now = target[:, j]
                if j == 0:
                    output, last_state = model(
                        torch.cat((input_now, input_depth_now), 1))
                else:
                    output, _ = model(torch.cat((input_now, input_depth_now),
                                                1),
                                      input_last_state=last_state)
                    # print(output.shape)

                # intermediate supervision: stacked models return a list of outputs
                if isinstance(output, list):
                    for o in output:
                        loss += criterion(o, target_now)
                    output = output[-1]  # keep only the final prediction

                temp_iou = intersectionOverUnion(output.cpu(),
                                                 target_now.cpu(),
                                                 idx)  # have not tested
                iou_list.append(temp_iou)
                score_map = output.cpu()  # output is a single tensor at this point

                if RELABEL:
                    # save in same checkpoint
                    raw_mask_path = meta['mask_path_list'][j][0]
                    img_index = meta['image_index_list'][j][0]
                    temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                    temp_tail = ('/').join(raw_mask_path.split('/')[-5:])
                    temp = os.path.join(temp_head, 'code/train_two_steps',
                                        checkpoint, 'pred_vis', temp_tail)
                    relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                    relabel_mask_dir = os.path.dirname(relabel_mask_dir)
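                    # relabel_mask_dir now points under <checkpoint>/pred_vis, mirroring the
                    # dataset's directory layout (inferred from the path splits above)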

                    raw_mask_rgb_path = os.path.join(
                        os.path.dirname(os.path.dirname(raw_mask_path)),
                        'first_mask_rgb', relabel_mask_name)
                    new_mask_rgb_path = os.path.join(relabel_mask_dir,
                                                     'gt_' + relabel_mask_name)
                    raw_rgb_frame_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'raw_frames', \
                        relabel_mask_name[:-4] + '.png')

                    # print(relabel_mask_dir)
                    # print(relabel_mask_name)
                    from PIL import Image
                    import numpy as np
                    if os.path.exists(raw_mask_rgb_path):
                        gt_mask_rgb = np.array(Image.open(raw_mask_rgb_path))
                    else:
                        gt_mask_rgb = np.array(Image.open(raw_rgb_frame_path))

                    # pred_batch_img, pred_mask = relabel_heatmap(input_now, score_map, 'pred') # old
                    _, pred_mask = relabel_heatmap(input_now, score_map,
                                                   'pred')

                    # preprocess: convert the CHW float input back to an HWC uint8 image
                    # and the score map to a 64x64 uint8 heatmap for visualization
                    temp = input_now[0].cpu().numpy() * 255
                    temp_input = np.zeros(
                        (temp.shape[1], temp.shape[2], temp.shape[0]))
                    for _i in range(3):
                        temp_input[:, :, _i] = temp[_i, :, :]
                    temp_input = np.asarray(temp_input, np.uint8)
                    temp_output = score_map.cpu().numpy() * 255
                    temp_output = np.asarray(temp_output, np.uint8)
                    temp_output = np.reshape(temp_output, (64, 64))

                    pred_batch_img = eval_heatmap(
                        temp_input, temp_output)  # return correct mask + image

                    if not isdir(relabel_mask_dir):
                        mkdir_p(relabel_mask_dir)

                    if not gt_win or not pred_win:
                        ax1 = plt.subplot(121)
                        ax1.title.set_text('MASK_RGB_GT')
                        gt_win = plt.imshow(gt_mask_rgb)
                        ax2 = plt.subplot(122)
                        ax2.title.set_text('Mask_RGB_PRED')
                        pred_win = plt.imshow(pred_batch_img)
                    else:
                        gt_win.set_data(gt_mask_rgb)
                        pred_win.set_data(pred_batch_img)
                    plt.plot()
                    index_name = "%05d.jpg" % (img_index)
                    plt.savefig(
                        os.path.join(relabel_mask_dir, 'vis_' + index_name))
                    pred_mask.save(os.path.join(relabel_mask_dir, index_name))

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))
            # acces.update(acc[0], input.size(0))
            ioues.update(sum(iou_list) / len(iou_list), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f}'.format(
                batch=i + 1,
                size=len(val_loader),
                data=data_time.val,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg)
            bar.next()
        bar.finish()
    return losses.avg, ioues.avg, predictions
def main(args):
    global best_iou
    global idx
    global output_res
    output_res = args.out_res

    # 2020.3.2
    global REDRAW

    # 2020.3.4
    # if args.resume is given, args.checkpoint is derived from its first two path components
    if args.resume != '':
        args.checkpoint = ('/').join(args.resume.split('/')[:2])

    if args.relabel:
        args.test_batch = 1
    elif args.test:
        # args.train_batch = 4
        # args.test_batch = 4
        # args.epochs = 20
        args.train_batch = 2
        args.test_batch = 2
        args.epochs = 10

    # write line-chart and stop program
    if args.write:
        draw_line_chart(args, os.path.join(args.checkpoint, 'log.txt'))
        return

    # idx is the index of joints used to compute accuracy
    if args.dataset in ['mpii', 'lsp']:
        idx = [1, 2, 3, 4, 5, 6, 11, 12, 15, 16]
    elif args.dataset == 'coco':
        idx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    elif args.dataset in ['sad', 'sad_step_2', 'sad_step_2_eval']:
        idx = [1]  # support affordance
    else:
        print("Unknown dataset: {}".format(args.dataset))
        assert False

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    njoints = datasets.__dict__[args.dataset].njoints

    model = models.__dict__[args.arch](num_stacks=args.stacks,
                                       num_blocks=args.blocks,
                                       num_classes=njoints,
                                       resnet_layers=args.resnet_layers)

    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    criterion = losses.BCELoss().to(device)

    if args.solver == 'rms':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    elif args.solver == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.lr,
        )
    else:
        print('Unknown solver: {}'.format(args.solver))
        assert False

    # optionally resume from a checkpoint
    title = args.dataset + ' ' + args.arch
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Val Acc'])

    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # create data loader
    train_dataset = datasets.__dict__[args.dataset](
        is_train=True,
        **vars(args))  # args.dataset selects the dataset class from datasets.__dict__
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.train_batch,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    val_dataset = datasets.__dict__[args.dataset](is_train=False, **vars(args))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.test_batch,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # relabel (redraw) the training / test labels:
    global RELABEL
    if args.relabel:
        RELABEL = True
        if args.evaluate:
            print('\nRelabel val label')
            # validate returns three values; keep the loss and the accuracy/IoU metric
            loss, acc, _ = validate(val_loader, model, criterion, njoints,
                                    args.checkpoint, args.debug, args.flip)
            print("Val acc: %.3f" % (acc))
            return

    # evaluation only
    global JUST_EVALUATE
    JUST_EVALUATE = False
    if args.evaluate:
        print('\nEvaluation only')
        JUST_EVALUATE = True
        loss, acc, _ = validate(val_loader, model, criterion, njoints,
                                args.checkpoint, args.debug, args.flip)
        print("Val acc: %.3f" % (acc))
        return
def validate(val_loader,
             model,
             criterion,
             num_classes,
             checkpoint,
             debug=False,
             flip=True):
    import numpy as np

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    acces = AverageMeter()
    ioues = AverageMeter()

    # for statistics
    gt_trues = AverageMeter()
    gt_falses = AverageMeter()
    pred_trues = AverageMeter()  # prediction True, GT True, and IoU > 50%
    pred_falses = AverageMeter()

    pred_trues_first = AverageMeter()  # prediction True and GT True (ignoring IoU)

    # a prediction is counted as correct when the GT and predicted labels are both True
    # and IoU > 50%, or when the GT and predicted labels are both False
    final_acces = AverageMeter()

    # switch to evaluate mode
    model.eval()

    gt_win, pred_win = None, None
    iou = None
    end = time.time()
    bar = Bar('Eval ', max=len(val_loader))
    with torch.no_grad():
        for i, (input, input_depth, input_mask, target, meta,
                gt_mask) in enumerate(val_loader):

            # if i == 10 : break

            # measure data loading time
            data_time.update(time.time() - end)

            input, input_mask, target = input.to(device), input_mask.to(
                device), target.to(device, non_blocking=True)
            input_depth = input_depth.to(device)

            batch_size = input.shape[0]
            loss = 0
            last_state = None
            acc_list = []
            iou_list = []
            final_acc_list = []

            # for statistics
            gt_true_list = []
            gt_false_list = []
            pred_true_list = []
            pred_false_list = []
            pred_true_first_list = []

            for j in range(6):
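                # process the 6 frames of the clip in order, feeding the previous
                # frame's state back into the model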
                input_now = input[:, j]  # [B, 3, 256, 256]
                input_depth_now = input_depth[:, j]
                input_mask_now = input_mask[:, j]
                gt_mask_now = gt_mask[:, j]
                target_now = target[:, j]

                if j == 0:
                    output, output_state = model(
                        torch.cat((input_now, input_depth_now, input_mask_now),
                                  1))
                else:
                    output, output_state = model(torch.cat(
                        (input_now, input_depth_now, input_mask_now), 1),
                                                 input_state=last_state)
                    # print(output.shape)

                last_state = output_state

                #############################
                # (experimental) Restrict the predicted mask and the GT mask to the
                # detected bounding-box regions (rescaled from 640x480 pixels to the
                # 64x64 mask grid) before computing IoU.
                input_mask_now = input_mask_now.cpu()
                gt_mask_now = gt_mask_now.cpu()

                raw_mask_path = meta['mask_path_list'][j][0]
                img_index = meta['image_index_list'][j][0]
                temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                temp_tail = ('/').join(raw_mask_path.split('/')[-6:])
                temp = os.path.join(temp_head, 'code/train_two_steps',
                                    checkpoint, 'pred_vis', temp_tail)
                relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                relabel_mask_dir = os.path.dirname(relabel_mask_dir)
                area_head = '/home/s5078345/Affordance-Detection-on-Video/faster-rcnn.pytorch/data_affordance_bbox'
                area_tail = ('/').join(raw_mask_path.split('/')[3:-1])

                # print(raw_mask_path)
                # print(area_tail)

                area_to_detect_data_path = os.path.join(
                    area_head, area_tail, relabel_mask_name[:-4] + '.txt')
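                # each line of the bbox file is expected to hold one box as
                # space-separated integers: x_min y_min x_max y_max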
                area_to_detect_list = []
                with open(area_to_detect_data_path) as f:
                    for line in f:
                        inner_list = [
                            int(elt.strip()) for elt in line.split(' ')
                        ]
                        # alternatively, if the file is comma-separated:
                        # inner_list = [int(elt.strip()) for elt in line.split(',')]
                        area_to_detect_list.append(inner_list)
                if len(area_to_detect_list) == 0:
                    area_to_detect_list = None

                area_to_detect = area_to_detect_list
                if area_to_detect is not None:
                    out_resized_area = torch.zeros((1, 1, 64, 64))
                    gt_resized_area = torch.zeros((1, 1, 64, 64))
                    for _i in range(len(area_to_detect)):
                        x_min, y_min, x_max, y_max = area_to_detect[_i]

                        x_min = math.floor(x_min / 640 * 64)
                        y_min = math.floor(y_min / 480 * 64)
                        x_max = math.ceil(x_max / 640 * 64)
                        y_max = math.ceil(y_max / 480 * 64)

                        # clip pred
                        out_resized_area[0, 0, y_min:y_max,
                                         x_min:x_max] = input_mask_now[
                                             0, 0, y_min:y_max, x_min:x_max]

                        # clip GT
                        gt_resized_area[0, 0, y_min:y_max,
                                        x_min:x_max] = gt_mask_now[0, 0,
                                                                   y_min:y_max,
                                                                   x_min:x_max]

                    input_mask_now = out_resized_area
                    gt_mask_now = gt_resized_area

                ####################################################

                # compute loss
                round_output = torch.round(output).float()
                loss += criterion(output, target_now)

                temp_acc = float(
                    (round_output == target_now).sum()) / batch_size

                temp_1 = (round_output == 1) & (target_now == 1)
                temp_acc_1 = temp_1.cpu().numpy()
                temp_2 = (round_output == 0) & (target_now == 0)
                temp_acc_2 = temp_2.cpu().numpy()

                temp_iou = intersectionOverUnion(gt_mask_now,
                                                 input_mask_now.cpu(),
                                                 idx,
                                                 return_list=True)

                final_pred_1 = np.logical_and(temp_acc_1, temp_iou > 0.5)
                final_pred_2 = temp_acc_2
                final_pred = np.logical_or(final_pred_1, final_pred_2)
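                # a frame counts as finally correct if (pred True, GT True, IoU > 0.5)
                # or (pred False, GT False)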

                acc_list.append(temp_acc)
                final_acc_list.append(np.sum(final_pred) / batch_size)
                round_output = round_output.cpu()

                # for statistics
                temp_1 = (target_now == 1).cpu().numpy()
                temp_2 = (target_now == 0).cpu().numpy()
                gt_true_list.append(np.sum(temp_1) / batch_size)
                gt_false_list.append(np.sum(temp_2) / batch_size)

                pred_true_list.append(np.sum(final_pred_1) / batch_size)
                pred_false_list.append(np.sum(final_pred_2) / batch_size)
                pred_true_first_list.append(np.sum(temp_acc_1) / batch_size)

                if RELABEL:
                    '''
                    left image : GT
                        image : gt_mask_rgb
                        label : target_now
                    right image : predict result
                        image : raw_frames (pred false) or ./checkpoint_0428/pred_vis (pred true)
                        label : round_output
                    '''
                    from PIL import Image
                    import numpy as np
                    import copy

                    # save in same checkpoint
                    img_index = meta['image_index_list'][j][0]

                    raw_mask_path = meta['mask_path_list'][j][0]
                    gt_mask_path = meta['gt_mask_path_list'][j][0]

                    temp_head = ('/').join(gt_mask_path.split('/')[:-8])
                    temp_tail = ('/').join(gt_mask_path.split('/')[-5:])
                    temp = os.path.join(temp_head,
                                        'code/train_two_steps/eval_bbox_clip',
                                        'pred_vis', temp_tail)
                    relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                    relabel_mask_dir = os.path.dirname(
                        relabel_mask_dir)  # new dir name for pred_vis

                    # raw frame
                    raw_rgb_frame_path = os.path.join(
                        os.path.dirname(os.path.dirname(gt_mask_path)),
                        'raw_frames',
                        gt_mask_path.split('/')[-1][:-4] + '.png')
                    raw_frame = np.array(Image.open(raw_rgb_frame_path))

                    # gt_mask_rgb
                    gt_mask_rgb_path = os.path.join(
                        os.path.dirname(os.path.dirname(gt_mask_path)),
                        'mask_rgb',
                        gt_mask_path.split('/')[-1])
                    if os.path.exists(gt_mask_rgb_path):
                        gt_mask_rgb = np.array(Image.open(gt_mask_rgb_path))
                    else:
                        gt_mask_rgb = copy.deepcopy(raw_frame)

                    # pred mask
                    pred_mask_path = os.path.join(
                        os.path.dirname(raw_mask_path), relabel_mask_name)
                    pred_mask = np.array(Image.open(pred_mask_path))

                    pred_mask_rgb = eval_heatmap(raw_frame,
                                                 pred_mask)  # generate rgb

                    if not isdir(relabel_mask_dir):
                        mkdir_p(relabel_mask_dir)

                    gt_label_str = None
                    pred_label_str = None
                    gt_output = gt_mask_rgb
                    pred_output = None
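                    # right panel: show the raw frame when the prediction is False, or the
                    # predicted-mask overlay when it is True; the IoU > 0.5 check decides
                    # the "(IoU : O/X)" tag in the title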

                    if target_now[0][0] == 0:
                        gt_label_str = "GT : False"
                    elif target_now[0][0] == 1:
                        gt_label_str = "GT : True"

                    if round_output[0][0] == 0:
                        pred_label_str = "Pred : False"
                        pred_output = raw_frame
                    elif round_output[0][0] == 1:
                        pred_output = pred_mask_rgb
                        if target_now[0][0] == 0:
                            pred_label_str = "Pred : True"
                        elif target_now[0][0] == 1 and temp_iou > 0.5:
                            pred_label_str = "Pred : True (IoU : O)"
                        elif target_now[0][0] == 1 and temp_iou <= 0.5:
                            pred_label_str = "Pred : True (IoU : X)"

                    # output_str = gt_label_str + '. ' + pred_label_str

                    if not gt_win or not pred_win:
                        ax1 = plt.subplot(121)
                        ax1.title.set_text(gt_label_str)
                        gt_win = plt.imshow(gt_output)
                        ax2 = plt.subplot(122)
                        ax2.title.set_text(pred_label_str)
                        pred_win = plt.imshow(pred_output)

                    else:
                        gt_win.set_data(gt_output)
                        pred_win.set_data(pred_output)

                        ax1.title.set_text(gt_label_str)
                        ax2.title.set_text(pred_label_str)

                    plt.plot()
                    index_name = "%05d.jpg" % (img_index)
                    plt.savefig(
                        os.path.join(relabel_mask_dir, 'vis_' + index_name))

            # measure accuracy and record loss
            losses.update(loss.item(), input.size(0))
            acces.update(sum(acc_list) / len(acc_list), input.size(0))
            final_acces.update(
                sum(final_acc_list) / len(final_acc_list), input.size(0))

            # for statistics
            gt_trues.update(
                sum(gt_true_list) / len(gt_true_list), input.size(0))
            gt_falses.update(
                sum(gt_false_list) / len(gt_false_list), input.size(0))
            pred_trues.update(
                sum(pred_true_list) / len(pred_true_list), input.size(0))
            pred_falses.update(
                sum(pred_false_list) / len(pred_false_list), input.size(0))
            pred_trues_first.update(
                sum(pred_true_first_list) / len(pred_true_first_list),
                input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f}'.format(
                batch=i + 1,
                size=len(val_loader),
                data=data_time.val,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg)
            bar.next()
        bar.finish()

    # for statistics
    print("GT true : %.3f" % (gt_trues.avg))
    print("GT false : %.3f" % (gt_falses.avg))
    print("Pred true : %.3f" % (pred_trues.avg))
    print("Pred false : %.3f" % (pred_falses.avg))
    print("====")
    print("Pred true (no considering IoU) : %.3f" % (pred_trues_first.avg))
    print("IoU > 50 percent accuracy : %.3f" %
          (pred_trues.avg / pred_trues_first.avg))
    print("===")
    print("True part : %.3f acc, False part : %.3f acc" %
          (pred_trues.avg / gt_trues.avg, pred_falses.avg / gt_falses.avg))
    print("Predict true label correct : %.3f" %
          (pred_trues_first.avg / gt_trues.avg))

    return losses.avg, acces.avg, final_acces.avg
def validate_v2(val_loader,
                model,
                criterion,
                num_classes,
                checkpoint,
                debug=False,
                flip=True):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    ioues = AverageMeter()

    # predictions
    predictions = torch.Tensor(val_loader.dataset.__len__(), num_classes, 2)

    # switch to evaluate mode
    model.eval()

    gt_win, pred_win = None, None
    iou = None
    end = time.time()
    bar = Bar('Eval ', max=len(val_loader))
    with torch.no_grad():
        for i, (input, input_depth, target, meta, video_input_eval, _,
                video_target_eval,
                area_to_detect_list) in enumerate(val_loader):
            # if i == 1 : break

            # if i != 5 : continue
            # if i == 6 : break
            # print()

            # measure data loading time
            data_time.update(time.time() - end)

            input = input.to(device, non_blocking=True)
            input_depth = input_depth.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            video_target_eval = video_target_eval.to(device, non_blocking=True)

            batch_size = input.shape[0]
            loss = 0
            iou_list = []

            for j in range(6):
                # load whole image
                input_whole = video_input_eval[:, j]
                target_whole = video_target_eval[:, j]  #[1, 32, 32]
                output_whole = torch.zeros((1, 32, 32))
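                # accumulate the per-object crop predictions into one 32x32 whole-frame map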
                # output_whole = output_whole.to(device, non_blocking=True)

                for obj_i in range(3):
                    if len(area_to_detect_list[j]) <= obj_i:
                        continue

                    input_now = input[:, j, obj_i]  # [B, 3, 256, 256]
                    input_depth_now = input_depth[:, j, obj_i]
                    target_now = target[:, j, obj_i]
                    output, _ = model(
                        torch.cat((input_now, input_depth_now), 1))

                    output = output[-1]  # keep the final element of the model output (last-stack prediction)
                    score_map_part = output.cpu()
                    score_map_gt = target_now.cpu()
                    output = output[-1].cpu().numpy()  # last (only, for batch size 1) sample as a numpy array

                    # map the crop prediction back onto the whole-frame grid using the detected box
                    _x_min = float(area_to_detect_list[j][obj_i][0].numpy())
                    _y_min = float(area_to_detect_list[j][obj_i][1].numpy())
                    _x_max = float(area_to_detect_list[j][obj_i][2].numpy())
                    _y_max = float(area_to_detect_list[j][obj_i][3].numpy())
                    OUT_RES = 32
                    # order is reversed
                    # print(_x_min, _y_min, _x_max, _y_max)
                    x_min = round(_x_min / 640 * OUT_RES)
                    y_min = round(_y_min / 480 * OUT_RES)
                    x_max = min(round(_x_max / 640 * OUT_RES), 31)
                    y_max = min(round(_y_max / 480 * OUT_RES), 31)
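                    # box coordinates come in 640x480 pixel space; they were rescaled to the
                    # 32x32 output grid and clamped to stay inside it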
                    x_len = x_max - x_min
                    y_len = y_max - y_min
                    # print(x_min, y_min, x_max, y_max)

                    # resize the crop prediction to the box size and paste it into the whole-frame map
                    if x_len > 0 and y_len > 0:
                        resized_out = faster_rcnn_crop(output, x_len, y_len)
                        # print(resized_out.shape)
                        # output_whole[0, y_min:y_max, x_min:x_max] = 1
                        output_whole[0, y_min:y_max, x_min:x_max] = resized_out

                        # resized_out = faster_rcnn_crop(output, y_len, x_len)
                        # output_whole[0, x_min:x_max, y_min:y_max] = resized_out

                        # # part image
                        # if RELABEL:
                        #     # save in same checkpoint
                        #     raw_mask_path = meta['mask_path_list'][j][0]
                        #     img_index = meta['image_index_list'][j][0]
                        #     temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                        #     temp_tail = ('/').join(raw_mask_path.split('/')[-6:])
                        #     temp = os.path.join(temp_head, 'code/train_two_steps', checkpoint, 'pred_vis', temp_tail)
                        #     relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                        #     relabel_mask_dir = os.path.dirname(relabel_mask_dir)

                        #     raw_mask_rgb_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'first_mask_rgb', relabel_mask_name)
                        #     new_mask_rgb_path = os.path.join(relabel_mask_dir, 'gt_' + relabel_mask_name)
                        #     raw_rgb_frame_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'raw_frames', \
                        #         relabel_mask_name[:-4] + '.png')

                        #     from PIL import Image
                        #     import numpy as np
                        #     # if os.path.exists(raw_mask_rgb_path):
                        #     #     gt_mask_rgb = np.array(Image.open(raw_mask_rgb_path))
                        #     # else :
                        #     #     gt_mask_rgb = np.array(Image.open(raw_rgb_frame_path))

                        #     gt_mask_rgb, _ = relabel_heatmap(input_now, score_map_gt, 'gt') # checked

                        #     pred_batch_img, pred_mask = relabel_heatmap(input_now, score_map_part, 'pred') # return an Image object

                        #     if not isdir(relabel_mask_dir):
                        #         mkdir_p(relabel_mask_dir)

                        #     if not gt_win or not pred_win:
                        #         ax1 = plt.subplot(121)
                        #         ax1.title.set_text('MASK_RGB_GT')
                        #         gt_win = plt.imshow(gt_mask_rgb)
                        #         ax2 = plt.subplot(122)
                        #         ax2.title.set_text('Mask_RGB_PRED')
                        #         pred_win = plt.imshow(pred_batch_img)
                        #     else:
                        #         gt_win.set_data(gt_mask_rgb)
                        #         pred_win.set_data(pred_batch_img)
                        #     plt.plot()

                        #     plt.savefig(os.path.join(relabel_mask_dir, '%05d_part_%d.jpg' % (img_index, obj_i)))
                        #     # pred_mask.save(os.path.join(relabel_mask_dir, index_name))

                output_whole = torch.unsqueeze(output_whole, 0)
                target_whole = torch.unsqueeze(target_whole, 0)

                score_map_whole = output_whole.cpu()

                temp_iou = intersectionOverUnion(output_whole,
                                                 target_whole.cpu(),
                                                 idx)  # have not tested
                iou_list.append(temp_iou)

                # whole image
                if RELABEL:
                    # save in same checkpoint
                    raw_mask_path = meta['mask_path_list'][j][0]
                    img_index = meta['image_index_list'][j][0]
                    temp_head = ('/').join(raw_mask_path.split('/')[:-8])
                    temp_tail = ('/').join(raw_mask_path.split('/')[-6:])
                    temp = os.path.join(temp_head, 'code/train_two_steps',
                                        checkpoint, 'pred_vis', temp_tail)
                    relabel_mask_dir, relabel_mask_name = os.path.split(temp)
                    relabel_mask_dir = os.path.dirname(relabel_mask_dir)

                    raw_mask_rgb_path = os.path.join(
                        os.path.dirname(os.path.dirname(raw_mask_path)),
                        'first_mask_rgb', relabel_mask_name)
                    new_mask_rgb_path = os.path.join(relabel_mask_dir,
                                                     'gt_' + relabel_mask_name)
                    raw_rgb_frame_path = os.path.join(os.path.dirname(os.path.dirname(raw_mask_path)), 'raw_frames', \
                        relabel_mask_name[:-4] + '.png')

                    from PIL import Image
                    import numpy as np
                    if os.path.exists(raw_mask_rgb_path):
                        gt_mask_rgb = np.array(Image.open(raw_mask_rgb_path))
                    else:
                        gt_mask_rgb = np.array(Image.open(raw_rgb_frame_path))
                    # print(input_now.shape)
                    # print(score_map.shape)
                    pred_batch_img, pred_mask = relabel_heatmap(
                        input_whole, score_map_whole,
                        'pred')  # return an Image object

                    if not isdir(relabel_mask_dir):
                        mkdir_p(relabel_mask_dir)

                    if not gt_win or not pred_win:
                        ax1 = plt.subplot(121)
                        ax1.title.set_text('MASK_RGB_GT')
                        gt_win = plt.imshow(gt_mask_rgb)
                        ax2 = plt.subplot(122)
                        ax2.title.set_text('Mask_RGB_PRED')
                        pred_win = plt.imshow(pred_batch_img)
                    else:
                        gt_win.set_data(gt_mask_rgb)
                        pred_win.set_data(pred_batch_img)
                    plt.plot()
                    index_name = "%05d.jpg" % (img_index)
                    plt.savefig(
                        os.path.join(relabel_mask_dir, 'vis_' + index_name))
                    pred_mask.save(os.path.join(relabel_mask_dir, index_name))
                    # print(os.path.join(relabel_mask_dir, 'vis_' + index_name))

            # acces.update(acc[0], input.size(0))
            ioues.update(sum(iou_list) / len(iou_list), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.6f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                batch=i + 1,
                size=len(val_loader),
                data=data_time.val,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td)
            bar.next()
        bar.finish()
    return 0, ioues.avg, predictions