Пример #1
0
    def __init__(self, args):
        """Build the model, loss, optimizer and data loaders used for training.

        Args:
            args: Options object. The attributes read here are ``epochs``,
                ``bs``, ``lr``, ``wd``, ``img_size``, ``aug``, ``n_worker``,
                ``test_interval``, ``save_interval``, ``save_opt``,
                ``log_interval``, ``res_mod``, ``res_opt``, ``use_gpu``,
                ``alpha_sal``, ``wbce_w0``, ``wbce_w1`` and ``base_save_path``.

        Side effects:
            Creates ``<model_path>/weights`` and ``<model_path>/optimizers``
            on disk and prints the save location (plus a message for each
            checkpoint that is restored).
        """
        self.epochs = args.epochs
        self.bs = args.bs
        self.lr = args.lr
        self.wd = args.wd
        self.img_size = args.img_size
        self.aug = args.aug
        self.n_worker = args.n_worker
        self.test_interval = args.test_interval
        self.save_interval = args.save_interval
        self.save_opt = args.save_opt
        self.log_interval = args.log_interval
        self.res_mod_path = args.res_mod
        self.res_opt_path = args.res_opt
        self.use_gpu = args.use_gpu

        self.alpha_sal = args.alpha_sal
        self.wbce_w0 = args.wbce_w0
        self.wbce_w1 = args.wbce_w1

        # The output directory name encodes the loss hyper-parameters of this run.
        self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(str(self.alpha_sal), str(self.wbce_w0), str(self.wbce_w1))
        print('Models would be saved at : {}\n'.format(self.model_path))
        # exist_ok avoids the check-then-create race of a separate os.path.exists test.
        for sub_dir in ('weights', 'optimizers'):
            os.makedirs(os.path.join(self.model_path, sub_dir), exist_ok=True)

        # Use CUDA only when it was requested AND is actually available, so the
        # use_gpu option is honoured instead of being silently ignored.
        if self.use_gpu and torch.cuda.is_available():
            self.device = torch.device(device='cuda')
        else:
            self.device = torch.device(device='cpu')

        self.model = SODModel()
        self.model.to(self.device)
        self.criterion = EdgeSaliencyLoss(device=self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        # Load model and optimizer if resumed.
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        if self.res_mod_path is not None:
            chkpt = torch.load(self.res_mod_path, map_location=self.device)
            self.model.load_state_dict(chkpt['model'])
            print("Resuming training with checkpoint : {}\n".format(self.res_mod_path))
        if self.res_opt_path is not None:
            chkpt = torch.load(self.res_opt_path, map_location=self.device)
            self.optimizer.load_state_dict(chkpt['optimizer'])
            print("Resuming training with optimizer : {}\n".format(self.res_opt_path))

        # Augmentation is applied to the training split only; the test split
        # is never augmented or shuffled.
        self.train_data = SODLoader(mode='train', augment_data=self.aug, target_size=self.img_size)
        self.test_data = SODLoader(mode='test', augment_data=False, target_size=self.img_size)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs, shuffle=True, num_workers=self.n_worker)
        self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs, shuffle=False, num_workers=self.n_worker)
Пример #2
0
def calculate_mae(args):
    """Print the mean absolute error of a saved model over the test split.

    Args:
        args: Options object. The attributes read here are ``use_gpu``,
            ``model_path``, ``img_size`` and ``bs``.

    The function returns nothing; the result is written to stdout.
    """
    # CUDA is used only when it was requested and is actually present.
    cuda_ok = args.use_gpu and torch.cuda.is_available()
    device = torch.device(device='cuda' if cuda_ok else 'cpu')

    # Restore the network weights and switch to inference mode.
    net = SODModel()
    state = torch.load(args.model_path, map_location=device)
    net.load_state_dict(state['model'])
    net.to(device)
    net.eval()

    dataset = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    loader = DataLoader(dataset, batch_size=args.bs, shuffle=False, num_workers=2)

    # One mean-absolute-error value per image, gathered across all batches.
    per_image_mae = []
    with torch.no_grad():
        for images, targets in tqdm.tqdm(loader):
            images = images.to(device)
            targets = targets.to(device)
            # The model returns a pair; only the first element is scored.
            predictions, _ = net(images)

            # Reduce over dims 1-3 so that one value remains per batch entry.
            abs_err = torch.abs(predictions - targets)
            batch_mae = torch.mean(abs_err, dim=(1, 2, 3))
            per_image_mae.extend(batch_mae.cpu().numpy())

    print('MAE for the test set is :', np.mean(per_image_mae))