Exemplo n.º 1
0
def create_grid(samples, img_files, depth=7):
    """
    utility function to create one grid image of GAN samples per scale

    Each entry of ``samples`` is colour-adjusted with
    ``Generator.adjust_dynamic_range``, resized by a factor of
    ``2 ** (depth - 1 - i)`` (``i`` being its position in ``samples``) so all
    scales end up at a common resolution, and written as a square-ish grid.
    Because the factor shrinks as ``i`` grows, ``samples`` is expected to be
    ordered from the lowest to the highest resolution.

    :param samples: list of generated sample batches, one per scale
    :param img_files: list of file names to write, one per entry of ``samples``
    :param depth: number of scales; the sample at index ``depth - 1`` keeps its
                  own resolution. The default of 7 reproduces the previous
                  hard-coded factor ``2 ** (6 - i)``.
    :return: None (saves the files)
    """
    from torchvision.utils import save_image
    from torch.nn.functional import interpolate
    from numpy import sqrt
    from MSG_GAN.GAN import Generator

    # dynamically adjust the colour of the images
    samples = [Generator.adjust_dynamic_range(sample) for sample in samples]

    # resize the samples to have the same resolution. A plain Python power is
    # used so an index past ``depth - 1`` gives a fractional (down-scaling)
    # factor instead of numpy's "negative integer power" ValueError.
    samples = [
        interpolate(sample, scale_factor=2 ** (depth - 1 - i))
        for i, sample in enumerate(samples)
    ]

    # save the images:
    for sample, img_file in zip(samples, img_files):
        save_image(sample, img_file, nrow=int(sqrt(sample.shape[0])),
                   normalize=True, scale_each=True, padding=0)
def main(args):
    """
    Main function for the script: generate scale-synchronized image grids.

    For each of ``args.num_samples`` random latents (rescaled to norm
    ``sqrt(args.latent_size)``), the generator's outputs at every scale are
    colour-adjusted, upscaled, tiled into a single grid and written to
    ``<args.out_dir>/<n>.png`` with ``n`` starting at 1.

    :param args: parsed command line arguments
    :return: None
    """
    # DataParallel requires the wrapped module's parameters to live on its
    # first device whenever CUDA is available, so place the module explicitly
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    print("Creating generator object ...")
    # create the generator object
    gen = th.nn.DataParallel(
        Generator(depth=args.depth, latent_size=args.latent_size).to(device))

    print("Loading the generator weights from:", args.generator_file)
    # load the weights into it; map_location lets checkpoints saved on a GPU
    # load on CPU-only machines
    gen.load_state_dict(th.load(args.generator_file, map_location=str(device)))

    # path for saving the files (created up front so the writes cannot fail
    # on a missing directory):
    save_path = args.out_dir
    os.makedirs(save_path, exist_ok=True)

    print("Generating scale synchronized images ...")
    for img_num in tqdm(range(1, args.num_samples + 1)):
        # generate the images:
        with th.no_grad():
            points = th.randn(1, args.latent_size)
            points = (points / points.norm()) * sqrt(args.latent_size)
            ss_images = gen(points)

        # colour adjust the images
        ss_images = [adjust_dynamic_range(ss_image) for ss_image in ss_images]

        # resize the images:
        ss_images = progressive_upscaling(ss_images)

        # squeeze the batch dimension from each image
        ss_images = [th.squeeze(x, dim=0) for x in ss_images]

        # make a grid out of them
        num_cols = (int(ceil(sqrt(len(ss_images))))
                    if args.num_columns is None else args.num_columns)
        if num_cols == 1:
            # tower image condition: reverse the order of the scales
            ss_images = list(reversed(ss_images))

        ss_image = make_grid(ss_images,
                             nrow=num_cols,
                             normalize=True,
                             scale_each=True)

        # save the ss_image in the directory
        imsave(os.path.join(save_path,
                            str(img_num) + ".png"),
               ss_image.permute(1, 2, 0).cpu())

    print("Generated %d images at %s" % (args.num_samples, save_path))
