def main():
    """Download the MedNIST hand dataset, train a GAN with MONAI's GanTrainer,
    and write a few sample generator outputs to disk as PNGs."""
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    set_determinism(12345)
    device = torch.device("cuda:0")

    # Fetch and unpack the real training images.
    mednist_url = "https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz?dl=1"
    md5_value = "0bc7306e7427e00ad1c5526a6677552d"
    extract_dir = "data"
    tar_save_path = os.path.join(extract_dir, "MedNIST.tar.gz")
    download_and_extract(mednist_url, tar_save_path, extract_dir, md5_value)

    hand_dir = os.path.join(extract_dir, "MedNIST", "Hand")
    real_data = [{"hand": os.path.join(hand_dir, filename)} for filename in os.listdir(hand_dir)]

    # Preprocessing + light augmentation for the real images.
    train_transforms = Compose(
        [
            LoadPNGD(keys=["hand"]),
            AddChannelD(keys=["hand"]),
            ScaleIntensityD(keys=["hand"]),
            RandRotateD(keys=["hand"], range_x=15, prob=0.5, keep_size=True),
            RandFlipD(keys=["hand"], spatial_axis=0, prob=0.5),
            RandZoomD(keys=["hand"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
            ToTensorD(keys=["hand"]),
        ]
    )

    # Cached dataset and loader for the real images.
    real_dataset = CacheDataset(real_data, train_transforms)
    batch_size = 300
    real_dataloader = DataLoader(real_dataset, batch_size=batch_size, shuffle=True, num_workers=10)

    def prepare_batch(batchdata):
        """Extract the image tensors the discriminator inferer consumes from a loader batch dict."""
        return batchdata["hand"]

    # Build the discriminator and generator networks.
    disc_net = Discriminator(
        in_shape=(1, 64, 64),
        channels=(8, 16, 32, 64, 1),
        strides=(2, 2, 2, 2, 1),
        num_res_units=1,
        kernel_size=5,
    ).to(device)

    latent_size = 64
    gen_net = Generator(
        latent_shape=latent_size,
        start_shape=(latent_size, 8, 8),
        channels=[32, 16, 8, 1],
        strides=[2, 2, 2, 1],
    )

    # Initialize both networks' weights.
    disc_net.apply(normal_init)
    gen_net.apply(normal_init)

    # Input images are scaled to [0, 1], so constrain generated outputs the same way.
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)

    # Optimizers and loss criteria.
    learning_rate = 2e-4
    betas = (0.5, 0.999)
    disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
    gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)
    disc_loss_criterion = torch.nn.BCELoss()
    gen_loss_criterion = torch.nn.BCELoss()
    real_label = 1
    fake_label = 0

    def discriminator_loss(gen_images, real_images):
        """Mean of D's BCE losses on a real batch and a (detached) generated batch."""
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        realloss = disc_loss_criterion(disc_net(real_images), real)
        # detach() keeps discriminator gradients out of the generator.
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
        return (genloss + realloss) / 2

    def generator_loss(gen_images):
        """BCE loss rewarding generated images that D classifies as real."""
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    # Output directory for checkpoints and sample images.
    run_dir = "model_out"
    print("Saving model output to: %s " % run_dir)

    # Workflow handlers: per-batch loss logging and periodic checkpointing.
    handlers = [
        StatsHandler(
            name="batch_training_loss",
            output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]},
        ),
        CheckpointSaver(
            save_dir=run_dir,
            save_dict={"g_net": gen_net, "d_net": disc_net},
            save_interval=10,
            save_final=True,
            epoch_level=True,
        ),
    ]

    # No key metric is tracked for this run.
    key_train_metric = None

    # Adversarial trainer: D takes several steps per G step.
    disc_train_steps = 5
    num_epochs = 50
    trainer = GanTrainer(
        device,
        num_epochs,
        real_dataloader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_prepare_batch=prepare_batch,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=handlers,
    )

    # Run GAN training.
    trainer.run()

    # Training completed: save a few random generated images.
    print("Saving trained generator sample output.")
    test_img_count = 10
    test_latents = make_latent(test_img_count, latent_size).to(device)
    fakes = gen_net(test_latents)
    for i, image in enumerate(fakes):
        filename = "gen-fake-final-%d.png" % (i)
        save_path = os.path.join(run_dir, filename)
        img_array = image[0].cpu().data.numpy()
        png_writer.write_png(img_array, save_path, scale=255)
