Пример #1
0
def extract(tfrecord_dir, output_dir):
    """Dump every image of a TFRecords dataset to numbered PNG files."""
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(
        tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()

    print('Extracting images to "%s"' % output_dir)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    count = 0
    while True:
        if count % 10 == 0:
            print('%d\r' % count, end='', flush=True)
        try:
            images, _labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        chw = images[0]
        # Single-channel data saved as grayscale; otherwise CHW => HWC RGB.
        if chw.shape[0] == 1:
            img = PIL.Image.fromarray(chw[0], 'L')
        else:
            img = PIL.Image.fromarray(chw.transpose(1, 2, 0), 'RGB')
        img.save(os.path.join(output_dir, 'img%08d.png' % count))
        count += 1
    print('Extracted %d images.' % count)
Пример #2
0
def main_conditional():
    """Generate class-conditional samples from a pre-trained StyleGAN.

    For each of 20 random seeds, generates one 32x32 image per class
    (conditioned via a one-hot identity matrix) and saves the stacked
    strip as a PNG next to the network snapshot.

    NOTE(review): relies on module-level globals `num_classes` and `fmt`
    (output_transform dict) defined elsewhere in the file — confirm.
    """
    # Initialize TensorFlow
    tflib.init_tf()

    # Load pre-trained network (renamed from `dir` to avoid shadowing the builtin)
    result_dir = 'results/00004-sgan-cifar10-1gpu-cond/'
    fn = 'network-snapshot-010372.pkl'
    _G, _D, Gs = pickle.load(open(os.path.join(result_dir, fn), 'rb'))

    # Print network details
    Gs.print_layers()

    # Initialize conditioning: one one-hot row per class
    conditioning = np.eye(num_classes)

    for i, rnd in enumerate([np.random.RandomState(i) for i in np.arange(20)]):

        # Pick one latent vector per class.
        latents = rnd.randn(num_classes, Gs.input_shape[1])

        # Generate image.
        images = Gs.run(latents,
                        conditioning,
                        truncation_psi=0.7,
                        randomize_noise=True,
                        output_transform=fmt)
        # Stack the per-class 32x32 images into a single vertical strip.
        images = images.reshape(32 * 10, 32, 3)

        # Save image.
        png_filename = os.path.join(result_dir, 'example_{}.png'.format(i))
        PIL.Image.fromarray(images, 'RGB').save(png_filename)
Пример #3
0
def display(tfrecord_dir):
    """Interactively browse the images of a TFRecords dataset via OpenCV."""
    print('Loading dataset "%s"' % tfrecord_dir)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset = dataset.TFRecordDataset(
        tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
    tflib.init_uninitialized_vars()
    import cv2  # pip install opencv-python

    count = 0
    while True:
        try:
            images, labels = dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            break
        if count == 0:
            print('Displaying images')
            cv2.namedWindow('dataset_tool')
            print('Press SPACE or ENTER to advance, ESC to exit')
        print('\nidx = %-8d\nlabel = %s' % (count, labels[0].tolist()))
        # CHW => HWC, RGB => BGR (OpenCV expects BGR).
        frame = images[0].transpose(1, 2, 0)[:, :, ::-1]
        cv2.imshow('dataset_tool', frame)
        count += 1
        if cv2.waitKey() == 27:  # ESC exits
            break
    print('\nDisplayed %d images.' % count)
Пример #4
0
def main():
    """Generate one sample from the pre-trained FFHQ StyleGAN generator."""
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'  # karras2019stylegan-ffhq-1024x1024.pkl
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

    # Print network details.
    Gs.print_layers()

    # Pick latent vector. (Bug fix: `rnd` was previously undefined — NameError.)
    rnd = np.random.RandomState(5)
    latents = rnd.randn(1, Gs.input_shape[1])

    # Convert generator output (NCHW float) to uint8 HWC for PIL.
    # (Bug fix: `fmt` was previously undefined — NameError.)
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)

    # Generate image.
    images = Gs.run(latents,
                    None,
                    truncation_psi=0.7,
                    randomize_noise=True,
                    output_transform=fmt)

    # Save image.
    os.makedirs(config.result_dir, exist_ok=True)
    png_filename = os.path.join(config.result_dir, 'example.png')
    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
Пример #5
0
def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):
    """Evaluate a single metric on one explicitly-named network pickle."""
    run_ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl))
    metric_obj = dnnlib.util.call_func_by_name(**metric_args)
    print()
    metric_obj.run(network_pkl,
                   dataset_args=dataset_args,
                   mirror_augment=mirror_augment,
                   num_gpus=submit_config.num_gpus)
    print()
    run_ctx.close()
