Exemplo n.º 1
0
def main():
    """Train a GAN on the MedNIST "Hand" images and save sample outputs.

    The function fetches the MedNIST archive into ``data/``, builds a cached
    dataset of the ``Hand`` images with random rotate/flip/zoom augmentation,
    trains a ``Generator``/``Discriminator`` pair with ``GanTrainer`` for 50
    epochs (checkpointing into ``model_out`` every 10 epochs and at the end),
    and finally writes ten generated images to ``model_out`` as PNG files.
    """
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    set_determinism(12345)
    # Fall back to the CPU so the script still runs on hosts without CUDA.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load real data
    mednist_url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5_value = "0bc7306e7427e00ad1c5526a6677552d"
    extract_dir = "data"
    tar_save_path = os.path.join(extract_dir, "MedNIST.tar.gz")
    download_and_extract(mednist_url, tar_save_path, extract_dir, md5_value)
    hand_dir = os.path.join(extract_dir, "MedNIST", "Hand")
    # os.listdir() returns entries in filesystem-dependent order; sorting keeps
    # the dataset order (and therefore the seeded shuffling) reproducible.
    real_data = [{
        "hand": os.path.join(hand_dir, filename)
    } for filename in sorted(os.listdir(hand_dir))]

    # define real data transforms
    train_transforms = Compose([
        LoadPNGD(keys=["hand"]),
        AddChannelD(keys=["hand"]),
        ScaleIntensityD(keys=["hand"]),
        RandRotateD(keys=["hand"], range_x=15, prob=0.5, keep_size=True),
        RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
        RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
        ToTensorD(keys=["hand"]),
    ])

    # create dataset and dataloader
    real_dataset = CacheDataset(real_data, train_transforms)
    batch_size = 300
    real_dataloader = DataLoader(real_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=10)

    # define function to process batchdata for input into discriminator
    def prepare_batch(batchdata):
        """
        Process Dataloader batchdata dict object and return image tensors for D Inferer
        """
        return batchdata["hand"]

    # define networks
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)

    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])

    # initialize both networks
    disc_net.apply(normal_init)
    gen_net.apply(normal_init)

    # input images are scaled to [0,1] so enforce the same of generated outputs
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)

    # create optimizers and loss functions
    learning_rate = 2e-4
    betas = (0.5, 0.999)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)

    disc_loss_criterion = torch.nn.BCELoss()
    gen_loss_criterion = torch.nn.BCELoss()
    real_label = 1
    fake_label = 0

    def discriminator_loss(gen_images, real_images):
        """
        The discriminator loss is calculated by comparing D
        prediction for real and generated images.

        """
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)

        realloss = disc_loss_criterion(disc_net(real_images), real)
        # detach() keeps the discriminator update from reaching the generator
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)

        return (genloss + realloss) / 2

    def generator_loss(gen_images):
        """
        The generator loss is calculated by determining how realistic
        the discriminator classifies the generated images.

        """
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    # initialize current run dir
    run_dir = "model_out"
    # Make sure the directory exists before anything is written into it.
    os.makedirs(run_dir, exist_ok=True)
    print("Saving model output to: %s " % run_dir)

    # create workflow handlers
    handlers = [
        StatsHandler(
            name="batch_training_loss",
            output_transform=lambda x: {
                Keys.GLOSS: x[Keys.GLOSS],
                Keys.DLOSS: x[Keys.DLOSS]
            },
        ),
        CheckpointSaver(
            save_dir=run_dir,
            save_dict={
                "g_net": gen_net,
                "d_net": disc_net
            },
            save_interval=10,
            save_final=True,
            epoch_level=True,
        ),
    ]

    # define key metric
    key_train_metric = None

    # create adversarial trainer
    disc_train_steps = 5
    num_epochs = 50

    trainer = GanTrainer(
        device,
        num_epochs,
        real_dataloader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_prepare_batch=prepare_batch,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=handlers,
    )

    # run GAN training
    trainer.run()

    # Training completed, save a few random generated images.
    print("Saving trained generator sample output.")
    test_img_count = 10
    test_latents = make_latent(test_img_count, latent_size).to(device)
    # No gradients are needed for sampling; skip building the autograd graph.
    with torch.no_grad():
        fakes = gen_net(test_latents)
    for i, image in enumerate(fakes):
        filename = "gen-fake-final-%d.png" % (i)
        save_path = os.path.join(run_dir, filename)
        # image[0] selects the first channel of the generated sample
        img_array = image[0].cpu().numpy()
        png_writer.write_png(img_array, save_path, scale=255)