def run_training_test(root_dir, device="cuda:0"):
    """Run a short GAN training loop over the nifti images under *root_dir*
    and return the finished trainer's state."""
    real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    # NOTE(review): zip() over a single list yields 1-tuples, so each "reals"
    # value is a tuple containing one path — presumably the nifti loader
    # accepts that; confirm against LoadNiftid's contract.
    train_files = [{"reals": img} for img in zip(real_images)]

    # Real-data preprocessing pipeline.
    train_transforms = Compose(
        [
            LoadNiftid(keys=["reals"]),
            AsChannelFirstd(keys=["reals"]),
            ScaleIntensityd(keys=["reals"]),
            RandFlipd(keys=["reals"], prob=0.5),
            ToTensord(keys=["reals"]),
        ]
    )
    train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
    train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

    # Shared hyperparameters and target labels.
    learning_rate = 2e-4
    betas = (0.5, 0.999)
    real_label = 1
    fake_label = 0

    # Discriminator network, optimizer, and loss.
    disc_net = Discriminator(
        in_shape=(1, 64, 64),
        channels=(8, 16, 32, 64, 1),
        strides=(2, 2, 2, 2, 1),
        num_res_units=1,
        kernel_size=5,
    ).to(device)
    disc_net.apply(normal_init)
    disc_opt = torch.optim.Adam(disc_net.parameters(), learning_rate, betas=betas)
    disc_loss_criterion = torch.nn.BCELoss()

    def discriminator_loss(gen_images, real_images):
        """Mean BCE loss of D over one real batch and one detached generated batch."""
        real = real_images.new_full((real_images.shape[0], 1), real_label)
        gen = gen_images.new_full((gen_images.shape[0], 1), fake_label)
        realloss = disc_loss_criterion(disc_net(real_images), real)
        genloss = disc_loss_criterion(disc_net(gen_images.detach()), gen)
        return torch.div(torch.add(realloss, genloss), 2)

    # Generator network (sigmoid output to match the [0, 1]-scaled inputs).
    latent_size = 64
    gen_net = Generator(
        latent_shape=latent_size,
        start_shape=(latent_size, 8, 8),
        channels=[32, 16, 8, 1],
        strides=[2, 2, 2, 1],
    )
    gen_net.apply(normal_init)
    gen_net.conv.add_module("activation", torch.nn.Sigmoid())
    gen_net = gen_net.to(device)
    gen_opt = torch.optim.Adam(gen_net.parameters(), learning_rate, betas=betas)
    gen_loss_criterion = torch.nn.BCELoss()

    def generator_loss(gen_images):
        """BCE loss rewarding generated images that D classifies as real."""
        output = disc_net(gen_images)
        cats = output.new_full(output.shape, real_label)
        return gen_loss_criterion(output, cats)

    # Console + TensorBoard loss logging, plus periodic checkpointing.
    key_train_metric = None
    train_handlers = [
        StatsHandler(
            name="training_loss",
            output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]},
        ),
        TensorBoardStatsHandler(
            log_dir=root_dir,
            tag_name="training_loss",
            output_transform=lambda x: {Keys.GLOSS: x[Keys.GLOSS], Keys.DLOSS: x[Keys.DLOSS]},
        ),
        CheckpointSaver(
            save_dir=root_dir,
            save_dict={"g_net": gen_net, "d_net": disc_net},
            save_interval=2,
            epoch_level=True,
        ),
    ]

    # Short adversarial run: D takes two steps per G step.
    disc_train_steps = 2
    num_epochs = 5
    trainer = GanTrainer(
        device,
        num_epochs,
        train_loader,
        gen_net,
        gen_opt,
        generator_loss,
        disc_net,
        disc_opt,
        discriminator_loss,
        d_train_steps=disc_train_steps,
        latent_shape=latent_size,
        key_train_metric=key_train_metric,
        train_handlers=train_handlers,
    )
    trainer.run()
    return trainer.state
