Example #1
0
    def imageClean(self, imgs):
        """Run inference on ``imgs`` and upload the images with their predictions.

        Args:
            imgs: batch of images; axis 3 is squeezed away before logging, so it
                must have size 1 (``np.squeeze`` raises otherwise).

        Returns:
            The upload counter returned by ``log_images`` (started from 0).
        """
        pred = self.inferBatch(imgs)
        # Drop the singleton axis 3 before logging, exactly as ``validate`` does
        # before its own ``log_images`` call. Previously this squeezed copy was
        # computed but the unsqueezed ``imgs`` was uploaded instead.
        images = np.squeeze(imgs, axis=3)
        return log_images(images, pred, 0, self.experiment,
                          self.args.ckptpath)
Example #2
0
    def validate(self, loader, epoch, is_testing=False):
        """Evaluate the network on every batch of ``loader``.

        Accumulates the session loss and a confusion histogram (built with
        ``get_hist``) over all batches. On the final epoch
        (``epoch == self.args.max_epoch``) predictions are also uploaded via
        ``log_images`` until its returned counter reaches 1000.

        Args:
            loader: iterable of ``(images, labels)`` tensors exposing ``.numpy()``.
            epoch: current epoch number; only used to decide on image uploads.
            is_testing: only changes the banner that is printed.

        Returns:
            ``(avg_batch_loss, acc_total, capture_rate_class0, capture_rate_class1)``.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        print('Testing NN' if is_testing else 'Validating NN')
        total_val_loss = 0.0
        num_batches = 0
        hist = np.zeros((self.num_classes, self.num_classes))

        image_upload_count = 0
        for images, labels in loader:
            images = images.numpy()
            labels = labels.numpy()
            val_loss, val_logit = self.sess.run(
                [self.loss, self.logit],
                feed_dict={
                    self.input_images: images,
                    self.input_labels: labels,
                    self.phase_train: False,
                })

            num_batches += 1
            total_val_loss += val_loss
            hist += get_hist(val_logit, labels)

            # Upload sample predictions only on the last epoch, capped at 1000.
            if epoch == self.args.max_epoch and image_upload_count < 1000:
                pred = val_logit.argmax(3)
                images = np.squeeze(images, axis=3)
                image_upload_count = log_images(images, pred,
                                                image_upload_count,
                                                self.experiment,
                                                self.args.ckptpath)

        if num_batches == 0:
            raise ValueError('validate() received a loader with no batches')

        # Divide by the number of batches actually seen. The former
        # ``total / idx`` used the last *index* (count - 1), which inflated the
        # loss and raised ZeroDivisionError for a single-batch loader.
        avg_batch_loss = total_val_loss / num_batches

        # Rows of ``hist`` are per-class sample counts; the diagonal holds the
        # correctly classified samples of each class.
        cls_sample_nums = hist.sum(1).astype(float)
        capture_array = np.diag(hist)
        acc_total = capture_array.sum() / hist.sum()
        capture_rate_ls = [
            0.0 if cls_sample_nums[c] == 0 else capture_array[c] / cls_sample_nums[c]
            for c in range(self.num_classes)
        ]
        print(
            'VALID: Total accuracy: %f%%. Class 0 capture: %f%%. Class 1 capture: %f%%'
            % (acc_total * 100.0, capture_rate_ls[0] * 100.0,
               capture_rate_ls[1] * 100.0))
        return avg_batch_loss, acc_total, capture_rate_ls[0], capture_rate_ls[1]
Example #3
0
                      fake_image,
                      mask=mask,
                      mask_val=config.mask_values["non_lung_tissue"])
        writer.add_scalar("L1 diff/Train", l1_diff, epoch)

        f = create_figure([
            masked_image[0, 0, :, :], fake_image[0, 0, :, :], image[0, 0, :, :]
        ],
                          figsize=(12, 4))

        writer.add_figure("Image outputs/Real image, fake image, mask", f,
                          epoch)

        log_images([masked_image, fake_image, image],
                   path=config.image_logs,
                   run_id=start_time,
                   step=epoch,
                   context="train",
                   figsize=(12, 4))

        data = next(iter(valid_dataloader))
        valid_image, valid_masked_image, valid_mask = data
        valid_image, valid_masked_image = valid_image.float().to(
            device), valid_masked_image.float().to(device)
        generator.eval()
        valid_fake_image = generator(valid_masked_image)
        valid_image = valid_image.float().detach().cpu().numpy()
        valid_masked_image = valid_masked_image.float().detach().cpu().numpy()
        valid_fake_image = valid_fake_image.detach().cpu().numpy()

        log_data(valid_fake_image,
                 config.image_logs,
Example #4
0
    def validate(self):
        """Run one validation pass of ``self.model`` followed by ``self.model2``.

        For each batch of ``self.val_loader``:
        - normalise the input by its maximum;
        - downsample it to ``config.data_size``;
        - apply one randomly chosen augmentation to input and target alike;
        - run both models and update the running loss and accuracy averages;
        - write a visualisation and an ``.h5`` prediction file under
          ``config.checkpoint_dir``.

        Afterwards the last batch is logged to ``self.writer`` and the epoch
        averages are added to the saved trend files.

        Returns:
            ``(val_accuracy_label1.avg, val_accuracy_label2.avg, val_loss.avg)``.

        Raises:
            ValueError: if ``self.val_loader`` yields no batches.
        """
        val_loss = utils.RunningAverage()
        val_accuracy_label1 = utils.RunningAverage()
        val_accuracy_label2 = utils.RunningAverage()
        print("Validation begins...")

        # Output directories are loop-invariant, so create them once up front.
        # ``exist_ok=True`` replaces the per-batch exists/mkdir check.
        save_path_visual = os.path.join(self.config.checkpoint_dir, 'visual')
        save_path_pred = os.path.join(self.config.checkpoint_dir, 'pred')
        os.makedirs(save_path_visual, exist_ok=True)
        os.makedirs(save_path_pred, exist_ok=True)

        self.model2.eval()
        # NOTE(review): only model2 is switched to eval mode; self.model keeps
        # whatever mode it was already in — confirm this is intended.
        try:
            with torch.no_grad():
                i = -1
                for i, (_, inputs, target) in enumerate(self.val_loader):
                    # ``.float()`` changes dtype on the tensor's current device.
                    # The old ``.type(torch.cuda.FloatTensor)`` forced CUDA even
                    # when ``self.device`` was the CPU.
                    inputs = inputs.to(self.device).float()
                    target = target.to(self.device).float()

                    # Normalise the input image to a maximum of 1.
                    inputs = inputs / torch.max(inputs)

                    # NOTE(review): ``torch.squeeze`` drops every size-1 axis,
                    # including a batch axis of size 1 — confirm batch size > 1.
                    input_ds = torch.nn.functional.interpolate(
                        torch.squeeze(inputs),
                        (self.config.data_size, self.config.data_size),
                        mode='nearest').unsqueeze(1)

                    # Same random augmentation (1 of 8) for input and target so
                    # that they stay aligned.
                    random_num = random.randint(0, 7)
                    input_ds = self.data_aug.forward(input_ds, random_num,
                                                     self.device)
                    target = self.data_aug.forward(target, random_num,
                                                   self.device)

                    # model's last-decoder features are the input to model2.
                    _, output_ds_last_decoder = self.model(input_ds)
                    output = self.model2(output_ds_last_decoder)

                    loss = self.loss_criterion(output, target)
                    accuracy_indi, _ = self.accuracy_criterion(output, target)

                    # Weight by the real batch size so that a smaller final
                    # batch is not over-counted in the running averages.
                    batch_n = inputs.size(0)
                    val_loss.update(loss, batch_n)
                    val_accuracy_label1.update(accuracy_indi[1], batch_n)
                    val_accuracy_label2.update(accuracy_indi[2], batch_n)

                    utils.visualize_prediction(inputs, target, output, i,
                                               save_path_visual)
                    utils.save_prediction(save_path_pred, f'patch{i}_pred.h5',
                                          'raw', inputs.cpu(), 'pred',
                                          output.cpu(), 'label', target.cpu())

                if i < 0:
                    raise ValueError("validation loader yielded no batches")

                # Log the last batch: raw input, target and argmax prediction.
                prediction = torch.argmax(output, dim=1)
                prediction = torch.unsqueeze(prediction, dim=1).float()
                utils.log_images(writer=self.writer,
                                 num_iter=self.num_iter,
                                 name1='raw',
                                 data1=inputs,
                                 name2='target',
                                 data2=target,
                                 name3='prediction',
                                 data3=prediction,
                                 num_per_row=8)
                print(
                    "========> Validation Iteration {}, Loss {:.02f}, Accuracy for label 1 {:.02f}, Accuracy for label 2 {:.02f}"
                    .format(i, val_loss.avg, val_accuracy_label1.avg,
                            val_accuracy_label2.avg))

                # Persist the per-epoch trends.
                self.dict_val_loss = utils.save_trends(
                    self.dict_val_loss, self.num_epoch, val_loss.avg,
                    os.path.join(self.config.checkpoint_dir, 'val_loss'))
                self.dict_val_accuracy_label1 = utils.save_trends(
                    self.dict_val_accuracy_label1, self.num_epoch,
                    val_accuracy_label1.avg,
                    os.path.join(self.config.checkpoint_dir,
                                 'val_accuracy_label1'))
                self.dict_val_accuracy_label2 = utils.save_trends(
                    self.dict_val_accuracy_label2, self.num_epoch,
                    val_accuracy_label2.avg,
                    os.path.join(self.config.checkpoint_dir,
                                 'val_accuracy_label2'))
        finally:
            # Always hand model2 back in training mode, even if validation fails.
            self.model2.train()
        return val_accuracy_label1.avg, val_accuracy_label2.avg, val_loss.avg
def train(data_folderpath='data/edges2shoes', image_size=256, ndf=64, ngf=64,
          lr_d=2e-4, lr_g=2e-4, n_iterations=int(1e6),
          batch_size=64, iters_per_checkpoint=100, n_checkpoint_samples=16,
          reconstruction_weight=100, out_dir='gan'):
    """Train a conditional image-to-image GAN (A -> B) with an L1-style term.

    Reads paired images from ``<data_folderpath>/train/*.jpg`` and
    ``<data_folderpath>/val/*.jpg``. Index 0 of axis 1 in each batch is
    treated as image A (the generator input) and index 1 as image B (the
    target).

    Each iteration trains D once and then G once. Every
    ``iters_per_checkpoint`` iterations, G's outputs for a fixed validation
    batch are logged and G is saved to ``out_dir``. Losses are logged every
    iteration.

    Args:
        data_folderpath: root folder containing ``train/`` and ``val/``.
        image_size: forwarded to ``iterate_minibatches``.
        ndf, ngf: base filter counts for the discriminator and generator.
        lr_d, lr_g: Adam learning rates for D and G.
        n_iterations: number of D+G update iterations.
        batch_size: training batch size.
        iters_per_checkpoint: period for sample logging and model saving.
        n_checkpoint_samples: size of the fixed validation batch.
        reconstruction_weight: loss weight of the reconstruction term in G.
        out_dir: TensorBoard log directory, also passed to ``save_model``.
    """

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)

    data_iterator = iterate_minibatches(
        data_folderpath + "/train/*.jpg", batch_size, image_size)
    val_data_iterator = iterate_minibatches(
        data_folderpath + "/val/*.jpg", n_checkpoint_samples, image_size)
    # One fixed validation batch, reused for every checkpoint visualisation.
    img_ab_fixed, _ = next(val_data_iterator)
    img_a_fixed, img_b_fixed = img_ab_fixed[:, 0], img_ab_fixed[:, 1]

    img_a_shape = img_a_fixed.shape[1:]
    img_b_shape = img_b_fixed.shape[1:]
    # Spatial size of D's patch output; assumes D downsamples 4 times (2**4)
    # — must match build_discriminator, verify if that changes.
    patch = int(img_a_shape[0] / 2**4)  # n_layers
    disc_patch = (patch, patch, 1)
    print("img a shape ", img_a_shape)
    print("img b shape ", img_b_shape)
    print("disc_patch ", disc_patch)

    # Log the real A and B images once for visual reference.
    log_images(img_a_fixed, 'real_a', '0', logger)
    log_images(img_b_fixed, 'real_b', '0', logger)

    # build models
    D = build_discriminator(
        img_a_shape, img_b_shape, ndf, activation='sigmoid')
    G = build_generator(img_a_shape, ngf)

    # build model outputs
    img_a_input = Input(shape=img_a_shape)
    img_b_input = Input(shape=img_b_shape)

    fake_samples = G(img_a_input)
    D_real = D([img_a_input, img_b_input])
    D_fake = D([img_a_input, fake_samples])

    # The reconstruction loss has the real/fake tensors bound in directly.
    # Presumably it ignores Keras' y_true, which is why a dummy target is
    # passed for it below — confirm against mean_absolute_error.
    loss_reconstruction = partial(mean_absolute_error,
                                  real_samples=img_b_input,
                                  fake_samples=fake_samples)
    loss_reconstruction.__name__ = 'loss_reconstruction'

    # define D graph and optimizer
    # The trainable flags are set *before* each compile so that the compiled
    # model only updates the intended sub-network; do not reorder.
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_real, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                    loss='binary_crossentropy')

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_fake, fake_samples])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['binary_crossentropy', loss_reconstruction],
                    loss_weights=[1, reconstruction_weight])

    # Patch-shaped real/fake labels for D's outputs.
    ones = np.ones((batch_size, ) + disc_patch, dtype=np.float32)
    zeros = np.zeros((batch_size, ) + disc_patch, dtype=np.float32)
    # Placeholder target for the reconstruction output (see note above).
    dummy = zeros

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False

        # D step: real pairs labelled 1, generated pairs labelled 0.
        image_ab_batch, _ = next(data_iterator)
        loss_d = D_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, zeros])

        D.trainable = False
        G.trainable = True
        # G step on a fresh batch: G tries to make D output 1 for its samples.
        image_ab_batch, _ = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict(img_a_fixed)
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        log_losses(loss_d, loss_g, i, logger)
def main(args):
    """Train a UNet with Dice loss, keeping the weights with the best validation DSC.

    Every epoch runs a training pass and then a validation pass over the
    loaders returned by ``data_loaders(args)``.

    - Training losses are flushed to the logger whenever
      ``(step + 1) % 10 == 0``.
    - During validation, sample images are logged on visualisation epochs
      (every ``args.vis_freq`` epochs and the final one).
    - After each validation pass, the mean per-volume DSC is logged. When it
      beats the best value so far, the state dict is saved to
      ``<args.weights>/unet.pt``.
    """
    makedirs(args)
    snapshotargs(args)
    cuda_missing = not torch.cuda.is_available()
    device = torch.device("cpu" if cuda_missing else args.device)

    train_loader, valid_loader = data_loaders(args)
    loaders = {"train": train_loader, "valid": valid_loader}

    model = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels)
    model.to(device)

    criterion = DiceLoss()
    best_dsc = 0.0

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Echo the key run settings to stdout.
    print("Learning rate = ", args.lr)
    print("Batch-size = ", args.batch_size)
    print("Number of visualization images to save in log file = ", args.vis_images)

    logger = Logger(args.logs)
    train_losses = []
    valid_losses = []
    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ("train", "valid"):
            training = phase == "train"
            if training:
                model.train()
            else:
                model.eval()

            preds_all = []
            trues_all = []

            for batch_idx, (x, y_true) in enumerate(loaders[phase]):
                # The global step only advances on training batches.
                if training:
                    step += 1

                x = x.to(device)
                y_true = y_true.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(training):
                    y_pred = model(x)
                    loss = criterion(y_pred, y_true)

                    if training:
                        train_losses.append(loss.item())
                        loss.backward()
                        optimizer.step()
                    else:
                        valid_losses.append(loss.item())
                        # Split each batch into per-sample arrays for the
                        # per-volume DSC computed after the pass.
                        preds_all.extend(list(y_pred.detach().cpu().numpy()))
                        trues_all.extend(list(y_true.detach().cpu().numpy()))

                        vis_epoch = (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1)
                        if vis_epoch:
                            already_seen = batch_idx * args.batch_size
                            if already_seen < args.vis_images:
                                logger.image_list_summary(
                                    "image/{}".format(batch_idx),
                                    log_images(x, y_true, y_pred)[:args.vis_images - already_seen],
                                    step,
                                )

                if training and (step + 1) % 10 == 0:
                    log_loss_summary(logger, train_losses, step)
                    train_losses = []

            if training:
                continue

            log_loss_summary(logger, valid_losses, step, prefix="val_")
            mean_dsc = np.mean(
                dsc_per_volume(
                    preds_all,
                    trues_all,
                    valid_loader.dataset.patient_slice_index,
                )
            )
            logger.scalar_summary("val_dsc", mean_dsc, step)
            if mean_dsc > best_dsc:
                best_dsc = mean_dsc
                torch.save(model.state_dict(), os.path.join(args.weights, "unet.pt"))
            valid_losses = []

    print("Best validation mean DSC: {:4f}".format(best_dsc))
def train(n_channels=3,
          resolution=32,
          z_dim=128,
          n_labels=0,
          lr=1e-3,
          e_drift=1e-3,
          wgp_target=750,
          initial_resolution=4,
          total_kimg=25000,
          training_kimg=500,
          transition_kimg=500,
          iters_per_checkpoint=500,
          n_checkpoint_images=16,
          glob_str='cifar10',
          out_dir='cifar10'):
    """Train a generator/critic pair whose active resolution block changes over time.

    Training runs until ``total_kimg * 1000`` real images have been shown to
    the critic. The run is divided into phases of
    ``training_kimg + transition_kimg`` thousand images.

    - At each phase change, the batch size is taken from
      ``MINIBATCH_OVERWRITES[phase_idx]`` and the data iterator is rebuilt.
    - During transitions, ``cur_block`` is blended linearly from one block
      to the next, and both graphs receive it via ``K.set_value``.
    - The critic is trained ``N_CRITIC_ITERS`` times per generator step. Its
      losses are ``loss_mean``, ``loss_gradient_penalty`` (weighted by
      ``GP_WEIGHT``) and an ``mse`` drift term (weighted by ``e_drift``).
    - Every ``iters_per_checkpoint`` steps, samples from a fixed z are logged
      and G is saved under ``out_dir``.
    """
    # instantiate logger
    logger = SummaryWriter(out_dir)

    # load data
    batch_size = MINIBATCH_OVERWRITES[0]
    train_iterator = iterate_minibatches(glob_str, batch_size, resolution)

    # build models
    G = Generator(n_channels, resolution, z_dim, n_labels)
    D = Discriminator(n_channels, resolution, n_labels)

    G_train, D_train = GAN(G, D, z_dim, n_labels, resolution, n_channels)

    D_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    G_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

    # define loss functions
    D_loss = [loss_mean, loss_gradient_penalty, 'mse']
    G_loss = [loss_wasserstein]

    # compile graphs used during training
    # D.trainable is toggled around the compiles so that G_train freezes D
    # while D_train updates it; do not reorder these statements.
    G.compile(G_opt, loss=loss_wasserstein)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=[1, GP_WEIGHT, e_drift])

    # for computing the loss
    ones = np.ones((batch_size, 1), dtype=np.float32)
    zeros = ones * 0.0

    # fix a z vector for training evaluation
    z_fixed = np.random.normal(0, 1, size=(n_checkpoint_images, z_dim))

    # vars
    resolution_log2 = int(np.log2(resolution))
    starting_block = resolution_log2
    starting_block -= np.floor(np.log2(initial_resolution))
    cur_block = starting_block
    cur_nimg = 0

    # compute duration of each phase and use proxy to update minibatch size
    phase_kdur = training_kimg + transition_kimg
    phase_idx_prev = 0

    # offset variable for transitioning between blocks
    offset = 0
    i = 0
    while cur_nimg < total_kimg * 1000:
        # block processing
        kimg = cur_nimg / 1000.0
        phase_idx = int(np.floor((kimg + transition_kimg) / phase_kdur))
        # Keep an int: phase_idx indexes MINIBATCH_OVERWRITES below, and
        # ``max(int, 0.0)`` could return a float.
        phase_idx = max(phase_idx, 0)
        phase_kimg = phase_idx * phase_kdur

        # update batch size and ones vector if we switched phases
        if phase_idx_prev < phase_idx:
            batch_size = MINIBATCH_OVERWRITES[phase_idx]
            # Pass ``resolution`` exactly as the initial iterator did; it used
            # to be dropped here, so later phases fell back to the helper's
            # own default resolution (or failed if it had none).
            train_iterator = iterate_minibatches(glob_str, batch_size,
                                                 resolution)
            ones = np.ones((batch_size, 1), dtype=np.float32)
            zeros = ones * 0.0
            phase_idx_prev = phase_idx

        # possibly gradually update current level of detail
        if transition_kimg > 0 and phase_idx > 0:
            # Fraction of the way through the current transition, clamped to 1.
            offset = (kimg + transition_kimg - phase_kimg) / transition_kimg
            offset = min(offset, 1.0)
            offset = offset + phase_idx - 1
            cur_block = max(starting_block - offset, 0.0)

        # update level of detail
        K.set_value(G_train.cur_block, np.float32(cur_block))
        K.set_value(D_train.cur_block, np.float32(cur_block))

        # train D
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(train_iterator)
            fake_batch = G.predict_on_batch([z])
            interpolated_batch = get_interpolated_images(
                real_batch, fake_batch)
            losses_d = D_train.train_on_batch(
                [real_batch, fake_batch, interpolated_batch],
                [ones, ones * wgp_target, zeros])
            cur_nimg += batch_size

        # train G
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        loss_g = G_train.train_on_batch(z, -1 * ones)

        logger.add_scalar("cur_block", cur_block, i)
        logger.add_scalar("learning_rate", lr, i)
        logger.add_scalar("batch_size", z.shape[0], i)
        print("iter", i, "cur_block", cur_block, "lr", lr, "kimg", kimg,
              "losses_d", losses_d, "loss_g", loss_g)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_images = G.predict(z_fixed)
            # log fake images
            log_images(fake_images, 'fake', i, logger, fake_images.shape[1],
                       fake_images.shape[2], int(np.sqrt(n_checkpoint_images)))

            # plot real images for reference
            log_images(real_batch[:n_checkpoint_images], 'real', i, logger,
                       real_batch.shape[1], real_batch.shape[2],
                       int(np.sqrt(n_checkpoint_images)))

            # save the model to eventually resume training or do inference
            save_model(G, out_dir + "/model.json", out_dir + "/model.h5")

        log_losses(losses_d, loss_g, i, logger)
        i += 1
Example #8
0
        print(f"""
                total_loss_G = {total_loss_G}\n
                loss_G_GAN = {loss_GAN_G12 + loss_GAN_G21}\n
                loss_G_Cycl = {cyclic_loss}\n
                loss_D = {loss_D_1 + loss_D_2}
            """)

        log_losses({
            "total_loss_G": total_loss_G.item(),
            "loss_G_GAN": (loss_GAN_G12 + loss_GAN_G21).item(),
            "loss_G_Cycl": cyclic_loss.item(),
            "loss_D": (loss_D_1 + loss_D_2).item()
        })

        log_images({
            "epoch": epoch,
            "batch": i,
            "images": {
                "realA": realA,
                "fakeA": fakeA,
                "realB": realB,
                "fakeB": fakeB
            }
        })

    #------------------- save model -----------------#
    save_model(G12, epoch)
    save_model(G21, epoch)
    save_model(D1, epoch)
    save_model(D2, epoch)
            real_A[0, 0, :, :], fake_image_B[0, 0, :, :],
            recovered_image_A[0, 0, :, :]
        ],
                          figsize=(12, 4))
        writer.add_figure("Image outputs/A to B to A", f, epoch)

        f = create_figure([
            real_B[0, 0, :, :], fake_image_A[0, 0, :, :],
            recovered_image_B[0, 0, :, :]
        ],
                          figsize=(12, 4))
        writer.add_figure("Image outputs/B to A to B", f, epoch)

        log_images([real_A, fake_image_B, recovered_image_A],
                   path=config.image_logs,
                   run_id=start_time,
                   step=epoch,
                   context="train_ABA",
                   figsize=(12, 4))

        log_images([real_B, fake_image_A, recovered_image_B],
                   path=config.image_logs,
                   run_id=start_time,
                   step=epoch,
                   context="train_BAB",
                   figsize=(12, 4))

        data = next(iter(valid_dataloader))
        real_A, real_B = data
        real_A, real_B = real_A.float().to(device), real_B.float().to(device)
        netG_B2A.eval()
        netG_A2B.eval()
Example #10
0
        optimizer_D.zero_grad()

        #calculate loss
        real_loss = adversarial_loss(discriminator(masked_parts), valid)
        fake_loss = adversarial_loss(discriminator(gen_parts.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)

        #update weights
        d_loss.backward()
        optimizer_D.step()

        #log to tensorboard
        writer.add_scalar('real loss', real_loss, current_step)
        writer.add_scalar('fake loss', fake_loss, current_step)
        writer.add_scalar('total d loss', d_loss, current_step)
        writer.add_scalar('adversarial loss', g_adv, current_step)
        writer.add_scalar('pixelwise loss', g_pixel, current_step)
        writer.add_scalar('total g loss', g_loss, current_step)

        #update tqdm bar
        if i % 50 == 0:
            t.set_description('epoch:{} g_loss:{:.4f} d_loss:{:.4f}'.format(
                epoch, g_loss.item(), d_loss.item()))
            t.refresh()

        if i % run_config['save_frequency'] == 0:

            generator.eval()
            log_images(writer, test_dataloader, mask_size, device, generator,
                       current_step)
Example #11
0
def train(data_filepath='data/flowers.hdf5',
          ndf=64,
          ngf=128,
          z_dim=128,
          emb_dim=128,
          lr_d=2e-4,
          lr_g=2e-4,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=500,
          n_checkpoint_samples=16,
          out_dir='gan'):
    """Train a text-embedding-conditioned GAN on ``data_filepath``.

    D is trained on three pairings:

    - (real image, matching embedding) labelled 1;
    - (real image, shuffled embedding) labelled 0, weighted 0.5;
    - (generated image, embedding) labelled 0, weighted 0.5.

    G is trained to make D output 1 for its samples. Every
    ``iters_per_checkpoint`` iterations, samples for a fixed validation batch
    are logged and G is saved to ``out_dir``.
    """
    # ``time.clock`` was removed in Python 3.8; perf_counter is its replacement.
    from time import perf_counter

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    # One fixed validation batch: images, embeddings and raw text.
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    # Labels previously printed each other's shape.
    print("img shape {}".format(img_shape))
    print("emb shape {}".format(emb_shape))
    z_shape = (z_dim, )

    # Log the real images and their text once for reference.
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape,
                            emb_shape,
                            emb_dim,
                            ndf,
                            activation='sigmoid')
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    txt_shuf_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))

    fake_samples = G([z_inputs, txt_inputs])
    D_real = D([real_inputs, txt_inputs])
    D_wrong = D([real_inputs, txt_shuf_inputs])
    D_fake = D([fake_samples, txt_inputs])

    # define D graph and optimizer
    # trainable flags are set before each compile; do not reorder.
    G.trainable = False
    D.trainable = True
    D_model = Model(
        inputs=[real_inputs, txt_inputs, txt_shuf_inputs, z_inputs],
        outputs=[D_real, D_wrong, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy',
                    loss_weights=[1, 0.5, 0.5])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[z_inputs, txt_inputs], outputs=D_fake)
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy')

    ones = np.ones((batch_size, 1, 1, 1), dtype=np.float32)
    zeros = np.zeros((batch_size, 1, 1, 1), dtype=np.float32)

    # Fix a z vector for training evaluation, drawn from the same N(0, 1)
    # prior used for every training step below (it was U(-1, 1) before, so
    # checkpoint samples did not reflect the distribution G is trained on).
    z_fixed = np.random.normal(0, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        start = perf_counter()
        D.trainable = True
        G.trainable = False
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        images_batch = images_from_bytes(real_batch[0])
        emb_text_batch = real_batch[1]
        # Shuffled embeddings give mismatched (image, text) pairs for D_wrong.
        ids = np.arange(len(emb_text_batch))
        np.random.shuffle(ids)
        emb_text_batch_shuffle = emb_text_batch[ids]
        loss_d = D_model.train_on_batch(
            [images_batch, emb_text_batch, emb_text_batch_shuffle, z],
            [ones, zeros, zeros])

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch([z, real_batch[1]], ones)

        print("iter", i, "time", perf_counter() - start)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            # Save under the configured output directory; the default
            # out_dir='gan' keeps the previous location.
            save_model(G, out_dir)

        log_losses(loss_d, loss_g, i, logger)
def main(args):
    """Train a NestedUNet with BCE+Dice loss and keep the best-DSC weights.

    Each epoch runs a training pass and then a validation pass; the learning
    rate is decayed by 0.3 at epochs 30 and 60.

    During validation, only slices whose ground-truth mask is non-empty
    contribute to the mean DSC. Weights are saved to
    ``<args.weights>/unet.pt`` whenever that mean improves.
    """
    makedirs(args)
    snapshotargs(args)
    device = torch.device(
        "cpu" if not torch.cuda.is_available() else args.device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = NestedUNet()
    unet.to(device)

    dsc_loss = BCEDiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[30, 60],
                                               gamma=0.3)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    step = 0

    for epoch in range(args.epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)
                    print('epoch: ', epoch, ' | step: ', i, ' | loss : ',
                          loss.cpu().detach().numpy())

                    if phase == "valid":
                        loss_valid.append(loss.item())

                        # The loss takes logits; DSC uses probabilities.
                        y_pred = torch.sigmoid(y_pred)
                        y_pred_np = y_pred.detach().cpu().numpy()
                        y_true_np = y_true.detach().cpu().numpy()

                        # Keep every sample of the batch as a (C, H, W) slice
                        # (C == 1 for a single-channel output). The former
                        # ``np.resize(..., (1, 256, 256))`` silently kept only
                        # the first sample of larger batches and scrambled any
                        # input that was not 256x256.
                        for pred_slice, true_slice in zip(y_pred_np, y_true_np):
                            if np.any(true_slice):
                                validation_pred.append(
                                    pred_slice.reshape((-1, ) + pred_slice.shape[-2:]))
                                validation_true.append(
                                    true_slice.reshape((-1, ) + true_slice.shape[-2:]))

                        if (epoch % args.vis_freq
                                == 0) or (epoch == args.epochs - 1):
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                    ))
                logger.scalar_summary("val_dsc", mean_dsc, step)
                print('best_score: ', best_validation_dsc, ' | mean_score: ',
                      mean_dsc)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(),
                               os.path.join(args.weights, "unet.pt"))
                loss_valid = []

        scheduler.step()
    print("Best validation mean DSC: {:4f}".format(best_validation_dsc))
def train(data_folderpath='data/cityscapes',
          image_size=256,
          ndf=32,
          ngf=64,
          lr_d=2e-4,
          lr_g=2e-4,
          n_feat_channels=3,
          use_edges=True,
          n_iterations=int(1e6),
          batch_size=1,
          iters_per_checkpoint=100,
          n_checkpoint_samples=4,
          feature_matching_weight=10,
          out_dir='pix2pixhd'):
    """Train a pix2pixHD-style generator against three multi-scale discriminators.

    Each iteration first updates the discriminators D0, D1 and D2 (built with
    0, 1 and 2 downsampling steps) using an MSE loss against all-ones (real)
    and all-zeros (fake) patch targets. It then updates the generator G
    through a combined model whose losses are three adversarial MSE terms
    plus three feature-matching terms (one per discriminator). Every
    ``iters_per_checkpoint`` iterations the generator output on a fixed batch
    is logged and G is saved.

    Args:
        data_folderpath: Root folder passed to ``iterate_minibatches_cityscapes``.
        image_size: Not used in this function; kept for interface compatibility.
        ndf: Base number of discriminator filters.
        ngf: Base number of generator filters.
        lr_d: Adam learning rate for the discriminators.
        lr_g: Adam learning rate for the generator.
        n_feat_channels: Not used in this function; kept for interface
            compatibility.
        use_edges: Forwarded to the data iterators.
        n_iterations: Number of training iterations.
        batch_size: Training batch size.
        iters_per_checkpoint: Log samples and save G every this many iterations.
        n_checkpoint_samples: Size of the fixed batch used for visualisation.
        feature_matching_weight: Weight passed to each feature-matching loss.
        out_dir: TensorBoard log directory; also passed to ``save_model``.
    """

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)

    # instantiate training data and the iterator for the fixed checkpoint batch
    data_iterator = iterate_minibatches_cityscapes(data_folderpath, "train",
                                                   use_edges, batch_size)
    # NOTE(review): the fixed checkpoint batch is drawn from the "train"
    # split rather than a held-out split -- confirm this is intentional.
    val_data_iterator = iterate_minibatches_cityscapes(data_folderpath,
                                                       "train", use_edges,
                                                       n_checkpoint_samples)

    # fixed data for evaluation; channel layout is [image(0:3) | label(3) | edges(4)]
    input_val, _ = next(val_data_iterator)
    img_fixed = input_val[..., :3]
    lbl_fixed = input_val[..., 3][..., None]
    edges_fixed = input_val[..., 4][..., None]

    # instantiate and report data shapes; each extra downsampling step in a
    # discriminator halves its output patch grid
    img_shape = img_fixed.shape[1:]
    lbl_shape = lbl_fixed.shape[1:]
    edges_shape = edges_fixed.shape[1:]
    h_patch = int(lbl_shape[0] / 2**3)  # n_layers
    w_patch = int(lbl_shape[1] / 2**3)  # n_layers
    disc_patch_0 = (h_patch, w_patch, 1)
    disc_patch_1 = (int(h_patch / 2), int(w_patch / 2), 1)
    disc_patch_2 = (int(h_patch / 4), int(w_patch / 4), 1)
    print("img shape ", img_shape)
    print("lbl shape ", lbl_shape)
    print("edges shape ", edges_shape)
    print("disc_patch ", disc_patch_0)
    print("disc_patch ", disc_patch_1)
    print("disc_patch ", disc_patch_2)

    # plot real data for reference
    plot_dims = int(np.sqrt(img_fixed.shape[0]))
    log_images(img_fixed, 'real_img', '0', logger, img_fixed.shape[1],
               img_fixed.shape[2], plot_dims)
    log_images(lbl_fixed, 'real_lbls', '0', logger, img_fixed.shape[1],
               img_fixed.shape[2], plot_dims)
    log_images(edges_fixed, 'real_edges', '0', logger, img_fixed.shape[1],
               img_fixed.shape[2], plot_dims)

    # build models
    D0 = build_discriminator(lbl_shape,
                             img_shape,
                             edges_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=0,
                             name='Discriminator0')
    D1 = build_discriminator(lbl_shape,
                             img_shape,
                             edges_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=1,
                             name='Discriminator1')
    D2 = build_discriminator(lbl_shape,
                             img_shape,
                             edges_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=2,
                             name='Discriminator2')
    G = build_global_generator(lbl_shape, edges_shape, ngf)

    # build model inputs and outputs; G and each D return a list of outputs
    # whose first element is the image / patch prediction
    lbl_input = Input(shape=lbl_shape)
    img_input = Input(shape=img_shape)
    edges_input = Input(shape=edges_shape)

    fake_samples = G([lbl_input, edges_input])[0]
    D0_real = D0([lbl_input, img_input, edges_input])[0]
    D0_fake = D0([lbl_input, fake_samples, edges_input])[0]

    D1_real = D1([lbl_input, img_input, edges_input])[0]
    D1_fake = D1([lbl_input, fake_samples, edges_input])[0]

    D2_real = D2([lbl_input, img_input, edges_input])[0]
    D2_fake = D2([lbl_input, fake_samples, edges_input])[0]

    # define graph and optimizer for the Discriminator. The `trainable`
    # flags are captured when each model is compiled, so G is frozen inside
    # the D models and the Ds are frozen inside G_model below.
    G.trainable = False
    D0.trainable = True
    D1.trainable = True
    D2.trainable = True
    D0_model = Model([lbl_input, img_input, edges_input], [D0_real, D0_fake],
                     name='Discriminator0_model')
    D1_model = Model([lbl_input, img_input, edges_input], [D1_real, D1_fake],
                     name='Discriminator1_model')
    D2_model = Model([lbl_input, img_input, edges_input], [D2_real, D2_fake],
                     name='Discriminator2_model')
    D0_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])
    D1_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])
    D2_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])

    # define D(G(z)) loss, graph and optimizer
    G.trainable = True
    D0.trainable = False
    D1.trainable = False
    D2.trainable = False

    loss_fm0 = partial(loss_feature_matching,
                       lbl_map=lbl_input,
                       real_samples=img_input,
                       edges_map=edges_input,
                       D=D0,
                       feature_matching_weight=feature_matching_weight)
    loss_fm1 = partial(loss_feature_matching,
                       lbl_map=lbl_input,
                       real_samples=img_input,
                       edges_map=edges_input,
                       D=D1,
                       feature_matching_weight=feature_matching_weight)
    loss_fm2 = partial(loss_feature_matching,
                       lbl_map=lbl_input,
                       real_samples=img_input,
                       edges_map=edges_input,
                       D=D2,
                       feature_matching_weight=feature_matching_weight)

    # fake_samples appears three times so each feature-matching loss gets
    # its own output slot
    G_model = Model(inputs=[lbl_input, img_input, edges_input],
                    outputs=[
                        D0_fake, D1_fake, D2_fake, fake_samples, fake_samples,
                        fake_samples
                    ])
    # learning rate passed positionally, matching the discriminator optimizers
    # (the `lr=` keyword is deprecated in newer Keras)
    G_model.compile(Adam(lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['mse', 'mse', 'mse', loss_fm0, loss_fm1, loss_fm2])

    # instantiate target tensors for computing the loss
    ones_0 = np.ones((batch_size, ) + disc_patch_0, dtype=np.float32)
    ones_1 = np.ones((batch_size, ) + disc_patch_1, dtype=np.float32)
    ones_2 = np.ones((batch_size, ) + disc_patch_2, dtype=np.float32)
    zeros_0 = np.zeros((batch_size, ) + disc_patch_0, dtype=np.float32)
    zeros_1 = np.zeros((batch_size, ) + disc_patch_1, dtype=np.float32)
    zeros_2 = np.zeros((batch_size, ) + disc_patch_2, dtype=np.float32)
    # placeholder targets for the feature-matching outputs
    # NOTE(review): presumably ignored by loss_feature_matching -- confirm.
    dummy = np.ones((batch_size, ), dtype=np.float32)

    # training loop
    for i in range(n_iterations):
        # train discriminators only
        D0.trainable = True
        D1.trainable = True
        D2.trainable = True
        G.trainable = False

        # sample batch of data; the D models generate their own fakes
        # internally, so no separate G.predict is needed here
        batch, _ = next(data_iterator)
        img = batch[..., :3]
        segmap = batch[..., 3][..., None]
        edges = batch[..., 4][..., None]

        # compute loss on current batch
        loss_d0 = D0_model.train_on_batch([segmap, img, edges],
                                          [ones_0, zeros_0])
        loss_d1 = D1_model.train_on_batch([segmap, img, edges],
                                          [ones_1, zeros_1])
        loss_d2 = D2_model.train_on_batch([segmap, img, edges],
                                          [ones_2, zeros_2])
        # train generator only
        D0.trainable = False
        D1.trainable = False
        D2.trainable = False
        G.trainable = True

        # sample a fresh batch of data for the generator step
        batch, _ = next(data_iterator)
        img = batch[..., :3]
        segmap = batch[..., 3][..., None]
        edges = batch[..., 4][..., None]

        # compute loss on current batch
        loss_g = G_model.train_on_batch(
            [segmap, img, edges],
            [ones_0, ones_1, ones_2, dummy, dummy, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([lbl_fixed, edges_fixed])[0]
            log_images(fake_image, 'val_fake', i, logger, fake_image.shape[1],
                       fake_image.shape[2], int(np.sqrt(fake_image.shape[0])))
            save_model(G, out_dir)

        log_losses_pix2pixhd([loss_d0, loss_d1, loss_d2], loss_g, i, logger)
Example #14
0
def main(args):
    """Train a UNet with a Dice loss and keep the best model by validation DSC.

    Runs ``args.epochs`` epochs, each with a "train" and a "valid" phase.
    - Training losses are logged every 10 steps.
    - After each validation phase the mean of ``dsc_per_volume`` is logged.
      If it beats the best value so far, the model ``state_dict`` is saved
      to ``<args.weights>/unet.pt``.
    - Per-phase wall time and the final best DSC are reported through
      ``run.log`` on the run returned by ``Run.get_context()``.

    Args:
        args: Parsed command-line namespace. This function reads ``images``,
            ``device``, ``lr``, ``logs``, ``epochs``, ``vis_freq``,
            ``batch_size``, ``vis_images`` and ``weights``. The helpers
            ``makedirs``, ``snapshotargs`` and ``data_loaders`` may read more.
    """

    global run
    run = Run.get_context()

    print("Current directory:", os.getcwd())
    print("Data directory:", args.images)
    print("Training directory content:", os.listdir(args.images))

    makedirs(args)
    snapshotargs(args)

    device = torch.device(
        "cpu" if not torch.cuda.is_available() else args.device)
    print("Using device:", device)

    loader_train, loader_valid = data_loaders(args)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels,
                out_channels=Dataset.out_channels)

    unet = unet.to(device)
    # NOTE: because the model is wrapped in DataParallel, the saved
    # state_dict keys carry a "module." prefix.
    unet = torch.nn.DataParallel(unet)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    optimizer = optim.Adam(unet.parameters(), lr=args.lr)

    logger = Logger(args.logs)
    loss_train = []
    loss_valid = []

    step = 0

    for epoch in tqdm(range(args.epochs), total=args.epochs):
        for phase in ["train", "valid"]:

            start = time.time()

            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                # autograd is only enabled during the training phase
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        # store one array per sample: iterating an ndarray
                        # yields its sub-arrays along axis 0
                        validation_pred.extend(y_pred.detach().cpu().numpy())
                        validation_true.extend(y_true.detach().cpu().numpy())
                        if (epoch % args.vis_freq
                                == 0) or (epoch == args.epochs - 1):
                            # only the first `vis_images` samples are logged
                            if i * args.batch_size < args.vis_images:
                                tag = "image/{}".format(i)
                                num_images = args.vis_images - i * args.batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    ))
                logger.scalar_summary("val_dsc", mean_dsc, step)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(),
                               os.path.join(args.weights, "unet.pt"))
                loss_valid = []

            run.log("time_" + phase, time.time() - start)

    # `.4f` = four decimal places (the previous `{:4f}` only set a width of 4)
    print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))
    # metric key left unchanged so existing run queries keep matching
    run.log("best_validation_mean_dsv", best_validation_dsc)
Example #15
0
def main():
    """Train a UNet on ``./kaggle_3m`` with a Dice loss, keeping the best model.

    All settings are hard-coded below:
    - image size 256, scale 0.05, angle 15, batch size 16, 4 workers;
    - Adam with lr 1e-4, 100 epochs;
    - image logging every 10 epochs for up to 200 images.

    Training losses are logged every 10 steps. After each validation phase
    the mean of ``dsc_per_volume`` is logged. When that value improves, the
    model ``state_dict`` is saved to ``./weights/unet.pt``.
    """

    weights = './weights'
    logs = './logs'
    makedirs(weights, logs)
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda:0")

    # data settings
    images = './kaggle_3m'
    image_size = 256
    scale = 0.05
    angle = 15
    batch_size = 16
    workers = 4

    loader_train, loader_valid = data_loaders(images, image_size, scale, angle,
                                              batch_size, workers)
    loaders = {"train": loader_train, "valid": loader_valid}

    unet = UNet(in_channels=Dataset.in_channels,
                out_channels=Dataset.out_channels)
    unet.to(device)

    dsc_loss = DiceLoss()
    best_validation_dsc = 0.0

    lr = 0.0001

    optimizer = optim.Adam(unet.parameters(), lr)

    logger = Logger(logs)
    loss_train = []
    loss_valid = []

    # training / visualisation schedule
    step = 0
    epochs = 100
    vis_images = 200
    vis_freq = 10

    for epoch in tqdm(range(epochs), total=epochs):
        for phase in ["train", "valid"]:
            if phase == "train":
                unet.train()
            else:
                unet.eval()

            validation_pred = []
            validation_true = []

            for i, data in enumerate(loaders[phase]):
                if phase == "train":
                    step += 1

                x, y_true = data
                x, y_true = x.to(device), y_true.to(device)

                optimizer.zero_grad()

                # autograd is only enabled during the training phase
                with torch.set_grad_enabled(phase == "train"):
                    y_pred = unet(x)

                    loss = dsc_loss(y_pred, y_true)

                    if phase == "valid":
                        loss_valid.append(loss.item())
                        # store one array per sample: iterating an ndarray
                        # yields its sub-arrays along axis 0
                        validation_pred.extend(y_pred.detach().cpu().numpy())
                        validation_true.extend(y_true.detach().cpu().numpy())
                        if (epoch % vis_freq == 0) or (epoch == epochs - 1):
                            # only the first `vis_images` samples are logged
                            if i * batch_size < vis_images:
                                tag = "image/{}".format(i)
                                num_images = vis_images - i * batch_size
                                logger.image_list_summary(
                                    tag,
                                    log_images(x, y_true, y_pred)[:num_images],
                                    step,
                                )

                    if phase == "train":
                        loss_train.append(loss.item())
                        loss.backward()
                        optimizer.step()

                if phase == "train" and (step + 1) % 10 == 0:
                    log_loss_summary(logger, loss_train, step)
                    loss_train = []

            if phase == "valid":
                log_loss_summary(logger, loss_valid, step, prefix="val_")
                mean_dsc = np.mean(
                    dsc_per_volume(
                        validation_pred,
                        validation_true,
                        loader_valid.dataset.patient_slice_index,
                    ))
                logger.scalar_summary("val_dsc", mean_dsc, step)
                if mean_dsc > best_validation_dsc:
                    best_validation_dsc = mean_dsc
                    torch.save(unet.state_dict(),
                               os.path.join(weights, "unet.pt"))
                loss_valid = []

    # `.4f` = four decimal places (the previous `{:4f}` only set a width of 4)
    print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))
Example #16
0
def train(data_folderpath='data/edges2shoes',
          image_size=256,
          ndf=64,
          ngf=64,
          lr_d=2e-4,
          lr_g=2e-4,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=100,
          n_checkpoint_samples=16,
          feature_matching_weight=10,
          out_dir='lsgan'):
    """Train an A->B image generator against three multi-scale discriminators.

    Each minibatch element stacks a pair of images along axis 1
    (``[:, 0]`` is the input A, ``[:, 1]`` is the target B).

    Each iteration:
    - updates D0, D1 and D2 (0, 1 and 2 downsampling steps) with an MSE loss
      against all-ones (real) and all-zeros (fake) patch targets;
    - updates G through a combined model with three adversarial MSE terms
      plus three feature-matching terms.

    Every ``iters_per_checkpoint`` iterations the generator output for a
    fixed validation batch is logged and G is saved.

    Args:
        data_folderpath: Folder containing ``train/*.jpg`` and ``val/*.jpg``.
        image_size: Image size passed to ``iterate_minibatches``.
        ndf: Base number of discriminator filters.
        ngf: Base number of generator filters.
        lr_d: Adam learning rate for the discriminators.
        lr_g: Adam learning rate for the generator.
        n_iterations: Number of training iterations.
        batch_size: Training batch size.
        iters_per_checkpoint: Log samples and save G every this many iterations.
        n_checkpoint_samples: Size of the fixed validation batch.
        feature_matching_weight: Weight passed to each feature-matching loss.
        out_dir: TensorBoard log directory; also passed to ``save_model``.
    """

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)

    data_iterator = iterate_minibatches(data_folderpath + "/train/*.jpg",
                                        batch_size, image_size)
    val_data_iterator = iterate_minibatches(data_folderpath + "/val/*.jpg",
                                            n_checkpoint_samples, image_size)
    img_ab_fixed, _ = next(val_data_iterator)
    img_a_fixed, img_b_fixed = img_ab_fixed[:, 0], img_ab_fixed[:, 1]

    # patch grid of the base discriminator; each extra downsampling step
    # halves it. Only the height is used, i.e. square images are assumed.
    img_a_shape = img_a_fixed.shape[1:]
    img_b_shape = img_b_fixed.shape[1:]
    patch = int(img_a_shape[0] / 2**3)  # n_layers
    disc_patch_0 = (patch, patch, 1)
    disc_patch_1 = (int(patch / 2), int(patch / 2), 1)
    disc_patch_2 = (int(patch / 4), int(patch / 4), 1)
    print("img a shape ", img_a_shape)
    print("img b shape ", img_b_shape)
    print("disc_patch ", disc_patch_0)
    print("disc_patch ", disc_patch_1)
    print("disc_patch ", disc_patch_2)

    # plot real data for reference
    log_images(img_a_fixed, 'real_a', '0', logger)
    log_images(img_b_fixed, 'real_b', '0', logger)

    # build models
    D0 = build_discriminator(img_a_shape,
                             img_b_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=0,
                             name='Discriminator0')
    D1 = build_discriminator(img_a_shape,
                             img_b_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=1,
                             name='Discriminator1')
    D2 = build_discriminator(img_a_shape,
                             img_b_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=2,
                             name='Discriminator2')
    G = build_global_generator(img_a_shape, ngf)

    # build model outputs; G and each D return a list of outputs whose first
    # element is the image / patch prediction
    img_a_input = Input(shape=img_a_shape)
    img_b_input = Input(shape=img_b_shape)

    fake_samples = G(img_a_input)[0]
    D0_real = D0([img_a_input, img_b_input])[0]
    D0_fake = D0([img_a_input, fake_samples])[0]

    D1_real = D1([img_a_input, img_b_input])[0]
    D1_fake = D1([img_a_input, fake_samples])[0]

    D2_real = D2([img_a_input, img_b_input])[0]
    D2_fake = D2([img_a_input, fake_samples])[0]

    # define D graph and optimizer. The `trainable` flags are captured when
    # each model is compiled, so G is frozen inside the D models and the Ds
    # are frozen inside G_model below.
    G.trainable = False
    D0.trainable = True
    D1.trainable = True
    D2.trainable = True
    D0_model = Model([img_a_input, img_b_input], [D0_real, D0_fake],
                     name='Discriminator0_model')
    D1_model = Model([img_a_input, img_b_input], [D1_real, D1_fake],
                     name='Discriminator1_model')
    D2_model = Model([img_a_input, img_b_input], [D2_real, D2_fake],
                     name='Discriminator2_model')
    D0_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])
    D1_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])
    D2_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D0.trainable = False
    D1.trainable = False
    D2.trainable = False

    loss_fm0 = partial(loss_feature_matching,
                       image_input=img_a_input,
                       real_samples=img_b_input,
                       D=D0,
                       feature_matching_weight=feature_matching_weight)
    loss_fm1 = partial(loss_feature_matching,
                       image_input=img_a_input,
                       real_samples=img_b_input,
                       D=D1,
                       feature_matching_weight=feature_matching_weight)
    loss_fm2 = partial(loss_feature_matching,
                       image_input=img_a_input,
                       real_samples=img_b_input,
                       D=D2,
                       feature_matching_weight=feature_matching_weight)

    # fake_samples appears three times so each feature-matching loss gets
    # its own output slot
    G_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[
                        D0_fake, D1_fake, D2_fake, fake_samples, fake_samples,
                        fake_samples
                    ])
    # learning rate passed positionally, matching the discriminator optimizers
    # (the `lr=` keyword is deprecated in newer Keras)
    G_model.compile(Adam(lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['mse', 'mse', 'mse', loss_fm0, loss_fm1, loss_fm2])

    ones_0 = np.ones((batch_size, ) + disc_patch_0, dtype=np.float32)
    ones_1 = np.ones((batch_size, ) + disc_patch_1, dtype=np.float32)
    ones_2 = np.ones((batch_size, ) + disc_patch_2, dtype=np.float32)
    zeros_0 = np.zeros((batch_size, ) + disc_patch_0, dtype=np.float32)
    zeros_1 = np.zeros((batch_size, ) + disc_patch_1, dtype=np.float32)
    zeros_2 = np.zeros((batch_size, ) + disc_patch_2, dtype=np.float32)
    # placeholder targets for the feature-matching outputs
    # NOTE(review): presumably ignored by loss_feature_matching -- confirm.
    dummy = np.ones((batch_size, ), dtype=np.float32)

    for i in range(n_iterations):
        # train discriminators only
        D0.trainable = True
        D1.trainable = True
        D2.trainable = True
        G.trainable = False

        # the D models generate their own fakes internally, so no separate
        # G.predict is needed here
        image_ab_batch, _ = next(data_iterator)
        d_inputs = [image_ab_batch[:, 0], image_ab_batch[:, 1]]
        # each discriminator is trained through its own model with targets
        # matching its own patch size
        loss_d0 = D0_model.train_on_batch(d_inputs, [ones_0, zeros_0])
        loss_d1 = D1_model.train_on_batch(d_inputs, [ones_1, zeros_1])
        loss_d2 = D2_model.train_on_batch(d_inputs, [ones_2, zeros_2])

        # train generator only, on a fresh batch
        D0.trainable = False
        D1.trainable = False
        D2.trainable = False
        G.trainable = True
        image_ab_batch, _ = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones_0, ones_1, ones_2, dummy, dummy, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            # select the image output of the multi-output generator
            fake_image = G.predict(img_a_fixed)[0]
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        # element-wise mean of the three discriminators' train_on_batch
        # results, so log_losses still receives one D-loss record
        loss_d = np.mean([loss_d0, loss_d1, loss_d2], axis=0)
        log_losses(loss_d, loss_g, i, logger)