Example 1
    def save(self,
             data: Union[torch.Tensor, np.ndarray],
             meta_data: Optional[Dict] = None) -> None:
        """
        Save data into a png file.
        The meta_data could optionally have the following keys:

            - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
            - ``'spatial_shape'`` -- for data output shape.
            - ``'patch_index'`` -- if the data is a patch of big image, append the patch index to filename.

        If meta_data is None, use the default index (starting from 0) as the filename.

        Args:
            data: target data content to be saved as a png format file.
                The data is assumed to be channel-first with two spatial
                dimensions, i.e. shape (C, H, W), where C must be 1, 3 or 4.
            meta_data: the meta data information corresponding to the data.

        Raises:
            ValueError: When ``data`` channels is not one of [1, 3, 4].

        See Also
            :py:meth:`monai.data.png_writer.write_png`

        """
        filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
        self._data_index += 1
        spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None
        patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None

        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()

        path = create_file_basename(self.output_postfix, filename,
                                    self.output_dir, self.data_root_dir,
                                    patch_index)
        path = f"{path}{self.output_ext}"

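        # collapse a singleton channel to (H, W), or move 3 or 4 channels to
        # the last axis as (H, W, C), the layout expected by write_png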
        if data.shape[0] == 1:
            data = data.squeeze(0)
        elif 2 < data.shape[0] < 5:
            data = np.moveaxis(np.asarray(data), 0, -1)
        else:
            raise ValueError(
                f"Unsupported number of channels: {data.shape[0]}, available options are [1, 3, 4]"
            )

        write_png(
            np.asarray(data),
            file_name=path,
            output_spatial_shape=spatial_shape,
            mode=self.mode,
            scale=self.scale,
        )
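
A minimal usage sketch for the method above, assuming it is the save method of monai.data.PNGSaver (which these snippets appear excerpted from); the exact output path depends on the saver's postfix and create_file_basename:

import numpy as np
from monai.data import PNGSaver

saver = PNGSaver(output_dir="./out", output_ext=".png", scale=255)
image = np.random.randint(0, 255, size=(1, 64, 64)).astype(np.uint8)  # (C, H, W) with C == 1
# with meta_data, the file name is derived from 'filename_or_obj'
saver.save(image, meta_data={"filename_or_obj": "sample.png"})
# with meta_data=None, the running index (0, 1, ...) names the file
saver.save(image)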
Example 2
    def save(self, data, meta_data=None):
        """
        Save data into a png file.
        The meta_data could optionally have the following keys:

            - ``'filename_or_obj'`` -- for output file name creation, corresponding to filename or object.
            - ``'spatial_shape'`` -- for data output shape.

        If meta_data is None, use the default index (starting from 0) as the filename.

        Args:
            data (Tensor or ndarray): target data content to be saved as a png format file.
                The data is assumed to be channel-first with two spatial
                dimensions, i.e. shape (C, H, W), where C must be 1, 3 or 4.
            meta_data (dict): the meta data information corresponding to the data.

        See Also
            :py:meth:`monai.data.png_writer.write_png`
        """
        filename = meta_data["filename_or_obj"] if meta_data else str(
            self._data_index)
        self._data_index += 1
        spatial_shape = meta_data.get(
            "spatial_shape", None) if meta_data and self.resample else None

        if torch.is_tensor(data):
            data = data.detach().cpu().numpy()

        filename = create_file_basename(self.output_postfix, filename,
                                        self.output_dir)
        filename = f"{filename}{self.output_ext}"

        if data.shape[0] == 1:
            data = data.squeeze(0)
        elif 2 < data.shape[0] < 5:
            data = np.moveaxis(data, 0, -1)
        else:
            raise ValueError("PNG image should only have 1, 3 or 4 channels.")

        write_png(
            data,
            file_name=filename,
            output_shape=spatial_shape,
            interp_order=self.interp_order,
            mode=self.mode,
            cval=self.cval,
            scale=self.scale,
        )
Example 3
# imports reconstructed so the excerpt runs standalone; they follow the
# upstream MONAI GAN workflow example this function appears to come from
import logging
import os
import sys

import torch

import monai
from monai.apps import download_and_extract
from monai.data import CacheDataset, DataLoader, png_writer
from monai.engines import GanTrainer
from monai.engines.utils import GanKeys as Keys
from monai.engines.utils import default_make_latent as make_latent
from monai.handlers import CheckpointSaver, StatsHandler
from monai.networks import normal_init
from monai.networks.nets import Discriminator, Generator
from monai.transforms import (AddChannelD, Compose, LoadPNGD, RandFlipD,
                              RandRotateD, RandZoomD, ScaleIntensityD,
                              ToTensorD)
from monai.utils import set_determinism


def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    set_determinism(12345)
    device = torch.device("cuda:0")

    # load real data
    mednist_url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5_value = "0bc7306e7427e00ad1c5526a6677552d"
    extract_dir = "data"
    tar_save_path = os.path.join(extract_dir, "MedNIST.tar.gz")
    download_and_extract(mednist_url, tar_save_path, extract_dir, md5_value)
    hand_dir = os.path.join(extract_dir, "MedNIST", "Hand")
    real_data = [{
        "hand": os.path.join(hand_dir, filename)
    } for filename in os.listdir(hand_dir)]

    # define real data transforms
    train_transforms = Compose([
        LoadPNGD(keys=["hand"]),
        AddChannelD(keys=["hand"]),
        ScaleIntensityD(keys=["hand"]),
        RandRotateD(keys=["hand"], range_x=15, prob=0.5, keep_size=True),
        RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
        RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensorD(keys=["hand"]),
    ])

    # create dataset and dataloader
    real_dataset = CacheDataset(real_data, train_transforms)
    batch_size = 300
    real_dataloader = DataLoader(real_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=10)

    # define function to process batchdata for input into discriminator
    def prepare_batch(batchdata):
        """
        Process Dataloader batchdata dict object and return image tensors for D Inferer
        """
        return batchdata["hand"]

    # define networks
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)

    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])

    # initialize both networks
    disc_net.apply(normal_init)
    gen_net.apply(normal_init)

    # input images are scaled to [0, 1], so enforce the same range on generated outputs
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)

    # create optimizers and loss functions
    learning_rate = 2e-4
    betas = (0.5, 0.999)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)

    disc_loss_criterion = torch.nn.BCELoss()
    gen_loss_criterion = torch.nn.BCELoss()
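    # label convention: the discriminator should predict 1 for real images and 0 for fakes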
    real_label = 1
    fake_label = 0

    def discriminator_loss(gen_images, real_images):
        """
        The discriminator loss is calculated by comparing the discriminator's
        predictions for real and generated images against their true labels.

        """
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)

        realloss = disc_loss_criterion(disc_net(real_images), real)
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)

        return (genloss + realloss) / 2

    def generator_loss(gen_images):
        """
        The generator loss is calculated from how realistic the discriminator
        judges the generated images to be, scoring its predictions against the
        real label.

        """
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    # initialize current run dir
    run_dir = "model_out"
    print("Saving model output to: %s " % run_dir)

    # create workflow handlers
    handlers = [
        StatsHandler(
            name="batch_training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        CheckpointSaver(
            save_dir=run_dir,
            save_dict={
                "g_net": gen_net,
                "d_net": disc_net
            },
            save_interval=10,
            save_final=True,
            epoch_level=True,
        ),
    ]

    # define key metric
    key_train_metric = None

    # create adversarial trainer
    disc_train_steps = 5
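    # GanTrainer runs d_train_steps discriminator updates on each batch
    # before performing a single generator update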
    num_epochs = 50

    trainer = GanTrainer(
        device,
        num_epochs,
        real_dataloader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_prepare_batch=prepare_batch,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=handlers,
    )

    # run GAN training
    trainer.run()

    # Training completed, save a few random generated images.
    print("Saving trained generator sample output.")
    test_img_count = 10
    test_latents = make_latent(test_img_count, latent_size).to(device)
    fakes = gen_net(test_latents)
    for i, image in enumerate(fakes):
        filename = "gen-fake-final-%d.png" % (i)
        save_path = os.path.join(run_dir, filename)
        img_array = image[0].detach().cpu().numpy()
        png_writer.write_png(img_array, save_path, scale=255)
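
For completeness, a script built around this function would presumably close with the standard entry point:

if __name__ == "__main__":
    main()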
Example 4
# minimal imports this helper needs (os from the standard library;
# png_writer from monai.data, as in Example 3)
import os

from monai.data import png_writer


def save_generator_fakes(run_folder, g_output_tensor):
    for i, image in enumerate(g_output_tensor):
        filename = "gen-fake-%d.png" % (i)
        save_path = os.path.join(run_folder, filename)
        img_array = image[0].detach().cpu().numpy()
        png_writer.write_png(img_array, save_path, scale=255)
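
A hypothetical usage sketch, reusing the trained generator and latent helper from Example 3 (all names here are illustrative):

# write a batch of generated images to the run folder as PNGs
test_latents = make_latent(8, latent_size).to(device)
save_generator_fakes(run_dir, gen_net(test_latents))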