def __init__(self, args):
    """Initialize the training engine from parsed CLI arguments.

    Stores hyper-parameters, creates the save directories, selects the
    compute device, builds the model/loss/optimizer, optionally resumes
    from checkpoints, and constructs the train/test data loaders.

    Args:
        args: argparse.Namespace carrying the attributes read below
            (epochs, bs, lr, wd, img_size, aug, n_worker, intervals,
            resume paths, use_gpu, loss weights, base_save_path).
    """
    self.epochs = args.epochs
    self.bs = args.bs
    self.lr = args.lr
    self.wd = args.wd
    self.img_size = args.img_size
    self.aug = args.aug
    self.n_worker = args.n_worker
    self.test_interval = args.test_interval
    self.save_interval = args.save_interval
    self.save_opt = args.save_opt
    self.log_interval = args.log_interval
    self.res_mod_path = args.res_mod
    self.res_opt_path = args.res_opt
    self.use_gpu = args.use_gpu
    self.alpha_sal = args.alpha_sal
    self.wbce_w0 = args.wbce_w0
    self.wbce_w1 = args.wbce_w1

    # Save path encodes the loss hyper-parameters so runs don't collide.
    self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(
        str(self.alpha_sal), str(self.wbce_w0), str(self.wbce_w1))
    print('Models would be saved at : {}\n'.format(self.model_path))

    # exist_ok=True avoids the check-then-create race of the original
    # os.path.exists() guard and is a no-op when the dirs already exist.
    os.makedirs(os.path.join(self.model_path, 'weights'), exist_ok=True)
    os.makedirs(os.path.join(self.model_path, 'optimizers'), exist_ok=True)

    # FIX: honor the use_gpu flag. The original ignored self.use_gpu here
    # and always chose CUDA when available, inconsistent with
    # calculate_mae(), which checks args.use_gpu as well.
    if self.use_gpu and torch.cuda.is_available():
        self.device = torch.device(device='cuda')
    else:
        self.device = torch.device(device='cpu')

    self.model = SODModel()
    self.model.to(self.device)
    self.criterion = EdgeSaliencyLoss(device=self.device)
    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr,
                                weight_decay=self.wd)

    # Resume model weights and/or optimizer state when paths are given.
    # Checkpoints are expected to be dicts with 'model' / 'optimizer' keys.
    if self.res_mod_path is not None:
        chkpt = torch.load(self.res_mod_path, map_location=self.device)
        self.model.load_state_dict(chkpt['model'])
        print("Resuming training with checkpoint : {}\n".format(self.res_mod_path))
    if self.res_opt_path is not None:
        chkpt = torch.load(self.res_opt_path, map_location=self.device)
        self.optimizer.load_state_dict(chkpt['optimizer'])
        print("Resuming training with optimizer : {}\n".format(self.res_opt_path))

    # Augmentation only on the training split; test data is deterministic.
    self.train_data = SODLoader(mode='train', augment_data=self.aug,
                                target_size=self.img_size)
    self.test_data = SODLoader(mode='test', augment_data=False,
                               target_size=self.img_size)
    self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs,
                                       shuffle=True, num_workers=self.n_worker)
    self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs,
                                      shuffle=False, num_workers=self.n_worker)
def calculate_mae(args):
    """Evaluate a saved SOD model on the test set and print its mean MAE.

    Args:
        args: argparse.Namespace with use_gpu, model_path, img_size, bs.

    Side effects:
        Loads a checkpoint from args.model_path and prints the mean
        absolute error over all test images to stdout.
    """
    # Use the GPU only when both requested and available.
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model weights. NOTE(review): torch.load unpickles arbitrary
    # objects — only evaluate checkpoints from trusted sources.
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    test_data = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    test_dataloader = DataLoader(test_data, batch_size=args.bs, shuffle=False,
                                 num_workers=2)

    # Per-image mean absolute errors, accumulated across all batches.
    mae_list = []
    with torch.no_grad():
        # FIX: dropped the unused enumerate()/batch_idx binding.
        for inp_imgs, gt_masks in tqdm.tqdm(test_dataloader):
            inp_imgs = inp_imgs.to(device)
            gt_masks = gt_masks.to(device)
            pred_masks, _ = model(inp_imgs)
            # Average |pred - gt| over channel/height/width, keeping the
            # batch dimension so each image contributes one MAE value.
            mae = torch.mean(torch.abs(pred_masks - gt_masks),
                             dim=(1, 2, 3)).cpu().numpy()
            mae_list.extend(mae)

    print('MAE for the test set is :', np.mean(mae_list))