Пример #6
0
def run_snapshot(submit_config, metric_args, run_id, snapshot):
    """Evaluate a single metric on one snapshot of a previous training run."""
    run_ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot))
    # Resolve the run directory and the snapshot pickle inside it.
    run_dir = misc.locate_run_dir(run_id)
    network_pkl = misc.locate_network_pkl(run_dir, snapshot)
    metric_obj = dnnlib.util.call_func_by_name(**metric_args)
    print()
    metric_obj.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()
    run_ctx.close()
Пример #7
0
def run_all_snapshots(submit_config, metric_args, run_id):
    """Evaluate a metric on every snapshot pickle of a previous training run."""
    run_ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id))
    run_dir = misc.locate_run_dir(run_id)
    network_pkls = misc.list_network_pkls(run_dir)
    metric_obj = dnnlib.util.call_func_by_name(**metric_args)
    print()
    total = len(network_pkls)
    for idx, network_pkl in enumerate(network_pkls):
        run_ctx.update('', idx, total)
        metric_obj.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()
    run_ctx.close()
Пример #8
0
def get_generator(batch_size=1):
    """Build a Generator wrapper around the pre-trained FFHQ StyleGAN.

    Returns a (generator, Gs_network) pair; the discriminator and the
    instantaneous generator snapshot are discarded immediately.
    """
    tiled_dlatent = False
    randomize_noise = False
    clipping_threshold = 2
    dlatent_avg = ''

    tflib.init_tf()
    with dnnlib.util.open_url(URL_FFHQ, cache_dir=config.cache_dir) as f:
        generator_network, discriminator_network, Gs_network = pickle.load(f)
        del discriminator_network, generator_network
    generator = Generator(Gs_network,
                          batch_size=batch_size,
                          clipping_threshold=clipping_threshold,
                          tiled_dlatent=tiled_dlatent,
                          randomize_noise=randomize_noise)
    # Optionally override the truncation center with a dlatent average from disk.
    if dlatent_avg != '':
        generator.set_dlatent_avg(np.load(dlatent_avg))
    return generator, Gs_network
Пример #9
0
def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
    """Compare two TFRecords datasets pairwise, image by image."""
    max_label_size = 0 if ignore_labels else 'full'
    print('Loading dataset "%s"' % tfrecord_dir_a)
    tflib.init_tf({'gpu_options.allow_growth': True})
    dset_a = dataset.TFRecordDataset(tfrecord_dir_a,
                                     max_label_size=max_label_size,
                                     repeat=False,
                                     shuffle_mb=0)
    print('Loading dataset "%s"' % tfrecord_dir_b)
    dset_b = dataset.TFRecordDataset(tfrecord_dir_b,
                                     max_label_size=max_label_size,
                                     repeat=False,
                                     shuffle_mb=0)
    tflib.init_uninitialized_vars()

    def fetch(dset):
        # Returns (None, None) once the dataset is exhausted.
        try:
            return dset.get_minibatch_np(1)
        except tf.errors.OutOfRangeError:
            return None, None

    print('Comparing datasets')
    idx = 0
    identical_images = 0
    identical_labels = 0
    while True:
        if idx % 100 == 0:
            print('%d\r' % idx, end='', flush=True)
        images_a, labels_a = fetch(dset_a)
        images_b, labels_b = fetch(dset_b)
        if images_a is None or images_b is None:
            # One side ran out; report if the other still has data.
            if images_a is not None or images_b is not None:
                print('Datasets contain different number of images')
            break
        if images_a.shape == images_b.shape and np.all(images_a == images_b):
            identical_images += 1
        else:
            print('Image %d is different' % idx)
        if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
            identical_labels += 1
        else:
            print('Label %d is different' % idx)
        idx += 1
    print('Identical images: %d / %d' % (identical_images, idx))
    if not ignore_labels:
        print('Identical labels: %d / %d' % (identical_labels, idx))
