def calculate_mae(args):
    """Print the mean absolute error of a saved model over the test split.

    Args:
        args: Parsed options; this function reads ``use_gpu``,
            ``model_path``, ``img_size`` and ``bs``.
    """
    # CUDA is used only when it is both requested and actually available.
    want_cuda = args.use_gpu and torch.cuda.is_available()
    device = torch.device(device='cuda' if want_cuda else 'cpu')

    # Restore the weights stored under the checkpoint's 'model' key and
    # switch the network to inference mode.
    net = SODModel()
    checkpoint = torch.load(args.model_path, map_location=device)
    net.load_state_dict(checkpoint['model'])
    net.to(device)
    net.eval()

    dataset = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    loader = DataLoader(dataset, batch_size=args.bs, shuffle=False, num_workers=2)

    # One absolute-error average per image, accumulated across all batches.
    per_image_mae = []
    with torch.no_grad():
        for images, masks in tqdm.tqdm(loader):
            images = images.to(device)
            masks = masks.to(device)
            predictions, _ = net(images)
            # Average over every non-batch axis (dims 1, 2 and 3).
            abs_error = (predictions - masks).abs()
            per_image_mae.extend(abs_error.mean(dim=(1, 2, 3)).cpu().numpy())

    print('MAE for the test set is :', np.mean(per_image_mae))
def save_pred(args):
    """Run a saved model over a folder of images and write each mask as a PNG.

    Every predicted mask is scaled to 0-255, resized to a square whose side is
    the larger of the image's recorded height/width, centre-cropped to that
    height/width, and written to ``./data/MSD/test/pred/<name>.png``.

    Args:
        args: Parsed options; this function reads ``use_gpu``,
            ``model_path``, ``imgs_folder``, ``img_size`` and ``raw``.
            When ``raw`` is ``True`` the unrounded prediction is saved,
            otherwise the prediction is rounded first.
    """
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
    # The post-processing below squeezes the batch axis and reads element 0 of
    # the per-batch name/size entries, so batch_size must stay 1.
    inf_dataloader = DataLoader(inf_data, batch_size=1, shuffle=True, num_workers=2)

    # Directory to save the predictions; makedirs also creates missing parents
    # and does not fail when the directory already exists.
    pred_dir = './data/MSD/test/pred'
    os.makedirs(pred_dir, exist_ok=True)

    with torch.no_grad():
        # The first item of each batch (the un-normalised image) is not needed
        # for saving, so it is ignored.
        for _, img_tor, img_name, hw in tqdm.tqdm(inf_dataloader):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            if args.raw is not True:
                pred_masks = pred_masks.round()
            # Drop the batch and channel axes (both must have length 1).
            mask = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))

            # hw[0] is a string of the form '<height> <width>'.
            h, w = [int(x) for x in hw[0].split(' ')]
            s = max(h, w)

            mask = mask * 255
            mask = cv2.resize(mask, (s, s), interpolation=cv2.INTER_AREA)

            # Centre-crop back to exactly (h, w). Using a floor offset plus the
            # target extent avoids the off-by-one that rounding the offset on
            # both sides produced whenever (s - h) or (s - w) was odd.
            # NOTE(review): assumes the loader pads symmetrically - confirm
            # against InfDataloader.
            top = (s - h) // 2
            left = (s - w) // 2
            mask = mask[top:top + h, left:left + w]

            cv2.imwrite(os.path.join(pred_dir, img_name[0] + '.png'), mask)
def run_inference(args):
    """Display the model's saliency predictions for a folder of images.

    For each image three OpenCV windows are shown (input, raw mask, rounded
    mask) and the function blocks on a key press; pressing 'q' stops the loop.

    Args:
        args: Parsed options; this function reads ``use_gpu``,
            ``model_path``, ``imgs_folder`` and ``img_size``.
    """
    # CUDA is used only when it is both requested and actually available.
    want_cuda = args.use_gpu and torch.cuda.is_available()
    device = torch.device(device='cuda' if want_cuda else 'cpu')

    # Restore the trained weights and switch to inference mode.
    net = SODModel()
    checkpoint = torch.load(args.model_path, map_location=device)
    net.load_state_dict(checkpoint['model'])
    net.to(device)
    net.eval()

    dataset = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
    # Images are shown one at a time and the code below squeezes the batch
    # axis, so the batch size has to remain 1.
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    print("Press 'q' to quit.")
    with torch.no_grad():
        for image_no, (frame, net_input, img_name, _) in enumerate(loader, start=1):
            net_input = net_input.to(device)
            prediction, _ = net(net_input)

            # Drop the batch axis and convert RGB -> BGR for OpenCV display.
            frame = np.squeeze(frame.numpy(), axis=0).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            raw_mask = np.squeeze(prediction.cpu().numpy(), axis=(0, 1))
            rounded_mask = np.squeeze(prediction.round().cpu().numpy(), axis=(0, 1))

            print('Image :', image_no)
            windows = (
                ('Input Image', frame),
                ('Generated Saliency Mask', raw_mask),
                ('Rounded-off Saliency Mask', rounded_mask),
            )
            for title, picture in windows:
                cv2.imshow(title, picture)
            print(img_name)

            # Block until a key is pressed; 'q' ends the session.
            if cv2.waitKey(0) == ord('q'):
                break