class MaskGAN(pl.LightningModule):
    """Adversarially trained 3D segmentation model.

    A UNet generator predicts 2-channel masks from single-channel images; a
    discriminator scores hard (argmax) masks as real vs. generated. Lightning
    alternates the two optimizers returned by ``configure_optimizers``:
    optimizer index 0 trains the generator (Dice + adversarial term), index 1
    trains the discriminator (real vs. fake BCE).
    """

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): direct hparams assignment is the legacy Lightning
        # idiom; newer versions expect self.save_hyperparameters().
        self.hparams = hparams
        self.generator = UNet(
            dimensions=3,
            in_channels=1,
            out_channels=2,
            # Power-of-two channel progression; the original had 258 here,
            # an apparent typo for 256.
            channels=(64, 128, 256, 512, 1024),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
            dropout=0,
        )
        self.discriminator = Discriminator(
            in_shape=self.hparams.patch_size,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            norm=monai.networks.layers.Norm.BATCH,
        )
        # The Dice criterion is stateless: build it once instead of on every
        # generator step.
        self.dice_loss = monai.losses.DiceLoss(to_onehot_y=True, softmax=True)
        # Generator output cached so the discriminator pass of the same batch
        # can reuse it.
        self.generated_masks = None
        # Middle-slice samples collected during validation for TensorBoard.
        self.sample_masks = []

    # Data setup
    def setup(self, stage):
        """Read the image/mask file list and build cached train/val datasets."""
        # NOTE(review): hard-coded, site-specific input list — consider
        # lifting the path into hparams.
        data_df = pd.read_csv(
            '/data/shared/prostate/yale_prostate/input_lists/MR_yale.csv')
        train_imgs = data_df['IMAGE'][0:295].tolist()
        train_masks = data_df['SEGM'][0:295].tolist()
        train_dicts = [
            {'image': image, 'mask': mask}
            for (image, mask) in zip(train_imgs, train_masks)
        ]
        train_dicts, val_dicts = train_test_split(train_dicts, test_size=0.15)

        # Basic transforms shared by train and val.
        data_keys = ["image", "mask"]
        data_transforms = Compose([
            LoadNiftid(keys=data_keys),
            AddChanneld(keys=data_keys),
            NormalizeIntensityd(keys="image"),
            RandCropByPosNegLabeld(
                keys=data_keys,
                label_key="mask",
                spatial_size=self.hparams.patch_size,
                num_samples=4,
                image_key="image",
            ),
        ])
        self.train_dataset = monai.data.CacheDataset(
            data=train_dicts,
            transform=Compose([data_transforms, ToTensord(keys=data_keys)]),
            cache_rate=1.0,
        )
        self.val_dataset = monai.data.CacheDataset(
            data=val_dicts,
            transform=Compose([data_transforms, ToTensord(keys=data_keys)]),
            cache_rate=1.0,
        )

    def train_dataloader(self):
        """Shuffled training loader sized from hparams."""
        return monai.data.DataLoader(
            self.train_dataset,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
        )

    def val_dataloader(self):
        """Validation loader (no shuffling)."""
        return monai.data.DataLoader(
            self.val_dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
        )

    # Training setup
    def forward(self, image):
        """Run the generator (UNet) on an image batch."""
        return self.generator(image)

    def generator_loss(self, y_hat, y):
        """Dice loss between predicted logits and reference masks."""
        return self.dice_loss(y_hat, y)

    def adversarial_loss(self, y_hat, y):
        """BCE between discriminator outputs and real/fake targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs, labels = batch['image'], batch['mask']
        batch_size = inputs.size(0)
        device = inputs.device

        # Generator training
        if optimizer_idx == 0:
            self.generated_masks = self(inputs)
            # Loss from difference between real and generated masks.
            g_loss = self.generator_loss(self.generated_masks, labels)
            # Loss from discriminator: the generator wants the discriminator
            # to be wrong, so the "real" targets are used for fakes.
            fake_labels = torch.ones(batch_size, 1, device=device)
            # argmax over the channel dim yields a hard mask; .float() keeps
            # it on-device (the original round-tripped through CPU via
            # torch.FloatTensor and crashed on CPU-only runs).
            d_loss = self.adversarial_loss(
                self.discriminator(self.generated_masks.argmax(1).float()),
                fake_labels,
            )
            avg_loss = g_loss + 0.5 * d_loss
            self.logger.log_metrics({"g_train/g_loss": g_loss}, self.global_step)
            self.logger.log_metrics({"g_train/d_loss": d_loss}, self.global_step)
            self.logger.log_metrics({"g_train/tot_loss": avg_loss}, self.global_step)
            return {'loss': avg_loss}

        # Discriminator training
        else:
            # Learning real masks.
            real_labels = torch.ones(batch_size, 1, device=device)
            real_loss = self.adversarial_loss(
                self.discriminator(labels.squeeze(1).float()), real_labels
            )
            # Learning "fake" masks; detach() keeps gradients out of G.
            fake_labels = torch.zeros(batch_size, 1, device=device)
            fake_loss = self.adversarial_loss(
                self.discriminator(self.generated_masks.argmax(1).detach().float()),
                fake_labels,
            )
            avg_loss = real_loss + fake_loss
            self.logger.log_metrics({"d_train/real_loss": real_loss}, self.global_step)
            self.logger.log_metrics({"d_train/fake_loss": fake_loss}, self.global_step)
            self.logger.log_metrics({"d_train/tot_loss": avg_loss}, self.global_step)
            return {'loss': avg_loss}

    def configure_optimizers(self):
        """Two Adam optimizers: index 0 -> generator, index 1 -> discriminator."""
        lr = self.hparams.lr
        g_optimizer = torch.optim.Adam(self.generator.parameters(), lr=lr)
        d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
        return [g_optimizer, d_optimizer], []

    def validation_step(self, batch, batch_idx):
        inputs, labels = (
            batch["image"],
            batch["mask"],
        )
        outputs = self(inputs)
        # Collect a middle slice of the first prediction for the TensorBoard
        # grid (skipped during the initial epoch / sanity check).
        if self.current_epoch != 0:
            middle = outputs[0].argmax(0).shape[2] // 2
            image = outputs[0].argmax(0)[:, :, middle].unsqueeze(0).detach()
            self.sample_masks.append(image)
        loss = self.generator_loss(outputs, labels)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """Log mean validation loss and the collected sample-mask grid."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        self.logger.log_metrics({"val/loss": avg_loss}, self.current_epoch)
        if self.current_epoch != 0:
            grid = torchvision.utils.make_grid(self.sample_masks)
            self.logger.experiment.add_image('sample_masks', grid, self.current_epoch)
            self.sample_masks = []
        return {"val_loss": avg_loss}