Пример #10
0
def main_textual():
    """Generate text-conditioned samples from a pre-trained StyleGAN.

    Uses every 5th COCO text embedding (one caption per image, to compare
    against attnGAN) as the conditioning label and saves one 256x256 PNG
    per embedding.

    NOTE(review): relies on a module-level `fmt` (output_transform dict)
    and on `Image` (PIL.Image) being imported elsewhere — confirm.
    """
    # Initialize Tensorflow
    tflib.init_tf()

    # Renamed from `dir` to avoid shadowing the builtin.
    result_dir = 'results/00015-sgancoco_train-1gpu-cond'
    fn = 'network-snapshot-025000.pkl'
    _, _, Gs = pickle.load(open(os.path.join(result_dir, fn), 'rb'))

    # Print network details
    Gs.print_layers()
    embeddings = np.load('datasets/coco_test/coco_test-rxx.labels')
    fns = np.load('datasets/coco_test/fns.npy')

    # Use only 1 description (instead of all 5, to compare to attnGAN)
    embeddings = embeddings[0::5]
    fns = fns[0::5]

    for i, rnd in enumerate(
        [np.random.RandomState(i) for i in np.arange(embeddings.shape[0])]):

        latent = rnd.randn(1, Gs.input_shape[1])

        # Text embedding used as the conditioning label.
        emb = embeddings[i].reshape(1, -1)

        image = Gs.run(latent,
                       emb,
                       truncation_psi=0.8,
                       randomize_noise=True,
                       output_transform=fmt)

        image = image.reshape(256, 256, 3)

        png_filename = os.path.join(result_dir, 'examples/{}.png'.format(fns[i]))

        image = Image.fromarray(image)
        image.save(png_filename)