Exemplo n.º 3
0
def main(args):
    """
    Main function of the script: render latent-space interpolation frames.

    ``int(args.time * args.fps)`` random latents (scaled by ``args.std``) are
    smoothed along the time axis with a Gaussian filter (sigma of
    ``args.fps`` frames), rescaled to norm ``sqrt(args.latent_size)``, and
    each one is rendered to ``<args.out_dir>/<n>.png`` with ``n`` starting
    at 1.

    :param args: parsed commandline arguments
    :return: None
    """
    from MSG_GAN.GAN import Generator

    # create generator object:
    print("Creating a generator object ...")
    generator = th.nn.DataParallel(
        Generator(depth=args.depth, latent_size=args.latent_size).to(device))

    # load the trained generator weights; map_location lets checkpoints saved
    # on a GPU load when ``device`` is the CPU
    print("loading the trained generator weights ...")
    generator.load_state_dict(
        th.load(args.generator_file, map_location=str(device)))

    # total_frames in the video:
    total_frames = int(args.time * args.fps)

    # Let's create the animation video from the latent space interpolation
    # all latent vectors, smoothed over time so consecutive frames change
    # gradually:
    all_latents = th.randn(total_frames,
                           args.latent_size).to(device) * args.std
    all_latents = gaussian_filter(all_latents.cpu(), [args.fps, 0])
    all_latents = th.from_numpy(all_latents)
    all_latents = all_latents / all_latents.norm(dim=-1, keepdim=True) \
        * (sqrt(args.latent_size))

    # create output directory
    os.makedirs(args.out_dir, exist_ok=True)

    # Run the main loop for the interpolation:
    print("Generating the video frames ...")
    # inference only: no autograd graph is needed for the frames
    with th.no_grad():
        for frame_num, latent in enumerate(tqdm(all_latents), start=1):
            latent = th.unsqueeze(latent, dim=0)

            # generate the image for this point:
            img = get_image(generator, latent)

            # save the image:
            plt.imsave(
                os.path.join(args.out_dir, str(frame_num) + ".png"), img)

    # video frames have been generated
    print("Video frames have been generated at:", args.out_dir)
Exemplo n.º 4
0
def main(args):
    """
    Main function for the script: generate multi-scale sample grids.

    Runs the generator once per random latent (``args.num_samples`` times),
    optionally displays every scale, then writes one grid per scale to
    ``<args.output_dir>/<res>_<res>.png`` for res = 4, 8, ...,
    ``2 ** (args.depth + 1)``.

    :param args: parsed command line arguments
    :return: None
    """
    from MSG_GAN.GAN import Generator, MSG_GAN
    from torch.nn import DataParallel

    # create a generator:
    msg_gan_generator = Generator(depth=args.depth,
                                  latent_size=args.latent_size).to(device)

    if device == th.device("cuda"):
        msg_gan_generator = DataParallel(msg_gan_generator)

    if args.generator_file is not None:
        # load the weights into generator; map_location lets checkpoints saved
        # on a GPU load when ``device`` is the CPU
        msg_gan_generator.load_state_dict(
            th.load(args.generator_file, map_location=str(device)))

    print("Loaded Generator Configuration: ")
    print(msg_gan_generator)

    # generate all the samples in a list of lists (one inner sequence per
    # latent, one entry per scale). no_grad: every sample is kept until the
    # end, and without it each one would also keep its autograd graph alive.
    samples = []
    with th.no_grad():
        for _ in range(args.num_samples):
            gen_samples = msg_gan_generator(
                th.randn(1, args.latent_size).to(device))
            samples.append(gen_samples)

            if args.show_samples:
                for gen_sample in gen_samples:
                    plt.figure()
                    # matplotlib can only read host memory, hence .cpu()
                    plt.imshow(
                        th.squeeze(gen_sample).permute(1, 2, 0).cpu() / 2 + 0.5)
                plt.show()

    # create a grid of the generated samples, one file per scale:
    os.makedirs(args.output_dir, exist_ok=True)
    file_names = []
    for res_val in range(args.depth):
        res_dim = 2 ** (res_val + 2)
        file_names.append(
            os.path.join(args.output_dir,
                         str(res_dim) + "_" + str(res_dim) + ".png"))

    # regroup per scale: concatenate the batches of all latents at each scale
    images = [th.cat(scale_samples, dim=0) for scale_samples in zip(*samples)]
    MSG_GAN.create_grid(images, file_names)

    print("samples have been generated. Please check:", args.output_dir)
