Esempio n. 1
0
 def test_raise(self):
     """Requesting an unknown loss name from ``cfg_tools.load_loss`` raises ``ValueError``."""
     unknown_loss = 'InvalidLossName'
     # Callable form of assertRaises: load_loss(unknown_loss) is invoked with no kwargs.
     self.assertRaises(ValueError, cfg_tools.load_loss, unknown_loss)
Esempio n. 2
0
 def test_jaccard_loss(self):
     """Loading 'JaccardLoss' with ``num_classes=5`` yields a ``losses.JaccardLoss`` instance."""
     loss_fn = cfg_tools.load_loss('JaccardLoss', num_classes=5)
     self.assertIsInstance(loss_fn, losses.JaccardLoss)
Esempio n. 3
0
 def test_cross_entropy(self):
     """Loading 'CrossEntropyLoss' with no extra kwargs yields an ``nn.CrossEntropyLoss`` instance."""
     loss_fn = cfg_tools.load_loss('CrossEntropyLoss')
     self.assertIsInstance(loss_fn, nn.CrossEntropyLoss)
Esempio n. 4
0
 def test_focal_loss(self):
     """Loading 'FocalLoss' with no extra kwargs yields a ``losses.FocalLoss`` instance."""
     loss_fn = cfg_tools.load_loss('FocalLoss')
     self.assertIsInstance(loss_fn, losses.FocalLoss)
    cfg_dict: Dict[str, Any] = utils.load_yaml(args.config)
    cfg: utils.DotDict = utils.DotDict(cfg_dict)
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = models.utils.load_model(num_classes=cfg.num_classes,
                                    architecture=cfg.model.architecture,
                                    backbone=cfg.model.backbone,
                                    pretrained=True)
    model.load_state_dict(torch.load(args.weights_path, map_location=device))
    model.eval()
    model = model.to(device)

    logger.info(f'Configurations: {cfg}')

    criterion = cfg_tools.load_loss(cfg.loss.name, **cfg.loss.params)

    _, X_test, _, y_test = load_dataset()

    dtest = SegmentationDataset(X=X_test,
                                y=y_test,
                                num_classes=cfg.num_classes,
                                img_size=cfg.img_size,
                                transforms=albu.core.serialization.from_dict(
                                    cfg.albumentations.eval))
    test_loader = torch.utils.data.DataLoader(dtest,
                                              batch_size=cfg.batch_size,
                                              shuffle=False,
                                              drop_last=False)

    model.eval()