Пример #11
0
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id='latest',  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Main progressive-growing GAN training loop.

    Builds (or resumes) the G/D networks, constructs a multi-GPU training
    graph, then alternates D and G updates while periodically exporting
    image grids, network pickles, metrics, and autosummaries.

    NOTE(review): the dict/list default arguments are mutable; they are only
    read here (forwarded as **kwargs), but converting to None-sentinels would
    be safer. Left unchanged to preserve the exact signature.
    """

    # Initialize dnnlib and TensorFlow.
    # NOTE(review): `train` is a module-level name not visible in this file
    # chunk — presumably the config module passed for run context; confirm.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        # Load pre-trained
        if resume_run_id is not None:
            if resume_run_id == 'latest':
                network_pkl, resume_kimg = misc.locate_latest_pkl()
                print('Loading networks from "%s"...' % network_pkl)
                G, D, Gs = misc.load_pkl(network_pkl)

            elif resume_run_id == 'restore_partial':
                print('Restore partially...')
                # Initialize networks
                G = tflib.Network('G',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **G_args)
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **D_args)
                Gs = G.clone('Gs')

                # Load pre-trained networks
                # NOTE(review): `restore_partial_fn` is not defined anywhere
                # in this view — it must be a module-level global; confirm.
                assert restore_partial_fn != None
                G_partial, D_partial, Gs_partial = pickle.load(
                    open(restore_partial_fn, 'rb'))

                # Restore (subset of) pre-trained weights
                # (only parameters that match both name and shape)
                G.copy_compatible_trainables_from(G_partial)
                D.copy_compatible_trainables_from(D_partial)
                Gs.copy_compatible_trainables_from(Gs_partial)

            else:
                network_pkl = misc.locate_network_pkl(resume_run_id,
                                                      resume_snapshot)
                print('Loading networks from "%s"...' % network_pkl)
                G, D, Gs = misc.load_pkl(network_pkl)

        # Start from scratch
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    # Scalar placeholders fed each iteration: level-of-detail, learning rate,
    # and total minibatch size (split evenly across GPUs below).
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # EMA decay for Gs derived from the half-life G_smoothing_kimg.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    # Replicate the networks on every GPU (gpu 0 uses the originals).
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # tf.contrib may be unavailable; fall back to a constant 0 reading.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        # Reset Adam moments etc. whenever the lod crosses an integer
        # boundary (i.e. new layers fade in or out).
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats discriminator steps per generator step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
Пример #12
0
def main():
    parser = argparse.ArgumentParser(
        description=
        'Find latent representation of reference images using perceptual losses',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('src_dir', help='Directory with images for encoding')
    parser.add_argument('generated_images_dir',
                        help='Directory for storing generated images')
    parser.add_argument('dlatent_dir',
                        help='Directory for storing dlatent representations')
    parser.add_argument('--data_dir',
                        default='data',
                        help='Directory for storing optional models')
    parser.add_argument('--mask_dir',
                        default='masks',
                        help='Directory for storing optional masks')
    parser.add_argument('--load_last',
                        default='',
                        help='Start with embeddings from directory')
    parser.add_argument(
        '--dlatent_avg',
        default='',
        help=
        'Use dlatent from file specified here for truncation instead of dlatent_avg from Gs'
    )
    parser.add_argument(
        '--model_url',
        default=
        'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ',
        help='Fetch a StyleGAN model to train on from this URL'
    )  # karras2019stylegan-ffhq-1024x1024.pkl
    parser.add_argument('--model_res',
                        default=1024,
                        help='The dimension of images in the StyleGAN model',
                        type=int)
    parser.add_argument('--batch_size',
                        default=1,
                        help='Batch size for generator and perceptual model',
                        type=int)

    # Perceptual model params
    parser.add_argument('--image_size',
                        default=256,
                        help='Size of images for perceptual model',
                        type=int)
    parser.add_argument('--resnet_image_size',
                        default=256,
                        help='Size of images for the Resnet model',
                        type=int)
    parser.add_argument('--lr',
                        default=0.02,
                        help='Learning rate for perceptual model',
                        type=float)
    parser.add_argument('--decay_rate',
                        default=0.9,
                        help='Decay rate for learning rate',
                        type=float)
    parser.add_argument('--iterations',
                        default=100,
                        help='Number of optimization steps for each batch',
                        type=int)
    parser.add_argument(
        '--decay_steps',
        default=10,
        help='Decay steps for learning rate decay (as a percent of iterations)',
        type=float)
    parser.add_argument(
        '--load_effnet',
        default='data/finetuned_effnet.h5',
        help='Model to load for EfficientNet approximation of dlatents')
    parser.add_argument(
        '--load_resnet',
        default='data/finetuned_resnet.h5',
        help='Model to load for ResNet approximation of dlatents')

    # Loss function options
    parser.add_argument(
        '--use_vgg_loss',
        default=0.4,
        help='Use VGG perceptual loss; 0 to disable, > 0 to scale.',
        type=float)
    parser.add_argument('--use_vgg_layer',
                        default=9,
                        help='Pick which VGG layer to use.',
                        type=int)
    parser.add_argument(
        '--use_pixel_loss',
        default=1.5,
        help='Use logcosh image pixel loss; 0 to disable, > 0 to scale.',
        type=float)
    parser.add_argument(
        '--use_mssim_loss',
        default=100,
        help='Use MS-SIM perceptual loss; 0 to disable, > 0 to scale.',
        type=float)
    parser.add_argument(
        '--use_lpips_loss',
        default=100,
        help='Use LPIPS perceptual loss; 0 to disable, > 0 to scale.',
        type=float)
    parser.add_argument(
        '--use_l1_penalty',
        default=1,
        help='Use L1 penalty on latents; 0 to disable, > 0 to scale.',
        type=float)

    # Generator params
    parser.add_argument('--randomize_noise',
                        default=False,
                        help='Add noise to dlatents during optimization',
                        type=bool)
    parser.add_argument(
        '--tile_dlatents',
        default=False,
        help='Tile dlatents to use a single vector at each scale',
        type=bool)
    parser.add_argument(
        '--clipping_threshold',
        default=2.0,
        help='Stochastic clipping of gradient values outside of this threshold',
        type=float)

    # Masking params
    parser.add_argument('--load_mask',
                        default=False,
                        help='Load segmentation masks',
                        type=bool)
    parser.add_argument(
        '--face_mask',
        default=False,
        help='Generate a mask for predicting only the face area',
        type=bool)
    parser.add_argument(
        '--use_grabcut',
        default=True,
        help=
        'Use grabcut algorithm on the face mask to better segment the foreground',
        type=bool)
    parser.add_argument(
        '--scale_mask',
        default=1.5,
        help='Look over a wider section of foreground for grabcut',
        type=float)

    # Video params
    parser.add_argument('--video_dir',
                        default='videos',
                        help='Directory for storing training videos')
    parser.add_argument('--output_video',
                        default=False,
                        help='Generate videos of the optimization process',
                        type=bool)
    parser.add_argument('--video_codec',
                        default='MJPG',
                        help='FOURCC-supported video codec name')
    parser.add_argument('--video_frame_rate',
                        default=24,
                        help='Video frames per second',
                        type=int)
    parser.add_argument('--video_size',
                        default=512,
                        help='Video size in pixels',
                        type=int)
    parser.add_argument(
        '--video_skip',
        default=1,
        help='Only write every n frames (1 = write every frame)',
        type=int)

    args, other_args = parser.parse_known_args()

    args.decay_steps *= 0.01 * args.iterations  # Calculate steps as a percent of total iterations

    if args.output_video:
        import cv2
        synthesis_kwargs = dict(output_transform=dict(
            func=tflib.convert_images_to_uint8, nchw_to_nhwc=False),
                                minibatch_size=args.batch_size)

    ref_images = [
        os.path.join(args.src_dir, x) for x in os.listdir(args.src_dir)
    ]
    ref_images = list(filter(os.path.isfile, ref_images))

    if len(ref_images) == 0:
        raise Exception('%s is empty' % args.src_dir)

    os.makedirs(args.data_dir, exist_ok=True)
    os.makedirs(args.mask_dir, exist_ok=True)
    os.makedirs(args.generated_images_dir, exist_ok=True)
    os.makedirs(args.dlatent_dir, exist_ok=True)
    os.makedirs(args.video_dir, exist_ok=True)

    # Initialize generator and perceptual model
    tflib.init_tf()
    with familyGan.stylegan_encoder.dnnlib.util.open_url(
            args.model_url,
            cache_dir=familyGan.stylegan_encoder.config.cache_dir) as f:
        generator_network, discriminator_network, Gs_network = pickle.load(f)

    generator = Generator(Gs_network,
                          args.batch_size,
                          clipping_threshold=args.clipping_threshold,
                          tiled_dlatent=args.tile_dlatents,
                          model_res=args.model_res,
                          randomize_noise=args.randomize_noise)
    if (args.dlatent_avg != ''):
        generator.set_dlatent_avg(np.load(args.dlatent_avg))

    perc_model = None
    if (args.use_lpips_loss > 0.00000001):
        with familyGan.stylegan_encoder.dnnlib.util.open_url(
                'https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2',
                cache_dir=familyGan.stylegan_encoder.config.cache_dir) as f:
            perc_model = pickle.load(f)
    perceptual_model = PerceptualModel(args,
                                       perc_model=perc_model,
                                       batch_size=args.batch_size)
    perceptual_model.build_perceptual_model(generator)

    ff_model = None

    # Optimize (only) dlatents by minimizing perceptual loss between reference and generated images in feature space
    for images_batch in tqdm(split_to_batches(ref_images, args.batch_size),
                             total=len(ref_images) // args.batch_size):
        names = [
            os.path.splitext(os.path.basename(x))[0] for x in images_batch
        ]
        if args.output_video:
            video_out = {}
            for name in names:
                video_out[name] = cv2.VideoWriter(
                    os.path.join(args.video_dir, f'{name}.avi'),
                    cv2.VideoWriter_fourcc(*args.video_codec),
                    args.video_frame_rate, (args.video_size, args.video_size))

        perceptual_model.set_reference_images(images_batch)
        dlatents = None
        if (args.load_last != ''):  # load previous dlatents for initialization
            for name in names:
                dl = np.expand_dims(np.load(
                    os.path.join(args.load_last, f'{name}.npy')),
                                    axis=0)
                if (dlatents is None):
                    dlatents = dl
                else:
                    dlatents = np.vstack((dlatents, dl))
        else:
            if (ff_model is None):
                if os.path.exists(args.load_resnet):
                    print("Loading ResNet Model:")
                    ff_model = load_model(args.load_resnet)
                    from keras.applications.resnet50 import preprocess_input
            if (ff_model is None):
                if os.path.exists(args.load_effnet):
                    import efficientnet
                    print("Loading EfficientNet Model:")
                    ff_model = load_model(args.load_effnet)
                    from efficientnet import preprocess_input
            if (ff_model
                    is not None):  # predict initial dlatents with ResNet model
                dlatents = ff_model.predict(
                    preprocess_input(
                        load_images(images_batch,
                                    image_size=args.resnet_image_size)))
        if dlatents is not None:
            generator.set_dlatents(dlatents)
        op = perceptual_model.optimize(generator.dlatent_variable,
                                       iterations=args.iterations)
        pbar = tqdm(op, leave=False, total=args.iterations)
        vid_count = 0
        best_loss = None
        best_dlatent = None
        for loss_dict in pbar:
            pbar.set_description(" ".join(names) + ": " + "; ".join(
                ["{} {:.4f}".format(k, v) for k, v in loss_dict.items()]))
            if best_loss is None or loss_dict["loss"] < best_loss:
                best_loss = loss_dict["loss"]
                best_dlatent = generator.get_dlatents()
            if args.output_video and (vid_count % args.video_skip == 0):
                batch_frames = generator.generate_images()
                for i, name in enumerate(names):
                    video_frame = PIL.Image.fromarray(
                        batch_frames[i], 'RGB').resize(
                            (args.video_size, args.video_size),
                            PIL.Image.LANCZOS)
                    video_out[name].write(
                        cv2.cvtColor(
                            np.array(video_frame).astype('uint8'),
                            cv2.COLOR_RGB2BGR))
            generator.stochastic_clip_dlatents()
        print(" ".join(names), " Loss {:.4f}".format(best_loss))

        if args.output_video:
            for name in names:
                video_out[name].release()

        # Generate images from found dlatents and save them
        generator.set_dlatents(best_dlatent)
        generated_images = generator.generate_images()
        generated_dlatents = generator.get_dlatents()
        for img_array, dlatent, img_name in zip(generated_images,
                                                generated_dlatents, names):
            img = PIL.Image.fromarray(img_array, 'RGB')
            img.save(
                os.path.join(args.generated_images_dir, f'{img_name}.png'),
                'PNG')
            np.save(os.path.join(args.dlatent_dir, f'{img_name}.npy'), dlatent)

        generator.reset_dlatents()
Example #13
0
def main():
    """Render the StyleGAN paper figures and save them as PNGs.

    Each call below produces one figure in config.result_dir; the
    generator networks are fetched through load_Gs for every figure
    (presumably cached inside load_Gs — confirm against its definition).
    """
    tflib.init_tf()
    os.makedirs(config.result_dir, exist_ok=True)

    def _dest(filename):
        # Output path inside the shared result directory.
        return os.path.join(config.result_dir, filename)

    # Figure 2: uncurated FFHQ samples over several levels of detail.
    draw_uncurated_result_figure(
        _dest('figure02-uncurated-ffhq.png'), load_Gs(url_ffhq),
        cx=0, cy=0, cw=1024, ch=1024, rows=3,
        lods=[0, 1, 2, 2, 3, 3], seed=5)

    # Figure 3: style mixing between source and destination seeds.
    draw_style_mixing_figure(
        _dest('figure03-style-mixing.png'), load_Gs(url_ffhq),
        w=1024, h=1024,
        src_seeds=[639, 701, 687, 615, 2268],
        dst_seeds=[888, 829, 1898, 1733, 1614, 845],
        style_ranges=[range(0, 4)] * 3 + [range(4, 8)] * 2 + [range(8, 18)])

    # Figure 4: effect of per-layer noise inputs on fine detail.
    draw_noise_detail_figure(
        _dest('figure04-noise-detail.png'), load_Gs(url_ffhq),
        w=1024, h=1024, num_samples=100, seeds=[1157, 1012])

    # Figure 5: contribution of noise applied to different layer ranges.
    draw_noise_components_figure(
        _dest('figure05-noise-components.png'), load_Gs(url_ffhq),
        w=1024, h=1024, seeds=[1967, 1555],
        noise_ranges=[range(0, 18), range(0, 0),
                      range(8, 18), range(0, 8)],
        flips=[1])

    # Figure 8: truncation trick at several psi values (including negative).
    draw_truncation_trick_figure(
        _dest('figure08-truncation-trick.png'), load_Gs(url_ffhq),
        w=1024, h=1024, seeds=[91, 388],
        psis=[1, 0.7, 0.5, 0, -0.5, -1])

    # Figures 10-12: uncurated samples for the LSUN-trained models.
    draw_uncurated_result_figure(
        _dest('figure10-uncurated-bedrooms.png'), load_Gs(url_bedrooms),
        cx=0, cy=0, cw=256, ch=256, rows=5,
        lods=[0, 0, 1, 1, 2, 2, 2], seed=0)
    draw_uncurated_result_figure(
        _dest('figure11-uncurated-cars.png'), load_Gs(url_cars),
        cx=0, cy=64, cw=512, ch=384, rows=4,
        lods=[0, 1, 2, 2, 3, 3], seed=2)
    draw_uncurated_result_figure(
        _dest('figure12-uncurated-cats.png'), load_Gs(url_cats),
        cx=0, cy=0, cw=256, ch=256, rows=5,
        lods=[0, 0, 1, 1, 2, 2, 2], seed=1)
Example #14
0
def main_binary():
    """Sample a conditional StyleGAN trained on CelebA-HQ binary attributes.

    Loads the generator from the latest training snapshot, builds one fixed
    40-dimensional binary attribute vector, and writes 20 example images
    (one per fixed random seed) into <result_dir>/examples/.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.  NOTE: the original code first assigned the
    # 00005-... run directory and immediately overwrote it; only the 00006
    # snapshot is actually used.  Renamed from `dir` to avoid shadowing the
    # builtin, and use a context manager so the file handle is closed.
    result_dir = 'results/00006-sgancelebahq-binary-1gpu-cond-wgangp/'
    snapshot = 'network-snapshot-006926.pkl'
    with open(os.path.join(result_dir, snapshot), 'rb') as f:
        _, _, Gs = pickle.load(f)

    # Print network details
    Gs.print_layers()

    # Create binary attributes
    # eyeglasses, male, black_hair, smiling, young

    # Fixed CelebA attribute vector: 1 = attribute present, 0 = absent.
    classes = {
        '5_o_Clock_Shadow': 0,
        'Arched_Eyebrows': 0,
        'Attractive': 1,
        'Bags_Under_Eyes': 0,
        'Bald': 0,
        'Bangs': 0,
        'Big_Lips': 0,
        'Big_Nose': 0,
        'Black_Hair': 0,
        'Blond_Hair': 0,
        'Blurry': 0,
        'Brown_Hair': 1,
        'Bushy_Eyebrows': 0,
        'Chubby': 0,
        'Double_Chin': 0,
        'Eyeglasses': 0,
        'Goatee': 0,
        'Gray_Hair': 0,
        'Heavy_Makeup': 1,
        'High_Cheekbones': 1,
        'Male': 0,
        'Mouth_Slightly_Open': 1,
        'Mustache': 0,
        'Narrow_Eyes': 0,
        'No_Beard': 0,
        'Oval_Face': 1,
        'Pale_Skin': 0,
        'Pointy_Nose': 0,
        'Receding_Hairline': 0,
        'Rosy_Cheeks': 0,
        'Sideburns': 0,
        'Smiling': 0,
        'Straight_Hair': 0,
        'Wavy_Hair': 1,
        'Wearing_Earrings': 0,
        'Wearing_Hat': 0,
        'Wearing_Lipstick': 1,
        'Wearing_Necklace': 0,
        'Wearing_Necktie': 0,
        'Young': 1
    }

    # Log which attributes are switched on (dict order is insertion order).
    print([attr for (attr, key) in classes.items() if key == 1])

    # Shape (1, 40): one conditioning row shared by every generated image.
    binary = np.array(list(classes.values())).reshape(1, -1)

    # Ensure the output sub-directory exists (the original code crashed on
    # save if examples/ was missing).
    os.makedirs(os.path.join(result_dir, 'examples'), exist_ok=True)

    for i, rnd in enumerate([np.random.RandomState(i) for i in np.arange(20)]):
        # One latent per image; per-index RandomState keeps runs reproducible.
        latent = rnd.randn(1, Gs.input_shape[1])

        image = Gs.run(latent,
                       binary,
                       truncation_psi=0.7,
                       randomize_noise=True,
                       output_transform=fmt)
        image = image.reshape(256, 256, 3)

        png_filename = os.path.join(result_dir, 'examples/example{}.png'.format(i))
        PIL.Image.fromarray(image, 'RGB').save(png_filename)
