Example #1
0
def evaluation(opt):
    """Compare a candidate SR model with the SRCNN baseline and print a score.

    Loads the model under test (``model``) from ``opt.model_path`` and the
    ``SRCNN`` baseline from ``opt.baseline_path``. It then measures PSNR/SSIM
    of both on the ``TestDataset`` built from ``opt.HR_path`` / ``opt.LR_path``
    and times ``opt.cycle_num`` passes of each with ``sr_forward_time``.
    Finally it prints

        alpha * (test_psnr - baseline_psnr)
        + beta * (test_ssim - baseline_ssim)
        + gamma * min(baseline_times / test_times, 4)

    followed by the FLOPs / parameter count of the test model on a
    1x3x240x360 input.

    Args:
        opt: parsed options; reads ``cuda``, ``alpha``, ``beta``, ``gamma``,
            ``cycle_num``, ``upscale``, ``model_path``, ``baseline_path``,
            ``HR_path``, ``LR_path`` and ``batch_size``.

    Raises:
        ValueError: if ``opt.cycle_num`` is not positive (the timing ratio
            would otherwise divide by zero).
    """
    if opt.cycle_num <= 0:
        raise ValueError(
            'cycle_num must be positive, got {}'.format(opt.cycle_num))

    device = torch.device('cuda' if opt.cuda else 'cpu')

    alpha = opt.alpha
    beta = opt.beta
    gamma = opt.gamma
    cycle_num = opt.cycle_num

    # load test model
    test_model = model(opt.upscale)
    test_model.load_state_dict(load_state_dict(opt.model_path))
    test_model = test_model.to(device)
    test_model.eval()

    # load baseline
    baseline_model = SRCNN(opt.upscale)
    baseline_model.load_state_dict(load_state_dict(opt.baseline_path))
    baseline_model = baseline_model.to(device)
    baseline_model.eval()

    # load dataset; a fixed order keeps the evaluation reproducible and gives
    # both models the same batches
    dataset = TestDataset(opt.HR_path, opt.LR_path)
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False)
    crop_boarder = opt.upscale

    baseline_psnr, baseline_ssim = sr_forward_psnr(dataloader, baseline_model, device, crop_boarder)
    print("baseline_psnr:" + str(baseline_psnr))
    print("baseline_ssim:" + str(baseline_ssim))
    test_psnr, test_ssim = sr_forward_psnr(dataloader, test_model, device, crop_boarder)
    print("test_psnr:" + str(test_psnr))
    print("test_ssim:" + str(test_ssim))

    # accumulate the timing of cycle_num passes; the running total is printed
    baseline_times = 0.0
    for _ in range(cycle_num):
        baseline_times += sr_forward_time(dataloader, baseline_model, device)
        print("baseline time" + str(baseline_times))

    test_times = 0.0
    for _ in range(cycle_num):
        test_times += sr_forward_time(dataloader, test_model, device)
        print("test time" + str(test_times))

    # speed-up is capped at 4x so timing cannot dominate the score
    psnr_term = alpha * (test_psnr - baseline_psnr)
    ssim_term = beta * (test_ssim - baseline_ssim)
    time_term = gamma * min(baseline_times / test_times, 4)
    score = psnr_term + ssim_term + time_term

    print('psnr: {:.4f}'.format(psnr_term))
    print('ssim: {:.4f}'.format(ssim_term))
    print('time: {:.4f}'.format(time_term))
    print('score: {:.4f}'.format(score))

    print('avarage score: {:.4f}'.format(score))

    # calc FLOPs on a fixed 240x360 input
    width = 360
    height = 240
    flops, params = profile(test_model, input_size=(1, 3, height, width))
    print('test_model{} x {}, flops: {:.4f} GFLOPs, params: {}'.format(height, width, flops / (1e9), params))
