import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local helpers (get_multi_vgg_model, get_inception3_base_model,
# get_vec_vgg_model, get_extractor, Cub_Loader, Cut_Cub_Loader,
# get_raw_imgs_by_id, get_max_binary_area, get_bbox_from_binary_cam,
# calculate_iou, plot_different_figs, draw_bbox_on_raw, fmp_crop,
# co_atten_loss, get_finetune_optimizer, log_value) are assumed to be
# imported from elsewhere in this repo, along with the global `args`.


def cls_infer():
    # load model
    model = get_multi_vgg_model(args=args)
    model.load_state_dict(
        torch.load(args.save_best_model_path, map_location='cpu'))
    model = model.eval()
    if args.cuda:
        model.cuda()

    val_dataset = Cub_Loader(args=args, mode='val')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    val_result = []
    with torch.no_grad():
        for step, (img_id, img, label,
                   bbox) in enumerate(tqdm(val_dataloader)):
            if args.cuda:
                img = img.cuda()
                label = label.cuda(non_blocking=True)

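            # Multi-crop inference: fold the (batch, crop) dims together for
            # the forward pass, then average the logits over the crops.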
            b, crop, c, w, h = img.size()
            img = img.view(b * crop, c, w, h)
            logits1, logits2, logits3 = model(img)
            logits3 = logits3.view(b, crop, args.class_nums)
            logits3 = torch.mean(logits3, dim=1)
            target_cls = torch.argmax(logits3, dim=-1)

            val_result.append(target_cls.cpu().numpy() == label.cpu().numpy())

        val_acc = np.concatenate(val_result)
        print('val acc:{}'.format(np.mean(val_acc)))


def vgg_infer_second():
    # load model
    model = get_multi_vgg_model(args=args, inference=True)
    model.load_state_dict(
        torch.load(args.save_best_model_path, map_location='cpu'))
    model.eval()
    if args.cuda:
        model.cuda()

    val_dataset = Cub_Loader(args=args, mode='test')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

    # val_cls_list = pickle.load(open('save/train/val_cls.pkl', 'rb'))

    iou_result = []
    cls_result = []
    bbox_result = []
    logits_result = dict()
    cam3_predict_idx = []
    cam3_predict_result = []
    cam3_gt_result = []
    for step, (img_id, img, label, bbox) in enumerate(tqdm(val_dataloader)):
        # Parse the id/bbox regardless of device; only tensors move to GPU.
        img_id = img_id[0].item()
        bbox = [float(x) for x in bbox[0].split(' ')]
        if args.one_obj and len(bbox) > 4:
            continue
        if args.cuda:
            img = img.cuda()
            label = label.cuda(non_blocking=True)

        b, crop, c, w, h = img.size()

        img = img.view(b * crop, c, w, h)

        logits3, cam1, cam2, cam3 = model(img)
        prediction_cls = torch.argmax(torch.mean(logits3, dim=0), -1)
        raw_img = get_raw_imgs_by_id(args, [img_id], val_dataset)[0]
        max_value_in_cam1 = torch.max(cam1).item()
        max_value_in_cam2 = torch.max(cam2).item()
        max_value_in_cam3 = torch.max(cam3).item()

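        # Slice each 200-way CAM stack (28x28 maps for this VGG) down to a
        # single channel: the ground-truth class and the predicted class.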
        gt_cam = cam3.view(b, crop, 200, 28, 28)[:, :, label.item(), :, :]
        cam1 = cam1.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]
        cam2 = cam2.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]
        cam3 = cam3.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]

        up_cam1 = F.interpolate(cam1,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()
        up_cam2 = F.interpolate(cam2,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()
        up_cam3 = F.interpolate(cam3,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()
        up_gt_cam = F.interpolate(gt_cam,
                                  size=(224, 224),
                                  mode='bilinear',
                                  align_corners=False).detach()

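        # merge_ten_crop_cam presumably stitches the ten per-crop CAMs back
        # into one full-image map, which is then resized to the raw image.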
        merge_cam1 = model.merge_ten_crop_cam(up_cam1)
        merge_cam2 = model.merge_ten_crop_cam(up_cam2)
        merge_cam3 = model.merge_ten_crop_cam(up_cam3)
        merge_gt_cam = model.merge_ten_crop_cam(up_gt_cam)

        merge_cam1 = F.interpolate(merge_cam1.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()
        merge_cam2 = F.interpolate(merge_cam2.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()
        merge_cam3 = F.interpolate(merge_cam3.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()
        merge_gt_cam = F.interpolate(merge_gt_cam.unsqueeze(0).unsqueeze(0),
                                     size=(raw_img.size[1], raw_img.size[0]),
                                     mode='bilinear',
                                     align_corners=False).squeeze()

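        # Binarize each merged CAM at 80% of its pre-merge maximum value.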
        final_cam1 = model.norm_cam_2_binary(merge_cam1,
                                             thd=max_value_in_cam1 * 0.8)
        final_cam2 = model.norm_cam_2_binary(merge_cam2,
                                             thd=max_value_in_cam2 * 0.8)
        final_cam3 = model.norm_cam_2_binary(merge_cam3,
                                             thd=max_value_in_cam3 * 0.8)
        final_gt_cam = model.norm_cam_2_binary(merge_gt_cam,
                                               thd=max_value_in_cam3 * 0.8)

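        # Union of the three binary maps; the largest connected region then
        # yields the predicted bounding box.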
        sum_cam = final_cam1 + final_cam2 + final_cam3
        sum_cam[sum_cam > 1] = 1
        max_final_cam = get_max_binary_area(sum_cam.detach().cpu().numpy())

        result_bbox = get_bbox_from_binary_cam(max_final_cam)
        result_iou = calculate_iou(result_bbox, bbox)

        iou_result.append(result_iou)
        cls_result.append(prediction_cls.item() == label.item())
        bbox_result.append([
            result_bbox['x1'], result_bbox['y1'], result_bbox['x2'],
            result_bbox['y2']
        ])
        logits_result[img_id] = F.softmax(
            torch.mean(logits3, dim=0).detach().cpu(), dim=-1)
        cam3_predict_idx.append(prediction_cls.item())
        cam3_predict_result.append(final_cam3.cpu().numpy())
        cam3_gt_result.append(final_gt_cam.cpu().numpy())

    cut_loader = Cut_Cub_Loader(args=args,
                                bbox_result=bbox_result,
                                mode='test')
    cut_dataloader = DataLoader(cut_loader, batch_size=1, shuffle=False)

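    # Second pass: re-classify images cropped to the first-pass boxes and
    # fuse the softmax scores of both passes by multiplication.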
    second_cls_result = []
    for step, (img_id, img, label, bbox) in enumerate(tqdm(cut_dataloader)):
        img_id = img_id[0].item()
        bbox = [float(x) for x in bbox[0].split(' ')]
        if args.one_obj and len(bbox) > 4:
            continue
        if args.cuda:
            img = img.cuda()
            label = label.cuda(non_blocking=True)

        b, crop, c, w, h = img.size()

        img = img.view(b * crop, c, w, h)

        logits3, cam1, cam2, cam3 = model(img)
        merge_logits = F.softmax(torch.mean(logits3, dim=0),
                                 dim=-1).cpu() * logits_result[img_id]
        prediction_cls = torch.argmax(merge_logits, -1)
        raw_img = get_raw_imgs_by_id(args, [img_id], val_dataset)[0]

        max_value_in_cam3 = torch.max(cam3).item()

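        # Keep the CAM of the first-pass prediction so before/after
        # localizations can be compared.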
        before_cam3 = cam3.view(b, crop, 200, 28,
                                28)[:, :, cam3_predict_idx[step], :, :]
        cam3 = cam3.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]

        up_cam3 = F.interpolate(cam3,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()
        before_up_cam3 = F.interpolate(before_cam3,
                                       size=(224, 224),
                                       mode='bilinear',
                                       align_corners=False).detach()

        merge_cam3 = model.merge_ten_crop_cam(up_cam3)
        before_merge_cam3 = model.merge_ten_crop_cam(before_up_cam3)

        final_cam3 = model.norm_cam_2_binary(merge_cam3,
                                             thd=max_value_in_cam3 * 0.8)
        before_final_cam3 = model.norm_cam_2_binary(
            before_merge_cam3, thd=max_value_in_cam3 * 0.8)

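        # Dump comparison figures for cases the second pass fixed (first
        # pass wrong, second pass correct).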
        if (prediction_cls.item() == label.item()
                and cam3_predict_idx[step] != label.item()):
            plot_different_figs(
                'save/imgs/before_after/{}.png'.format(img_id), [
                    raw_img, cam3_predict_result[step], cam3_gt_result[step],
                    final_cam3.cpu().numpy(),
                    before_final_cam3.cpu().numpy()
                ])

        second_cls_result.append(prediction_cls.item() == label.item())

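    # 'iou*' is localization only (IoU >= 0.5 with the ground-truth box);
    # 'iou' additionally requires the second-pass class to be correct.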
    print('second cls:{}'.format(np.mean(second_cls_result)))
    print('iou*:{}'.format(np.mean(np.array(iou_result) >= 0.5)))
    print('iou:{}'.format(
        np.mean(np.array(second_cls_result) * (np.array(iou_result) >= 0.5))))


def base_inception_infer_with_top5():
    # load model
    model = get_inception3_base_model(args=args, inference=True)
    model.load_state_dict(
        torch.load(args.save_best_model_path, map_location='cpu'))
    model.eval()
    if args.cuda:
        model.cuda()

    val_dataset = Cub_Loader(args=args, mode='test')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

    # val_cls_list = pickle.load(open('save/train/val_cls.pkl', 'rb'))

    iou_result = []
    cls_result = []
    cls5_result = []
    bbox_result = []
    logits_result = dict()
    for step, (img_id, img, label, bbox) in enumerate(tqdm(val_dataloader)):
        img_id = img_id[0]
        bbox = [float(x) for x in bbox[0].split(' ')]
        if args.one_obj and len(bbox) > 4:
            continue
        if args.cuda:
            img = img.cuda()
            label = label.cuda(non_blocking=True)

        b, crop, c, w, h = img.size()

        img = img.view(b * crop, c, w, h)

        logits3, cam3 = model(img)
        prediction_cls = torch.argmax(torch.mean(logits3, dim=0), -1)
        prediction_cls5 = torch.argsort(torch.mean(logits3, dim=0))[-5:]
        raw_img = get_raw_imgs_by_id(args, [img_id], val_dataset)[0]
        max_value_in_cam3 = torch.max(cam3).item()

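        # Inception settings: 40x40 CAMs upsampled to the 321x321 crop size,
        # binarized at a looser 40% of the CAM maximum.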
        cam3 = cam3.view(b, crop, 200, 40, 40)[:, :, prediction_cls, :, :]

        up_cam3 = F.interpolate(cam3,
                                size=(321, 321),
                                mode='bilinear',
                                align_corners=False).detach()

        merge_cam3 = model.merge_ten_crop_cam(up_cam3)

        merge_cam3 = F.interpolate(merge_cam3.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()

        final_cam3 = model.norm_cam_2_binary(merge_cam3,
                                             thd=max_value_in_cam3 * 0.4)

        sum_cam = final_cam3
        max_final_cam = get_max_binary_area(sum_cam.detach().cpu().numpy())

        result_bbox = get_bbox_from_binary_cam(max_final_cam)
        result_iou = calculate_iou(result_bbox, bbox)

        iou_result.append(result_iou)
        cls_result.append(prediction_cls.item() == label.item())
        cls5_result.append(label.item() in prediction_cls5.cpu().numpy())
        bbox_result.append([
            result_bbox['x1'], result_bbox['y1'], result_bbox['x2'],
            result_bbox['y2']
        ])
        logits_result[img_id] = F.softmax(
            torch.mean(logits3, dim=0).detach().cpu(), dim=-1)

        if step % 100 == 0:
            print('cls:{}'.format(np.mean(cls_result)))
            print('cls5:{}'.format(np.mean(cls5_result)))
            print('iou*:{}'.format(np.mean(np.array(iou_result) >= 0.5)))
            print('iou:{}'.format(
                np.mean(np.array(cls_result) * (np.array(iou_result) >= 0.5))))
            print('iou5:{}'.format(
                np.mean(np.array(cls5_result) *
                        (np.array(iou_result) >= 0.5))))

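    # cls/cls5: Top-1/Top-5 accuracy; 'iou*': GT-known localization
    # (IoU >= 0.5 only); 'iou'/'iou5': Top-1/Top-5 localization accuracy.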
    print('cls:{}'.format(np.mean(cls_result)))
    print('cls5:{}'.format(np.mean(cls5_result)))
    print('iou*:{}'.format(np.mean(np.array(iou_result) >= 0.5)))
    print('iou:{}'.format(
        np.mean(np.array(cls_result) * (np.array(iou_result) >= 0.5))))
    print('iou5:{}'.format(
        np.mean(np.array(cls5_result) * (np.array(iou_result) >= 0.5))))


def multi_loc_plot():
    # load model
    model = get_multi_vgg_model(args=args, inference=True)
    model.load_state_dict(
        torch.load(args.save_best_model_path, map_location='cpu'))
    model.eval()
    if args.cuda:
        model.cuda()

    val_dataset = Cub_Loader(args=args, mode='val')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=1,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

    # val_cls_list = pickle.load(open('save/train/val_cls.pkl', 'rb'))

    iou_result = []
    cls_result = []
    for step, (img_id, img, label, bbox) in enumerate(tqdm(val_dataloader)):
        img_id = img_id[0].item()
        bbox = [float(x) for x in bbox[0].split(' ')]
        if args.cuda:
            img = img.cuda()
            label = label.cuda(non_blocking=True)

        b, crop, c, w, h = img.size()

        img = img.view(b * crop, c, w, h)

        logits3, cam1, cam2, cam3 = model(img)
        prediction_cls = torch.argmax(torch.mean(logits3, dim=0), -1)
        raw_img = get_raw_imgs_by_id(args, [img_id], val_dataset)[0]

        max_value_in_cam1 = torch.max(cam1).item()
        max_value_in_cam2 = torch.max(cam2).item()
        max_value_in_cam3 = torch.max(cam3).item()

        cam1 = cam1.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]
        cam2 = cam2.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]
        cam3 = cam3.view(b, crop, 200, 28, 28)[:, :, prediction_cls, :, :]

        up_cam1 = F.interpolate(cam1,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()
        up_cam2 = F.interpolate(cam2,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()
        up_cam3 = F.interpolate(cam3,
                                size=(224, 224),
                                mode='bilinear',
                                align_corners=False).detach()

        merge_cam1 = model.merge_ten_crop_cam(up_cam1)
        merge_cam2 = model.merge_ten_crop_cam(up_cam2)
        merge_cam3 = model.merge_ten_crop_cam(up_cam3)

        merge_cam1 = F.interpolate(merge_cam1.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()
        merge_cam2 = F.interpolate(merge_cam2.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()
        merge_cam3 = F.interpolate(merge_cam3.unsqueeze(0).unsqueeze(0),
                                   size=(raw_img.size[1], raw_img.size[0]),
                                   mode='bilinear',
                                   align_corners=False).squeeze()

        final_cam1 = model.norm_cam_2_binary(merge_cam1,
                                             thd=max_value_in_cam1 * 0.8)
        final_cam2 = model.norm_cam_2_binary(merge_cam2,
                                             thd=max_value_in_cam2 * 0.8)
        final_cam3 = model.norm_cam_2_binary(merge_cam3,
                                             thd=max_value_in_cam3 * 0.8)

        sum_cam = final_cam1 + final_cam2 + final_cam3
        sum_cam[sum_cam > 1] = 1

        max_final_cam = get_max_binary_area(sum_cam.cpu().numpy())

        result_bbox = get_bbox_from_binary_cam(max_final_cam)

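        # Save qualitative figures only for correctly classified images.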
        if prediction_cls.item() == label.item():
            plot_different_figs(
                'save/imgs/tmp_imgs/{}.png'.format(img_id),
                plot_list=[
                    raw_img,
                    draw_bbox_on_raw(raw_img.copy(), result_bbox, bbox),
                    final_cam1.cpu().numpy(),
                    final_cam2.cpu().numpy(),
                    final_cam3.cpu().numpy(),
                    (final_cam1 + final_cam2 + final_cam3).cpu().numpy(),
                    max_final_cam
                ])


def vgg_main():
    # load model
    generator = get_vec_vgg_model(args=args)
    extractor = get_extractor(args=args, out_dim=2048)
    if args.cuda:
        generator = nn.DataParallel(generator).cuda()
        extractor = nn.DataParallel(extractor).cuda()

    # load data
    train_dataset = Cub_Loader(args=args, mode='train')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

    val_dataset = Cub_Loader(args=args, mode='val')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

    # load loss
    loss_func = torch.nn.CrossEntropyLoss()

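    # Two-network setup: the generator's feature maps drive object/background
    # cropping (fmp_crop), and the extractor embeds those crops for the
    # co-attention loss; both networks are updated jointly each step.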
    for epoch in range(args.epoch):
        # init model
        generator = generator.train()
        extractor = extractor.train()

        # init opt
        opt1 = get_finetune_optimizer(args, generator, epoch)
        opt2 = get_finetune_optimizer(args, extractor, epoch)

        # other init
        train_result = {'cls': [], 'loss_co': [], 'loss': []}

        for step, (img_id, img, label) in enumerate(tqdm(train_dataloader)):
            if args.cuda:
                img = img.cuda()
                label = label.cuda()

            # generator
            # logits: [batch, cls_num], fmap: [batch, channel, height, width]
            logits, fmap = generator(img)
            loss_cls = loss_func(logits, label)

            # generate object and background
            objs, bgs = fmp_crop(img, fmap, label.unsqueeze(-1))

            # extract feature vector by extractor
            obj_vec = extractor(objs)
            bg_vec = extractor(bgs)

            # loss 2
            loss_co = co_atten_loss(obj_vec, bg_vec, label.unsqueeze(-1))

            # back
            loss = loss_cls + loss_co
            opt1.zero_grad()
            opt2.zero_grad()
            loss.backward()
            opt1.step()
            opt2.step()

            # log
            train_result['cls'].append(
                torch.argmax(logits, dim=-1).cpu().numpy() ==
                label.cpu().numpy())
            train_result['loss_co'].append(loss_co.item())
            train_result['loss'].append(loss.item())

        log_value('generator_cls',
                  np.mean(np.concatenate(train_result['cls'])), epoch)
        log_value('extractor_loss', np.mean(train_result['loss_co']), epoch)
        log_value('total_loss', np.mean(train_result['loss']), epoch)


def inception_base_main():
    # load model
    model = get_inception3_base_model(args=args, pretrained=True)
    if args.cuda:
        model = nn.DataParallel(model).cuda()

    train_dataset = Cub_Loader(args=args, mode='train')
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

    val_dataset = Cub_Loader(args=args, mode='val')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

    # load param
    loss_func = torch.nn.CrossEntropyLoss()

    all_train_acc = []
    all_val_acc = []
    for epoch in range(args.epoch):
        model = model.train()

        opt = get_finetune_optimizer(args, model, epoch)

        train_result = {'cls3': []}
        train_loss = 0.
        for step, (img, label) in enumerate(tqdm(train_dataloader)):
            if args.cuda:
                img = img.cuda()
                label = label.cuda(non_blocking=True)

            logits3 = model(img)
            loss = loss_func(logits3, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            train_result['cls3'].append(
                torch.argmax(logits3, dim=-1).cpu().numpy() ==
                label.cpu().numpy())
            train_loss += loss.item()

        log_value('train_acc3', np.mean(np.concatenate(train_result['cls3'])),
                  epoch)

        model = model.eval()

        val_result = {'cls3': []}
        with torch.no_grad():
            for step, (img_id, img, label,
                       bbox) in enumerate(tqdm(val_dataloader)):
                if args.cuda:
                    img = img.cuda()
                    label = label.cuda(non_blocking=True)

                logits3 = model(img)

                val_result['cls3'].append(
                    torch.argmax(logits3, dim=-1).cpu().numpy() ==
                    label.cpu().numpy())

            log_value('val_acc3', np.mean(np.concatenate(val_result['cls3'])),
                      epoch)

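            # Save the latest weights every epoch, and keep a separate copy
            # whenever validation accuracy improves.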
            torch.save(model.module.state_dict(), args.save_model_path)
            print('epoch:{} loss:{} lr:{} train_acc:{} val_acc:{}'.format(
                epoch, train_loss, args.lr * (args.lr_decay**epoch),
                np.mean(np.concatenate(train_result['cls3'])),
                np.mean(np.concatenate(val_result['cls3']))))
            if np.mean(np.concatenate(val_result['cls3'])) > args.best_acc:
                torch.save(model.module.state_dict(),
                           args.save_best_model_path)
                args.best_acc = np.mean(np.concatenate(val_result['cls3']))
                print('weights updated')