Example #1
    def __init__(self, args):
        self.epochs = args.epochs
        self.bs = args.bs
        self.lr = args.lr
        self.wd = args.wd
        self.img_size = args.img_size
        self.aug = args.aug
        self.n_worker = args.n_worker
        self.test_interval = args.test_interval
        self.save_interval = args.save_interval
        self.save_opt = args.save_opt
        self.log_interval = args.log_interval
        self.res_mod_path = args.res_mod
        self.res_opt_path = args.res_opt
        self.use_gpu = args.use_gpu

        self.alpha_sal = args.alpha_sal
        self.wbce_w0 = args.wbce_w0
        self.wbce_w1 = args.wbce_w1

        self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(self.alpha_sal, self.wbce_w0, self.wbce_w1)
        print('Models will be saved at: {}\n'.format(self.model_path))
        if not os.path.exists(os.path.join(self.model_path, 'weights')):
            os.makedirs(os.path.join(self.model_path, 'weights'))
        if not os.path.exists(os.path.join(self.model_path, 'optimizers')):
            os.makedirs(os.path.join(self.model_path, 'optimizers'))

        if self.use_gpu and torch.cuda.is_available():
            self.device = torch.device(device='cuda')
        else:
            self.device = torch.device(device='cpu')

        self.model = SODModel()
        self.model.to(self.device)
        self.criterion = EdgeSaliencyLoss(device=self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        # Load model and optimizer if resumed
        if self.res_mod_path is not None:
            chkpt = torch.load(self.res_mod_path, map_location=self.device)
            self.model.load_state_dict(chkpt['model'])
            print("Resuming training with checkpoint : {}\n".format(self.res_mod_path))
        if self.res_opt_path is not None:
            chkpt = torch.load(self.res_opt_path, map_location=self.device)
            self.optimizer.load_state_dict(chkpt['optimizer'])
            print("Resuming training with optimizer : {}\n".format(self.res_opt_path))

        self.train_data = SODLoader(mode='train', augment_data=self.aug, target_size=self.img_size)
        self.test_data = SODLoader(mode='test', augment_data=False, target_size=self.img_size)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs, shuffle=True, num_workers=self.n_worker)
        self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs, shuffle=False, num_workers=self.n_worker)
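
The constructor above only consumes an argparse-style namespace. A minimal sketch of a matching argument parser (the flag names come from the attributes read above; the defaults, help texts and the parse_engine_arguments name are illustrative assumptions):

import argparse

def parse_engine_arguments():
    # Hypothetical parser: flag names mirror the attributes the constructor reads,
    # defaults are placeholders only.
    parser = argparse.ArgumentParser(description='Train SODModel')
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--bs', type=int, default=8, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--wd', type=float, default=0.0, help='Weight decay')
    parser.add_argument('--img_size', type=int, default=256)
    parser.add_argument('--aug', action='store_true', help='Augment the training data')
    parser.add_argument('--n_worker', type=int, default=2)
    parser.add_argument('--test_interval', type=int, default=1)
    parser.add_argument('--save_interval', type=int, default=1)
    parser.add_argument('--save_opt', action='store_true', help='Also checkpoint the optimizer')
    parser.add_argument('--log_interval', type=int, default=100)
    parser.add_argument('--res_mod', default=None, help='Model checkpoint to resume from')
    parser.add_argument('--res_opt', default=None, help='Optimizer checkpoint to resume from')
    parser.add_argument('--use_gpu', action='store_true')
    parser.add_argument('--alpha_sal', type=float, default=0.7)
    parser.add_argument('--wbce_w0', type=float, default=1.0)
    parser.add_argument('--wbce_w1', type=float, default=1.15)
    parser.add_argument('--base_save_path', default='./models')
    return parser.parse_args()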
Example #2
def calculate_mae(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    test_data = SODLoader(mode='test', augment_data=False, target_size=args.img_size)
    test_dataloader = DataLoader(test_data, batch_size=args.bs, shuffle=False, num_workers=2)

    # List to save mean absolute error of each image
    mae_list = []
    with torch.no_grad():
        for batch_idx, (inp_imgs, gt_masks) in enumerate(tqdm.tqdm(test_dataloader), start=1):
            inp_imgs = inp_imgs.to(device)
            gt_masks = gt_masks.to(device)
            pred_masks, _ = model(inp_imgs)

            mae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
            mae_list.extend(mae)

    print('MAE for the test set is :', np.mean(mae_list))
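
Reducing over dim=(1, 2, 3) averages the absolute error over channel, height and width, leaving one MAE value per image; a quick shape check of that reduction (the sizes are arbitrary):

import torch

pred = torch.rand(4, 1, 8, 8)   # (batch, channel, H, W)
gt = torch.rand(4, 1, 8, 8)
per_image_mae = torch.mean(torch.abs(pred - gt), dim=(1, 2, 3))
print(per_image_mae.shape)      # torch.Size([4]) -> one value per image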
Example #3
def run_inference(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder,
                             target_size=args.img_size)
    # The images and masks are written to disk one at a time,
    # and the code below assumes batch_size = 1, so do not change it
    inf_dataloader = DataLoader(inf_data,
                                batch_size=1,
                                shuffle=True,
                                num_workers=2)

    print("Press 'q' to quit.")
    with torch.no_grad():
        for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # Assuming batch_size = 1
            img_np = np.squeeze(img_np.numpy(), axis=0)
            img_np = img_np.astype(np.uint8)
            img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
            # Scale the masks from [0, 1] to [0, 255] so cv2.imwrite saves them correctly
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=(0, 1)) * 255
            pred_masks_round = np.squeeze(pred_masks.round().cpu().numpy(),
                                          axis=(0, 1)) * 255

            print('Image :', batch_idx)
            cv2.imwrite(
                '/content/test/' + 'Input Image' + str(batch_idx) + '.png',
                img_np)
            cv2.imwrite(
                '/content/test/' + 'Generated Saliency Mask' + str(batch_idx) +
                '.png', pred_masks_raw)
            cv2.imwrite(
                '/content/test/' + 'Rounded-off Saliency Mask' +
                str(batch_idx) + '.png', pred_masks_round)

Example #4
def run_inference(args):
    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel()
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    inf_data = InfDataloader(img_folder=args.imgs_folder, target_size=args.img_size)
    # The saliency masks are written straight to disk, so the batch size is not constrained to 1 here
    inf_dataloader = DataLoader(inf_data, batch_size=4, shuffle=False, num_workers=2)

    #print("Press 'q' to quit.")
    with torch.no_grad():
        for batch_idx, (img_np, img_tor) in enumerate(inf_dataloader, start=1):
            img_tor = img_tor.to(device)
            pred_masks, _ = model(img_tor)

            # pred_masks has shape (batch, 1, H, W); drop the channel dim and scale to [0, 255]
            pred_masks_raw = np.squeeze(pred_masks.cpu().numpy(), axis=1) * 255

            print('Batch :', batch_idx)

            for im_idx in range(pred_masks_raw.shape[0]):
                out_path = os.path.join(args.output_folder, str(batch_idx)+"-"+str(im_idx)+"-subject_4.png")
                print(out_path)
                cv2.imwrite(out_path, pred_masks_raw[im_idx])
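
One hedged way to drive run_inference() is with a hand-built namespace; the attribute names mirror what the function reads, while the concrete paths and values below are placeholders:

import argparse
import os

if __name__ == '__main__':
    args = argparse.Namespace(
        imgs_folder='./data/test_images',      # placeholder input folder
        img_size=256,
        model_path='./models/best-model.pth',  # placeholder checkpoint path
        output_folder='./output',
        use_gpu=True,
    )
    os.makedirs(args.output_folder, exist_ok=True)
    run_inference(args)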
Example #5
class Engine:
    def __init__(self, args):
        self.epochs = args.epochs
        self.bs = args.bs
        self.lr = args.lr
        self.wd = args.wd
        self.img_size = args.img_size
        self.aug = args.aug
        self.n_worker = args.n_worker
        self.test_interval = args.test_interval
        self.save_interval = args.save_interval
        self.save_opt = args.save_opt
        self.log_interval = args.log_interval
        self.res_mod_path = args.res_mod
        self.res_opt_path = args.res_opt
        self.use_gpu = args.use_gpu

        self.alpha_sal = args.alpha_sal
        self.wbce_w0 = args.wbce_w0
        self.wbce_w1 = args.wbce_w1

        self.model_path = args.base_save_path + '/alph-{}_wbce_w0-{}_w1-{}'.format(self.alpha_sal, self.wbce_w0, self.wbce_w1)
        print('Models will be saved at: {}\n'.format(self.model_path))
        if not os.path.exists(os.path.join(self.model_path, 'weights')):
            os.makedirs(os.path.join(self.model_path, 'weights'))
        if not os.path.exists(os.path.join(self.model_path, 'optimizers')):
            os.makedirs(os.path.join(self.model_path, 'optimizers'))

        if self.use_gpu and torch.cuda.is_available():
            self.device = torch.device(device='cuda')
        else:
            self.device = torch.device(device='cpu')

        self.model = SODModel()
        self.model.to(self.device)
        self.criterion = EdgeSaliencyLoss(device=self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

        # Load model and optimizer if resumed
        if self.res_mod_path is not None:
            chkpt = torch.load(self.res_mod_path, map_location=self.device)
            self.model.load_state_dict(chkpt['model'])
            print("Resuming training with checkpoint : {}\n".format(self.res_mod_path))
        if self.res_opt_path is not None:
            chkpt = torch.load(self.res_opt_path, map_location=self.device)
            self.optimizer.load_state_dict(chkpt['optimizer'])
            print("Resuming training with optimizer : {}\n".format(self.res_opt_path))

        self.train_data = SODLoader(mode='train', augment_data=self.aug, target_size=self.img_size)
        self.test_data = SODLoader(mode='test', augment_data=False, target_size=self.img_size)
        self.train_dataloader = DataLoader(self.train_data, batch_size=self.bs, shuffle=True, num_workers=self.n_worker)
        self.test_dataloader = DataLoader(self.test_data, batch_size=self.bs, shuffle=False, num_workers=self.n_worker)

    def train(self):
        best_test_mae = float('inf')
        for epoch in range(self.epochs):
            self.model.train()
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.train_dataloader):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                self.optimizer.zero_grad()
                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg  # Activity regularizer from Channel-wise Att.

                loss.backward()
                self.optimizer.step()

                if batch_idx % self.log_interval == 0:
                    print('TRAIN :: Epoch : {}\tBatch : {}/{} ({:.2f}%)\t\tTot Loss : {:.4f}\tReg : {:.4f}'
                          .format(epoch + 1,
                                  batch_idx + 1, len(self.train_dataloader),
                                  (batch_idx + 1) * 100 / len(self.train_dataloader),
                                  loss.item(),
                                  ca_act_reg))

            # Validation
            if epoch % self.test_interval == 0 or epoch % self.save_interval == 0:
                te_avg_loss, te_acc, te_pre, te_rec, te_mae = self.test()
                mod_chkpt = {'epoch': epoch,
                            'test_mae' : float(te_mae),
                            'model' : self.model.state_dict(),
                            'test_loss': float(te_avg_loss),
                            'test_acc': float(te_acc),
                            'test_pre': float(te_pre),
                            'test_rec': float(te_rec)}

                if self.save_opt:
                    opt_chkpt = {'epoch': epoch,
                                'test_mae' : float(te_mae),
                                'optimizer': self.optimizer.state_dict(),
                                'test_loss': float(te_avg_loss),
                                'test_acc': float(te_acc),
                                'test_pre': float(te_pre),
                                'test_rec': float(te_rec)}

                # Save the best model
                if te_mae < best_test_mae:
                    best_test_mae = te_mae
                    torch.save(mod_chkpt, self.model_path + '/weights/best-model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
                               format(epoch, best_test_mae, te_avg_loss))
                    if self.save_opt:
                        torch.save(opt_chkpt, self.model_path + '/optimizers/best-opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
                                   format(epoch, best_test_mae, te_avg_loss))
                    print('Best Model Saved !!!\n')
                    continue
                
                # Save model at regular intervals
                if self.save_interval is not None and epoch % self.save_interval == 0:
                    torch.save(mod_chkpt, self.model_path + '/weights/model_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
                               format(epoch, te_mae, te_avg_loss))
                    if self.save_opt:
                        torch.save(opt_chkpt, self.model_path + '/optimizers/opt_epoch-{:03}_mae-{:.4f}_loss-{:.4f}.pth'.
                                   format(epoch, te_mae, te_avg_loss))
                    print('Model Saved !!!\n')
                    continue
            print('\n')

    def test(self):
        self.model.eval()

        tot_loss = 0
        tp_tn = 0   # TruePositive + TrueNegative, for accuracy
        tp = 0      # TruePositive
        pred_true = 0   # Number of '1' predictions, for precision
        gt_true = 0     # Number of '1's in gt mask, for recall
        mae_list = []   # List to save mean absolute error of each image

        with torch.no_grad():
            for batch_idx, (inp_imgs, gt_masks) in enumerate(self.test_dataloader, start=1):
                inp_imgs = inp_imgs.to(self.device)
                gt_masks = gt_masks.to(self.device)

                pred_masks, ca_act_reg = self.model(inp_imgs)
                loss = self.criterion(pred_masks, gt_masks) + ca_act_reg

                tot_loss += loss.item()

                tp_tn += (pred_masks.round() == gt_masks).float().sum()
                tp += torch.mul(pred_masks.round(), gt_masks).sum()
                pred_true += pred_masks.round().sum()
                gt_true += gt_masks.sum()

                # Record the absolute errors
                ae = torch.mean(torch.abs(pred_masks - gt_masks), dim=(1, 2, 3)).cpu().numpy()
                mae_list.extend(ae)

        avg_loss = tot_loss / batch_idx
        accuracy = tp_tn / (len(self.test_data) * self.img_size * self.img_size)
        precision = tp / pred_true
        recall = tp / gt_true
        mae = np.mean(mae_list)

        print('TEST :: MAE : {:.4f}\tACC : {:.4f}\tPRE : {:.4f}\tREC : {:.4f}\tAVG-LOSS : {:.4f}\n'.format(mae,
                                                                                             accuracy,
                                                                                             precision,
                                                                                             recall,
                                                                                             avg_loss))

        return avg_loss, accuracy, precision, recall, mae
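
In test() above, pred_true accumulates TP + FP and gt_true accumulates TP + FN, so the final ratios are ordinary pixel-level precision and recall; a toy restatement of that bookkeeping on hand-made binary masks:

import torch

pred = torch.tensor([[1., 0., 1., 1.]])   # rounded predictions
gt = torch.tensor([[1., 0., 0., 1.]])     # ground-truth mask

tp_tn = (pred == gt).float().sum()        # correctly classified pixels (TP + TN)
tp = (pred * gt).sum()                    # true positives
precision = tp / pred.sum()               # TP / (TP + FP)
recall = tp / gt.sum()                    # TP / (TP + FN)
accuracy = tp_tn / pred.numel()
print(precision.item(), recall.item(), accuracy.item())  # ~0.667, 1.0, 0.75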
Example #6
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', help='Path of the trained model checkpoint (.pth)')
    parser.add_argument('--output', help='Path of the exported ONNX file')
    parser.add_argument('--use_gpu', action="store_true", help='Whether to use GPU or not')
    parser.add_argument('--no_activation', action="store_true", help='Whether to use activation function before output')

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_arguments()

    # Determine device
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device(device='cuda')
    else:
        device = torch.device(device='cpu')

    # Load model
    model = SODModel(last_activation=not args.no_activation)
    chkpt = torch.load(args.model_path, map_location=device)
    model.load_state_dict(chkpt['model'])
    model.to(device)
    model.eval()

    batch_size = 1
    x = torch.randn(batch_size, 3, 256, 256, requires_grad=True).to(device)  # dummy input on the same device as the model
    torch_out = model(x)

    torch.onnx.export(
        model,              # model being run
        x,                  # model input (or a tuple for multiple inputs)
        args.output,        # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX version to export the model to
    )
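
Assuming the onnx and onnxruntime packages are available, the exported graph can be sanity-checked against the PyTorch output with a short verification sketch (it picks up args, x and torch_out from the script above):

import numpy as np
import onnx
import onnxruntime as ort

# Structurally validate the exported graph.
onnx_model = onnx.load(args.output)
onnx.checker.check_model(onnx_model)

# Run the same random input through onnxruntime and compare with the PyTorch prediction.
sess = ort.InferenceSession(args.output)
input_name = sess.get_inputs()[0].name
ort_outs = sess.run(None, {input_name: x.detach().cpu().numpy()})
np.testing.assert_allclose(torch_out[0].detach().cpu().numpy(), ort_outs[0],
                           rtol=1e-3, atol=1e-5)
print('ONNX output matches the PyTorch output')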