def main():
    """Train the FERPlus attention + region-branch model, checkpointing each epoch.

    Parses options with the module-level ``parser`` and builds the train
    (``Images/FER2013TrainValid/``) and test (``Images/FER2013Test/``)
    ``ImageList`` datasets. It optionally rebalances classes according to
    ``args.train_rule`` (``'None'``, ``'Resample'`` or ``'Reweight'``) and
    builds a ResNet-50 backbone plus ``AttentionBranch`` / ``RegionBranch``
    heads wrapped in ``DataParallel``. All three are trained with one SGD
    optimizer: the backbone param group uses a fixed lr of 1e-4, the heads
    use ``args.lr``. After every epoch the model is validated on the test
    split and ``save_checkpoint`` is called.

    Reads and rebinds the module-level globals ``args`` and ``best_prec1``.
    ``device`` is not defined here and is presumably a module-level global
    (TODO confirm).

    Returns early, after ``warnings.warn``, when ``args.train_rule`` or
    ``args.loss_type`` is not a supported value.
    """
    global args, best_prec1
    args = parser.parse_args()
    print(
        '\n\t\t\t\t Aum Sri Sai Ram\nFER on FERPLUS using Local and global Attention along with region branch (non-overlapping patches)\n\n'
    )
    print(args)
    print('\ntrain rule: ', args.train_rule, ' and loss type: ',
          args.loss_type, '\n')

    print('\n lr is : ', args.lr)

    print('img_dir:', args.root_path)

    print('\nTraining mode: ', args.mode)

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.3,
                               saturation=0.25,
                               hue=0.05),
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    valid_transform = transforms.Compose([
        transforms.Resize((args.imagesize, args.imagesize)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_dataset = ImageList(root=args.root_path +
                              'Images/FER2013TrainValid/',
                              fileList=args.train_list,
                              transform=train_transform,
                              mode=args.mode)

    test_data = ImageList(root=args.root_path + 'Images/FER2013Test/',
                          fileList=args.test_list,
                          transform=valid_transform,
                          mode=args.mode)

    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=args.batch_size_t,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=True)

    cls_num_list = train_dataset.get_cls_num_list()
    print('Train split class wise is :', cls_num_list)

    if args.train_rule == 'None':
        train_sampler = None
        per_cls_weights = None
    elif args.train_rule == 'Resample':
        train_sampler = ImbalancedDatasetSampler(train_dataset)
        per_cls_weights = None
    elif args.train_rule == 'Reweight':
        train_sampler = None
        # effective-number weighting: w_c ∝ (1 - beta) / (1 - beta ** n_c),
        # normalized so the weights sum to the number of classes
        beta = 0.9999
        effective_num = 1.0 - np.power(beta, cls_num_list)
        per_cls_weights = (1.0 - beta) / np.array(effective_num)
        per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(
            cls_num_list)
        per_cls_weights = torch.FloatTensor(per_cls_weights).to(device)
    else:
        # previously fell through and crashed with UnboundLocalError below
        warnings.warn('Sample rule is not listed')
        return

    if args.loss_type == 'CE':
        criterion = nn.CrossEntropyLoss(weight=per_cls_weights).to(device)
    elif args.loss_type == 'Focal':
        criterion = FocalLoss(weight=per_cls_weights, gamma=2).to(device)
    else:
        warnings.warn('Loss type is not listed')
        return

    # a custom sampler and shuffle=True are mutually exclusive in DataLoader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    print('length of  train+valid Database for training: ' +
          str(len(train_loader.dataset)))

    print('length of  test Database: ' + str(len(test_loader.dataset)))

    # prepare model
    basemodel = resnet50(pretrained=False)
    attention_model = AttentionBranch(inputdim=512,
                                      num_regions=args.num_attentive_regions,
                                      num_classes=args.num_classes)
    region_model = RegionBranch(inputdim=1024,
                                num_regions=args.num_regions,
                                num_classes=args.num_classes)

    basemodel = torch.nn.DataParallel(basemodel).to(device)
    attention_model = torch.nn.DataParallel(attention_model).to(device)
    region_model = torch.nn.DataParallel(region_model).to(device)

    print('\nNumber of parameters:')
    print(
        'Base Model: {}, Attention Branch:{}, Region Branch:{} and Total: {}'.
        format(
            count_parameters(basemodel), count_parameters(attention_model),
            count_parameters(region_model),
            count_parameters(basemodel) + count_parameters(attention_model) +
            count_parameters(region_model)))

    # backbone is fine-tuned with a small fixed lr; heads use args.lr
    optimizer = torch.optim.SGD([{
        "params": basemodel.parameters(),
        "lr": 0.0001,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }])

    optimizer.add_param_group({
        "params": attention_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })
    optimizer.add_param_group({
        "params": region_model.parameters(),
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    })

    if args.pretrained:

        util.load_state_dict(
            basemodel, 'pretrainedmodels/vgg_msceleb_resnet50_ft_weight.pkl')

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE: torch.load unpickles the file; only resume from trusted checkpoints
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            basemodel.load_state_dict(checkpoint['base_state_dict'])
            attention_model.load_state_dict(checkpoint['attention_state_dict'])
            region_model.load_state_dict(checkpoint['region_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    print('\nTraining starting:\n')
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, basemodel, attention_model, region_model,
              criterion, optimizer, epoch)

        adjust_learning_rate(optimizer, epoch)
        prec1 = validate(test_loader, basemodel, attention_model, region_model,
                         criterion, epoch)
        print("Epoch: {}   Test Acc: {}".format(epoch, prec1))
        # remember best prec@1 and save checkpoint; bool()/float() accept both
        # 1-element tensors and plain Python numbers
        is_best = bool(prec1 > best_prec1)

        best_prec1 = max(float(prec1), best_prec1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'base_state_dict': basemodel.state_dict(),
                'attention_state_dict': attention_model.state_dict(),
                'region_state_dict': region_model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example #3
0
                                CONFIG.dataset_dir,
                                CONFIG.labels_name,
                                transforms=val_transform)

        train_loader, val_loader, test_loader = get_dataloader(
            train_data, val_data, val_data, CONFIG)

        model = Model(input_size=CONFIG.input_size,
                      classes=CONFIG.classes,
                      se=True,
                      activation="hswish",
                      l_cfgs_name=CONFIG.model,
                      seg_state=CONFIG.seg_state)

        if args.load_pretrained:
            pretrained_dict = load_state_dict(CONFIG.model_pretrained,
                                              use_ema=CONFIG.ema)
            model.load_state_dict(pretrained_dict, strict=False)
            logging.info("Load pretrained from {} to {}".format(
                CONFIG.model_pretrained, CONFIG.model))

        if (device.type == "cuda" and CONFIG.ngpu >= 1):
            model = model.to(device)
            model = nn.DataParallel(model, list(range(CONFIG.ngpu)))

        optimizer = get_optimizer(model.parameters(), CONFIG.optim_state)
        criterion = Loss(device, CONFIG)
        scheduler = get_lr_scheduler(optimizer, len(train_loader), CONFIG)

        start_time = time.time()
        trainer = Trainer(criterion, optimizer, scheduler, device, CONFIG)
        trainer.train_loop(train_loader, test_loader, model, fold)
Example #4
0
def evaluation(opt):
    """Score a named SR model against the baseline model and print the result.

    Imports ``model.<opt.test_model>`` and loads its weights from
    ``pre-train-model/<opt.test_model>.pth``. It imports ``model.baseline``
    and loads its weights from ``opt.baseline_path``. If the test model needs
    more than 2 GMACs (as reported by ``profile``) on a 1x3x240x360 input,
    it prints a message and exits with status -1.

    Otherwise it measures PSNR/SSIM of both models, passing
    ``results/<opt.test_model>`` to ``sr_forward_psnr`` for the test model
    (presumably so SR/GT images are written there; TODO confirm). It then
    times ``opt.cycle_num`` interleaved passes of each model and prints the
    score

        alpha * dPSNR + beta * dSSIM
        + gamma * min((baseline_t - test_t) / baseline_t, 2)

    Args:
        opt: parsed options; reads ``cuda``, ``alpha``, ``beta``, ``gamma``,
            ``cycle_num``, ``upscale``, ``HR_path``, ``LR_path``,
            ``batch_size``, ``test_model`` and ``baseline_path``.

    Raises:
        ValueError: if ``opt.cycle_num`` is not positive.
        SystemExit: with status -1 when the GFLOPs budget is exceeded.
    """
    if opt.cycle_num <= 0:
        raise ValueError(
            'cycle_num must be positive, got {}'.format(opt.cycle_num))

    device = torch.device('cuda' if opt.cuda else 'cpu')

    alpha = opt.alpha
    beta = opt.beta
    gamma = opt.gamma
    cycle_num = opt.cycle_num

    crop_boarder = opt.upscale

    # load dataset
    dataset = TestDataset(opt.HR_path, opt.LR_path)
    dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False)

    # load test model
    test_module = importlib.import_module('model.{}'.format(opt.test_model))
    test_model = test_module.model(opt.upscale)
    state_dict = load_state_dict('pre-train-model/{}.pth'.format(
        opt.test_model))
    test_model.load_state_dict(state_dict)
    test_model = test_model.to(device)
    test_model.eval()

    # load baseline
    baseline_module = importlib.import_module('model.{}'.format('baseline'))
    baseline_model = baseline_module.model(opt.upscale)
    baseline_dict = load_state_dict(opt.baseline_path)
    baseline_model.load_state_dict(baseline_dict)
    baseline_model = baseline_model.to(device)
    baseline_model.eval()

    # calc FLOPs; profile on the selected device (the model is already there),
    # instead of forcing it onto CUDA, which broke CPU runs
    width = 360
    height = 240

    inputs = torch.randn(1, 3, height, width).to(device)
    macs = profile(test_model, inputs)
    print('{:.4f} G'.format(macs / 1e9))
    if (macs / 1e9) > 2.0:
        print('model GFLOPs is more than 2\n')
        raise SystemExit(-1)

    # makedirs(exist_ok=True) also creates a missing 'results/' parent and
    # fills in SR/GT when save_path already exists without them
    save_path = 'results/{}'.format(opt.test_model)
    for sub_dir in ('SR', 'GT'):
        os.makedirs(os.path.join(save_path, sub_dir), exist_ok=True)

    test_psnr, test_ssim = sr_forward_psnr(dataloader, test_model, device,
                                           crop_boarder, save_path)
    baseline_psnr, baseline_ssim = sr_forward_psnr(dataloader, baseline_model,
                                                   device, crop_boarder)

    # interleave test/baseline runs so drift affects both equally
    test_times = 0
    baseline_times = 0
    for _ in range(cycle_num):
        test_times += sr_forward_time(dataloader, test_model, device)
        baseline_times += sr_forward_time(dataloader, baseline_model, device)

    # NOTE(review): `// 100` floors the mean times to whole multiples of 100
    # (they are multiplied back by 100 when printed), which looks like
    # deliberate noise quantization. A baseline mean below 100 floors to 0
    # and makes the ratio below raise ZeroDivisionError; confirm the units
    # returned by sr_forward_time.
    avg_test_time = (test_times / cycle_num) // 100
    avg_baseline_time = (baseline_times / cycle_num) // 100
    # relative speed-up, capped at 2
    avg_time_score = gamma * min(
        (avg_baseline_time - avg_test_time) / avg_baseline_time, 2)

    score = alpha * (test_psnr - baseline_psnr) + beta * (
        test_ssim - baseline_ssim) + avg_time_score

    print('model: {}'.format(opt.test_model))
    print('test model: {:.4f}, base model: {:.4f}, psnr: {:.4f}'.format(
        test_psnr, baseline_psnr, alpha * (test_psnr - baseline_psnr)))
    print('test model: {:.4f}, base model: {:.4f}, ssim: {:.4f}'.format(
        test_ssim, baseline_ssim, beta * (test_ssim - baseline_ssim)))
    print(
        'test model: {:.4f} ms, base model: {:.4f} ms, time: {:.4f} ms'.format(
            avg_test_time * 100, avg_baseline_time * 100, avg_time_score))
    print('score: {:.4f}'.format(score))
Example #5
0
    def init_weights(self, pretrained_model_path):
        """Load backbone weights from disk and zero the ``fc`` layer.

        Args:
            pretrained_model_path: checkpoint path handed to
                ``util.load_state_dict`` together with ``self.pretrained_model``.
        """
        util.load_state_dict(self.pretrained_model, pretrained_model_path)

        # The head starts from all-zero parameters (weight first, then bias).
        for head_param in (self.fc.weight, self.fc.bias):
            torch.nn.init.zeros_(head_param)
Example #6
0
    test_csv_path = os.path.join(CONFIG.dataset_dir, "test.csv")
    test_dataset = TestMangoDataset(test_csv_path, test_root_path, transforms=test_transform)
    test_loader = torch.utils.data.DataLoader(
                        test_dataset,
                        batch_size=CONFIG.batch_size,
                        pin_memory=True,
                        num_workers=CONFIG.num_workers,
                        shuffle=False
                    )

    set_random_seed(CONFIG.seed)

    tta_pred_labels = []
    for i, path in enumerate([CONFIG.path_to_save_model[:-4]+"_{}".format(i)+CONFIG.path_to_save_model[-4:] for i in range(5)]):
        model = Model(input_size=CONFIG.input_size, classes=CONFIG.classes, se=True, activation="hswish", l_cfgs_name=CONFIG.model, seg_state=CONFIG.seg_state)
        pretrained_dict = load_state_dict(path)
        model.load_state_dict(pretrained_dict["model"], strict=False)

        if (device.type == "cuda" and CONFIG.ngpu >= 1):
            model = model.to(device)
            model = nn.DataParallel(model, list(range(CONFIG.ngpu)))
        model.module.set_state(False)

        with torch.no_grad():
            for t in range(args.tta):
                pred_labels = []
                model.eval()
                for step, X in enumerate(test_loader):
                    X = X["image"]
                    X = X.to(device, non_blocking=True)