Exemplo n.º 5
0
def main(args):
    """
    Main function for the script: show a live latent-space interpolation.

    ``args.num_points * args.transition_points`` random latents are smoothed
    along the time axis with a wrapping Gaussian filter (sigma of
    ``args.smoothing * args.transition_points`` frames), rescaled to norm
    ``sqrt(args.latent_size)``, and played back as a matplotlib animation.

    :param args: parsed command line arguments
    :return: None
    """

    # load the model for the demo. The module is moved to ``device`` before
    # wrapping because DataParallel requires its parameters on the first
    # device whenever CUDA is available.
    gen = th.nn.DataParallel(
        Generator(
            depth=args.depth,
            latent_size=args.latent_size).to(device))
    gen.load_state_dict(th.load(args.generator_file, map_location=str(device)))

    # generate the set of points:
    total_frames = args.num_points * args.transition_points
    all_latents = th.randn(total_frames, args.latent_size).to(device)
    all_latents = th.from_numpy(
        gaussian_filter(
            all_latents.cpu(),
            [args.smoothing * args.transition_points, 0], mode="wrap"))
    all_latents = (all_latents /
                   all_latents.norm(dim=-1, keepdim=True)) * sqrt(args.latent_size)

    def render(point):
        """Render the single latent vector ``point`` with ``get_image``."""
        # inference only: no autograd graph is needed for display frames
        with th.no_grad():
            return get_image(gen, th.unsqueeze(point, dim=0))

    fig, _ = plt.subplots()
    plt.axis("off")
    # the first latent gives the initial frame; the rest drive the animation
    shower = plt.imshow(render(all_latents[0]))

    def init():
        return shower,

    def update(point):
        shower.set_data(render(point))
        return shower,

    # keep ``ani`` referenced while the window is open: an unreferenced
    # FuncAnimation can be garbage-collected and stop updating
    ani = FuncAnimation(fig, update, frames=all_latents[1:],
                        init_func=init)
    # pyplot.show() accepts no positional argument (``block`` is keyword-only)
    plt.show()
Exemplo n.º 6
0
def main(args):
    """
    Main function for the script: save one chosen scale per random latent.

    Draws ``args.num_samples`` latents up front, rescales each to norm
    ``sqrt(args.latent_size)``, runs the generator, colour-adjusts and
    upscales every scale, and saves only the scale at index
    ``args.out_depth`` to ``<args.out_dir>/<n>.png`` with ``n`` starting at 1.

    :param args: parsed command line arguments
    :return: None
    """
    # DataParallel requires the wrapped module's parameters to live on its
    # first device whenever CUDA is available, so place the module explicitly
    device = th.device("cuda" if th.cuda.is_available() else "cpu")

    print("Creating generator object ...")
    # create the generator object
    gen = th.nn.DataParallel(
        Generator(depth=args.depth, latent_size=args.latent_size).to(device))

    print("Loading the generator weights from:", args.generator_file)
    # load the weights into it; map_location lets checkpoints saved on a GPU
    # load on CPU-only machines
    gen.load_state_dict(th.load(args.generator_file, map_location=str(device)))

    # path for saving the files (created up front so the writes cannot fail
    # on a missing directory):
    save_path = args.out_dir
    os.makedirs(save_path, exist_ok=True)

    print("Generating scale synchronized images ...")
    points = th.randn(args.num_samples, args.latent_size)
    for img_num in tqdm(range(1, args.num_samples + 1)):
        # generate the images:
        with th.no_grad():
            point = th.unsqueeze(points[img_num - 1], 0)
            point = (point / point.norm()) * (args.latent_size ** 0.5)
            ss_images = gen(point)

        # colour adjust and resize the images, then keep the requested scale:
        ss_images = [adjust_dynamic_range(ss_image) for ss_image in ss_images]
        ss_images = progressive_upscaling(ss_images)
        ss_image = ss_images[args.out_depth]

        # save the ss_image in the directory
        imsave(os.path.join(save_path,
                            str(img_num) + ".png"),
               ss_image.squeeze(0).permute(1, 2, 0).cpu())

    print("Generated %d images at %s" % (args.num_samples, save_path))
Exemplo n.º 7
0
# ==========================================================================
# Tweakable parameters
# ==========================================================================
generator_file_path = "models/celebahq_testing/GAN_GEN_SHADOW_720.pth"
depth = 8
latent_size = 512
num_points = 30
transition_points = 15
# ==========================================================================

# create the device for running the demo:
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# load the model for the demo. The generator is moved to ``device`` before
# wrapping because DataParallel requires the wrapped module's parameters to
# live on its first device whenever CUDA is available; map_location makes
# the checkpoint tensors land on that same device.
gen = th.nn.DataParallel(
    Generator(depth=depth + 1, latent_size=latent_size).to(device))
gen.load_state_dict(th.load(generator_file_path, map_location=str(device)))


# function to generate an image given a latent_point
def get_image(point):
    """
    Render every scale the module-level ``gen`` produces for ``point`` as one grid.

    :param point: latent batch passed straight to ``gen``; a batch of one is
                  assumed, since the leading dimension is squeezed away below
    :return: the grid as a channels-last numpy array, built by ``make_grid``
             with ``normalize`` and ``scale_each`` enabled
    """
    # inference only: no_grad avoids building the autograd graph that the
    # outputs would otherwise carry (and that ``detach`` used to throw away)
    with th.no_grad():
        images = list(gen(point))
    images = progressive_upscaling(images)
    images = [image.squeeze(dim=0) for image in images]
    image = make_grid(
        images,
        nrow=int(ceil(sqrt(len(images)))),
        normalize=True,
        scale_each=True
    )
    # CHW tensor -> HWC array on the host for matplotlib
    return image.cpu().numpy().transpose(1, 2, 0)
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
            np.float32(drange_in[1]) - np.float32(drange_in[0]))
        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
        data = data * scale + bias
    return th.clamp(data, min=0, max=1)


