Example #1
# Assumes `import numpy as np`, `import torch`, and a project-local
# `evaluate` module providing the accuracy helper used below.
def test(testloader, model, criterion, epoch, use_cuda):
    model.eval()
    accs = np.ones((len(testloader))) * -1000.0
    losses = np.ones((len(testloader))) * -1000.0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)
            losses[batch_idx] = criterion(outputs, targets).item()
            accs[batch_idx] = evaluate.accuracy(outputs.data,
                                                targets.data,
                                                topk=(1, ))[0].item()
    return (np.average(losses), np.average(accs))
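
All of the examples on this page call a project-local `evaluate` module rather than a published library. Below is a minimal sketch of a top-k accuracy helper consistent with the `evaluate.accuracy(output, target, topk=...)` calls on this page; the percent scaling and the one-tensor-per-k return shape are inferred from how the results are consumed (e.g. `...[0].item()`), not taken from the module itself.

import torch

def accuracy(output, target, topk=(1,)):
    # Sketch of a top-k accuracy helper; the real `evaluate.accuracy` is
    # not shown on this page. Returns one value per k, in percent.
    maxk = max(topk)
    batch_size = target.size(0)

    # (maxk, batch) matrix marking whether each top-k prediction is correct.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res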
Example #2
def train(trainloader, model, criterion, optimizer, epoch, use_cuda):
    model.train()
    accs = np.ones((len(trainloader))) * -1000.0
    losses = np.ones((len(trainloader))) * -1000.0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        losses[batch_idx] = loss.item()
        accs[batch_idx] = evaluate.accuracy(outputs.data,
                                            targets.data)[0].item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return (np.average(losses), np.average(accs))
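
For context, a minimal driver for the `train`/`test` pair above might look like the following. The CIFAR-10 dataset, ResNet-18 model, and hyperparameters are illustrative assumptions, not taken from the original project; the loops themselves also require the project's `evaluate` module (see the sketch after Example #1).

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

use_cuda = torch.cuda.is_available()
transform = T.ToTensor()
trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data', train=True, download=True,
                                 transform=transform),
    batch_size=128, shuffle=True)
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data', train=False, download=True,
                                 transform=transform),
    batch_size=128, shuffle=False)

model = torchvision.models.resnet18(num_classes=10)
if use_cuda:
    model = model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for epoch in range(10):
    train_loss, train_acc = train(trainloader, model, criterion,
                                  optimizer, epoch, use_cuda)
    val_loss, val_acc = test(testloader, model, criterion, epoch, use_cuda)
    print('epoch {}: val loss {:.4f}, top-1 {:.2f}%'.format(
        epoch, val_loss, val_acc))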
Example #3
def val(args):
    with open(args.test_box, 'r') as f:
        # list(...) is required: under Python 3, a bare map object is not
        # subscriptable in the corner-format conversion below.
        gt_boxes = [
            list(map(float, x.strip().split(' ')[2:]))
            for x in f.readlines()
        ]
    gt_boxes = [(box[0], box[1], box[0] + box[2] - 1, box[1] + box[3] - 1)
                for box in gt_boxes]

    # meters
    top1_clsacc = AverageMeter()
    top1_locerr = AverageMeter()
    top5_clsacc = AverageMeter()
    top5_locerr = AverageMeter()
    top1_clsacc.reset()
    top1_locerr.reset()
    top5_clsacc.reset()
    top5_locerr.reset()

    # get model
    model = get_model(args)
    model.eval()

    # get data
    _, valcls_loader, valloc_loader = data_loader(args, test_path=True)
    assert len(valcls_loader) == len(valloc_loader), \
        'Error! Different sizes for the two datasets: loc({}), cls({})'.format(len(valloc_loader), len(valcls_loader))

    # testing
    DEBUG = True
    if DEBUG:
        # show_idxs = np.arange(20)
        np.random.seed(2333)
        show_idxs = np.arange(len(valcls_loader))
        np.random.shuffle(show_idxs)
        show_idxs = show_idxs[:20]

    # evaluation classification task
    pred_prob1 = []
    pred_prob2 = []
    pred_prob3 = []
    for dat in tqdm(valcls_loader):
        # parse data
        img_path, img, label_in = dat
        if args.tencrop == 'True':
            bs, ncrops, c, h, w = img.size()
            img = img.view(-1, c, h, w)
            label_input = label_in.repeat(10, 1)
            label = label_input.view(-1)
        else:
            label = label_in

        # forward pass
        img, label = img.cuda(), label.cuda()
        logits = model(img)

        # get classification prob
        logits0 = logits[-1]
        logits0 = F.softmax(logits0, dim=1)
        if args.tencrop == 'True':
            logits0 = logits0.view(1, ncrops, -1).mean(1)
        pred_prob3.append(logits0.cpu().data.numpy())

        logits1 = logits[-2]
        logits1 = F.softmax(logits1, dim=1)
        if args.tencrop == 'True':
            logits1 = logits1.view(1, ncrops, -1).mean(1)
        pred_prob2.append(logits1.cpu().data.numpy())

        logits2 = logits[-3]
        logits2 = F.softmax(logits2, dim=1)
        if args.tencrop == 'True':
            logits2 = logits2.view(1, ncrops, -1).mean(1)
        pred_prob1.append(logits2.cpu().data.numpy())
        # update result record
        prec1_1, prec5_1 = evaluate.accuracy(logits0.cpu().data,
                                             label_in.long(),
                                             topk=(1, 5))
        top1_clsacc.update(prec1_1[0].numpy(), img.size()[0])
        top5_clsacc.update(prec5_1[0].numpy(), img.size()[0])

    pred_prob1 = np.concatenate(pred_prob1, axis=0)
    pred_prob2 = np.concatenate(pred_prob2, axis=0)
    pred_prob3 = np.concatenate(pred_prob3, axis=0)
    print('== cls err')
    print('Top1: {:.2f} Top5: {:.2f}\n'.format(100.0 - top1_clsacc.avg,
                                               100.0 - top5_clsacc.avg))

    thresholds = list(map(float, args.threshold.split(',')))
    for th in thresholds:
        top1_locerr.reset()
        top5_locerr.reset()
        for idx, dat in tqdm(enumerate(valloc_loader)):
            # parse data
            img_path, img, label = dat

            # forward pass
            img, label = img.cuda(), label.cuda()
            logits = model(img)
            child_map = F.interpolate(model.module.get_child_maps(),
                                      size=(28, 28),
                                      mode='bilinear',
                                      align_corners=True)
            child_map = child_map.cpu().data.numpy()
            parent_maps = F.interpolate(model.module.get_parent_maps(),
                                        size=(28, 28),
                                        mode='bilinear',
                                        align_corners=True)
            parent_maps = parent_maps.cpu().data.numpy()
            root_maps = model.module.get_root_maps()
            root_maps = root_maps.cpu().data.numpy()
            top_boxes, top_maps = get_topk_boxes_hier(pred_prob3[idx, :],
                                                      pred_prob2[idx, :],
                                                      pred_prob1[idx, :],
                                                      child_map,
                                                      parent_maps,
                                                      root_maps,
                                                      img_path[0],
                                                      args.input_size,
                                                      args.crop_size,
                                                      topk=(1, 5),
                                                      threshold=th,
                                                      mode='union')
            top1_box, top5_boxes = top_boxes

            # update result record
            locerr_1, locerr_5 = evaluate.locerr(
                (top1_box, top5_boxes),
                label.cpu().data.long().numpy(),
                gt_boxes[idx],
                topk=(1, 5))
            top1_locerr.update(locerr_1, img.size()[0])
            top5_locerr.update(locerr_5, img.size()[0])
            if DEBUG:
                if idx in show_idxs:
                    save_im_heatmap_box(
                        img_path[0],
                        top_maps,
                        top5_boxes,
                        '../figs/',
                        gt_label=label.cpu().data.long().numpy(),
                        gt_box=gt_boxes[idx])
        print('=========== threshold: {} ==========='.format(th))
        print('== loc err')
        print('Top1: {:.2f} Top5: {:.2f}\n'.format(top1_locerr.avg,
                                                   top5_locerr.avg))
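
Example #3 and most of the later examples also depend on an `AverageMeter` with `val`, `avg`, `sum`, `reset()`, and `update(value, n)` members. A typical definition matching that usage follows; the original class is not included on this page.

class AverageMeter(object):
    # Tracks the most recent value plus a running sum, count, and average,
    # matching the meter usage in the loops on this page.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count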
Example #4
def train(args):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model, optimizer = get_model(args)
    model.train()
    train_loader, _, _ = data_loader(args)

    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)
        fw.write('#epoch \t loss \t pred@1 \t pred@5\n')

    # construct writer
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(log_dir=args.log_dir)

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)
    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        top1.reset()
        top5.reset()

        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.cuda(), label[2].cuda()

            logits = model(img)
            loss_val = model.module.get_loss(logits, label)

            # write into tensorboard
            writer.add_scalar('loss_val', loss_val, global_counter)

            # network parameter update
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            if not args.onehot == 'True':
                logits1 = torch.squeeze(logits)
                prec1, prec5 = evaluate.accuracy(logits1.data, label.long(), topk=(1, 5))
                top1.update(prec1[0], img.size()[0])
                top5.update(prec5[0], img.size()[0])


            losses.update(loss_val.data, img.size()[0])
            batch_time.update(time.time() - end)

            end = time.time()

            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = ((total_epoch - current_epoch) * steps_per_epoch +
                               (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds)))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(datetime.timedelta(seconds=int(eta_seconds_epoch)))
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'ETA {eta_str}({eta_str_epoch})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       current_epoch, global_counter % len(train_loader), len(train_loader), batch_time=batch_time,
                       eta_str=eta_str, eta_str_epoch=eta_str_epoch, loss=losses, top1=top1, top5=top5))

        # `plotter` is assumed to be a module-level helper defined elsewhere
        # in the original project (e.g. a Visdom line plotter).
        plotter.plot('Loss', 'train', current_epoch, losses.avg)
        plotter.plot('top1', 'train', current_epoch, top1.avg)
        plotter.plot('top5', 'train', current_epoch, top5.avg)

        current_epoch += 1
        if current_epoch % 10 == 0:
            save_checkpoint(args,
                            {
                                'epoch': current_epoch,
                                'arch': 'resnet',
                                'global_counter': global_counter,
                                'state_dict': model.state_dict(),
                                'optimizer': optimizer.state_dict()
                            }, is_best=False,
                            filename='%s_epoch_%d_glo_step_%d.pth.tar'
                                     % (args.dataset, current_epoch, global_counter))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
            fw.write('%d \t %.4f \t %.3f \t %.3f\n' % (current_epoch, losses.avg, top1.avg, top5.avg))

        losses.reset()
        top1.reset()
        top5.reset()
Example #5
    def validation(self, epoch):
        self.model.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        val_loss = 0.0

        AP = np.zeros(self.numClasses + 1)
        PCK = np.zeros(self.numClasses + 1)
        PCKh = np.zeros(self.numClasses + 1)
        count = np.zeros(self.numClasses + 1)

        cnt = 0
        for i, (input, heatmap, centermap, img_path) in enumerate(tbar):

            cnt += 1

            input_var = input.cuda()
            heatmap_var = heatmap.cuda()
            centermap_var = centermap.cuda()

            self.optimizer.zero_grad()

            heat = torch.zeros(self.numClasses + 1, 46, 46).cuda()
            cell = torch.zeros(15, 46, 46).cuda()
            hide = torch.zeros(15, 46, 46).cuda()

            losses = {}
            loss = 0

            start_model = time.time()
            for j in range(self.frame_memory):
                heat, cell, hide = self.model(input_var, centermap_var, j,
                                              heat, hide, cell)

                losses[j] = self.criterion(heat, heatmap_var[:, j])

                loss += losses[j].item()

                # Note: evaluate.accuracy also returns a count, which
                # overwrites the outer `cnt` counter here (as in the
                # original code).
                acc, acc_PCK, acc_PCKh, cnt, pred, visible = evaluate.accuracy(
                    heat.detach().cpu().numpy(),
                    heatmap_var[:, j].detach().cpu().numpy(),
                    0.2, 0.5, self.dataset)

                AP[0] = (AP[0] * (self.frame_memory * i + j) + acc[0]) / (
                    (self.frame_memory * i + j) + 1)
                PCK[0] = (PCK[0] *
                          (self.frame_memory * i + j) + acc_PCK[0]) / (
                              (self.frame_memory * i + j) + 1)
                PCKh[0] = (PCKh[0] *
                           (self.frame_memory * i + j) + acc_PCKh[0]) / (
                               (self.frame_memory * i + j) + 1)

                for k in range(self.numClasses + 1):
                    if visible[k] == 1:
                        AP[k] = (AP[k] * count[k] + acc[k]) / (count[k] + 1)
                        PCK[k] = (PCK[k] * count[k] + acc_PCK[k]) / (count[k] +
                                                                     1)
                        PCKh[k] = (PCKh[k] * count[k] +
                                   acc_PCKh[k]) / (count[k] + 1)
                        count[k] += 1

            mAP = AP[1:].sum() / (self.numClasses)
            mPCK = PCK[1:].sum() / (self.numClasses)
            mPCKh = PCKh[1:].sum() / (self.numClasses)

            val_loss += loss

            tbar.set_description('Val   loss: %.6f' %
                                 (val_loss / ((i + 1) * self.batch_size)))

        printAccuracies(mAP, AP, mPCKh, PCKh, mPCK, PCK, self.dataset)

        PCKhAvg = PCKh.sum() / (self.numClasses + 1)
        PCKAvg = PCK.sum() / (self.numClasses + 1)

        if mAP > self.isBest:
            self.isBest = mAP
            save_checkpoint({'state_dict': self.model.state_dict()},
                            self.isBest, self.args.model_name)

        if mPCKh > self.bestPCKh:
            self.bestPCKh = mPCKh
        if mPCK > self.bestPCK:
            self.bestPCK = mPCK

        print("Best AP = %.2f%%; PCK = %2.2f%%; PCKh = %2.2f%%" %
              (self.isBest * 100, self.bestPCK * 100, self.bestPCKh * 100))
Example #6
    def validation(self, epoch):
        self.model.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        val_loss = 0.0
        
        AP    = np.zeros(self.numClasses+1)
        PCK   = np.zeros(self.numClasses+1)
        PCKh  = np.zeros(self.numClasses+1)
        count = np.zeros(self.numClasses+1)

        cnt = 0
        for i, (input, heatmap, centermap, img_path, limbsmap) in enumerate(tbar):

            cnt += 1

            input_var     =      input.cuda()
            heatmap_var   =    heatmap.cuda()
            limbs_var     =   limbsmap.cuda()

            self.optimizer.zero_grad()

            heat, limbs = self.model(input_var)
            loss_heat   = self.criterion(heat,  heatmap_var)

            loss = loss_heat

            val_loss += loss_heat.item()

            tbar.set_description('Val   loss: %.6f' % (val_loss / ((i + 1)*self.batch_size)))

            acc, acc_PCK, acc_PCKh, cnt, pred, visible = evaluate.accuracy(
                heat.detach().cpu().numpy(),
                heatmap_var.detach().cpu().numpy(),
                0.2, 0.5, self.dataset)

            AP[0]     = (AP[0]  *i + acc[0])      / (i + 1)
            PCK[0]    = (PCK[0] *i + acc_PCK[0])  / (i + 1)
            PCKh[0]   = (PCKh[0]*i + acc_PCKh[0]) / (i + 1)

            for j in range(1,self.numClasses+1):
                if visible[j] == 1:
                    AP[j]     = (AP[j]  *count[j] + acc[j])      / (count[j] + 1)
                    PCK[j]    = (PCK[j] *count[j] + acc_PCK[j])  / (count[j] + 1)
                    PCKh[j]   = (PCKh[j]*count[j] + acc_PCKh[j]) / (count[j] + 1)
                    count[j] += 1

            mAP = AP[1:].sum() / self.numClasses
            mPCK = PCK[1:].sum() / self.numClasses
            mPCKh = PCKh[1:].sum() / self.numClasses

        printAccuracies(mAP, AP, mPCKh, PCKh, mPCK, PCK, self.dataset)

        PCKhAvg = PCKh.sum() / (self.numClasses + 1)
        PCKAvg = PCK.sum() / (self.numClasses + 1)

        if mAP > self.isBest:
            self.isBest = mAP
            save_checkpoint({'state_dict': self.model.state_dict()}, self.isBest, self.args.model_name)
            print("Model saved to "+self.args.model_name)

        if mPCKh > self.bestPCKh:
            self.bestPCKh = mPCKh
        if mPCK > self.bestPCK:
            self.bestPCK = mPCK

        print("Best AP = %.2f%%; PCK = %2.2f%%; PCKh = %2.2f%%" % (self.isBest*100, self.bestPCK*100,self.bestPCKh*100))
Example #7
def val(args):

    with open(args.test_box, 'r') as f:
        # list(...) is required: under Python 3, a bare map object is not
        # subscriptable in the corner-format conversion below.
        gt_boxes = [
            list(map(float, x.strip().split(' ')[2:]))
            for x in f.readlines()
        ]
    gt_boxes = [(box[0], box[1], box[0] + box[2] - 1, box[1] + box[3] - 1)
                for box in gt_boxes]

    # meters
    top1_clsacc = AverageMeter()
    top1_locerr = AverageMeter()
    top5_clsacc = AverageMeter()
    top5_locerr = AverageMeter()
    top1_clsacc.reset()
    top1_locerr.reset()
    top5_clsacc.reset()
    top5_locerr.reset()

    # get model
    model = get_model(args)
    model.eval()

    # get data
    _, valcls_loader, valloc_loader = data_loader(args, test_path=True)
    assert len(valcls_loader) == len(valloc_loader), \
        'Error! Different sizes for the two datasets: loc({}), cls({})'.format(len(valloc_loader), len(valcls_loader))

    # testing
    VISLOC = True
    if VISLOC:
        # show_idxs = np.arange(20)
        np.random.seed(2333)
        show_idxs = np.arange(len(valcls_loader))
        np.random.shuffle(show_idxs)
        show_idxs = show_idxs[:20]

    # evaluation classification task
    pred_prob = []
    for dat in tqdm(valcls_loader):
        # parse data
        img_path, img, label_in = dat
        if args.tencrop == 'True':
            bs, ncrops, c, h, w = img.size()
            img = img.view(-1, c, h, w)
            label_input = label_in.repeat(10, 1)
            label = label_input.view(-1)
        else:
            label = label_in

        # forward pass
        img, label = img.cuda(), label.cuda()
        logits = model(img)

        # get classification prob
        logits0 = logits
        logits0 = F.softmax(logits0, dim=1)
        if args.tencrop == 'True':
            logits0 = logits0.view(1, ncrops, -1).mean(1)
        pred_prob.append(logits0.cpu().data.numpy())

        # update result record
        prec1_1, prec5_1 = evaluate.accuracy(logits0.cpu().data,
                                             label_in.long(),
                                             topk=(1, 5))
        top1_clsacc.update(prec1_1[0].numpy(), img.size()[0])
        top5_clsacc.update(prec5_1[0].numpy(), img.size()[0])

    pred_prob = np.concatenate(pred_prob, axis=0)
    # with open('pred_prob.pkl', 'w') as f:
    #     cPickle.dump(pred_prob, f)
    print('== cls err')
    print('Top1: {:.2f} Top5: {:.2f}\n'.format(100.0 - top1_clsacc.avg,
                                               100.0 - top5_clsacc.avg))

    # with open('pred_prob.pkl', 'r') as f:
    #     pred_prob = cPickle.load(f)
    # evaluation localization task
    thresholds = list(map(float, args.threshold.split(',')))
    for th in thresholds:
        top1_locerr.reset()
        top5_locerr.reset()
        for idx, dat in tqdm(enumerate(valloc_loader)):
            # parse data
            img_path, img, label = dat

            # forward pass
            img, label = img.cuda(), label.cuda()
            logits = model(img)

            # get localization boxes
            cam_map = model.module.get_cam_maps()  # not normalized
            cam_map = cam_map.cpu().data.numpy()
            top_boxes, top_maps = get_topk_boxes(pred_prob[idx, :],
                                                 cam_map,
                                                 img_path[0],
                                                 args.input_size,
                                                 args.crop_size,
                                                 topk=(1, 5),
                                                 threshold=th,
                                                 mode='union')
            top1_box, top5_boxes = top_boxes

            # update result record
            locerr_1, locerr_5 = evaluate.locerr(
                (top1_box, top5_boxes),
                label.cpu().data.long().numpy(),
                gt_boxes[idx],
                topk=(1, 5))
            top1_locerr.update(locerr_1, img.size()[0])
            top5_locerr.update(locerr_5, img.size()[0])
            if VISLOC:
                if idx in show_idxs:
                    save_im_heatmap_box(
                        img_path[0],
                        top_maps,
                        top5_boxes,
                        '../figs/',
                        gt_label=label.cpu().data.long().numpy(),
                        gt_box=gt_boxes[idx])
        print('=========== threshold: {} ==========='.format(th))
        print('== loc err')
        print('Top1: {:.2f} Top5: {:.2f}\n'.format(top1_locerr.avg,
                                                   top5_locerr.avg))
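
The `evaluate.locerr` calls in Examples #3 and #7 score localization by intersection-over-union between the predicted and ground-truth boxes (an IoU threshold of 0.5 is the common convention; Example #9 exposes it as `iou_th`). A minimal IoU for the inclusive (x1, y1, x2, y2) corner format produced by the gt-box parsing above; this is a sketch, not the project's own implementation.

def iou(box_a, box_b):
    # Boxes are inclusive pixel coordinates (x1, y1, x2, y2), hence the +1
    # when converting corner differences to widths and heights.
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1) + 1)
    ih = max(0.0, min(ay2, by2) - max(ay1, by1) + 1)
    inter = iw * ih
    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    return inter / float(area_a + area_b - inter)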
Example #8
def train(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str

    # for reproducibility
    if args.seed is not None:
        np.random.seed(args.seed)
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = False
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    else:
        cudnn.benchmark = True

    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)
    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        config = json.dumps(vars(args), indent=4, separators=(',', ':'))
        fw.write(config)

    log_head = '#epoch \t loss \t pred@1 \t pred@5'
    batch_time = AverageMeter()
    losses = AverageMeter()
    if args.ram:
        losses_ra = AverageMeter()
        log_head += ' \t loss_ra'
    log_head += '\n'
    with open(os.path.join(args.snapshot_dir, 'train_record.csv'), 'a') as fw:
        fw.write(log_head)
    top1 = AverageMeter()
    top5 = AverageMeter()
    args.device = torch.device('cuda') if args.gpus[0] >= 0 else torch.device(
        'cpu')
    model, optimizer = get_model(args)

    model.train()
    train_loader = data_loader(args)

    # construct writer
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    writer = SummaryWriter(log_dir=args.log_dir)

    total_epoch = args.epoch
    global_counter = args.global_counter
    current_epoch = args.current_epoch
    end = time.time()
    max_iter = total_epoch * len(train_loader)
    print('Max iter:', max_iter)
    while current_epoch < total_epoch:
        model.train()
        losses.reset()
        if args.ram:
            losses_ra.reset()

        top1.reset()
        top5.reset()
        batch_time.reset()
        res = my_optim.reduce_lr(args, optimizer, current_epoch)

        if res:
            with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                      'a') as fw:
                for g in optimizer.param_groups:
                    out_str = 'Epoch:%d, %f\n' % (current_epoch, g['lr'])
                    fw.write(out_str)

        steps_per_epoch = len(train_loader)
        for idx, dat in enumerate(train_loader):
            img_path, img, label = dat
            global_counter += 1
            img, label = img.to(args.device), label.to(args.device)

            logits, _, _ = model(img)

            loss_val, loss_ra = model.module.get_loss(logits,
                                                      label,
                                                      epoch=current_epoch,
                                                      ram_start=args.ram_start)

            # write into tensorboard
            writer.add_scalar('loss_val', loss_val, global_counter)

            # network parameter update
            optimizer.zero_grad()
            # if args.mixp:
            #     with amp.scale_loss(loss_val, optimizer) as scaled_loss:
            #         scaled_loss.backward()
            # else:
            loss_val.backward()
            optimizer.step()

            logits = torch.mean(torch.mean(logits, dim=2), dim=2)
            if not args.onehot == 'True':
                prec1, prec5 = evaluate.accuracy(logits.data,
                                                 label.long(),
                                                 topk=(1, 5))
                top1.update(prec1[0], img.size()[0])
                top5.update(prec5[0], img.size()[0])

            losses.update(loss_val.data, img.size()[0])
            if args.ram:
                losses_ra.update(loss_ra.data, img.size()[0])
            batch_time.update(time.time() - end)

            end = time.time()
            if global_counter % args.disp_interval == 0:
                # Calculate ETA
                eta_seconds = (
                    (total_epoch - current_epoch) * steps_per_epoch +
                    (steps_per_epoch - idx)) * batch_time.avg
                eta_str = "{:0>8}".format(
                    str(datetime.timedelta(seconds=int(eta_seconds))))
                eta_seconds_epoch = steps_per_epoch * batch_time.avg
                eta_str_epoch = "{:0>8}".format(
                    str(datetime.timedelta(seconds=int(eta_seconds_epoch))))
                log_output = 'Epoch: [{0}][{1}/{2}] \t ' \
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t ' \
                            'ETA {eta_str}({eta_str_epoch})\t ' \
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t ' \
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t ' \
                            'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(current_epoch,
                            global_counter % len(train_loader), len(train_loader), batch_time=batch_time,
                            eta_str=eta_str, eta_str_epoch=eta_str_epoch, loss=losses, top1=top1, top5=top5)
                if args.ram:
                    log_output += 'Loss_ra {loss_ra.val:.4f} ({loss_ra.avg:.4f})\t'.format(
                        loss_ra=losses_ra)
                print(log_output)
                writer.add_scalar('top1', top1.avg, global_counter)
                writer.add_scalar('top5', top5.avg, global_counter)

        current_epoch += 1
        if current_epoch % 10 == 0:
            save_checkpoint(args, {
                'epoch': current_epoch,
                'arch': args.arch,
                'global_counter': global_counter,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
                            is_best=False,
                            filename='%s_epoch_%d.pth.tar' %
                            (args.dataset, current_epoch))

        with open(os.path.join(args.snapshot_dir, 'train_record.csv'),
                  'a') as fw:
            log_output = '{} \t {:.4f} \t {:.3f} \t {:.3f} \t'.format(
                current_epoch, losses.avg, top1.avg, top5.avg)
            if args.ram:
                log_output += '{:.4f}'.format(losses_ra.avg)
            log_output += '\n'
            fw.write(log_output)

        losses.reset()
        if args.ram:
            losses_ra.reset()
        top1.reset()
        top5.reset()
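
Examples #4 and #8 call `my_optim.reduce_lr(args, optimizer, current_epoch)` and log the new learning rates whenever it returns a truthy value. A plausible step-decay implementation consistent with that contract; the real `my_optim` module is not shown here, and the `decay_points` and `decay_factor` attributes are assumptions.

def reduce_lr(args, optimizer, epoch):
    # Multiply every parameter group's learning rate by a decay factor at
    # the epochs listed in args.decay_points; return True when the LR
    # changed so the caller knows to log the new rates.
    if epoch not in getattr(args, 'decay_points', ()):
        return False
    factor = getattr(args, 'decay_factor', 0.1)
    for g in optimizer.param_groups:
        g['lr'] *= factor
    return True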
Example #9
def val(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus_str

    print('Running parameters:\n')
    print(json.dumps(vars(args), indent=4, separators=(',', ':')))

    if not os.path.exists(args.snapshot_dir):
        os.mkdir(args.snapshot_dir)

    if args.dataset == 'ilsvrc':
        gt_boxes = []
        img_name = []
        with open(args.test_box, 'r') as f:
            for x in f.readlines():
                x = x.strip().split(' ')
                if len(x[1:]) % 4 == 0:
                    gt_boxes.append(list(map(float, x[1:])))
                    img_name.append(
                        os.path.join(args.img_dir,
                                     x[0].replace('.xml', '.JPEG')))
                else:
                    print('Wrong gt bboxes.')
    elif args.dataset == 'cub':
        with open(args.test_box, 'r') as f:
            gt_boxes = [
                list(map(float,
                         x.strip().split(' ')[2:])) for x in f.readlines()
            ]
        gt_boxes = [(box[0], box[1], box[0] + box[2] - 1, box[1] + box[3] - 1)
                    for box in gt_boxes]
    else:
        print('Wrong dataset.')
    # meters
    top1_clsacc = AverageMeter()
    top5_clsacc = AverageMeter()
    top1_clsacc.reset()
    top5_clsacc.reset()

    loc_err = {}
    for th in args.threshold:
        loc_err['top1_locerr_{}'.format(th)] = AverageMeter()
        loc_err['top1_locerr_{}'.format(th)].reset()
        loc_err['top5_locerr_{}'.format(th)] = AverageMeter()
        loc_err['top5_locerr_{}'.format(th)].reset()
        loc_err['gt_known_locerr_{}'.format(th)] = AverageMeter()
        loc_err['gt_known_locerr_{}'.format(th)].reset()
        for err in [
                'right', 'cls_wrong', 'mins_wrong', 'part_wrong', 'more_wrong',
                'other'
        ]:
            loc_err['top1_locerr_{}_{}'.format(err, th)] = AverageMeter()
            loc_err['top1_locerr_{}_{}'.format(err, th)].reset()
        if args.scg:
            loc_err['top1_locerr_scg_{}'.format(th)] = AverageMeter()
            loc_err['top1_locerr_scg_{}'.format(th)].reset()
            loc_err['top5_locerr_scg_{}'.format(th)] = AverageMeter()
            loc_err['top5_locerr_scg_{}'.format(th)].reset()
            loc_err['gt_known_locerr_scg_{}'.format(th)] = AverageMeter()
            loc_err['gt_known_locerr_scg_{}'.format(th)].reset()
            for err in [
                    'right', 'cls_wrong', 'mins_wrong', 'part_wrong',
                    'more_wrong', 'other'
            ]:
                loc_err['top1_locerr_scg_{}_{}'.format(err,
                                                       th)] = AverageMeter()
                loc_err['top1_locerr_scg_{}_{}'.format(err, th)].reset()
    # get model
    model = get_model(args)
    model.eval()
    # get data
    valcls_loader, valloc_loader = data_loader(args,
                                               test_path=True,
                                               train=False)
    assert len(valcls_loader) == len(valloc_loader), \
        'Error! Different sizes for the two datasets: loc({}), cls({})'.format(len(valloc_loader), len(valcls_loader))

    # testing
    if args.debug:
        # show_idxs = np.arange(20)
        np.random.seed(2333)
        show_idxs = np.arange(len(valcls_loader))
        np.random.shuffle(show_idxs)
        show_idxs = show_idxs[:]

    # evaluation classification task

    for idx, (dat_cls,
              dat_loc) in tqdm(enumerate(zip(valcls_loader, valloc_loader))):
        # parse data
        img_path, img, label_in = dat_cls
        if args.tencrop == 'True':
            bs, ncrops, c, h, w = img.size()
            img = img.view(-1, c, h, w)

        # forward pass
        args.device = torch.device(
            'cuda') if args.gpus[0] >= 0 else torch.device('cpu')
        img = img.to(args.device)

        if args.vis_feat:
            if idx in show_idxs:
                _, img_loc, label = dat_loc
                _ = model(img_loc)
                vis_feature(model.module.feat4,
                            img_path[0],
                            args.vis_dir,
                            layer='feat4')
                vis_feature(model.module.feat5,
                            img_path[0],
                            args.vis_dir,
                            layer='feat5')
                vis_feature(model.module.cls_map,
                            img_path[0],
                            args.vis_dir,
                            layer='cls_map')
            continue
        if args.vis_var:
            if idx in show_idxs:
                _, img_loc, label = dat_loc
                logits, _, _, _, _ = model(img_loc)
                cls_logits = F.softmax(logits, dim=1)
                var_logits = torch.var(cls_logits, dim=1).squeeze()
                logits_cls = logits[0, label.long(), ...]
                vis_var(var_logits,
                        logits_cls,
                        img_path[0],
                        args.vis_dir,
                        net='vgg_s10_loc_.4_.7_fpn_l4_var_cls')
            continue
        with torch.no_grad():
            logits, _, _ = model(img)
            cls_logits = torch.mean(torch.mean(logits, dim=2), dim=2)
            cls_logits = F.softmax(cls_logits, dim=1)
            if args.tencrop == 'True':
                cls_logits = cls_logits.view(1, ncrops, -1).mean(1)

            prec1_1, prec5_1 = evaluate.accuracy(cls_logits.cpu().data,
                                                 label_in.long(),
                                                 topk=(1, 5))
            top1_clsacc.update(prec1_1[0].numpy(), img.size()[0])
            top5_clsacc.update(prec5_1[0].numpy(), img.size()[0])

        _, img_loc, label = dat_loc
        with torch.no_grad():
            logits, sc_maps_fo, sc_maps_so = model(img_loc, scg_flag=args.scg)
            loc_map = F.relu(logits)

        for th in args.threshold:
            locerr_1, locerr_5, gt_known_locerr, top_maps, top5_boxes, gt_known_maps, top1_wrong_detail = \
                eval_loc(cls_logits, loc_map, img_path[0], label, gt_boxes[idx], topk=(1, 5), threshold=th,
                         mode='union', iou_th=args.iou_th)
            loc_err['top1_locerr_{}'.format(th)].update(
                locerr_1,
                img_loc.size()[0])
            loc_err['top5_locerr_{}'.format(th)].update(
                locerr_5,
                img_loc.size()[0])
            loc_err['gt_known_locerr_{}'.format(th)].update(
                gt_known_locerr,
                img_loc.size()[0])

            cls_wrong, multi_instances, region_part, region_more, region_wrong = top1_wrong_detail
            right = 1 - (cls_wrong + multi_instances + region_part +
                         region_more + region_wrong)
            loc_err['top1_locerr_right_{}'.format(th)].update(
                right,
                img_loc.size()[0])
            loc_err['top1_locerr_cls_wrong_{}'.format(th)].update(
                cls_wrong,
                img_loc.size()[0])
            loc_err['top1_locerr_mins_wrong_{}'.format(th)].update(
                multi_instances,
                img_loc.size()[0])
            loc_err['top1_locerr_part_wrong_{}'.format(th)].update(
                region_part,
                img_loc.size()[0])
            loc_err['top1_locerr_more_wrong_{}'.format(th)].update(
                region_more,
                img_loc.size()[0])
            loc_err['top1_locerr_other_{}'.format(th)].update(
                region_wrong,
                img_loc.size()[0])
            if args.debug and idx in show_idxs and (th == args.threshold[0]):
                top1_wrong_detail_dir = 'cls_{}-mins_{}-rpart_{}-rmore_{}-rwrong_{}'.format(
                    cls_wrong, multi_instances, region_part, region_more,
                    region_wrong)
                debug_dir = os.path.join(
                    args.debug_dir, top1_wrong_detail_dir
                ) if args.debug_detail else args.debug_dir
                save_im_heatmap_box(img_path[0],
                                    top_maps,
                                    top5_boxes,
                                    debug_dir,
                                    gt_label=label.data.long().numpy(),
                                    gt_box=gt_boxes[idx],
                                    epoch=args.current_epoch,
                                    threshold=th)

            if args.scg:
                sc_maps = []
                if args.scg_com:
                    for sc_map_fo_i, sc_map_so_i in zip(
                            sc_maps_fo, sc_maps_so):
                        if (sc_map_fo_i is not None) and (sc_map_so_i
                                                          is not None):
                            sc_map_i = torch.max(
                                sc_map_fo_i, args.scg_so_weight * sc_map_so_i)
                            sc_map_i = sc_map_i / (torch.sum(
                                sc_map_i, dim=1, keepdim=True) + 1e-10)
                            sc_maps.append(sc_map_i)
                elif args.scg_fo:
                    sc_maps = sc_maps_fo
                else:
                    sc_maps = sc_maps_so
                locerr_1_scg, locerr_5_scg, gt_known_locerr_scg, top_maps_scg, top5_boxes_scg, top1_wrong_detail_scg = \
                    eval_loc_scg(cls_logits, top_maps, gt_known_maps,
                                 sc_maps[-1] + sc_maps[-2], img_path[0],
                                 label, gt_boxes[idx], topk=(1, 5),
                                 threshold=th, mode='union',
                                 fg_th=args.scg_fg_th, bg_th=args.scg_bg_th,
                                 iou_th=args.iou_th, sc_maps_fo=None)
                loc_err['top1_locerr_scg_{}'.format(th)].update(
                    locerr_1_scg,
                    img_loc.size()[0])
                loc_err['top5_locerr_scg_{}'.format(th)].update(
                    locerr_5_scg,
                    img_loc.size()[0])
                loc_err['gt_known_locerr_scg_{}'.format(th)].update(
                    gt_known_locerr_scg,
                    img_loc.size()[0])

                cls_wrong_scg, multi_instances_scg, region_part_scg, region_more_scg, region_wrong_scg = top1_wrong_detail_scg
                right_scg = 1 - (cls_wrong_scg + multi_instances_scg +
                                 region_part_scg + region_more_scg +
                                 region_wrong_scg)
                loc_err['top1_locerr_scg_right_{}'.format(th)].update(
                    right_scg,
                    img_loc.size()[0])
                loc_err['top1_locerr_scg_cls_wrong_{}'.format(th)].update(
                    cls_wrong_scg,
                    img_loc.size()[0])
                loc_err['top1_locerr_scg_mins_wrong_{}'.format(th)].update(
                    multi_instances_scg,
                    img_loc.size()[0])
                loc_err['top1_locerr_scg_part_wrong_{}'.format(th)].update(
                    region_part_scg,
                    img_loc.size()[0])
                loc_err['top1_locerr_scg_more_wrong_{}'.format(th)].update(
                    region_more_scg,
                    img_loc.size()[0])
                loc_err['top1_locerr_scg_other_{}'.format(th)].update(
                    region_wrong_scg,
                    img_loc.size()[0])

                if args.debug and idx in show_idxs and (th
                                                        == args.threshold[0]):
                    top1_wrong_detail_dir = 'cls_{}-mins_{}-rpart_{}-rmore_{}-rwrong_{}_scg'.format(
                        cls_wrong_scg, multi_instances_scg, region_part_scg,
                        region_more_scg, region_wrong_scg)
                    debug_dir = os.path.join(
                        args.debug_dir, top1_wrong_detail_dir
                    ) if args.debug_detail else args.debug_dir
                    save_im_heatmap_box(img_path[0],
                                        top_maps_scg,
                                        top5_boxes_scg,
                                        debug_dir,
                                        gt_label=label.data.long().numpy(),
                                        gt_box=gt_boxes[idx],
                                        epoch=args.current_epoch,
                                        threshold=th,
                                        suffix='scg')

                    save_im_sim(img_path[0],
                                sc_maps_fo,
                                debug_dir,
                                gt_label=label.data.long().numpy(),
                                epoch=args.current_epoch,
                                suffix='fo')
                    save_im_sim(img_path[0],
                                sc_maps_so,
                                debug_dir,
                                gt_label=label.data.long().numpy(),
                                epoch=args.current_epoch,
                                suffix='so')
                    save_im_sim(img_path[0],
                                sc_maps_fo[-2] + sc_maps_fo[-1],
                                debug_dir,
                                gt_label=label.data.long().numpy(),
                                epoch=args.current_epoch,
                                suffix='fo_45')
                    # save_im_sim(img_path[0], aff_maps_so[-2] + aff_maps_so[-1], debug_dir,
                    #             gt_label=label.data.long().numpy(),
                    #             epoch=args.current_epoch, suffix='so_45')
                    # # save_im_sim(img_path[0], aff_maps, debug_dir, gt_label=label.data.long().numpy(),
                    # #             epoch=args.current_epoch, suffix='com')
                    save_sim_heatmap_box(img_path[0],
                                         top_maps,
                                         debug_dir,
                                         gt_label=label.data.long().numpy(),
                                         sim_map=sc_maps_fo[-2] +
                                         sc_maps_fo[-1],
                                         epoch=args.current_epoch,
                                         threshold=th,
                                         suffix='aff_fo_f45_cam',
                                         fg_th=args.scg_fg_th,
                                         bg_th=args.scg_bg_th)
                    # save_sim_heatmap_box(img_path[0], top_maps, debug_dir, gt_label=label.data.long().numpy(),
                    #                      sim_map=aff_maps_so[-2] + aff_maps_so[-1], epoch=args.current_epoch, threshold=th,
                    #                      suffix='aff_so_f5_cam', fg_th=args.scg_fg_th, bg_th=args.scg_bg_th)
                    # save_sim_heatmap_box(img_path[0], df_top_maps, debug_dir, gt_label=label.data.long().numpy(),
                    #                      sim_map=aff_maps_so[-2], epoch=args.current_epoch, threshold=th,
                    #                      suffix='aff_so_f4_cam',fg_th=args.nl_fg_th, bg_th=args.nl_bg_th)
                    # save_sim_heatmap_box(img_path[0], df_top_maps, debug_dir, gt_label=label.data.long().numpy(),
                    #                      sim_map=aff_maps_so[-1], epoch=args.current_epoch, threshold=th,
                    #                      suffix='aff_so_f5_cam', fg_th=args.nl_fg_th, bg_th=args.nl_bg_th)
                    # save_sim_heatmap_box(img_path[0], df_top_maps, debug_dir, gt_label=label.data.long().numpy(),
                    #                      sim_map=aff_maps[-2:],
                    #                      epoch=args.current_epoch, threshold=th, suffix='aff_com_cam',fg_th=args.nl_fg_th, bg_th=args.nl_bg_th)

    print('== cls err')
    print('Top1: {:.2f} Top5: {:.2f}\n'.format(100.0 - top1_clsacc.avg,
                                               100.0 - top5_clsacc.avg))
    for th in args.threshold:
        print('=========== threshold: {} ==========='.format(th))
        print('== loc err')
        print('CAM-Top1: {:.2f} Top5: {:.2f}\n'.format(
            loc_err['top1_locerr_{}'.format(th)].avg,
            loc_err['top5_locerr_{}'.format(th)].avg))
        print('CAM-Top1_err: {} {} {} {} {} {}\n'.format(
            loc_err['top1_locerr_right_{}'.format(th)].sum,
            loc_err['top1_locerr_cls_wrong_{}'.format(th)].sum,
            loc_err['top1_locerr_mins_wrong_{}'.format(th)].sum,
            loc_err['top1_locerr_part_wrong_{}'.format(th)].sum,
            loc_err['top1_locerr_more_wrong_{}'.format(th)].sum,
            loc_err['top1_locerr_other_{}'.format(th)].sum))
        if args.scg:
            print('SCG-Top1: {:.2f} Top5: {:.2f}\n'.format(
                loc_err['top1_locerr_scg_{}'.format(th)].avg,
                loc_err['top5_locerr_scg_{}'.format(th)].avg))
            print('SCG-Top1_err: {} {} {} {} {} {}\n'.format(
                loc_err['top1_locerr_scg_right_{}'.format(th)].sum,
                loc_err['top1_locerr_scg_cls_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_scg_mins_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_scg_part_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_scg_more_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_scg_other_{}'.format(th)].sum))
        print('== Gt-Known loc err')
        print('CAM-Top1: {:.2f} \n'.format(
            loc_err['gt_known_locerr_{}'.format(th)].avg))
        if args.scg:
            print('SCG-Top1: {:.2f} \n'.format(
                loc_err['gt_known_locerr_scg_{}'.format(th)].avg))

    setting = args.debug_dir.split('/')[-1]
    results_log_name = '{}_results.log'.format(setting)
    result_log = os.path.join(args.snapshot_dir, results_log_name)
    with open(result_log, 'a') as fw:
        fw.write('== cls err ')
        fw.write('Top1: {:.2f} Top5: {:.2f}\n'.format(100.0 - top1_clsacc.avg,
                                                      100.0 - top5_clsacc.avg))
        for th in args.threshold:
            fw.write('=========== threshold: {} ===========\n'.format(th))
            fw.write('== loc err ')
            fw.write('CAM-Top1: {:.2f} Top5: {:.2f}\n'.format(
                loc_err['top1_locerr_{}'.format(th)].avg,
                loc_err['top5_locerr_{}'.format(th)].avg))
            fw.write('CAM-Top1_err: {} {} {} {} {} {}\n'.format(
                loc_err['top1_locerr_right_{}'.format(th)].sum,
                loc_err['top1_locerr_cls_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_mins_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_part_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_more_wrong_{}'.format(th)].sum,
                loc_err['top1_locerr_other_{}'.format(th)].sum))
            if args.scg:
                fw.write('SCG-Top1: {:.2f} Top5: {:.2f}\n'.format(
                    loc_err['top1_locerr_scg_{}'.format(th)].avg,
                    loc_err['top5_locerr_scg_{}'.format(th)].avg))
                fw.write('SCG-Top1_err: {} {} {} {} {} {}\n'.format(
                    loc_err['top1_locerr_scg_right_{}'.format(th)].sum,
                    loc_err['top1_locerr_scg_cls_wrong_{}'.format(th)].sum,
                    loc_err['top1_locerr_scg_mins_wrong_{}'.format(th)].sum,
                    loc_err['top1_locerr_scg_part_wrong_{}'.format(th)].sum,
                    loc_err['top1_locerr_scg_more_wrong_{}'.format(th)].sum,
                    loc_err['top1_locerr_scg_other_{}'.format(th)].sum))
            fw.write('== Gt-Known loc err ')
            fw.write('CAM-Top1: {:.2f} \n'.format(
                loc_err['gt_known_locerr_{}'.format(th)].avg))
            if args.scg:
                fw.write('SCG-Top1: {:.2f} \n'.format(
                    loc_err['gt_known_locerr_scg_{}'.format(th)].avg))
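
The SCG branch in Example #9 fuses first- and second-order similarity maps with an elementwise max (weighting the second-order term by `args.scg_so_weight`) and then normalizes each row to sum to one. The same combination step in isolation; the default weight below is an assumption.

import torch

def combine_sc_maps(sc_map_fo, sc_map_so, so_weight=0.5, eps=1e-10):
    # Elementwise max of the first-order map and the weighted second-order
    # map, then row normalization, mirroring the fusion in Example #9.
    sc_map = torch.max(sc_map_fo, so_weight * sc_map_so)
    return sc_map / (torch.sum(sc_map, dim=1, keepdim=True) + eps)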