class Engine:
    """Training and evaluation driver for ``SODModel``.

    The constructor builds the model, loss, optimizer and the train/test
    data loaders from the parsed options; ``train`` runs the epoch loop and
    writes checkpoints; ``test`` evaluates on the test loader.
    """

    def __init__(self, args):
        """Set up model, optimizer, loss and data loaders.

        Args:
            args: Parsed options. Reads ``epochs``, ``bs``, ``lr``, ``wd``,
                ``img_size``, ``aug``, ``n_worker``, ``test_interval``,
                ``save_interval``, ``log_interval``, ``resume``, ``use_gpu``,
                ``alpha_sal``, ``wbce_w0``, ``wbce_w1`` and ``base_save_path``.
        """
        self.epochs = args.epochs
        self.bs = args.bs
        self.lr = args.lr
        self.wd = args.wd
        self.img_size = args.img_size
        self.aug = args.aug
        self.n_worker = args.n_worker
        self.test_interval = args.test_interval
        self.save_interval = args.save_interval
        self.log_interval = args.log_interval
        self.resume_chkpt = args.resume
        self.use_gpu = args.use_gpu
        # Within this class these three values are only used to name the
        # checkpoint directory.
        self.alpha_sal = args.alpha_sal
        self.wbce_w0 = args.wbce_w0
        self.wbce_w1 = args.wbce_w1

        self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(
            str(self.alpha_sal), str(self.wbce_w0), str(self.wbce_w1))
        print("Models would be saved at : {}\n".format(self.model_path))
        os.makedirs(self.model_path, exist_ok=True)

        # Honour the use_gpu option, consistent with the other entry points
        # in this file: CUDA only when requested AND available.
        if self.use_gpu and torch.cuda.is_available():
            self.device = torch.device(device='cuda')
        else:
            self.device = torch.device(device='cpu')

        self.model = SODModel()
        self.model.to(self.device)
        self.criterion = EdgeSaliencyLoss(device=self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        # Load model and optimizer if resumed
        if self.resume_chkpt is not None:
            chkpt = torch.load(self.resume_chkpt, map_location=self.device)
            self.model.load_state_dict(chkpt['model'])
            self.optimizer.load_state_dict(chkpt['optimizer'])
            print("Resuming training from model : {}\n".format(self.resume_chkpt))

        self.train_data = SODLoader(mode='train', augment_data=self.aug, target_size=self.img_size)
        self.test_data = SODLoader(mode='test', augment_data=False, target_size=self.img_size)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs, shuffle=True,
                                           num_workers=self.n_worker)
        self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs, shuffle=False,
                                          num_workers=self.n_worker)

    def train(self):
        """Run the training loop, validating and checkpointing periodically.

        Validation runs on epochs that are a multiple of ``test_interval`` or
        ``save_interval``. A checkpoint is written when the test MAE improves
        on the best seen so far; otherwise one is written on ``save_interval``
        epochs. At most one checkpoint is written per epoch.
        """
        best_test_mae = float('inf')
        for epoch in range(self.epochs):
            self.model.train()
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.train_dataloader):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                self.optimizer.zero_grad()
                pred_masks, ca_act_reg = self.model(inp_imgs)
                # The model's second output is added to the loss as an
                # activity regularizer (Channel-wise Att.).
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg
                loss.backward()
                self.optimizer.step()

                if batch_idx % self.log_interval == 0:
                    print('TRAIN :: Epoch : {}\tBatch : {}/{} ({:.2f}%)\t\tTot Loss : {:.4f}\tReg : {:.4f}'
                          .format(epoch + 1,
                                  batch_idx + 1, len(self.train_dataloader),
                                  (batch_idx + 1) * 100 / len(self.train_dataloader),
                                  loss.item(),
                                  ca_act_reg))

            # Validation
            if epoch % self.test_interval == 0 or epoch % self.save_interval == 0:
                te_avg_loss, te_acc, te_pre, te_rec, te_mae = self.test()
                chkpt = {'epoch': epoch,
                         'test_mae': te_mae,
                         'model': self.model.state_dict(),
                         'optimizer': self.optimizer.state_dict(),
                         'test_loss': te_avg_loss,
                         'test_acc': te_acc,
                         'test_pre': te_pre,
                         'test_rec': te_rec}

                if te_mae < best_test_mae:
                    # Save the best model
                    best_test_mae = te_mae
                    torch.save(chkpt, self.model_path + '/best_epoch-{}_mae-{:.4f}_loss-{:.4f}'.
                               format(epoch, best_test_mae, te_avg_loss) + '.pth')
                    print('Best Model Saved !!!\n')
                elif epoch % self.save_interval == 0:
                    # Save model at regular intervals
                    torch.save(chkpt, self.model_path + '/model_epoch-{}_mae-{:.4f}_loss-{:.4f}'.
                               format(epoch, te_mae, te_avg_loss) + '.pth')
                    print('Model Saved !!!\n')
                else:
                    print('\n')

    def test(self):
        """Evaluate the model on the test loader and print a summary line.

        Returns:
            Tuple ``(avg_loss, accuracy, precision, recall, mae)``.
            ``avg_loss`` is the mean per-batch loss (a Python float) and
            ``mae`` the mean of the per-image absolute errors; the other
            three are the results of the tensor arithmetic below.

        Raises:
            ValueError: If the test loader yields no batches.
        """
        self.model.eval()

        tot_loss = 0
        correct = 0     # TruePositive + TrueNegative, for accuracy
        tp = 0          # TruePositive
        pred_true = 0   # Number of '1' predictions, for precision
        gt_true = 0     # Number of '1's in gt mask, for recall
        mae_list = []   # Mean absolute error of each image

        # Pre-bound so an empty loader is detected instead of leaving the
        # loop variable undefined.
        num_batches = 0
        with torch.no_grad():
            for num_batches, (inp_imgs, gt_masks) in enumerate(self.test_dataloader, start=1):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg
                tot_loss += loss.item()

                rounded = pred_masks.round()
                correct += (rounded == gt_masks).float().sum()
                tp += torch.mul(rounded, gt_masks).sum()
                pred_true += rounded.sum()
                gt_true += gt_masks.sum()

                # Record the absolute errors (one value per image).
                ae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
                mae_list.extend(ae)

        if num_batches == 0:
            raise ValueError('Test dataloader yielded no batches; cannot compute metrics.')

        avg_loss = tot_loss / num_batches
        # Denominator is images * img_size * img_size, i.e. one value per pixel.
        accuracy = correct / (len(self.test_data) * self.img_size * self.img_size)
        precision = tp / pred_true
        recall = tp / gt_true
        mae = np.mean(mae_list)

        print('TEST :: MAE : {:.4f}\tACC : {:.4f}\tPRE : {:.4f}\tREC : {:.4f}\tAVG-LOSS : {:.4f}\n'.format(
            mae, accuracy, precision, recall, avg_loss))

        return avg_loss, accuracy, precision, recall, mae