# go over all the models and calculate it's fid by generating images from that model
for epoch in range((start // step), (total_range // step) + 1):
    epoch = 1 if epoch == 0 else epoch * step
    model_file = "GAN_GEN_SHADOW_" + str(epoch) + ".pth"
    model_file_path = os.path.join(models_path, model_file)

    # create a new generator object
    gen = th.nn.DataParallel(
        Generator(depth=depth, latent_size=latent_size).to(device))

    # load these weights into the model
    gen.load_state_dict(th.load(model_file_path))

    # empty the temp directory and make it to ensure it exists
    if os.path.isdir(temp_fid_path):
        rmtree(temp_fid_path)
    os.makedirs(temp_fid_path, exist_ok=True)

    print("\n\nLoaded model:", epoch)
    print("weights loaded from:", model_file_path)
    print("generating %d images using this model ..." % gen_fid_images)
    pbar = tqdm(total=gen_fid_images)
    generated_images = 0
    while generated_images < gen_fid_images:
Exemplo n.º 9
0
def main(args):
    """
    Main function of the script: render a "traverse, then pause" latent video.

    Starting from a random latent (scaled by ``args.std``), the script
    repeatedly picks a new random latent and writes
    ``int(args.traversal_time * args.fps)`` frames linearly interpolated
    towards it. It then repeats the image at the new latent for
    ``args.static_time * args.fps`` frames. Frames are written to
    ``<args.out_dir>/<n>.png`` (``n`` starting at 1) until the transitions
    that fit into ``args.time * 60`` seconds have been produced.

    :param args: parsed commandline arguments
    :return: None
    """
    from MSG_GAN.GAN import Generator

    # create generator object:
    print("Creating a generator object ...")
    generator = th.nn.DataParallel(
        Generator(depth=args.depth, latent_size=args.latent_size).to(device))

    # load the trained generator weights once. They never change afterwards,
    # so reloading them for every frame was pure disk I/O. map_location lets
    # checkpoints saved on a GPU load when ``device`` is the CPU.
    print("loading the trained generator weights ...")
    generator.load_state_dict(
        th.load(args.generator_file, map_location=str(device)))

    def _render(point):
        """Return the generator's last output for ``point`` as a host HWC image."""
        with th.no_grad():
            img = th.squeeze(generator(point)[-1], dim=0).permute(
                1, 2, 0) / 2 + 0.5
        # plt.imsave needs host memory and rejects float RGB outside [0, 1]
        return img.clamp(0, 1).cpu()

    def _save(img, frame_num):
        """Write ``img`` as frame number ``frame_num`` into ``args.out_dir``."""
        plt.imsave(os.path.join(args.out_dir, str(frame_num) + ".png"), img)

    # total_frames in the video:
    total_time_for_one_transition = args.traversal_time + args.static_time
    total_frames_for_one_transition = (total_time_for_one_transition *
                                       args.fps)
    number_of_transitions = int(
        (args.time * 60) / total_time_for_one_transition)
    total_frames = int(number_of_transitions * total_frames_for_one_transition)

    # frames per traversal and per pause (int() so float inputs are accepted)
    number_of_points = int(args.traversal_time * args.fps)
    number_of_static_frames = int(args.static_time * args.fps)

    # Let's create the animation video from the latent space interpolation
    point_1 = th.randn(1, args.latent_size).to(device) * args.std

    # create output directory
    os.makedirs(args.out_dir, exist_ok=True)

    # Run the main loop for the interpolation:
    global_frame_counter = 1  # index of the next frame to write
    while global_frame_counter <= total_frames:
        point_2 = th.randn(1, args.latent_size).to(device) * args.std
        direction = point_2 - point_1

        # linear traversal from point_1 towards (excluding) point_2:
        for i in range(number_of_points):
            point = point_1 + ((direction / number_of_points) * i)
            _save(_render(point), global_frame_counter)
            global_frame_counter += 1

        # at point_2, now add static frames showing the same image:
        img = _render(point_2)
        for _ in range(number_of_static_frames):
            _save(img, global_frame_counter)
            global_frame_counter += 1

        # set the point_1 := point_2
        point_1 = point_2

        # the counter points at the next frame, so written = counter - 1
        print("Generated %d frames ..." % (global_frame_counter - 1))

    # video frames have been generated
    print("Video frames have been generated at:", args.out_dir)