def train_networks(encoder,
                   ca,
                   c_pro_gan,
                   dataset,
                   validation_dataset,
                   epochs,
                   encoder_optim,
                   ca_optim,
                   fade_in_percentage,
                   batch_sizes,
                   start_depth,
                   num_workers,
                   feedback_factor,
                   log_dir,
                   sample_dir,
                   checkpoint_factor,
                   save_dir,
                   comment,
                   use_matching_aware_dis=True):
    # required only for type checking
    from networks.TextEncoder import PretrainedEncoder
    # Writer will output to ./runs/ directory by default
    writer = SummaryWriter(comment="_{}_{}".format(batch_sizes[0], comment))

    # input assertions
    assert c_pro_gan.depth == len(
        batch_sizes), "batch_sizes not compatible with depth"
    assert c_pro_gan.depth == len(
        epochs), "epochs not compatible with depth"
    assert c_pro_gan.depth == len(
        fade_in_percentage), "fade_in_percentage not compatible with depth"

    # put all the Networks in training mode:
    ca.train()
    c_pro_gan.gen.train()
    c_pro_gan.dis.train()

    if not isinstance(encoder, PretrainedEncoder):
        encoder.train()

    print("Starting the training process ... ")

    # create fixed_input for debugging
    temp_data = dl.get_data_loader(dataset,
                                   batch_sizes[start_depth],
                                   num_workers=num_workers)
    fixed_captions, fixed_real_images = next(iter(temp_data))
    fixed_embeddings = encoder(fixed_captions)
    if not isinstance(fixed_embeddings, th.Tensor):
        fixed_embeddings = th.from_numpy(fixed_embeddings)
    fixed_embeddings = fixed_embeddings.to(device)  # shape: (batch_size, 4096)

    fixed_c_not_hats, _, _ = ca(fixed_embeddings)  # shape: (batch_size, 256)

    fixed_noise = th.randn(len(fixed_captions), c_pro_gan.latent_size -
                           fixed_c_not_hats.shape[-1]).to(
                               device)  # shape: (batch_size, 256)

    fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)

    # save the fixed_images once:
    fixed_save_dir = os.path.join(sample_dir, "__Real_Info")
    os.makedirs(fixed_save_dir, exist_ok=True)
    create_grid(
        fixed_real_images,
        None,  # scale factor is not required here
        os.path.join(fixed_save_dir, "real_samples.png"),
        real_imgs=True)
    create_descriptions_file(os.path.join(fixed_save_dir, "real_captions.txt"),
                             fixed_captions, dataset)

    # create a global time counter
    global_time = time.time()

    # delete temp data loader:
    del temp_data
    for current_depth in range(start_depth, c_pro_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth],
                                  num_workers)

        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            # reset the per-epoch loss accumulators
            gen_losses = []
            dis_losses = []
            kl_losses = []
            val_gen_losses = []
            val_dis_losses = []
            val_kl_losses = []

            print("\nEpoch: %d" % epoch)
            total_batches = len(data)
            fader_point = int((fade_in_percentage[current_depth] / 100) *
                              epochs[current_depth] * total_batches)

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training
                captions, images = batch
                if encoder_optim is not None:
                    captions = captions.to(device)

                images = images.to(device)

                # run the text encoder:
                embeddings = encoder(captions)
                if not isinstance(embeddings, th.Tensor):
                    embeddings = th.from_numpy(embeddings)
                embeddings = embeddings.to(device)
                if encoder_optim is None:
                    # detach the LSTM from backpropagation
                    embeddings = embeddings.detach()
                c_not_hats, mus, sigmas = ca(embeddings)

                z = th.randn(len(captions), c_pro_gan.latent_size -
                             c_not_hats.shape[-1]).to(device)

                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize the discriminator:
                dis_loss = c_pro_gan.optimize_discriminator(
                    gan_input, images, embeddings.detach(), current_depth,
                    alpha, use_matching_aware_dis)

                dis_losses.append(dis_loss)
                writer.add_scalar(
                    f"Batch/Discriminator_Loss/{current_depth}/{epoch}",
                    dis_loss, i)

                # optimize the generator:
                z = th.randn(
                    captions.shape[0]
                    if isinstance(captions, th.Tensor) else len(captions),
                    c_pro_gan.latent_size - c_not_hats.shape[-1]).to(device)

                gan_input = th.cat((c_not_hats, z), dim=-1)

                if encoder_optim is not None:
                    encoder_optim.zero_grad()

                ca_optim.zero_grad()
                gen_loss = c_pro_gan.optimize_generator(
                    gan_input, embeddings, current_depth, alpha)
                gen_losses.append(gen_loss)
                writer.add_scalar(
                    f"Batch/Generator_Loss/{current_depth}/{epoch}", gen_loss,
                    i)
                # once the optimize_generator is called, it also sends gradients
                # to the Conditioning Augmenter and the TextEncoder. Hence the
                # zero_grad statements prior to the optimize_generator call
                # now perform optimization on those two as well
                # obtain the loss (KL divergence from ca_optim)
                kl_loss = th.mean(0.5 * th.sum(
                    (mus**2) + (sigmas**2) - th.log((sigmas**2)) - 1, dim=1))
                writer.add_scalar(f"Batch/KL_Loss/{current_depth}/{epoch}",
                                  kl_loss.item(), i)
                kl_losses.append(kl_loss.item())
                kl_loss.backward()
                ca_optim.step()
                if encoder_optim is not None:
                    encoder_optim.step()

                writer.add_image(
                    f"Batch/{current_depth}/{epoch}",
                    create_grid(
                        samples=c_pro_gan.gen(fixed_gan_input, current_depth,
                                              alpha),
                        scale_factor=int(
                            np.power(2, c_pro_gan.depth - current_depth - 1)),
                        img_file=None,  # if none we get the image grid returned
                    ),
                    i)
                # add an evaluation loop
                if i % 100 == 0:

                    v_temp_data = dl.get_data_loader(validation_dataset,
                                                     batch_sizes[current_depth],
                                                     num_workers=num_workers)
                    v_fixed_captions, v_fixed_real_images = next(
                        iter(v_temp_data))
                    v_fixed_embeddings = encoder(v_fixed_captions)
                    if not isinstance(v_fixed_embeddings, th.Tensor):
                        v_fixed_embeddings = th.from_numpy(v_fixed_embeddings)
                    v_fixed_embeddings = v_fixed_embeddings.to(
                        device)  # shape: (batch_size, 4096)

                    v_fixed_c_not_hats, _, _ = ca(
                        v_fixed_embeddings)  # shape: (batch_size, 256)

                    v_fixed_noise = th.randn(
                        len(v_fixed_captions), c_pro_gan.latent_size -
                        v_fixed_c_not_hats.shape[-1]).to(
                            device)  # shape: (batch_size, 256)

                    v_fixed_gan_input = th.cat(
                        (v_fixed_c_not_hats, v_fixed_noise), dim=-1)

                    v_dis_loss = c_pro_gan.optimize_discriminator(
                        v_fixed_gan_input,
                        v_fixed_real_images.to(device),
                        v_fixed_embeddings.detach(),
                        current_depth,
                        alpha,
                        use_matching_aware_dis,
                        trainable=False)
                    v_gen_loss = c_pro_gan.optimize_generator(
                        v_fixed_gan_input,
                        v_fixed_embeddings,
                        current_depth,
                        alpha,
                        trainable=False)

                    val_dis_losses.append(v_dis_loss)
                    val_gen_losses.append(v_gen_loss)

                    writer.add_scalar(
                        f"Batch/Val/Discriminator_Loss/{current_depth}/{epoch}",
                        v_dis_loss, i)
                    writer.add_scalar(
                        f"Batch/Val/Generator_Loss/{current_depth}/{epoch}",
                        v_gen_loss, i)
                    writer.add_text(
                        f"Batch/Val/Captions/{current_depth}/{epoch}",
                        str(v_fixed_captions), i)
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print(
                        "Validation [%s]  batch: %d  d_loss: %f  g_loss: %f  kl_los: %f"
                        % (elapsed, i, v_dis_loss, v_gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(
                        log_dir, "val_loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(
                            str(v_dis_loss) + "\t" + str(v_gen_loss) + "\t" +
                            str(kl_loss.item()) + "\n")

                    writer.add_image(
                        f'Batch/Val/{current_depth}/{epoch}',
                        create_grid(
                            samples=c_pro_gan.gen(v_fixed_gan_input,
                                                  current_depth, alpha),
                            scale_factor=int(
                                np.power(2,
                                         c_pro_gan.depth - current_depth - 1)),
                            img_file=None,  # if None the grid is returned
                        ),
                        i)
                # provide a loss feedback
                if i % int(total_batches / feedback_factor) == 0 or i == 1:
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print(
                        "Elapsed [%s]  batch: %d  d_loss: %f  g_loss: %f  kl_los: %f"
                        % (elapsed, i, dis_loss, gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(
                        log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(
                            str(dis_loss) + "\t" + str(gen_loss) + "\t" +
                            str(kl_loss.item()) + "\n")

                    # create a grid of samples and save it
                    gen_img_file = os.path.join(
                        sample_dir, "gen_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".png")

                    create_grid(
                        samples=c_pro_gan.gen(fixed_gan_input, current_depth,
                                              alpha),
                        scale_factor=int(
                            np.power(2, c_pro_gan.depth - current_depth - 1)),
                        img_file=gen_img_file,
                    )

                # increment the ticker:
                ticker += 1
            writer.add_scalar(f"Epoch/Generator_Loss/{current_depth}",
                              np.mean(gen_losses), epoch)
            writer.add_scalar(f"Epoch/Discriminator_Loss/{current_depth}",
                              np.mean(dis_losses), epoch)
            writer.add_scalar(f"Epoch/KL_Loss/{current_depth}",
                              np.mean(kl_losses), epoch)

            writer.add_scalar(f"Epoch/Val/Generator_Loss/{current_depth}",
                              np.mean(val_gen_losses), epoch)
            writer.add_scalar(f"Epoch/Val/Discriminator_Loss/{current_depth}",
                              np.mean(val_dis_losses), epoch)
            writer.add_image(
                f'Epoch/{current_depth}',
                create_grid(
                    samples=c_pro_gan.gen(fixed_gan_input, current_depth,
                                          alpha),
                    scale_factor=int(
                        np.power(2, c_pro_gan.depth - current_depth - 1)),
                    img_file=None,  # if none we get the image grid returned
                ),
                epoch)
            writer.flush()
            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0 or epoch == 0:
                # save the Model
                encoder_save_file = os.path.join(
                    save_dir, "Encoder_" + str(current_depth) + ".pth")
                ca_save_file = os.path.join(
                    save_dir,
                    "Condition_Augmentor_" + str(current_depth) + ".pth")
                gen_save_file = os.path.join(
                    save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                dis_save_file = os.path.join(
                    save_dir, "GAN_DIS_" + str(current_depth) + ".pth")

                os.makedirs(save_dir, exist_ok=True)

                if encoder_optim is not None:
                    th.save(encoder.state_dict(), encoder_save_file, pickle)
                th.save(ca.state_dict(), ca_save_file, pickle)
                th.save(c_pro_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(c_pro_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")
Example #2
def train_networks(encoder, ca, msg_gan, dataset, epochs,
                   encoder_optim, ca_optim, gen_optim, dis_optim, loss_fn, fade_in_percentage,
                   batch_sizes, start_depth, num_workers, feedback_factor,
                   log_dir, sample_dir, checkpoint_factor,
                   save_dir, use_matching_aware_dis=True):
    # required only for type checking
    from networks.TextEncoder import PretrainedEncoder

    # input assertions
    assert msg_gan.depth == len(batch_sizes), "batch_sizes not compatible with depth"
    assert msg_gan.depth == len(epochs), "epochs not compatible with depth"
    assert msg_gan.depth == len(fade_in_percentage), "fade_in_percentage not compatible with depth"

    # put all the Networks in training mode:
    ca.train()
    msg_gan.gen.train()
    msg_gan.dis.train()

    if not isinstance(encoder, PretrainedEncoder):
        encoder.train()

    print("Starting the training process ... ")

    # create fixed_input for debugging
    temp_data = dl.get_data_loader(dataset, batch_sizes[start_depth],
                                   num_workers=num_workers)
    fixed_captions, fixed_real_images = next(iter(temp_data))
    fixed_embeddings = encoder(fixed_captions.to(device)).to(device)

    fixed_c_not_hats, _, _ = ca(fixed_embeddings)

    fixed_noise = th.randn(len(fixed_captions),
                           msg_gan.latent_size - fixed_c_not_hats.shape[-1]).to(device)

    fixed_gan_input = th.cat((fixed_c_not_hats, fixed_noise), dim=-1)

    # create a global time counter
    global_time = time.time()

    # delete temp data loader:
    del temp_data
    for current_depth in range(start_depth, msg_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth], num_workers)

        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            total_batches = len(data)
            fader_point = int((fade_in_percentage[current_depth] / 100)
                              * epochs[current_depth] * total_batches)

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training
                captions, images = batch
                images = images.to(device)
                extracted_batch_size = images.shape[0]
                if encoder_optim is not None:
                    captions = captions.to(device)

                # create a list of downsampled versions of the real images
                # (range(1, 7) hard-codes seven scales, one per generator depth):
                images = [images] + [avg_pool2d(images, int(np.power(2, p)))
                                     for p in range(1, 7)]
                images = list(reversed(images))
                # run the text encoder:
                embeddings = encoder(captions).to(device)
                if encoder_optim is None:
                    # detach the LSTM from backpropagation
                    embeddings = embeddings.detach()
                c_not_hats, mus, sigmas = ca(embeddings)

                z = th.randn(
                    extracted_batch_size,
                    msg_gan.latent_size - c_not_hats.shape[-1]
                ).to(device)

                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize the discriminator:
                dis_loss = msg_gan.optimize_discriminator(dis_optim, gan_input,
                                                          images, loss_fn)

                # optimize the generator:
                z = th.randn(
                    captions.shape[0] if isinstance(captions, th.Tensor) else len(captions),
                    msg_gan.latent_size - c_not_hats.shape[-1]
                ).to(device)

                gan_input = th.cat((c_not_hats, z), dim=-1)

                if encoder_optim is not None:
                    encoder_optim.zero_grad()

                ca_optim.zero_grad()
                gen_loss = msg_gan.optimize_generator(gen_optim, gan_input,
                                                      images, loss_fn)

                # once the optimize_generator is called, it also sends gradients
                # to the Conditioning Augmenter and the TextEncoder. Hence the
                # zero_grad statements prior to the optimize_generator call
                # now perform optimization on those two as well
                # obtain the loss (KL divergence from ca_optim)
                kl_loss = th.mean(0.5 * th.sum((mus ** 2) + (sigmas ** 2)
                                               - th.log((sigmas ** 2)) - 1, dim=1))
                kl_loss.backward(retain_graph=True)
                ca_optim.step()
                if encoder_optim is not None:
                    encoder_optim.step()

                # provide a loss feedback
                if i % int(total_batches / feedback_factor) == 0 or i == 1:
                    elapsed = time.time() - global_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    print("Elapsed [%s]  batch: %d  d_loss: %f  g_loss: %f  kl_los: %f"
                          % (elapsed, i, dis_loss, gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(str(dis_loss) + "\t" + str(gen_loss)
                                  + "\t" + str(kl_loss.item()) + "\n")

                    # create a grid of samples at every resolution and save it
                    reses = [str(int(np.power(2, dep))) + "_x_"
                             + str(int(np.power(2, dep)))
                             for dep in range(2, 9)]
                    
                    gen_img_files = [os.path.join(sample_dir, res, "gen_" +
                                                  str(epoch) + "_" +
                                                  str(i) + ".png")
                                     for res in reses]
                    
                    os.makedirs(sample_dir, exist_ok=True)

                    for gen_img_file in gen_img_files:
                        os.makedirs(os.path.dirname(gen_img_file), exist_ok=True)

                    dis_optim.zero_grad()
                    gen_optim.zero_grad()

                    with th.no_grad():

                        # always sample from the EMA (shadow) generator
                        create_grid(samples=msg_gan.gen_shadow(fixed_gan_input),
                                    img_files=gen_img_files)

                # increment the ticker:
                ticker += 1

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0 or epoch == 0:
                # save the Model
                encoder_save_file = os.path.join(save_dir, "Encoder_" +
                                                 str(current_depth) + ".pth")
                ca_save_file = os.path.join(save_dir, "Condition_Augmentor_" +
                                            str(current_depth) + ".pth")
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" +
                                             str(current_depth) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" +
                                             str(current_depth) + ".pth")

                os.makedirs(save_dir, exist_ok=True)

                if encoder_optim is not None:
                    th.save(encoder.state_dict(), encoder_save_file, pickle)
                th.save(ca.state_dict(), ca_save_file, pickle)
                th.save(msg_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(msg_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")
Example #3
def train_networks(encoder,
                   ca,
                   c_pro_gan,
                   dataset,
                   epochs,
                   encoder_optim,
                   ca_optim,
                   fade_in_percentage,
                   batch_sizes,
                   start_depth,
                   num_workers,
                   feedback_factor,
                   log_dir,
                   sample_dir,
                   checkpoint_factor,
                   save_dir,
                   use_matching_aware_dis=True):
    assert c_pro_gan.depth == len(
        batch_sizes), "batch_sizes not compatible with depth"
    assert c_pro_gan.depth == len(
        epochs), "epochs not compatible with depth"
    assert c_pro_gan.depth == len(
        fade_in_percentage), "fade_in_percentage not compatible with depth"

    print("Starting the training process ... ")
    for current_depth in range(start_depth, c_pro_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth],
                                  num_workers)

        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            total_batches = len(data)
            fader_point = int((fade_in_percentage[current_depth] / 100) *
                              epochs[current_depth] * total_batches)

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training
                captions, images = batch

                if encoder_optim is not None:
                    captions = captions.to(device)

                images = images.to(device)

                # run the text encoder:
                embeddings = encoder(captions)
                if not isinstance(embeddings, th.Tensor):
                    embeddings = th.tensor(embeddings)
                embeddings = embeddings.to(device)
                c_not_hats, mus, sigmas = ca(embeddings)

                z = th.randn(
                    captions.shape[0]
                    if isinstance(captions, th.Tensor) else len(captions),
                    c_pro_gan.latent_size - c_not_hats.shape[-1]).to(device)

                gan_input = th.cat((c_not_hats, z), dim=-1)

                # optimize the discriminator:
                dis_loss = c_pro_gan.optimize_discriminator(
                    gan_input, images, embeddings, current_depth, alpha,
                    use_matching_aware_dis)

                # optimize the generator:
                z = th.randn(
                    captions.shape[0]
                    if isinstance(captions, th.Tensor) else len(captions),
                    c_pro_gan.latent_size - c_not_hats.shape[-1]).to(device)

                gan_input = th.cat((c_not_hats, z), dim=-1)

                if encoder_optim is not None:
                    encoder_optim.zero_grad()

                ca_optim.zero_grad()
                gen_loss = c_pro_gan.optimize_generator(
                    gan_input, embeddings, current_depth, alpha)

                # once the optimize_generator is called, it also sends gradients
                # to the Conditioning Augmenter and the TextEncoder. Hence the
                # zero_grad statements prior to the optimize_generator call
                # now perform optimization on those two as well
                # obtain the loss (KL divergence from ca_optim)
                kl_loss = th.mean(0.5 * th.sum(
                    (mus**2) + (sigmas**2) - th.log((sigmas**2)) - 1, dim=1))
                kl_loss.backward()
                ca_optim.step()
                if encoder_optim is not None:
                    encoder_optim.step()

                # provide a loss feedback
                if i % int(total_batches / feedback_factor) == 0 or i == 1:
                    print("batch: %d  d_loss: %f  g_loss: %f  kl_los: %f" %
                          (i, dis_loss, gen_loss, kl_loss.item()))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(
                        log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(
                            str(dis_loss) + "\t" + str(gen_loss) + "\t" +
                            str(kl_loss.item()) + "\n")

                    # create a grid of samples and save it
                    gen_img_file = os.path.join(
                        sample_dir, "gen_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".png")
                    orig_img_file = os.path.join(
                        sample_dir, "orig_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".png")
                    description_file = os.path.join(
                        sample_dir, "desc_" + str(current_depth) + "_" +
                        str(epoch) + "_" + str(i) + ".txt")
                    create_grid(
                        samples=c_pro_gan.gen(gan_input, current_depth, alpha),
                        scale_factor=int(
                            np.power(2, c_pro_gan.depth - current_depth - 1)),
                        img_file=gen_img_file,
                        width=int(np.sqrt(batch_sizes[current_depth])),
                    )

                    create_grid(samples=images,
                                scale_factor=int(
                                    np.power(
                                        2,
                                        c_pro_gan.depth - current_depth - 1)),
                                img_file=orig_img_file,
                                width=int(np.sqrt(batch_sizes[current_depth])),
                                real_imgs=True)

                    create_descriptions_file(description_file, captions,
                                             dataset)

                # increment the ticker:
                ticker += 1

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0 or epoch == 0:
                # save the Model
                encoder_save_file = os.path.join(
                    save_dir, "Encoder_" + str(current_depth) + ".pth")
                ca_save_file = os.path.join(
                    save_dir,
                    "Condition_Augmentor_" + str(current_depth) + ".pth")
                gen_save_file = os.path.join(
                    save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                dis_save_file = os.path.join(
                    save_dir, "GAN_DIS_" + str(current_depth) + ".pth")

                os.makedirs(save_dir, exist_ok=True)

                if encoder_optim is not None:
                    th.save(encoder.state_dict(), encoder_save_file, pickle)
                th.save(ca.state_dict(), ca_save_file, pickle)
                th.save(c_pro_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(c_pro_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")
Example #4
def train_networks(pro_gan, dataset, epochs,
                   fade_in_percentage, batch_sizes,
                   start_depth, num_workers, feedback_factor,
                   log_dir, sample_dir, checkpoint_factor,
                   save_dir):

    assert pro_gan.depth == len(batch_sizes), "batch_sizes not compatible with depth"

    print("Starting the training process ... ")
    for current_depth in range(start_depth, pro_gan.depth):

        print("\n\nCurrently working on Depth: ", current_depth)
        current_res = np.power(2, current_depth + 2)
        print("Current resolution: %d x %d" % (current_res, current_res))

        data = dl.get_data_loader(dataset, batch_sizes[current_depth], num_workers)
        ticker = 1

        for epoch in range(1, epochs[current_depth] + 1):
            start = timeit.default_timer()  # record time at the start of epoch

            print("\nEpoch: %d" % epoch)
            total_batches = len(data)

            fader_point = int((fade_in_percentage[current_depth] / 100)
                              * epochs[current_depth] * total_batches)

            for (i, batch) in enumerate(data, 1):
                # calculate the alpha for fading in the layers
                alpha = ticker / fader_point if ticker <= fader_point else 1

                # extract current batch of data for training
                images = batch.to(device)

                gan_input = th.randn(images.shape[0], pro_gan.gen.latent_size).to(pro_gan.device)

                # optimize the discriminator:
                dis_loss = pro_gan.optimize_discriminator(gan_input, images,
                                                          current_depth, alpha)

                # optimize the generator:
                gan_input = th.randn(images.shape[0], pro_gan.gen.latent_size).to(pro_gan.device)
                gen_loss = pro_gan.optimize_generator(gan_input, current_depth, alpha)

                # provide a loss feedback
                if i % int(total_batches / feedback_factor) == 0 or i == 1:
                    print("batch: %d  d_loss: %f  g_loss: %f" % (i, dis_loss, gen_loss))

                    # also write the losses to the log file:
                    os.makedirs(log_dir, exist_ok=True)
                    log_file = os.path.join(log_dir, "loss_" + str(current_depth) + ".log")
                    with open(log_file, "a") as log:
                        log.write(str(dis_loss) + "\t" + str(gen_loss) + "\n")

                    # create a grid of samples and save it
                    gen_img_file = os.path.join(sample_dir, "gen_" + str(current_depth) +
                                                "_" + str(epoch) + "_" +
                                                str(i) + ".png")
                    create_grid(
                        samples=pro_gan.gen(
                            gan_input,
                            current_depth,
                            alpha
                        ),
                        scale_factor=int(np.power(2, pro_gan.depth - current_depth - 1)),
                        img_file=gen_img_file,
                        width=int(np.sqrt(batch_sizes[current_depth])),
                    )

                # increment the alpha ticker
                ticker += 1

            stop = timeit.default_timer()
            print("Time taken for epoch: %.3f secs" % (stop - start))

            if epoch % checkpoint_factor == 0 or epoch == 0:
                gen_save_file = os.path.join(save_dir, "GAN_GEN_" + str(current_depth) + ".pth")
                dis_save_file = os.path.join(save_dir, "GAN_DIS_" + str(current_depth) + ".pth")

                os.makedirs(save_dir, exist_ok=True)

                th.save(pro_gan.gen.state_dict(), gen_save_file, pickle)
                th.save(pro_gan.dis.state_dict(), dis_save_file, pickle)

    print("Training completed ...")