Example #15
0
# --- Stochastic weight averaging over the last N checkpoints -------------
parser.add_argument(
    'results_dir',
    help='Directory with network checkpoints for weight averaging')
parser.add_argument('--filespec',
                    default='network*.pkl',
                    help='The files to average')
parser.add_argument('--output_model',
                    default='network_avg.pkl',
                    help='The averaged model to output')
parser.add_argument('--count',
                    default=6,
                    help='Average the last n checkpoints',
                    type=int)

args, other_args = parser.parse_known_args()
swa_epochs = args.count
filepath = args.output_model

# BUG FIX: glob.glob() returns files in arbitrary, OS-dependent order, so
# the list must be sorted BEFORE taking the last N — the original sliced
# first and sorted afterwards, which could average an arbitrary subset of
# checkpoints.  Snapshot filenames are zero-padded, so lexicographic order
# matches training order.
files = sorted(glob.glob(os.path.join(args.results_dir, args.filespec)))
if len(files) > swa_epochs:
    files = files[-swa_epochs:]
print(files)

init_tf()
models = fetch_models_from_files(files)
swa_models = apply_swa_to_checkpoints(models)

print('Final model parameters set to stochastic weight average.')
# Context manager closes the file even if pickling raises.
with open(filepath, 'wb') as f:
    pickle.dump(swa_models, f)
