import os

import cv2
import numpy as np
import torch
import tqdm
from torch import optim
from torch.utils.data import DataLoader

# Project-local modules (assumed to come from the surrounding repo): SODModel,
# SODLoader, InfDataloader, EvalDataLoader, EdgeSaliencyLoss, and the metric
# helpers calculate_auc, nss, cc and similarity are imported from their own
# modules, which are not part of this listing.


def calculate_mae(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load model
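    # The checkpoint is the dict written by Engine.train() below; the weights
    # are stored under the 'model' key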
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    test_data = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    test_dataloader = DataLoader(test_data, batch_size=args.bs, shuffle=False, num_workers=2)

    # List to save mean absolute error of each image
    mae_list = []
    with torch.no_grad():
        for inp_imgs, gt_masks in tqdm.tqdm(test_dataloader):
            inp_imgs = inp_imgs.to(device)
            gt_masks = gt_masks.to(device)
            pred_masks, _ = model(inp_imgs)

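            # Per-image MAE: average the absolute error over the channel and
            # spatial dims (1, 2, 3), leaving one scalar per image in the batch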
            mae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
            mae_list.extend(mae)

    print('MAE for the test set:', np.mean(mae_list))


def save_pred(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder,
                             target_size=args.img_size)
    # The post-processing below assumes batch_size = 1, so do not change it
    inf_dataloader = DataLoader(inf_data,
                                batch_size=1,
                                shuffle=True,
                                num_workers=2)

    # Directory to save the predictions
    pred_dir = './data/MSD/test/pred'
    os.makedirs(pred_dir, exist_ok=True)

    with torch.no_grad():
        for img_np, img_tor, img_name, hw in tqdm.tqdm(inf_dataloader):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # Assuming batch_size = 1
            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

            if args.raw:
                pred_masks = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
            else:
                pred_masks = np.squeeze(pred_masks.round().cpu().numpy(),
                                        axis=(0, 1))

            h, w = [int(x) for x in hw[0].split(' ')]

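            # Undo the square letterboxing (assuming InfDataloader pads each
            # image to a square of side max(h, w) before resizing): scale the
            # mask back up to the padded square, then crop away the padding
            # offsets to recover the original h x w frame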
            s = max(h, w)
            pred_masks *= 255
            pred_masks = cv2.resize(pred_masks, (s, s),
                                    interpolation=cv2.INTER_AREA)

            offset_h = round((s - h) / 2)
            offset_w = round((s - w) / 2)
            p0, p1, p2, p3 = offset_h, s - offset_h, offset_w, s - offset_w
            pred_masks = pred_masks[p0:p1, p2:p3]

            cv2.imwrite(os.path.join(pred_dir, img_name[0] + '.png'),
                        pred_masks)


def run_inference(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder,
                             target_size=args.img_size)
    # Since the images are displayed to the user, batch_size is set to 1.
    # Later code also assumes batch_size = 1, so do not change it
    inf_dataloader = DataLoader(inf_data,
                                batch_size=1,
                                shuffle=True,
                                num_workers=2)

    print("Press 'q' to quit.")
    with torch.no_grad():
        for batch_idx, (img_np, img_tor, img_name,
                        _) in enumerate(inf_dataloader, start=1):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # Assuming batch_size = 1
            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
            pred_masks_round = np.squeeze(pred_masks.round().cpu().numpy(),
                                          axis=(0, 1))

            print('Image:', batch_idx)
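            # cv2.imshow interprets float32 images as values in [0, 1], so the
            # raw mask can be displayed without rescaling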
            cv2.imshow('Input Image', img_np)
            cv2.imshow('Generated Saliency Mask', pred_masks_raw)
            cv2.imshow('Rounded-off Saliency Mask', pred_masks_round)

            print('Image name:', img_name[0])  # batch_size = 1

            key = cv2.waitKey(0)
            if key == ord('q'):
                break


class Engine:
    def __init__(self, args):
        self.epochs = args.epochs
        self.bs = args.bs
        self.lr = args.lr
        self.wd = args.wd
        self.img_size = args.img_size
        self.aug = args.aug
        self.n_worker = args.n_worker
        self.test_interval = args.test_interval
        self.save_interval = args.save_interval
        self.log_interval = args.log_interval
        self.resume_chkpt = args.resume
        self.use_gpu = args.use_gpu

        self.alpha_sal = args.alpha_sal
        self.wbce_w0 = args.wbce_w0
        self.wbce_w1 = args.wbce_w1

        self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(self.alpha_sal, self.wbce_w0, self.wbce_w1)
        print("Models will be saved at: {}\n".format(self.model_path))
        if not os.path.exists(self.model_path):
            os.makedirs(self.model_path)

        if self.use_gpu and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.model = SODModel()
        self.model.to(self.device)
        self.criterion = EdgeSaliencyLoss(device=self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        # Load model and optimizer if resumed
        if self.resume_chkpt is not None:
            chkpt = torch.load(self.resume_chkpt, map_location=self.device)
            self.model.load_state_dict(chkpt['model'])
            self.optimizer.load_state_dict(chkpt['optimizer'])
            print("Resuming training from model : {}\n".format(self.resume_chkpt))

        self.train_data = SODLoader(mode='train', augment_data=self.aug, target_size=self.img_size)
        self.test_data = SODLoader(mode='test', augment_data=False, target_size=self.img_size)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs, shuffle=True, num_workers=self.n_worker)
        self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs, shuffle=False, num_workers=self.n_worker)

    def train(self):
        best_test_mae = float('inf')
        for epoch in range(self.epochs):
            self.model.train()
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.train_dataloader):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                self.optimizer.zero_grad()
                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg  # Activity regularizer from Channel-wise Att.

                loss.backward()
                self.optimizer.step()

                if batch_idx % self.log_interval == 0:
                    print('TRAIN :: Epoch : {}\tBatch : {}/{} ({:.2f}%)\t\tTot Loss : {:.4f}\tReg : {:.4f}'
                          .format(epoch + 1,
                                  batch_idx + 1, len(self.train_dataloader),
                                  (batch_idx + 1) * 100 / len(self.train_dataloader),
                                  loss.item(),
                                  ca_act_reg))

            # Validation
            if epoch % self.test_interval == 0 or epoch % self.save_interval == 0:
                te_avg_loss, te_acc, te_pre, te_rec, te_mae = self.test()
                chkpt = {'epoch': epoch,
                         'test_mae': te_mae,
                         'model': self.model.state_dict(),
                         'optimizer': self.optimizer.state_dict(),
                         'test_loss': te_avg_loss,
                         'test_acc': te_acc,
                         'test_pre': te_pre,
                         'test_rec': te_rec}

                # Save the best model
                if te_mae < best_test_mae:
                    best_test_mae = te_mae
                    torch.save(chkpt, self.model_path + '/best_epoch-{}_mae-{:.4f}_loss-{:.4f}'.
                               format(epoch, best_test_mae, te_avg_loss) + '.pth')
                    print('Best Model Saved !!!\n')
                    continue
                
                # Save model at regular intervals
                if epoch % self.save_interval == 0:
                    torch.save(chkpt, self.model_path + '/model_epoch-{}_mae-{:.4f}_loss-{:.4f}'.
                               format(epoch, te_mae, te_avg_loss) + '.pth')
                    print('Model Saved !!!\n')
                    continue
            print('\n')

    def test(self):
        self.model.eval()

        tot_loss = 0
        tp_tn = 0   # TruePositive + TrueNegative (correct pixels), for accuracy
        tp = 0      # TruePositive
        pred_true = 0   # Number of '1' predictions, for precision
        gt_true = 0     # Number of '1's in gt mask, for recall
        mae_list = []   # List to save mean absolute error of each image

        with torch.no_grad():
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.test_dataloader, start=1):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg

                tot_loss += loss.item()

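                # round() binarizes the predictions at a 0.5 threshold; these
                # running sums feed accuracy, precision and recall below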
                tp_tn += (pred_masks.round() == gt_masks).float().sum()
                tp += torch.mul(pred_masks.round(), gt_masks).sum()
                pred_true += pred_masks.round().sum()
                gt_true += gt_masks.sum()

                # Record the absolute errors
                ae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
                mae_list.extend(ae)

        avg_loss = tot_loss / batch_idx
        accuracy = tp_tn / (len(self.test_data) * self.img_size * self.img_size)
        precision = tp / pred_true
        recall = tp / gt_true
        mae = np.mean(mae_list)

        print('TEST :: MAE : {:.4f}\tACC : {:.4f}\tPRE : {:.4f}\tREC : {:.4f}\tAVG-LOSS : {:.4f}\n'
              .format(mae, accuracy, precision, recall, avg_loss))

        return avg_loss, accuracy, precision, recall, mae


def visualize(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    eval_data = EvalDataLoader(img_folder=args.imgs_folder,
                               gt_path='./data/DUTS/DUTS-TE/DUTS-TE-Mask',
                               target_size=args.img_size)
    eval_dataloader = DataLoader(eval_data,
                                 batch_size=1,
                                 shuffle=True,
                                 num_workers=2)
    with torch.no_grad():
        for img_np, img_tor, gt_mask in eval_dataloader:
            gt_mask = np.squeeze(gt_mask.cpu().numpy(), axis=0)
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # Assuming batch_size = 1
            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
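            # Scale the [0, 1] float mask to uint8 [0, 255] so it is comparable
            # with the OpenCV saliency maps below (and, presumably, with what
            # calculate_auc expects)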
            pred_masks_raw = (pred_masks_raw * 255).round().astype(np.uint8)
            cv2.imshow('Input Image', img_np)
            cv2.imshow('Ground truth', gt_mask)
            cv2.imshow('Pyramid attention network', pred_masks_raw)
            calculate_auc(pred_masks_raw,
                          gt_mask,
                          plot=True,
                          model_name='Pyramid attention network')

            # OpenCV static saliency (spectral residual)
            saliency_spectral = cv2.saliency.StaticSaliencySpectralResidual_create()
            (success, saliencyMapSpectral) = saliency_spectral.computeSaliency(img_np)
            saliencyMapSpectral = (saliencyMapSpectral * 255).round().astype(np.uint8)
            cv2.imshow("Static (spectral residual)", saliencyMapSpectral)
            calculate_auc(saliencyMapSpectral,
                          gt_mask,
                          plot=True,
                          model_name='Static (spectral residual)')

            saliency_fg = cv2.saliency.StaticSaliencyFineGrained_create()
            (success, saliencyMapFG) = saliency_fg.computeSaliency(img_np)
            saliencyMapFG = (saliencyMapFG * 255).round().astype(np.uint8)
            cv2.imshow("Static (fine-grained)", saliencyMapFG)
            calculate_auc(saliencyMapFG,
                          gt_mask,
                          plot=True,
                          model_name='Static (fine-grained)')

            key = cv2.waitKey(0)
            if key == ord('q'):
                break


def compare_methods(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    print("here1")
    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()
    print("here2")
    eval_data = EvalDataLoader(img_folder=args.imgs_folder,
                               gt_path='./data/DUTS/DUTS-TE/DUTS-TE-Mask',
                               target_size=args.img_size)
    print("here3")
    eval_dataloader = DataLoader(eval_data,
                                 batch_size=1,
                                 shuffle=True,
                                 num_workers=2)
    print("here4")
    auc_pyramid, nss_pyramid, cc_pyramid, similarity_pyramid = 0, 0, 0, 0
    auc_spectral, nss_spectral, cc_spectral, similarity_spectral = 0, 0, 0, 0
    auc_fg, nss_fg, cc_fg, similarity_fg = 0, 0, 0, 0
    count = 0
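    # Metrics are accumulated over at most the first 101 images (see the count
    # cap below) and averaged at the end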
    with torch.no_grad():
        for img_np, img_tor, gt_mask in eval_dataloader:
            gt_mask = np.squeeze(gt_mask.cpu().numpy(), axis=0)
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1))
            pred_masks_raw = (pred_masks_raw * 255).round().astype(np.uint8)
            auc_pyramid += calculate_auc(pred_masks_raw, gt_mask)
            nss_pyramid += nss(pred_masks_raw, gt_mask)
            cc_pyramid += cc(pred_masks_raw, gt_mask)
            similarity_pyramid += similarity(pred_masks_raw, gt_mask)

            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            saliency_spectral = cv2.saliency.StaticSaliencySpectralResidual_create()
            (success, saliencyMapSpectral) = saliency_spectral.computeSaliency(img_np)
            saliencyMapSpectral = (saliencyMapSpectral * 255).round().astype(np.uint8)
            auc_spectral += calculate_auc(saliencyMapSpectral, gt_mask)
            nss_spectral += nss(saliencyMapSpectral, gt_mask)
            cc_spectral += cc(saliencyMapSpectral, gt_mask)
            similarity_spectral += similarity(saliencyMapSpectral, gt_mask)

            saliency_fg = cv2.saliency.StaticSaliencyFineGrained_create()
            (success, saliencyMapFG) = saliency_fg.computeSaliency(img_np)
            saliencyMapFG = (saliencyMapFG * 255).round().astype(np.uint8)
            auc_fg += calculate_auc(saliencyMapFG, gt_mask)
            nss_fg += nss(saliencyMapFG, gt_mask)
            cc_fg += cc(saliencyMapFG, gt_mask)
            similarity_fg += similarity(saliencyMapFG, gt_mask)
            count += 1
            print('Processed image', count)
            if count > 100:  # Cap evaluation at 101 images
                break
    print('Pyramid attention network: Average area under ROC curve: %f' % (auc_pyramid / count))
    print('CV2 static saliency (spectral): Average area under ROC curve: %f' % (auc_spectral / count))
    print('CV2 static saliency (fine-grained): Average area under ROC curve: %f' % (auc_fg / count))
    print('*********************************************************************************')
    print('Pyramid attention network: Normalized Scanpath Saliency: %f' % (nss_pyramid / count))
    print('CV2 static saliency (spectral): Normalized Scanpath Saliency: %f' % (nss_spectral / count))
    print('CV2 static saliency (fine-grained): Normalized Scanpath Saliency: %f' % (nss_fg / count))
    print('*********************************************************************************')
    print('Pyramid attention network: Pearson’s Correlation Coefficient: %f' % (cc_pyramid / count))
    print('CV2 static saliency (spectral): Pearson’s Correlation Coefficient: %f' % (cc_spectral / count))
    print('CV2 static saliency (fine-grained): Pearson’s Correlation Coefficient: %f' % (cc_fg / count))
    print('*********************************************************************************')
    print('Pyramid attention network: SIM: %f' % (similarity_pyramid / count))
    print('CV2 static saliency (spectral): SIM: %f' % (similarity_spectral / count))
    print('CV2 static saliency (fine-grained): SIM: %f' % (similarity_fg / count))
    return auc_pyramid / count, auc_spectral / count, auc_fg / count
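

# --- Usage sketch (hypothetical wiring; the original entry point is not part
# of this listing). The flags below simply mirror the attributes the functions
# above read from `args` (use_gpu, model_path, imgs_folder, img_size, bs, raw);
# the defaults are illustrative, and the imgs_folder default is a guess based
# on the gt_path used above. The Engine training class takes extra arguments
# and is not wired in here.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='SOD model evaluation utilities.')
    parser.add_argument('--mode', choices=['mae', 'save', 'infer', 'visualize', 'compare'],
                        default='infer')
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--imgs_folder', type=str,
                        default='./data/DUTS/DUTS-TE/DUTS-TE-Image')
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument('--bs', type=int, default=8)
    parser.add_argument('--raw', action='store_true')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    # Dispatch to the matching entry point defined above
    dispatch = {'mae': calculate_mae, 'save': save_pred, 'infer': run_inference,
                'visualize': visualize, 'compare': compare_methods}
    dispatch[args.mode](args)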