Beispiel #1
0
def test_shallow_copy_none() -> None:
    """Assigning content to a copy of a None DictConfig leaves the source None."""
    source = DictConfig(content=None)
    clone = source.copy()

    clone._set_value({"foo": 1})

    # The copy now carries the newly assigned content ...
    assert clone.foo == 1
    # ... while the original node is still None.
    assert source._is_none()
Beispiel #2
0
def main(cfg: DictConfig):
    """Benchmark OpenCV JPEG compression on a fixed set of test images.

    Each image in ``img_names`` is looked up in the directory containing
    ``cfg.img.path``, loaded with ``load_img`` and JPEG-encoded at every
    quality in ``quality_ll``. For each quality, the PSNR of the decoded image
    and the size of the encoded stream (KiB, rounded down) are recorded. The
    encoded streams are dumped to ``/tmp/jpeg_dump``; three plots are written
    to ``<original cwd>/outputs/plots/`` and the raw numbers to
    ``<original cwd>/outputs/csv/jpg_dump.json``.

    Args:
        cfg: Hydra config. ``cfg.img`` is forwarded as keyword arguments to
            ``load_img`` with its ``path`` entry replaced for each image;
            ``cfg`` itself is not modified.

    Raises:
        RuntimeError: If OpenCV fails to encode an image.
    """
    print(OmegaConf.to_yaml(cfg))

    img_names = ["flower_foveon.ppm", "big_building.ppm", "bridge.ppm"]
    img_dict = {}
    # JPEG quality is an integer encoder setting; truncate once so the values
    # used for encoding, dump file names, plots and the JSON output all agree.
    quality_ll = np.linspace(1, 100, 20).astype(int)

    # Initialize all output directories up front (the JSON dump at the end
    # needs the "csv" sibling of the plots directory to exist).
    path = Path(f"{hydra.utils.get_original_cwd()}/outputs/plots/")
    path.mkdir(exist_ok=True, parents=True)
    csv_path = path.parent / "csv"
    csv_path.mkdir(exist_ok=True, parents=True)
    dump_path = Path("/tmp/jpeg_dump")
    dump_path.mkdir(exist_ok=True, parents=True)

    img_dir = Path(cfg.img.path).parent
    for file_name in img_names:
        # Build per-image kwargs instead of assigning to cfg.copy().img.path:
        # DictConfig.copy() is shallow, so the nested ``img`` node would be
        # shared and ``cfg`` would be mutated as well.
        img_kwargs = {**cfg.img, "path": str(img_dir / file_name)}

        # Load normalised image (H x W x 3; presumably RGB in [0, 1] -- the
        # channel flip below turns it into OpenCV's BGR order).
        img = load_img(**img_kwargs)
        img = (img * 255.0).numpy()[:, :, ::-1]
        # Quantize to the contiguous uint8 image that is actually encoded so
        # PSNR is measured against exactly the encoder's input.
        img = np.ascontiguousarray(np.clip(np.rint(img), 0, 255).astype(np.uint8))
        reference = img.astype(np.float64)

        img_name = Path(file_name).stem

        psnr_ll = []
        size_ll = []

        for quality in quality_ll:
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
            result, encimg = cv2.imencode(".jpg", img, encode_param)
            if not result:
                raise RuntimeError(f"Could not encode image at quality {quality}")

            # Dump the encoded stream as-is. Writing the *decoded* image with
            # cv2.imwrite would re-encode it at OpenCV's default quality, so
            # the file size would not reflect ``quality``.
            dump_file = dump_path / f"{img_name}_{quality}.jpg"
            dump_file.write_bytes(encimg.tobytes())

            # decode from jpeg format
            decimg = cv2.imdecode(encimg, cv2.IMREAD_COLOR)
            # Subtract in float64: uint8 arithmetic would wrap around.
            mse = np.mean((decimg.astype(np.float64) - reference) ** 2)
            psnr = float("inf") if mse == 0 else 10 * np.log10(255**2 / mse)
            psnr_ll.append(psnr)

            size_ll.append(dump_file.stat().st_size // 1024)

        img_dict[img_name] = {
            "psnr": psnr_ll,
            "size": size_ll,
            "quality": quality_ll,
        }

    _plot_image_dict(img_dict, "quality", "psnr", path)
    _plot_image_dict(img_dict, "quality", "size", path)
    _plot_image_dict(img_dict, "size", "psnr", path)

    # numpy arrays are not JSON serializable; convert them to plain lists.
    for metrics in img_dict.values():
        for metric, values in metrics.items():
            if isinstance(values, np.ndarray):
                metrics[metric] = values.tolist()

    with open(csv_path / "jpg_dump.json", "w", encoding="utf-8") as f:
        json.dump(img_dict, f, indent=4, sort_keys=True)
Beispiel #3
0
def test_shallow_copy_missing() -> None:
    """Assigning content to a copy of a MISSING DictConfig leaves the source MISSING."""
    source = DictConfig(content=MISSING)
    clone = source.copy()

    clone._set_value({"foo": 1})

    # The copy now carries the newly assigned content ...
    assert clone.foo == 1
    # ... while the original node is still MISSING.
    assert source._is_missing()
Beispiel #4
0
    def __init__(self, network: DictConfig, training: DictConfig):
        """Build the segmentation model, loss terms and metrics from the config.

        Args:
            network: Network config. ``architecture`` selects the model class,
                ``losses`` lists the loss components, ``classes`` is the class
                list (its length becomes the model's ``classes`` argument);
                all remaining keys are forwarded to the model constructor.
            training: Training config; not read here directly, but recorded
                together with ``network`` by ``save_hyperparameters()``.

        Raises:
            NotImplementedError: If the architecture or a loss component is
                not recognized.
            ValueError: If both ``GDICE`` and ``DICE`` are requested, or if
                neither of them is.
        """
        super().__init__()

        architecture = network.architecture.lower().strip()
        if architecture == "unet":
            Model = smp.Unet
        elif architecture in ["unetplusplus", "unet++"]:
            Model = smp.UnetPlusPlus
        elif architecture == "resunet":
            Model = ResUnet
        elif architecture in ["resunetplusplus", "resunet++"]:
            Model = ResUnetPlusPlus
        elif architecture in ["efficientunetplusplus", "efficientunet++"]:
            Model = EfficientUnetPlusPlus
        else:
            raise NotImplementedError(
                "Currently only Unet, ResUnet, Unet++, ResUnet++, and EfficientUnet++ architectures are supported"
            )

        # Model does not accept "architecture" (nor "losses"/"classes") as an
        # argument, but the full network config must stay in hparams for
        # inference, so those keys are removed from a copy only; ``network``
        # keeps them (self.hparams["network"]["classes"] is read below).
        # TODO: cleanup?
        clean_network_conf = network.copy()
        del clean_network_conf.architecture
        del clean_network_conf.losses
        n_classes = len(clean_network_conf.classes)
        del clean_network_conf.classes

        self.model = Model(**clean_network_conf, classes=n_classes)
        # self.model.apply(initialize_weights)

        self.save_hyperparameters()

        self.classes = self.hparams["network"]["classes"]
        self.classes_int = list(range(len(self.classes)))
        # Class index 0 is treated as background: it is excluded from the
        # Dice/boundary losses and ignored by ``dice_metric`` below.
        self.classes_int_wout_bg = [c for c in self.classes_int if c != 0]

        self.in_channels = self.hparams["network"]["in_channels"]

        # losses (None = component not requested)
        self.dice_loss = None
        self.focal_loss = None
        self.boundary_loss = None

        # parse loss config
        # NOTE(review): presumably the initial boundary-loss weight used when
        # ramping -- confirm against the training step.
        self.initial_alpha = 0.01  # make this a hyperparameter and/ or scale with epoch
        self.boundary_loss_ramped = False

        # GDICE and DICE both populate self.dice_loss, so at most one may be
        # given. Explicit raise (not assert) so the check survives ``python -O``.
        if "GDICE" in network.losses and "DICE" in network.losses:
            raise ValueError(f"Only GDICE _OR_ DICE allowed {network.losses}")

        for loss_component in network.losses:
            if loss_component == "GDICE":
                # A Dice-type term is the only required loss term
                self.dice_loss = GeneralizedDice(idc=self.classes_int_wout_bg)
            elif loss_component == "DICE":
                # A Dice-type term is the only required loss term
                self.dice_loss = DiceLoss(idc=self.classes_int_wout_bg)
            elif loss_component == "FOCAL":
                self.focal_loss = FocalLoss(idc=self.classes_int, gamma=2)
            elif loss_component == "BOUNDARY":
                self.boundary_loss = BoundaryLoss(idc=self.classes_int_wout_bg)
            elif loss_component == "BOUNDARY-RAMPED":
                self.boundary_loss = BoundaryLoss(idc=self.classes_int_wout_bg)
                self.boundary_loss_ramped = True
            else:
                raise NotImplementedError(
                    f"The loss component <{loss_component}> is not recognized"
                )

        log.info(f"Losses: {network.losses}")

        # checks: a Dice-type loss (GDICE or DICE) is mandatory.
        if self.dice_loss is None:
            raise ValueError(
                f"Either GDICE or DICE loss is required, got {network.losses}"
            )

        self.dice_metric = smp.utils.metrics.Fscore(
            ignore_channels=[0],
        )

        self.dice_metric_with_bg = smp.utils.metrics.Fscore()

        # Per-split counters, one per stage.
        self.stats = {
            "train": Counter(),
            "val": Counter(),
            "test": Counter(),
        }