print('Final stochastic averaged weights saved to file.')
Example #16
0
# Command-line options for the latent-encoder training script.
parser.add_argument('--minibatch_size', default=16, type=int,
                    help='Size of minibatches for training and generation')
parser.add_argument('--seed', default=-1, type=int,
                    help='Pick a random seed for reproducibility '
                         '(-1 for no random seed selected)')
parser.add_argument('--loop', default=-1, type=int,
                    help='Run this many iterations '
                         '(-1 for infinite, halt with CTRL-C)')

args, other_args = parser.parse_known_args()

os.makedirs(args.data_dir, exist_ok=True)

# A seed of -1 means "do not seed".
if args.seed == -1:
    args.seed = None

# Optional half-precision training; a larger epsilon avoids fp16 underflow.
if args.use_fp16:
    K.set_floatx('float16')
    K.set_epsilon(1e-4)

tflib.init_tf()

# Build the feature-extractor model that predicts dlatents from images.
model = get_resnet_model(args.model_path,
                         model_res=args.model_res,
                         depth=args.model_depth,
                         size=args.model_size,
                         activation=args.activation,
                         optimizer=args.optimizer,
                         loss=args.loss)

# Fetch the pre-trained StyleGAN networks (only Gs_network is used below).
with dnnlib.util.open_url(args.model_url, cache_dir=config.cache_dir) as f:
    generator_network, discriminator_network, Gs_network = pickle.load(f)
def load_Gs():
    """Return the StyleGAN generator already unpickled at module scope."""
    return Gs_network

# Optionally freeze the first wrapped layer so only later layers train.
# NOTE(review): layers[1] is presumably the pre-trained backbone inside the
# assembled model — confirm against get_resnet_model's structure.
if args.freeze_first:
    model.layers[1].trainable = False
    # Recompile so the changed `trainable` flag actually takes effect.
    model.compile(loss=args.loss, metrics=[], optimizer=args.optimizer)

model.summary()