def calculate_mae(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    test_data = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    test_dataloader = DataLoader(test_data, batch_size=args.bs, shuffle=False, num_workers=2)

    # List to save the mean absolute error of each image
    mae_list = []
    with torch.no_grad():
        for batch_idx, (inp_imgs, gt_masks) in enumerate(tqdm.tqdm(test_dataloader), start=1):
            inp_imgs = inp_imgs.to(device)
            gt_masks = gt_masks.to(device)
            pred_masks, _ = model(inp_imgs)

            # Per-image MAE: average the absolute error over channel, height and width
            mae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
            mae_list.extend(mae)

    print('MAE for the test set is :', np.mean(mae_list))
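# --- Illustrative sketch (not part of the original script) ---
# A minimal, self-contained example of the per-image MAE reduction used in
# calculate_mae() above: predictions and ground-truth masks are (N, 1, H, W)
# tensors, and averaging over dims (1, 2, 3) yields one error value per image.
# The tensor shapes and random data below are made up for illustration only.
import torch

pred = torch.rand(4, 1, 256, 256)                    # hypothetical predicted masks in [0, 1]
gt = torch.randint(0, 2, (4, 1, 256, 256)).float()   # hypothetical binary ground-truth masks
per_image_mae = torch.mean(torch.abs(pred - gt), dim=(1, 2, 3))  # shape: (4,)
print(per_image_mae.mean().item())  # dataset-level MAE is the mean of the per-image values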
def run_inference(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
    # The squeezing code below assumes batch_size = 1, so do not change it
    inf_dataloader = DataLoader(inf_data, batch_size=1, shuffle=True, num_workers=2)

    with torch.no_grad():
        for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # Assuming batch_size = 1
            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
            pred_masks_round = np.squeeze(pred_masks.round().cpu().numpy(), axis=(0, 1))

            print('Image :', batch_idx)
            # The results are written to disk instead of being displayed, so no key
            # handling is needed. The masks are in [0, 1]; scale them to 0-255 uint8
            # so that cv2.imwrite produces visible PNGs.
            cv2.imwrite('/content/test/' + 'Input Image' + str(batch_idx) + '.png', img_np)
            cv2.imwrite('/content/test/' + 'Generated Saliency Mask' + str(batch_idx) + '.png',
                        (pred_masks_raw * 255).astype(np.uint8))
            cv2.imwrite('/content/test/' + 'Rounded-off Saliency Mask' + str(batch_idx) + '.png',
                        (pred_masks_round * 255).astype(np.uint8))
def run_inference(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
    # This variant writes the predicted masks to disk instead of displaying them,
    # so it can process several images per forward pass (batch_size = 4)
    inf_dataloader = DataLoader(inf_data, batch_size=4, shuffle=False, num_workers=2)

    with torch.no_grad():
        for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # pred_masks has shape (batch, 1, H, W); drop the channel dim and scale to 0-255
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=1) * 255
            print(pred_masks_raw.shape)
            print('Batch :', batch_idx)

            # Save each mask in the batch as a separate PNG (cast to uint8 for cv2.imwrite)
            for im_idx in range(pred_masks_raw.shape[0]):
                out_path = os.path.join(args.output_folder, str(batch_idx) + "-" + str(im_idx) + "-subject_4.png")
                print(out_path)
                cv2.imwrite(out_path, pred_masks_raw[im_idx].astype(np.uint8))
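# --- Illustrative sketch (not part of the original script) ---
# Example of driving the batched run_inference() above without a command line.
# The attribute names mirror the ones the function reads (args.imgs_folder,
# args.output_folder, args.model_path, args.img_size, args.use_gpu); the paths
# and values below are placeholders, not the repository's defaults.
import os
from argparse import Namespace

args = Namespace(
    imgs_folder='./data/test_images',       # placeholder input folder
    output_folder='./out_masks',            # placeholder output folder
    model_path='./models/best-model.pth',   # placeholder checkpoint path
    img_size=256,
    use_gpu=True,
)
os.makedirs(args.output_folder, exist_ok=True)  # run_inference assumes the output folder exists
run_inference(args)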
class Engine:
    def __init__(self, args):
        self.epochs = args.epochs
        self.bs = args.bs
        self.lr = args.lr
        self.wd = args.wd
        self.img_size = args.img_size
        self.aug = args.aug
        self.n_worker = args.n_worker
        self.test_interval = args.test_interval
        self.save_interval = args.save_interval
        self.save_opt = args.save_opt
        self.log_interval = args.log_interval
        self.res_mod_path = args.res_mod
        self.res_opt_path = args.res_opt
        self.use_gpu = args.use_gpu
        self.alpha_sal = args.alpha_sal
        self.wbce_w0 = args.wbce_w0
        self.wbce_w1 = args.wbce_w1

        self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(str(self.alpha_sal),
                                                                                   str(self.wbce_w0),
                                                                                   str(self.wbce_w1))
        print('Models would be saved at : {}\n'.format(self.model_path))

        if not os.path.exists(os.path.join(self.model_path, 'weights')):
            os.makedirs(os.path.join(self.model_path, 'weights'))
        if not os.path.exists(os.path.join(self.model_path, 'optimizers')):
            os.makedirs(os.path.join(self.model_path, 'optimizers'))

        if torch.cuda.is_available():
            self.device = torch.device(device='cuda')
        else:
            self.device = torch.device(device='cpu')

        self.model = SODModel()
        self.model.to(self.device)
        self.criterion = EdgeSaliencyLoss(device=self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        # Load model and optimizer if resumed
        if self.res_mod_path is not None:
            chkpt = torch.load(self.res_mod_path, map_location=self.device)
            self.model.load_state_dict(chkpt['model'])
            print("Resuming training with checkpoint : {}\n".format(self.res_mod_path))
        if self.res_opt_path is not None:
            chkpt = torch.load(self.res_opt_path, map_location=self.device)
            self.optimizer.load_state_dict(chkpt['optimizer'])
            print("Resuming training with optimizer : {}\n".format(self.res_opt_path))

        self.train_data = SODLoader(mode='train', augment_data=self.aug, target_size=self.img_size)
        self.test_data = SODLoader(mode='test', augment_data=False, target_size=self.img_size)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs, shuffle=True, num_workers=self.n_worker)
        self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs, shuffle=False, num_workers=self.n_worker)

    def train(self):
        best_test_mae = float('inf')
        for epoch in range(self.epochs):
            self.model.train()
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.train_dataloader):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                self.optimizer.zero_grad()
                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg  # Activity regularizer from Channel-wise Attention
                loss.backward()
                self.optimizer.step()

                if batch_idx % self.log_interval == 0:
                    print('TRAIN :: Epoch : {}\tBatch : {}/{} ({:.2f}%)\t\tTot Loss : {:.4f}\tReg : {:.4f}'
                          .format(epoch + 1, batch_idx + 1, len(self.train_dataloader),
                                  (batch_idx + 1) * 100 / len(self.train_dataloader), loss.item(), ca_act_reg))

            # Validation
            if epoch % self.test_interval == 0 or epoch % self.save_interval == 0:
                te_avg_loss, te_acc, te_pre, te_rec, te_mae = self.test()
                mod_chkpt = {'epoch': epoch,
                             'test_mae': float(te_mae),
                             'model': self.model.state_dict(),
                             'test_loss': float(te_avg_loss),
                             'test_acc': float(te_acc),
                             'test_pre': float(te_pre),
                             'test_rec': float(te_rec)}
                if self.save_opt:
                    opt_chkpt = {'epoch': epoch,
                                 'test_mae': float(te_mae),
                                 'optimizer': self.optimizer.state_dict(),
                                 'test_loss': float(te_avg_loss),
                                 'test_acc': float(te_acc),
                                 'test_pre': float(te_pre),
                                 'test_rec': float(te_rec)}

                # Save the best model (note the '/' separator so files land in the
                # 'weights' and 'optimizers' directories created in __init__)
                if te_mae < best_test_mae:
                    best_test_mae = te_mae
                    torch.save(mod_chkpt, self.model_path + '/weights/best-model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'
                               .format(epoch, best_test_mae, te_avg_loss))
                    if self.save_opt:
                        torch.save(opt_chkpt, self.model_path + '/optimizers/best-opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'
                                   .format(epoch, best_test_mae, te_avg_loss))
                    print('Best Model Saved !!!\n')
                    continue

                # Save model at regular intervals
                if self.save_interval is not None and epoch % self.save_interval == 0:
                    torch.save(mod_chkpt, self.model_path + '/weights/model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'
                               .format(epoch, te_mae, te_avg_loss))
                    if self.save_opt:
                        torch.save(opt_chkpt, self.model_path + '/optimizers/opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'
                                   .format(epoch, te_mae, te_avg_loss))
                    print('Model Saved !!!\n')
                    continue
            print('\n')

    def test(self):
        self.model.eval()
        tot_loss = 0
        tp_fp = 0        # TruePositive + TrueNegative, for accuracy
        tp = 0           # TruePositive
        pred_true = 0    # Number of '1' predictions, for precision
        gt_true = 0      # Number of '1's in the gt mask, for recall
        mae_list = []    # List to save the mean absolute error of each image

        with torch.no_grad():
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.test_dataloader, start=1):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg
                tot_loss += loss.item()

                tp_fp += (pred_masks.round() == gt_masks).float().sum()
                tp += torch.mul(pred_masks.round(), gt_masks).sum()
                pred_true += pred_masks.round().sum()
                gt_true += gt_masks.sum()

                # Record the absolute errors
                ae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
                mae_list.extend(ae)

        avg_loss = tot_loss / batch_idx
        accuracy = tp_fp / (len(self.test_data) * self.img_size * self.img_size)
        precision = tp / pred_true
        recall = tp / gt_true
        mae = np.mean(mae_list)
        print('TEST :: MAE : {:.4f}\tACC : {:.4f}\tPRE : {:.4f}\tREC : {:.4f}\tAVG-LOSS : {:.4f}\n'
              .format(mae, accuracy, precision, recall, avg_loss))

        return avg_loss, accuracy, precision, recall, mae
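# --- Illustrative sketch (not part of the original script) ---
# Minimal example of constructing the Engine and starting training. The attribute
# names follow the ones read in Engine.__init__; the numeric values below are
# illustrative guesses, not the repository's official defaults.
from argparse import Namespace

train_args = Namespace(
    epochs=60, bs=8, lr=0.0004, wd=0.0, img_size=256, aug=True, n_worker=2,
    test_interval=1, save_interval=1, save_opt=False, log_interval=100,
    res_mod=None, res_opt=None, use_gpu=True,
    alpha_sal=0.7, wbce_w0=1.0, wbce_w1=1.15,
    base_save_path='./models',  # placeholder save root
)
engine = Engine(train_args)
engine.train()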
    parser.add_argument('--use_gpu', action="store_true", help='Whether to use GPU or not')
    parser.add_argument('--no_activation', action="store_true", help='Disable the activation function before the output layer')

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()

    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel(last_activation=not args.no_activation)
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    batch_size = 1
    # Dummy input for tracing; keep it on the same device as the model
    x = torch.randn(batch_size, 3, 256, 256, requires_grad=True).to(device)
    torch_out = model(x)

    torch.onnx.export(
        model,               # model being run
        x,                   # model input (or a tuple for multiple inputs)
        args.output,         # where to save the model (can be a file or file-like object)
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=11,    # the ONNX version to export the model to
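# --- Illustrative sketch (not part of the original script) ---
# One way to sanity-check the exported file, assuming the torch.onnx.export call
# above completes and onnxruntime is installed. SODModel returns a tuple
# (pred_masks, ca_act_reg), so only the first output is compared against the
# first ONNX output.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(args.output)
input_name = sess.get_inputs()[0].name
onnx_out = sess.run(None, {input_name: x.detach().cpu().numpy()})[0]
np.testing.assert_allclose(torch_out[0].detach().cpu().numpy(), onnx_out, rtol=1e-03, atol=1e-05)
print('ONNX output matches the PyTorch output within tolerance.')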