def visualize(args):
    """Show the model's prediction next to two OpenCV saliency baselines.

    For every image the input, ground truth, model prediction and the two
    OpenCV static-saliency maps are displayed, and ``calculate_auc`` is
    called with ``plot=True`` for each of the three saliency maps. The
    function blocks on a key press after each image; 'q' stops the loop.

    Args:
        args: Parsed options; this function reads ``use_gpu``,
            ``model_path``, ``imgs_folder`` and ``img_size``.
    """
    # CUDA is used only when it is both requested and actually available.
    want_cuda = args.use_gpu and torch.cuda.is_available()
    device = torch.device(device='cuda' if want_cuda else 'cpu')

    # Restore the trained weights and switch to inference mode.
    net = SODModel()
    checkpoint = torch.load(args.model_path, map_location=device)
    net.load_state_dict(checkpoint['model'])
    net.to(device)
    net.eval()

    dataset = EvalDataLoader(img_folder=args.imgs_folder,
                             gt_path='./data/DUTS/DUTS-TE/DUTS-TE-Mask',
                             target_size=args.img_size)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

    # (window title / model name, factory) for the OpenCV baselines, in the
    # order they are displayed.
    baselines = (
        ("Static (spectral residual)", cv2.saliency.StaticSaliencySpectralResidual_create),
        ("Static (fine-grained)", cv2.saliency.StaticSaliencyFineGrained_create),
    )

    with torch.no_grad():
        for frame, net_input, gt_mask in loader:
            gt_mask = np.squeeze(gt_mask.cpu().numpy(), axis=0)

            net_input = net_input.to(device)
            prediction, _ = net(net_input)

            # Batch size is 1: drop the batch axis, then RGB -> BGR for OpenCV.
            frame = np.squeeze(frame.numpy(), axis=0).astype(np.uint8)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            # Scale the prediction to 8-bit for display and scoring.
            model_map = np.squeeze(prediction.cpu().numpy(), axis=(0, 1))
            model_map = (model_map * 255).round().astype(np.uint8)

            cv2.imshow('Input Image', frame)
            cv2.imshow('Ground truth', gt_mask)
            cv2.imshow('Pyramid attention network', model_map)
            calculate_auc(model_map, gt_mask, plot=True,
                          model_name='Pyramid attention network')

            # CV2 saliency baselines
            for title, make_detector in baselines:
                _, baseline_map = make_detector().computeSaliency(frame)
                baseline_map = (baseline_map * 255).round().astype(np.uint8)
                cv2.imshow(title, baseline_map)
                calculate_auc(baseline_map, gt_mask, plot=True, model_name=title)

            # Block until a key is pressed; 'q' ends the session.
            if cv2.waitKey(0) == ord('q'):
                break