Exemplo n.º 2
0
def run_training_test(root_dir, device="cuda:0"):
    """Run a short GAN training session on the NIfTI images in ``root_dir``.

    Args:
        root_dir: directory that is globbed for ``img*.nii.gz`` training
            images; it is also used as the TensorBoard log directory and the
            checkpoint directory.
        device: device the networks are moved to and the trainer runs on.

    Returns:
        The ``state`` attribute of the ``GanTrainer`` after ``run()`` returns.
    """
    real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    # One dict per image. (The previous ``zip(real_images)`` over a single
    # list wrapped every path in a 1-tuple for no reason.)
    train_files = [{"reals": img} for img in real_images]

    # prepare real data
    train_transforms = Compose([
        LoadNiftid(keys=["reals"]),
        AsChannelFirstd(keys=["reals"]),
        ScaleIntensityd(keys=["reals"]),
        RandFlipd(keys=["reals"], prob=0.5),
        ToTensord(keys=["reals"]),
    ])
    train_ds = monai.data.CacheDataset(data=train_files,
                                       transform=train_transforms,
                                       cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds,
                                         batch_size=2,
                                         shuffle=True,
                                         num_workers=4)

    learning_rate = 2e-4
    betas = (0.5, 0.999)
    real_label = 1
    fake_label = 0

    # create discriminator
    disc_net = Discriminator(in_shape=(1, 64, 64),
                             channels=(8, 16, 32, 64, 1),
                             strides=(2, 2, 2, 2, 1),
                             num_res_units=1,
                             kernel_size=5).to(device)
    disc_net.apply(normal_init)
    disc_opt = torch.optim.Adam(disc_net.parameters(),
                                learning_rate,
                                betas=betas)
    disc_loss_criterion = torch.nn.BCELoss()

    def discriminator_loss(gen_images, real_images):
        """Mean of the BCE losses on real images and on generated images."""
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        realloss = disc_loss_criterion(disc_net(real_images), real)
        # detach() keeps the discriminator update from reaching the generator
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
        return torch.div(torch.add(realloss, genloss), 2)

    # create generator
    latent_size = 64
    gen_net = Generator(latent_shape=latent_size,
                        start_shape=(latent_size, 8, 8),
                        channels=[32, 16, 8, 1],
                        strides=[2, 2, 2, 1])
    gen_net.apply(normal_init)
    # Sigmoid on the output to match the scaled intensity of the real images.
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)
    gen_opt = torch.optim.Adam(gen_net.parameters(),
                               learning_rate,
                               betas=betas)
    gen_loss_criterion = torch.nn.BCELoss()

    def generator_loss(gen_images):
        """BCE of the discriminator output on generated images vs. the real label."""
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    key_train_metric = None

    def loss_output_transform(output):
        """Pick the generator and discriminator losses out of the engine output."""
        return {
            Keys.GLOSS: output[Keys.GLOSS],
            Keys.DLOSS: output[Keys.DLOSS]
        }

    train_handlers = [
        StatsHandler(
            name="training_loss",
            output_transform=loss_output_transform,
        ),
        TensorBoardStatsHandler(
            log_dir=root_dir,
            tag_name="training_loss",
            output_transform=loss_output_transform,
        ),
        CheckpointSaver(save_dir=root_dir,
                        save_dict={
                            "g_net": gen_net,
                            "d_net": disc_net
                        },
                        save_interval=2,
                        epoch_level=True),
    ]

    disc_train_steps = 2
    num_epochs = 5

    trainer = GanTrainer(
        device,
        num_epochs,
        train_loader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=train_handlers,
    )
    trainer.run()

    return trainer.state