def compare_methods(args):
    """Compare the model against two OpenCV saliency baselines.

    For each evaluated image the model prediction and the two OpenCV
    static-saliency maps are scored with ``calculate_auc``, ``nss``, ``cc``
    and ``similarity`` against the ground-truth mask. The per-method averages
    are printed, grouped by metric.

    The loop stops once the running image count exceeds 100, so at most 101
    images are evaluated.

    Args:
        args: Parsed options; this function reads ``use_gpu``,
            ``model_path``, ``imgs_folder`` and ``img_size``.

    Returns:
        Tuple of the average AUC for (model, spectral-residual baseline,
        fine-grained baseline).

    Raises:
        ValueError: If the evaluation loader yields no images.
    """
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    eval_data = EvalDataLoader(img_folder=args.imgs_folder,
                               gt_path='./data/DUTS/DUTS-TE/DUTS-TE-Mask',
                               target_size=args.img_size)
    eval_dataloader = DataLoader(eval_data, batch_size=1, shuffle=True, num_workers=2)

    # (key, printed label) per method and (key, scorer, printed title) per
    # metric, both in report order.
    methods = (
        ('pyramid', 'Pyramid attention network'),
        ('spectral', 'CV2 static saliency (spectral)'),
        ('fg', 'CV2 static saliency (fine-grained)'),
    )
    metrics = (
        ('auc', calculate_auc, 'Average area under ROC curve'),
        ('nss', nss, 'Normalized Scanpath Saliency'),
        ('cc', cc, 'Pearson’s Correlation Coefficient'),
        ('sim', similarity, 'SIM'),
    )
    totals = {m_key: {k: 0 for k, _, _ in metrics} for m_key, _ in methods}

    def accumulate(method_key, saliency_map, gt_mask):
        # Add every metric's score for one saliency map to the running totals.
        for metric_key, scorer, _ in metrics:
            totals[method_key][metric_key] += scorer(saliency_map, gt_mask)

    count = 0
    with torch.no_grad():
        for img_np, img_tor, gt_mask in eval_dataloader:
            gt_mask = np.squeeze(gt_mask.cpu().numpy(), axis=0)

            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)
            # Batch size is 1: drop batch/channel axes and scale to 8-bit.
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
            pred_masks_raw = (pred_masks_raw * 255).round().astype(np.uint8)
            accumulate('pyramid', pred_masks_raw, gt_mask)

            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

            saliency_spectral = cv2.saliency.StaticSaliencySpectralResidual_create()
            _, saliency_map_spectral = saliency_spectral.computeSaliency(img_np)
            saliency_map_spectral = (saliency_map_spectral * 255).round().astype(np.uint8)
            accumulate('spectral', saliency_map_spectral, gt_mask)

            saliency_fg = cv2.saliency.StaticSaliencyFineGrained_create()
            _, saliency_map_fg = saliency_fg.computeSaliency(img_np)
            saliency_map_fg = (saliency_map_fg * 255).round().astype(np.uint8)
            accumulate('fg', saliency_map_fg, gt_mask)

            count += 1
            print(count)  # progress indicator
            if count > 100:
                break

    if count == 0:
        # Without this guard the averages below would divide by zero.
        raise ValueError('Evaluation dataloader yielded no images; nothing to compare.')

    separator = '*********************************************************************************'
    for position, (metric_key, _, title) in enumerate(metrics):
        if position:
            print(separator)
        for method_key, label in methods:
            print('%s: %s: %f' % (label, title, totals[method_key][metric_key] / count))

    return (totals['pyramid']['auc'] / count,
            totals['spectral']['auc'] / count,
            totals['fg']['auc'] / count)