Exemplo n.º 3
0
class MaskGAN(pl.LightningModule):
    """Segmentation network trained with a Dice loss plus an adversarial term.

    A 3D ``UNet`` (the "generator") predicts two-channel mask logits from an
    image, and a ``Discriminator`` is trained to tell ground-truth masks from
    the generator's arg-maxed predictions.

    ``hparams`` must provide ``patch_size``, ``batch_size``, ``num_workers``
    and ``lr``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.generator = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            # NOTE(review): 258 looks like a typo for 256; left unchanged here
            # because altering it changes the layer shapes of saved checkpoints.
            channels=(64, 128, 258, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )

        self.discriminator = Discriminator(
            in_shape=self.hparams.patch_size,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )

        # Generator output of the latest generator step; reused (detached) by
        # the discriminator step that follows it.
        self.generated_masks = None
        # Mid-slice predictions collected during validation for image logging.
        self.sample_masks = []

    # Data setup
    def setup(self, stage):
        """Read the image/mask list, split it 85/15 and build cached datasets."""
        data_df = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')

        train_imgs = data_df['IMAGE'][0:295].tolist()
        train_masks = data_df['SEGM'][0:295].tolist()

        train_dicts = [{
            'image': image,
            'mask': mask
        } for (image, mask) in zip(train_imgs, train_masks)]

        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.15)

        # Basic transforms
        data_keys = ["image", "mask"]
        data_transforms = Compose([
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            RandCropByPosNegLabeld(keys=data_keys,
                                   label_key="mask",
                                   spatial_size=self.hparams.patch_size,
                                   num_samples=4,
                                   image_key="image"),
        ])

        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose([data_transforms,
                               ToTensord(keys=data_keys)]),
            cache_rate=1.0)

        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose([data_transforms,
                               ToTensord(keys=data_keys)]),
            cache_rate=1.0)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return monai.data.DataLoader(self.train_dataset,
                                     batch_size=self.hparams.batch_size,
                                     shuffle=True,
                                     num_workers=self.hparams.num_workers)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return monai.data.DataLoader(self.val_dataset,
                                     batch_size=self.hparams.batch_size,
                                     num_workers=self.hparams.num_workers)

    # Training setup
    def forward(self, image):
        """Run the generator on ``image`` and return its raw output."""
        return self.generator(image)

    def generator_loss(self, y_hat, y):
        """Dice loss between generator output ``y_hat`` and label map ``y``."""
        dice_loss = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        return dice_loss(y_hat, y)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator output and target labels."""
        return F.binary_cross_entropy(y_hat, y)

    @staticmethod
    def _as_disc_input(mask, device):
        """Cast ``mask`` to float32 on ``device`` for the discriminator.

        A single ``.to()`` call avoids the host round trip of
        ``.type(torch.FloatTensor).cuda(...)`` and also works on CPU.
        """
        return mask.to(device=device, dtype=torch.float32)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step.

        ``optimizer_idx == 0`` trains the generator (Dice loss plus half the
        adversarial loss); any other index trains the discriminator on
        ground-truth masks (label 1) and generated masks (label 0).

        Returns:
            dict with the key ``'loss'`` holding the loss for this optimizer.
        """
        inputs, labels = batch['image'], batch['mask']
        batch_size = inputs.size(0)
        device = inputs.device
        # Generator training
        if optimizer_idx == 0:
            self.generated_masks = self(inputs)

            # Loss from difference between real and generated masks
            g_loss = self.generator_loss(self.generated_masks, labels)

            # Loss from discriminator
            # The generator wants the discriminator to be wrong,
            # so the wrong labels are used
            fake_labels = torch.ones(batch_size, 1, device=device)
            # NOTE(review): argmax yields integer class indices, so no gradient
            # flows from d_loss back into the generator weights - confirm
            # whether a differentiable (e.g. softmax) input was intended.
            d_loss = self.adversarial_loss(
                self.discriminator(
                    self._as_disc_input(self.generated_masks.argmax(1),
                                        device)),
                fake_labels)

            avg_loss = g_loss + 0.5 * d_loss

            self.logger.log_metrics({"g_train/g_loss": g_loss},
                                    self.global_step)
            self.logger.log_metrics({"g_train/d_loss": d_loss},
                                    self.global_step)
            self.logger.log_metrics({"g_train/tot_loss": avg_loss},
                                    self.global_step)
            return {'loss': avg_loss}

        # Discriminator training
        else:
            # Learning real masks
            real_labels = torch.ones(batch_size, 1, device=device)
            real_loss = self.adversarial_loss(
                self.discriminator(
                    self._as_disc_input(labels.squeeze(1), device)),
                real_labels)

            # Learning "fake" masks
            fake_labels = torch.zeros(batch_size, 1, device=device)
            fake_loss = self.adversarial_loss(
                self.discriminator(
                    self._as_disc_input(
                        self.generated_masks.argmax(1).detach(), device)),
                fake_labels)

            avg_loss = real_loss + fake_loss

            self.logger.log_metrics({"d_train/real_loss": real_loss},
                                    self.global_step)
            self.logger.log_metrics({"d_train/fake_loss": fake_loss},
                                    self.global_step)
            self.logger.log_metrics({"d_train/tot_loss": avg_loss},
                                    self.global_step)

            return {'loss': avg_loss}

    def configure_optimizers(self):
        """Return Adam optimizers for generator (index 0) and discriminator (index 1)."""
        lr = self.hparams.lr
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [g_optimizer, d_optimizer], []

    def validation_step(self, batch, batch_idx):
        """Compute the Dice loss on a validation batch.

        From epoch 1 onwards the middle slice (along the last axis) of the
        first predicted mask is also stored for image logging.
        """
        inputs, labels = (
            batch["image"],
            batch["mask"],
        )
        outputs = self(inputs)

        # Sample masks
        if self.current_epoch != 0:
            predicted = outputs[0].argmax(0)
            middle = int(predicted.shape[2] / 2)
            image = predicted[:, :, middle].unsqueeze(0).detach()
            self.sample_masks.append(image)

        loss = self.generator_loss(outputs, labels)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log the mean validation loss and a grid of the sampled masks."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.logger.log_metrics({"val/loss": avg_loss}, self.current_epoch)

        if self.current_epoch != 0:
            # make_grid cannot stack an empty list, so only log when samples
            # were actually collected.
            if self.sample_masks:
                grid = torchvision.utils.make_grid(self.sample_masks)
                self.logger.experiment.add_image('sample_masks', grid,
                                                 self.current_epoch)
            self.sample_masks = []

        return {"val_loss": avg_loss}