Example #1
0
def interpolate_latents(run_id,
                        snapshot,
                        video_fps=30,
                        filter_frames=30,
                        num_frames=60 * 30,
                        drange_net=[-1, 1],
                        image_grid_size=None,
                        zoom=None,
                        video_bitrate='16M'):
    """Render an MP4 that drifts smoothly through the generator's latent space.

    Loads the network for (run_id, snapshot), draws num_frames random latent
    batches (one latent per grid cell), blurs them along the time axis with a
    Gaussian of sigma filter_frames using wrap-around boundaries (so the end
    blends back into the start), rescales them to unit RMS, and encodes one
    image grid per frame with libx264 into <result_subdir>/<basename>.mp4.
    A _video-done.txt marker is written when finished.

    Args:
        run_id, snapshot: passed to imgapi_load_net to pick the network.
        video_fps: output frame rate; time t uses latent frame round(t * fps).
        filter_frames: temporal Gaussian sigma, in frames.
        num_frames: number of latent frames; duration is num_frames / video_fps.
        drange_net: accepted for API compatibility; not used here.
        image_grid_size: (columns, rows). Defaults to (1, 1) when zoom is also
            omitted, otherwise to as many cells as fit a 1920x1080 canvas
            (clamped to 1..16 per axis).
        zoom: nearest-neighbour upscale factor; by default the largest factor
            (at least 1) that fits one output image in 1920x1080.
        video_bitrate: bitrate string handed to moviepy.
    """
    import moviepy.editor  # pip install moviepy

    # Resolve grid size and zoom from the generator's output geometry.
    net = imgapi_load_net(run_id=run_id, snapshot=snapshot)
    out_w = net.G.output_shape[3]
    out_h = net.G.output_shape[2]
    if zoom is None and image_grid_size is None:
        image_grid_size = (1, 1)
    if zoom is None:
        zoom = max(min(1920 / out_w, 1080 / out_h), 1)
    if image_grid_size is None:
        cols = np.clip(int(np.floor(1920 / (out_w * zoom))), 1, 16)
        rows = np.clip(int(np.floor(1080 / (out_h * zoom))), 1, 16)
        image_grid_size = cols, rows

    # Latent tensor layout: (frame, image, channel, component).
    print('Generating latent vectors...')
    latents = np.random.randn(num_frames, np.prod(image_grid_size),
                              *net.G.input_shape[1:]).astype(np.float32)
    sigmas = [filter_frames] + [0] * len(net.G.input_shape)  # blur time only
    latents = scipy.ndimage.gaussian_filter(latents, sigmas, mode='wrap')
    latents /= np.sqrt(np.mean(latents**2))

    print('Generating video...')
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)

    def make_frame(t):
        idx = np.clip(int(np.round(t * video_fps)), 0, num_frames - 1)
        labels = net.example_labels[:latents.shape[1]]
        images = net.gen_fn(latents[idx], labels)
        grid = misc.create_image_grid(images, grid_size=image_grid_size)
        if zoom != 1:
            grid = scipy.ndimage.zoom(grid, [1, zoom, zoom], order=0)
        grid = grid.clip(0, 255).transpose(1, 2, 0)  # CHW => HWC
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)  # grayscale => RGB
        return grid

    clip_seconds = float(num_frames) / video_fps
    video = moviepy.editor.VideoClip(make_frame, duration=clip_seconds)
    mp4_path = os.path.join(result_subdir,
                            os.path.basename(result_subdir) + '.mp4')
    video.write_videofile(mp4_path,
                          fps=video_fps,
                          codec='libx264',
                          bitrate=video_bitrate)

    print('Done.')
    marker_path = os.path.join(result_subdir, '_video-done.txt')
    with open(marker_path, 'wt'):
        pass
Example #2
0
def predict_gan():
    """Generate five images with the pre-trained generator and post them to Streamlit.

    For each i in 1..5 a latent batch is drawn and rendered by G. The result is
    saved as pre-trained_00i.png in a fresh result subdirectory and shown
    (resized to 512x512) under an "IG Post #i" header. It is followed by text
    that gpt2-simple (run 'tree_run1') writes to a timestamped
    gpt2_gentext_*.txt file.

    Relies on module-level names defined outside this function: config,
    Generator, load_G_weights, random_latents, misc, st (Streamlit),
    Image (Pillow), gpt2 and the session ``sess`` it generates with.

    Raises:
        ValueError: if resume_network is empty, since there would be no
            weights to generate from.
    """
    drange_viz = [-1, 1]
    image_grid_size = (1, 1)
    image_grid_type = 'default'
    # The './' prefix makes the path relative to the current working directory.
    resume_network = './pre-trained_weight'

    np.random.seed(config.random_seed)

    if not resume_network:
        raise ValueError('predict_gan needs pre-trained weights', resume_network)
    print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=3, resolution=128, label_size=0, **config.G)
    G = load_G_weights(G, resume_network, True)

    print(G.summary())

    # Misc init.
    if image_grid_type != 'default':
        raise ValueError('Invalid image_grid_type', image_grid_type)
    if image_grid_size is None:
        # Assumes channel-last output (batch, H, W, C): axis 1 is height,
        # axis 2 is width.
        h, w = G.output_shape[1], G.output_shape[2]
        print("w:%d,h:%d" % (w, h))
        image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                           np.clip(1080 / h, 2, 16).astype('int'))
    print("image_grid_size:", image_grid_size)

    result_subdir = misc.create_result_subdir('pre-trained_result', config.run_desc)

    for i in range(1, 6):
        image_path = os.path.join(result_subdir, 'pre-trained_%03d.png' % i)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
        snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
        misc.save_image_grid(snapshot_fake_images, image_path,
                             drange=drange_viz, grid_size=image_grid_size)

        # Show the generated image in Streamlit.
        st.header('IG Post #' + str(i))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        # The `with` block releases the file handle Image.open holds.
        with Image.open(image_path) as im:
            st.image(im.resize((512, 512), Image.LANCZOS))

        # Caption: sample text from the pre-trained gpt2-simple run into a
        # timestamped file, then display it. Text generation is slow, which
        # noticeably delays each post.
        gen_file = os.path.join(
            result_subdir,
            'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow()))
        gpt2.generate_to_file(sess, destination_path=gen_file, run_name='tree_run1')
        with open(gen_file, 'r') as content:
            st.write(content.read())
Example #3
0
def predict_gan():
    """Generate five image grids from the pre-trained generator weights.

    Seeds NumPy from config.random_seed and loads G from resume_network. It then
    saves pre-trained_001.png .. pre-trained_005.png, each rendered from fresh
    random latents, into a new 'pre-trained_result' subdirectory.

    Raises:
        ValueError: if resume_network is empty (there would be no G to run),
            or image_grid_type is not 'default'.
    """
    drange_viz = [-1, 1]
    image_grid_size = (1, 1)
    image_grid_type = 'default'
    # The './' prefix makes the path relative to the current working directory.
    resume_network = './pre-trained_weight'

    np.random.seed(config.random_seed)

    if not resume_network:
        raise ValueError('predict_gan needs pre-trained weights', resume_network)
    print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=3, resolution=128, label_size=0, **config.G)
    G = load_G_weights(G, resume_network, True)

    print(G.summary())

    # Misc init.
    if image_grid_type != 'default':
        raise ValueError('Invalid image_grid_type', image_grid_type)
    if image_grid_size is None:
        # Assumes channel-last output (batch, H, W, C): axis 1 is height,
        # axis 2 is width.
        h, w = G.output_shape[1], G.output_shape[2]
        print("w:%d,h:%d" % (w, h))
        image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                           np.clip(1080 / h, 2, 16).astype('int'))
    print("image_grid_size:", image_grid_size)

    result_subdir = misc.create_result_subdir('pre-trained_result',
                                              config.run_desc)

    for i in range(1, 6):
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
        snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
        image_path = os.path.join(result_subdir, 'pre-trained_%03d.png' % i)
        misc.save_image_grid(snapshot_fake_images,
                             image_path,
                             drange=drange_viz,
                             grid_size=image_grid_size)
Example #4
0
def predict_gan():
    """Generate images from pre-trained G/D weights and score each with D.

    Builds a Generator/Discriminator pair at the module-level `resolution` and
    loads weights from the module-level path2preTrainedWeights. It then writes
    numOfGenImages single-image files pre-trained_001.png, pre-trained_002.png,
    ... into a new 'pre-trained_result' subdirectory, printing D's score for
    each image.

    No RNG seeding happens here, so every call draws new latents.

    Raises:
        ValueError: if path2preTrainedWeights is empty (there would be no
            weights to generate from).
    """
    drange_viz = [-1, 1]
    resume_network = path2preTrainedWeights

    if not resume_network:
        raise ValueError('predict_gan needs pre-trained weights', resume_network)
    print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=3, resolution=resolution, label_size=0, **config.G)
    D = Discriminator(num_channels=3, resolution=resolution, label_size=0, **config.D)
    G, D = load_GD_weights(G, D, resume_network, True)

    print(G.summary())
    print(D.summary())

    # Assumes channel-last output (batch, H, W, C): axis 1 is height, axis 2 width.
    h, w = G.output_shape[1], G.output_shape[2]
    print("w:%d,h:%d" % (w, h))
    # A 1x1 grid keeps one image per batch, so D's [0, 0, 0, 0] output below
    # is the score of exactly the image saved in that file.
    image_grid_size = (1, 1)
    print("image_grid_size:", image_grid_size)

    result_subdir = misc.create_result_subdir('pre-trained_result', config.run_desc)

    # Files are numbered from 1, so the upper bound is numOfGenImages + 1
    # to produce exactly numOfGenImages images.
    for i in range(1, numOfGenImages + 1):
        snapshot_fake_latents = random_latents(np.prod(image_grid_size), G.input_shape)
        snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
        snapshot_fake_score = D.predict_on_batch(snapshot_fake_images)[0, 0, 0, 0]
        image_name = 'pre-trained_%03d.png' % i
        misc.save_image_grid(snapshot_fake_images,
                             os.path.join(result_subdir, image_name),
                             drange=drange_viz,
                             grid_size=image_grid_size)
        print("%s D score: %f" % (image_name, snapshot_fake_score))
Example #5
0
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400,
              lod_transition_kimg=400,
              total_kimg=10000,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=4,
              network_snapshot_ticks=40,
              image_grid_type='default',
              resume_network_pkl=None,
              resume_kimg=0.0,
              resume_time=0.0):
    """Train a progressively growing GAN with Theano/Lasagne.

    Builds (or resumes) generator G and discriminator D. A smoothed copy Gs of G
    renders all snapshot images. Training runs until total_kimg * 1000 real
    images have been shown, lowering the level of detail (LOD) from initial_lod
    toward 0 as training progresses.

    Args:
        separate_funcs: if False, one joint function updates G, D and Gs per
            minibatch (D_training_repeats must then be 1). If True, D is
            updated D_training_repeats times per G update.
        D_training_repeats: D minibatches per G update (separate_funcs only).
        G_learning_rate_max, D_learning_rate_max: peak Adam learning rates,
            scaled by misc.rampup and misc.rampdown_linear.
        G_smoothing: beta for G.create_temporally_smoothed_version (Gs).
        adam_beta1, adam_beta2, adam_epsilon: Adam hyper-parameters.
        minibatch_default, minibatch_overrides: minibatch size, optionally
            overridden per current resolution.
        rampup_kimg, rampdown_kimg: learning-rate ramp lengths (kimg).
        lod_initial_resolution: resolution training starts at.
        lod_training_kimg, lod_transition_kimg: kimg spent stable, then fading,
            per LOD step.
        total_kimg: training length in thousands of real images.
        dequantize_reals: add uniform noise (low=-0.5, high=0.5) to reals
            before range adjustment.
        gdrop_beta, gdrop_lim, gdrop_coef, gdrop_exp: control D's
            gdrop_strength from a running average of D_loss.
        drange_net, drange_viz: value ranges of network I/O and saved fakes.
        image_grid_size: (columns, rows) of snapshot grids; derived from a
            1920x1080 canvas when None ('default' grid type).
        tick_kimg_default, tick_kimg_overrides: tick length per resolution.
        image_snapshot_ticks, network_snapshot_ticks: snapshot period in ticks.
        image_grid_type: 'default' (random reals) or 'category' (one column
            per label).
        resume_network_pkl: pickle under config.result_dir whose first two
            entries are G and D (the third is ignored; Gs is rebuilt).
        resume_kimg, resume_time: progress counters to resume from.

    Writes reals.png, fakesNNNNNN.png, network-snapshot-NNNNNN.pkl,
    network-final.pkl and _training-done.txt into a new result subdirectory.

    NOTE(review): minibatch_overrides, drange_net, drange_viz and
    tick_kimg_overrides are mutable default arguments; this function only
    reads the dicts, but verify misc.adjust_dynamic_range does not mutate
    the range lists.

    Raises:
        ValueError: for an unknown image_grid_type.
    """

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()
    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        # The pickle holds (G, D, Gs) as written by save_pkl below; the stored
        # Gs is discarded and re-created from G.
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    # Gs: smoothed copy of G (presumably an exponential moving average with
    # beta=G_smoothing). Its Gs.updates are applied by the training funcs
    # below, and gen_fn renders snapshots from it.
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)
    misc.print_network_topology_info(G.output_layers)
    misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    if image_grid_type == 'default':
        if image_grid_size is None:
            # Fit the grid to a 1920x1080 canvas: output_shape[3] is taken as
            # width and [2] as height (channel-first layout presumed).
            # Clamped to 3..16 columns and 2..16 rows.
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3,
                                      16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    elif image_grid_type == 'category':
        # One column per label: column x holds H reals that carry label x and
        # H fakes conditioned on a one-hot label x.
        W = training_set.labels.shape[1]
        H = W if image_grid_size is None else image_grid_size[1]
        image_grid_size = W, H
        snapshot_fake_latents = random_latents(W * H, G.input_shape)
        snapshot_fake_labels = np.zeros((W * H, W),
                                        dtype=training_set.labels.dtype)
        example_real_images = np.zeros((W * H, ) + training_set.shape[1:],
                                       dtype=training_set.dtype)
        for x in xrange(W):
            snapshot_fake_labels[x::W, x] = 1.0
            # Indices of all training images whose label x is non-zero.
            indices = np.arange(
                training_set.shape[0])[training_set.labels[:, x] != 0]
            for y in xrange(H):
                example_real_images[x + y * W] = training_set.h5_lods[0][
                    np.random.choice(indices)]
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Learning rates are shared variables so the loop can change them without
    # recompiling the training funcs.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # Number of halvings between the full resolution and lod_initial_resolution.
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Sentinels that never equal a real floor/ceil pair, so the training funcs
    # are compiled on the first iteration.
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    if config.D.get('mbdisc_kernels', None):
        # Run D once over 500 reals so that layers exposing `init_updates`
        # can initialize their parameters from real data.
        print 'Initializing minibatch discrimination...'
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(initial_lod))
        D.eval(real_images_var, deterministic=False, init=True)
        init_layers = lasagne.layers.get_all_layers(D.output_layers)
        init_updates = [
            update for layer in init_layers
            for update in getattr(layer, 'init_updates', [])
        ]
        init_fn = theano.function(inputs=[real_images_var],
                                  outputs=None,
                                  updates=init_updates)
        init_reals = training_set.get_random_minibatch(500, lod=initial_lod)
        init_reals = misc.adjust_dynamic_range(init_reals, drange_orig,
                                               drange_net)
        init_fn(init_reals)
        del init_reals

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # Each phase lasts lod_training_kimg + lod_transition_kimg. cur_lod
        # drops by one per completed phase and fades linearly during the last
        # lod_transition_kimg of a phase. It never goes below 0.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / 1000.0) / (lod_training_kimg +
                                          lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        # The funcs depend on floor/ceil of cur_lod, so they are recompiled
        # only when either of those integers changes.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                # Nearest-neighbour upsample the shrunken reals by 2**min_lod
                # along axes 2 and 3.
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()
            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            # Compile training funcs.
            # Both variants return [G_loss, D_loss, real_scores, fake_scores];
            # the loop below relies on index 1 being D_loss.
            if not separate_funcs:
                GD_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                              updates=G_updates + D_updates +
                                              Gs.updates,
                                              on_unused_input='ignore')
            else:
                D_train_fn = theano.function([
                    real_images_var, real_labels_var, fake_latents_var,
                    fake_labels_var
                ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                             updates=D_updates,
                                             on_unused_input='ignore')
                G_train_fn = theano.function(
                    [fake_latents_var, fake_labels_var], [],
                    updates=G_updates + Gs.updates,
                    on_unused_input='ignore')

        # Invoke training funcs.
        if not separate_funcs:
            assert D_training_repeats == 1
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_train_out = GD_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        else:
            for idx in xrange(D_training_repeats):
                mb_reals, mb_labels = training_set.get_random_minibatch(
                    minibatch_size,
                    lod=cur_lod,
                    shrink_based_on_lod=True,
                    labels=True)
                mb_train_out = D_train_fn(
                    mb_reals, mb_labels,
                    random_latents(minibatch_size, G.input_shape),
                    random_labels(minibatch_size, training_set))
                cur_nimg += minibatch_size
                tick_train_out.append(mb_train_out)
            G_train_fn(random_latents(minibatch_size, G.input_shape),
                       random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        # The strength is gdrop_coef * (avg - gdrop_lim) ** gdrop_exp once the
        # running average of D_loss (mb_train_out[1], clipped to [0, 1])
        # exceeds gdrop_lim, and 0 otherwise.
        # NOTE(review): with separate_funcs and D_training_repeats == 0,
        # mb_train_out is never assigned and the next line raises NameError.
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Mean of each training output (G loss, D loss, real scores,
            # fake scores) over every minibatch of this tick.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Empty marker file signalling that training finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
Example #6
0
def train_gan(separate_funcs=False,
              D_training_repeats=1,
              G_learning_rate_max=0.0010,
              D_learning_rate_max=0.0010,
              G_smoothing=0.999,
              adam_beta1=0.0,
              adam_beta2=0.99,
              adam_epsilon=1e-8,
              minibatch_default=16,
              minibatch_overrides={},
              rampup_kimg=40 / speed_factor,
              rampdown_kimg=0,
              lod_initial_resolution=4,
              lod_training_kimg=400 / speed_factor,
              lod_transition_kimg=400 / speed_factor,
              total_kimg=10000 / speed_factor,
              dequantize_reals=False,
              gdrop_beta=0.9,
              gdrop_lim=0.5,
              gdrop_coef=0.2,
              gdrop_exp=2.0,
              drange_net=[-1, 1],
              drange_viz=[-1, 1],
              image_grid_size=None,
              tick_kimg_default=50 / speed_factor,
              tick_kimg_overrides={
                  32: 20,
                  64: 10,
                  128: 10,
                  256: 5,
                  512: 2,
                  1024: 1
              },
              image_snapshot_ticks=1,
              network_snapshot_ticks=4,
              image_grid_type='default',
              resume_network=None,
              resume_kimg=0.0,
              resume_time=0.0):
    """Train a progressively growing GAN (Keras port, channel-last data).

    Builds G and D, optionally loading weights from
    ``config.result_dir/resume_network``, and wraps them with PG_GAN into the
    trainable G_train / D_train models. Each iteration runs D_training_repeats
    discriminator minibatches and then one generator minibatch. Training stops
    once ``total_kimg * 1000`` real images have been consumed. Learning rates
    follow rampup / rampdown_linear, and the level of detail (LOD) is lowered
    over time.

    Writes reals.png, fakesNNNNNN.png, network-snapshot-NNNNNN weight dumps
    and a final network-final dump into a new result subdirectory.

    separate_funcs, G_smoothing and dequantize_reals are accepted for
    signature compatibility with the Theano trainer but are not used here.

    Raises:
        ValueError: if D_training_repeats < 1, config.loss['type'] is neither
            'wass' nor 'iwass', or image_grid_type is not 'default'.
    """
    # At least one D step per iteration is needed: d_loss feeds the log line
    # and the gdrop running average below.
    if D_training_repeats < 1:
        raise ValueError('D_training_repeats must be >= 1', D_training_repeats)

    training_set, drange_orig = load_dataset()

    print("-" * 50)
    print("resume_kimg: %s" % resume_kimg)
    print("-" * 50)

    # The dataset is channel-last: shape[1] is the resolution, shape[3] the
    # channel count.
    num_channels = training_set.shape[3]
    resolution = training_set.shape[1]
    label_size = training_set.labels.shape[1]

    if resume_network:
        print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=num_channels,
                  resolution=resolution,
                  label_size=label_size,
                  **config.G)
    D = Discriminator(num_channels=num_channels,
                      resolution=resolution,
                      label_size=label_size,
                      **config.D)
    if resume_network:
        G, D = load_GD_weights(G, D,
                               os.path.join(config.result_dir, resume_network),
                               True)
        print("pre-trained weights loaded!")
        print("-" * 50)

    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0,
                              resolution, num_channels)

    print(G.summary())
    print(D.summary())

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # Number of halvings between the full resolution and lod_initial_resolution.
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    fake_score_avg = 0.0

    G_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)

    # D_loss_weight stays None (unweighted) unless the loss defines
    # per-output weights for D's outputs.
    loss_type = config.loss['type']
    D_loss_weight = None
    if loss_type == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
    elif loss_type == 'iwass':
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]
    else:
        raise ValueError('Invalid loss type', loss_type)

    G.compile(G_opt, loss=G_loss)
    # Freeze D while compiling G_train so generator steps only update G.
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    print("-" * 50)
    print("resume_kimg: %s" % resume_kimg)
    print("current nimg: %s" % cur_nimg)
    print("-" * 50)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time

    if image_grid_type != 'default':
        raise ValueError('Invalid image_grid_type', image_grid_type)
    if image_grid_size is None:
        # Channel-last output (batch, H, W, C): axis 1 is height, axis 2 width.
        h, w = G.output_shape[1], G.output_shape[2]
        print("w:%d,h:%d" % (w, h))
        image_grid_size = (np.clip(int(1920 // w), 3, 16).astype('int'),
                           np.clip(1080 / h, 2, 16).astype('int'))
    print("image_grid_size:", image_grid_size)

    # Labels are not used by this trainer; only the images are kept.
    example_real_images, _ = training_set.get_random_minibatch_channel_last(
        np.prod(image_grid_size), labels=True)
    # Fixed latents, so successive fakesNNNNNN.png grids are comparable.
    snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                           G.input_shape)

    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)

    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)

    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir,
                                      'fakes%06d.png' % (cur_nimg / 1000)),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD: one LOD step per
        # (lod_training_kimg + lod_transition_kimg) phase, fading linearly
        # during the transition part, never below 0.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            # NOTE(review): the lod_*_kimg defaults are already divided by
            # speed_factor, and the image count is scaled by it again here,
            # unlike the rampup/rampdown terms below. Confirm this is intended.
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                      rampdown_kimg)
        G_lrate = np.float32(lrate_coef * G_learning_rate_max)
        K.set_value(G.optimizer.lr, G_lrate)
        K.set_value(G_train.optimizer.lr, G_lrate)
        K.set_value(D_train.optimizer.lr,
                    np.float32(lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        # Reals are fetched shrunken by 2**min_lod and upsampled back below.
        min_lod = int(np.floor(cur_lod))

        # Train D.
        for idx in range(D_training_repeats):
            mb_reals, _ = training_set.get_random_minibatch_channel_last(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)

            mb_fakes = G.predict_on_batch([mb_latents])

            # Bring the reals into the network's dynamic range *before* mixing
            # them with the fakes. Otherwise the gradient-penalty samples
            # interpolate between images expressed in two different ranges.
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig,
                                                 drange_net)
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes

            d_loss, d_diff, d_norm = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation], [
                    np.ones((minibatch_size, 1, 1, 1)),
                    np.ones((minibatch_size, 1))
                ])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            # %f, not %d: the scores are floats and %d truncated them.
            print("%d real score: %f fake score: %f" %
                  (idx, np.mean(d_score_real), np.mean(d_score_fake)))
            print("-" * 50)
            cur_nimg += minibatch_size

        # Train G on fresh latents with a constant target of -1.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        g_loss = G_train.train_on_batch([mb_latents], (-1) * np.ones(
            (mb_latents.shape[0], 1, 1, 1)))

        print("%d images [D loss: %f] [G loss: %f]" %
              (cur_nimg, d_loss, g_loss))
        print("-" * 50)

        # Fade in D noise as the running average of D's loss (clipped to
        # [0, 1]) rises above gdrop_lim.
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_time = cur_time - tick_start_time
            tick_start_nimg = cur_nimg
            tick_start_time = cur_time
            print("tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-10.1f sec/tick %-9.1f Dgdrop %-8.4f" %
                  (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                   cur_time - train_start_time, tick_time, gdrop_strength))

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = G.predict_on_batch(
                    snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                save_GD_weights(
                    G, D,
                    os.path.join(result_subdir,
                                 'network-snapshot-%06d' % (cur_nimg / 1000)))

    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')
    print('Done.')
Example #7
0
def train_gan(
        separate_funcs=False,
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={
            32: 20,
            64: 10,
            128: 10,
            256: 5,
            512: 2,
            1024: 1
        },
        image_snapshot_ticks=1,
        network_snapshot_ticks=4,
        image_grid_type='default',
        resume_network='000-celeba/network-snapshot-000488',
        #resume_network          = None,
        resume_kimg=511.6,
        resume_time=0.0):
    """Train the Keras progressive GAN on the dataset from ``load_dataset()``.

    Builds ``Generator``/``Discriminator`` from ``config.G``/``config.D``
    (optionally restoring weights from ``config.result_dir/resume_network``),
    wraps them with ``PG_GAN`` into the trainable ``G_train``/``D_train``
    models, and alternates ``D_training_repeats`` discriminator steps with one
    generator step until ``total_kimg * 1000`` images have been processed.
    Image grids and weight snapshots are written to a new result subdir.

    ``separate_funcs``, ``G_smoothing``, ``dequantize_reals`` and
    ``resume_time`` are accepted for interface compatibility but unused.

    Raises:
        ValueError: if ``config.loss['type']`` or ``image_grid_type`` is not
            supported.
    """
    training_set, drange_orig = load_dataset()

    if resume_network:
        print("Resuming weight from:" + resume_network)
    G = Generator(num_channels=training_set.shape[3],
                  resolution=training_set.shape[1],
                  label_size=training_set.labels.shape[1],
                  **config.G)
    D = Discriminator(num_channels=training_set.shape[3],
                      resolution=training_set.shape[1],
                      label_size=training_set.labels.shape[1],
                      **config.D)
    if resume_network:
        G, D = load_GD_weights(G, D,
                               os.path.join(config.result_dir, resume_network),
                               True)
    G_train, D_train = PG_GAN(G, D, config.G['latent_size'], 0,
                              training_set.shape[1], training_set.shape[3])
    print(G.summary())
    print(D.summary())

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    # Learning rates start at 0 and are set every iteration from the ramp.
    G_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)
    D_opt = optimizers.Adam(lr=0.0,
                            beta_1=adam_beta1,
                            beta_2=adam_beta2,
                            epsilon=adam_epsilon)

    # D_loss_weight must exist for every loss type because it is always passed
    # to D_train.compile(); None lets Keras weight each output by 1.0.
    loss_type = config.loss['type']
    D_loss_weight = None
    if loss_type == 'wass':
        G_loss = wasserstein_loss
        D_loss = wasserstein_loss
    elif loss_type == 'iwass':
        G_loss = multiple_loss
        D_loss = [mean_loss, 'mse']
        D_loss_weight = [1.0, config.loss['iwass_lambda']]
    else:
        raise ValueError('Invalid loss type', loss_type)

    G.compile(G_opt, loss=G_loss)
    # Keras reads `trainable` at compile time: freeze D while compiling the
    # generator's training model, then unfreeze it for D_train.
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=D_loss_weight)

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    if image_grid_type != 'default':
        raise ValueError('Invalid image_grid_type', image_grid_type)
    if image_grid_size is None:
        # Images are channel-last here: output_shape = (batch, H, W, C).
        h, w = G.output_shape[1], G.output_shape[2]
        print("w:%d,h:%d" % (w, h))
        image_grid_size = (int(np.clip(1920 // w, 3, 16)),
                           int(np.clip(1080 // h, 2, 16)))
    print("image_grid_size:", image_grid_size)
    num_grid_images = int(np.prod(image_grid_size))

    example_real_images, _ = training_set.get_random_minibatch_channel_last(
        num_grid_images, labels=True)

    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    print("example_real_images.shape:", example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)

    # Fixed latents so that successive fake grids are directly comparable.
    snapshot_fake_latents = random_latents(num_grid_images, G.input_shape)

    def save_fake_grid():
        # Render the fixed latents and save them under the current kimg.
        snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
        misc.save_image_grid(snapshot_fake_images,
                             os.path.join(result_subdir,
                                          'fakes%06d.png' % (cur_nimg / 1000)),
                             drange=drange_viz,
                             grid_size=image_grid_size)

    save_fake_grid()

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update learning rates and LOD on the compiled models.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                      rampdown_kimg)
        G_lrate = np.float32(lrate_coef * G_learning_rate_max)
        K.set_value(G.optimizer.lr, G_lrate)
        K.set_value(G_train.optimizer.lr, G_lrate)
        K.set_value(D_train.optimizer.lr,
                    np.float32(lrate_coef * D_learning_rate_max))
        for model in (G_train, D_train):
            if hasattr(model, 'cur_lod'):
                K.set_value(model.cur_lod, np.float32(cur_lod))

        min_lod, max_lod = int(np.floor(cur_lod)), int(np.ceil(cur_lod))

        # Train D.
        d_loss = None
        for _ in range(D_training_repeats):
            mb_reals, _ = training_set.get_random_minibatch_channel_last(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)
            if min_lod > 0:  # compensate for shrink_based_on_lod
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)
            # Convert reals to the network range BEFORE mixing them with the
            # generator's output, otherwise the gradient-penalty samples mix
            # two different dynamic ranges.
            mb_reals = misc.adjust_dynamic_range(mb_reals, drange_orig,
                                                 drange_net)

            mb_fakes = G.predict_on_batch([mb_latents])

            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon * mb_reals + (1 - epsilon) * mb_fakes
            d_loss, d_diff, d_norm = D_train.train_on_batch(
                [mb_fakes, mb_reals, interpolation], [
                    np.ones((minibatch_size, 1, 1, 1)),
                    np.ones((minibatch_size, 1))
                ])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            print("real score: %f fake score: %f" %
                  (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # Train G.
        mb_latents = random_latents(minibatch_size, G.input_shape)
        g_loss = G_train.train_on_batch(
            [mb_latents], -np.ones((mb_latents.shape[0], 1, 1, 1)))

        print("%d [D loss: %f] [G loss: %f]" % (cur_nimg, d_loss, g_loss))

        # Fade in D noise if we're close to becoming unstable.
        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        done = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or done:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or done:
                save_fake_grid()

            if cur_tick % network_snapshot_ticks == 0 or done:
                save_GD_weights(
                    G, D,
                    os.path.join(result_subdir,
                                 'network-snapshot-%06d' % (cur_nimg / 1000)))

    save_GD(G, D, os.path.join(result_subdir, 'network-final'))
    training_set.close()
    print('Done.')
Exemple #8
0
def _validation_accuracy(D, ndim):
    """Classify the validation images with D and return the fraction correct.

    Images in ``config.validation_dog`` are class 0 and images in
    ``config.validation_cat`` are class 1. The predicted class is the argmax of
    the softmax of the ``K_logits_out`` returned by ``test_discriminator``.

    Args:
        D: discriminator network passed through to ``test_discriminator``.
        ndim: expected side length of every (square) validation image.

    Returns:
        ``correct / guesses`` as a float, or None when both directories are
        empty.

    Raises:
        ValueError: if an image is not ``ndim`` x ``ndim`` pixels.
    """
    correct = 0
    guesses = 0
    # misclassified[c][0] counts class-c images that received the other label
    # (printed under the historical "False Positives" heading).
    misclassified = [[0], [0]]
    class_dirs = (config.validation_dog, config.validation_cat)
    for true_label, directory in enumerate(class_dirs):
        for filename in os.listdir(directory):
            guesses += 1
            print(filename)
            with PIL.Image.open(os.path.join(directory, filename)) as image:
                pixels = np.asarray(image.convert('RGB'))
            if pixels.shape[:2] != (ndim, ndim):
                raise ValueError(
                    'Validation image %s is %dx%d, expected %dx%d' %
                    (filename, pixels.shape[1], pixels.shape[0], ndim, ndim))
            # PIL yields (H, W, C); the network takes (N, C, H, W). Reorder the
            # axes (a plain reshape would scramble pixels) and add a batch axis.
            img = np.expand_dims(pixels.transpose(2, 0, 1), axis=0)
            K_logits_out, _, _ = test_discriminator(D, img)

            # NOTE(review): tf.nn.softmax adds a new op to the default graph on
            # every call; consider building it once if graph growth matters.
            sample_probs = tf.nn.softmax(K_logits_out)
            label = np.argmax(sample_probs.eval()[0], axis=0)
            if label == true_label:
                correct += 1
            else:
                misclassified[true_label][0] += 1
            print("-----------------------------------")
            print("GUESSED LABEL: ", label)
            print("CORRECT LABEL: ", true_label)
            validation = (correct / guesses)
            print("Total Correct: ", correct, "\n", "Total Guesses: ",
                  guesses, "\n", "Percent correct: ", validation)
            print("False Positives: Dog, Cat", misclassified)
            print()
    return correct / guesses if guesses else None


def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train a semi-supervised progressive GAN with TensorFlow.

    Two loss/optimizer pairs are built on every GPU.
      * The "pggan" pair is trained while the schedule's lod is non-zero,
        i.e. while the network is still growing.
      * The semi-supervised pair is trained once lod reaches 0. Its D loss
        consumes labeled reals, labels and unlabeled reals; its G loss
        consumes unlabeled reals.

    At each tick the function does the following:
      * reports progress;
      * measures the discriminator's classification accuracy on
        ``config.validation_dog`` / ``config.validation_cat``;
      * writes image and network snapshots to a new result subdir.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    unlabeled_training_set = dataset.load_dataset(
        data_dir=config.unlabeled_data_dir,
        verbose=True,
        **config.unlabeled_dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            print("Training-Set Label Size: ", training_set.label_size)
            print("Unlabeled-Training-Set Label Size: ",
                  unlabeled_training_set.label_size)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])

        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        unlabeled_reals, _ = unlabeled_training_set.get_minibatch_tf()

        reals_split = tf.split(reals, config.num_gpus)
        unlabeled_reals_split = tf.split(unlabeled_reals, config.num_gpus)

        labels_split = tf.split(labels, config.num_gpus)

    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    G_opt_pggan = tfutil.Optimizer(name='TrainG_pggan',
                                   learning_rate=lrate_in,
                                   **config.G_opt)
    D_opt_pggan = tfutil.Optimizer(name='TrainD_pggan',
                                   learning_rate=lrate_in,
                                   **config.D_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    # Informational only: do not crash when the variable is unset.
    print("CUDA_VISIBLE_DEVICES: ", os.environ.get('CUDA_VISIBLE_DEVICES'))

    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]

            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            unlabeled_reals_gpu = process_reals(
                unlabeled_reals_split[gpu], lod_in, mirror_augment,
                unlabeled_training_set.dynamic_range, drange_net)

            labels_gpu = labels_split[gpu]

            # Each loss is given the optimizer that actually minimizes it.
            # NOTE(review): presumably the loss functions use `opt` for loss
            # scaling (as in upstream PGGAN), so a mismatched optimizer would
            # scale/unscale with the wrong state.
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.G_loss)
            with tf.name_scope('G_loss_pggan'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt_pggan,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss_pggan)

            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss)
            with tf.name_scope('D_loss_pggan'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss_pggan = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt_pggan,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    unlabeled_reals=unlabeled_reals_gpu,
                    **config.D_loss_pggan)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            G_opt_pggan.register_gradients(tf.reduce_mean(G_loss_pggan),
                                           G_gpu.trainables)
            D_opt_pggan.register_gradients(tf.reduce_mean(D_loss_pggan),
                                           D_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            print('GPU %d loaded!' % gpu)

    G_train_op = G_opt.apply_updates()
    G_train_op_pggan = G_opt_pggan.apply_updates()
    D_train_op_pggan = D_opt_pggan.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * TrainingSpeedInt, training_set,
                             **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.compat.v1.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print("Start Time: ",
          datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print('Training...')
    cur_nimg = int(resume_kimg * TrainingSpeedInt)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0

    while cur_nimg < total_kimg * TrainingSpeedInt:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        sched2 = TrainingSchedule(cur_nimg, unlabeled_training_set,
                                  **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        unlabeled_training_set.configure(sched2.minibatch, sched2.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                G_opt_pggan.reset_optimizer_state()
                D_opt_pggan.reset_optimizer_state()
        prev_lod = sched.lod

        # Semi-supervised ops at full resolution (lod == 0), PGGAN ops while
        # the network is still growing.
        if sched.lod == 0:
            d_op, g_op = D_train_op, G_train_op
        else:
            d_op, g_op = D_train_op_pggan, G_train_op_pggan

        # Run training ops.
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [d_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [g_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * TrainingSpeedInt)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * TrainingSpeedInt or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / TrainingSpeedFloat
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f date %s'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg',
                                      cur_nimg / TrainingSpeedFloat),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary(
                       'Timing/maintenance_sec', maintenance_time),
                   datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))

            # Validation accuracy. Validation images must be square and match
            # the network resolution (training_set.shape is [C, H, W]).
            accuracy = _validation_accuracy(D, ndim=training_set.shape[1])
            if accuracy is not None:
                tfutil.autosummary('Accuracy/Validation', accuracy)
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir, 'fakes%06d.png' %
                                         (cur_nimg // TrainingSpeedInt)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs),
                              os.path.join(
                                  result_subdir, 'network-snapshot-%06d.pkl' %
                                  (cur_nimg // TrainingSpeedInt)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Exemple #9
0
def train_detector(
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters. NOTE(review): unused; the repeat loop below is commented out.
    total_kimg=1,  # Total length of the training, measured in thousands of real images.
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    snapshot_size=16,  # Number of test images drawn for the bbox snapshot images.
    snapshot_ticks=2**13,  # Number of images before maintenance
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=1,  # How often to export network snapshots?
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. NOTE(review): unused; cur_nimg always starts at 0.
    resume_time=0.0  # Assumed wallclock time at the beginning. NOTE(review): unused; timing starts from now.
):
    """Train the detection network ``N`` and periodically evaluate it.

    Loads a training set from ``config.tfrecord_train`` and a testing set from
    ``config.tfrecord_test`` (with ``repeat=False, shuffle_mb=0``). It then
    builds the network ``N``, or unpickles it when ``resume_run_id`` is given,
    and runs one optimizer step per loop iteration until
    ``total_kimg * 1000`` images have been consumed.

    At each maintenance point it does the following:
      - averages the accumulated train loss;
      - computes a normalized test loss over the testing set;
      - prints and records progress summaries;
      - writes bbox snapshot images and network pickles into a freshly
        created result subdirectory.

    Args:
        minibatch_repeats: Unused (see the commented-out loop).
        total_kimg: Training length in thousands of images.
        drange_net: Dynamic range of the images fed to ``N``.
        snapshot_size: Number of test images used for the bbox snapshots.
        snapshot_ticks: Maintenance interval, measured in images (not ticks).
        image_snapshot_ticks: Write ``fakes%06d.png`` every this many ticks.
        network_snapshot_ticks: Write ``network-snapshot-%06d.pkl`` every this
            many ticks.
        save_tf_graph: Add the default TF graph to the tfevents file.
        save_weight_histograms: Call ``N.setup_weight_histograms()``.
        resume_run_id: Run ID or pkl passed to ``misc.locate_network_pkl``;
            ``None`` constructs a fresh network and runs the global
            variable initializer.
        resume_snapshot: Snapshot index passed to ``misc.locate_network_pkl``.
        resume_kimg: Unused in this function.
        resume_time: Unused in this function.

    Side effects:
        Creates a result subdirectory under ``config.result_dir`` containing:
          - ``reals.png`` and ``fakes.png``;
          - the periodic image and network snapshots;
          - TF summaries;
          - an empty ``_training-done.txt`` marker.
    """

    maintenance_start_time = time.time()

    # Load the datasets
    training_set = dataset.load_dataset(tfrecord=config.tfrecord_train,
                                        verbose=True,
                                        **config.dataset)
    # repeat=False / shuffle_mb=0 presumably give a single, unshuffled pass
    # for evaluation -- confirm against dataset.load_dataset.
    testing_set = dataset.load_dataset(tfrecord=config.tfrecord_test,
                                       verbose=True,
                                       repeat=False,
                                       shuffle_mb=0,
                                       **config.dataset)
    testing_set_len = len(testing_set)

    # TODO: data augmentation
    # TODO: testing set

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:  # TODO: save methods
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            N = misc.load_pkl(network_pkl)
        else:
            print('Constructing the network...'
                  )  # TODO: better network (like lod-wise network)
            N = tfutil.Network('N',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               **config.N)
    N.print_layers()

    print('Building TensorFlow graph...')
    # Training set up
    with tf.name_scope('Inputs'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        # minibatch_in            = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        reals, labels, bboxes = training_set.get_minibatch_tf(
        )  # TODO: increase the size of the batch by several loss computation and mean
    N_opt = tfutil.Optimizer(name='TrainN',
                             learning_rate=lrate_in,
                             **config.N_opt)

    with tf.device('/gpu:0'):
        # pre_process (defined elsewhere) turns the raw minibatch into network
        # inputs plus ground-truth targets (gt_outputs, gt_ref); its exact
        # contract is not visible here.
        reals, labels, gt_outputs, gt_ref = pre_process(
            reals, labels, bboxes, training_set.dynamic_range,
            [0, training_set.shape[-2]], drange_net)
        with tf.name_scope('N_loss'):  # TODO: loss inadapted
            N_loss = tfutil.call_func_by_name(N=N,
                                              reals=reals,
                                              gt_outputs=gt_outputs,
                                              gt_ref=gt_ref,
                                              **config.N_loss)

        # Gradients are taken on the mean loss. The training loop fetches
        # the raw N_loss tensor separately for reporting.
        N_opt.register_gradients(tf.reduce_mean(N_loss), N.trainables)
    N_train_op = N_opt.apply_updates()

    # Testing set up
    with tf.device('/gpu:0'):
        test_reals_tf, test_labels_tf, test_bboxes_tf = testing_set.get_minibatch_tf(
        )
        test_reals_tf, test_labels_tf, test_gt_outputs_tf, test_gt_ref_tf = pre_process(
            test_reals_tf, test_labels_tf, test_bboxes_tf,
            testing_set.dynamic_range, [0, testing_set.shape[-2]], drange_net)
        with tf.name_scope('N_test_loss'):
            test_loss = tfutil.call_func_by_name(N=N,
                                                 reals=test_reals_tf,
                                                 gt_outputs=test_gt_outputs_tf,
                                                 gt_ref=test_gt_ref_tf,
                                                 is_training=False,
                                                 **config.N_loss)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        N.setup_weight_histograms()

    # A fixed batch of test images (with ground-truth bboxes) is drawn once
    # and reused for every bbox snapshot below.
    test_reals, _, test_bboxes = testing_set.get_minibatch_np(snapshot_size)
    misc.save_img_bboxes(test_reals,
                         test_bboxes,
                         os.path.join(result_subdir, 'reals.png'),
                         snapshot_size,
                         adjust_range=False)

    test_reals = misc.adjust_dynamic_range(test_reals,
                                           training_set.dynamic_range,
                                           drange_net)
    # NOTE(review): N.run executes here, before the
    # tf.global_variables_initializer() call below. This presumably relies
    # on tfutil.Network initializing its own variables when constructed --
    # confirm.
    test_preds, _ = N.run(test_reals, minibatch_size=snapshot_size)
    misc.save_img_bboxes(test_reals, test_preds,
                         os.path.join(result_subdir, 'fakes.png'),
                         snapshot_size)

    print('Training...')
    # When not resuming, run the global variable initializer (presumably
    # covering the optimizer state as well).
    # NOTE(review): if N's variables are global, this also re-initializes
    # them after fakes.png above was rendered -- confirm that's intended.
    if resume_run_id is None:
        tfutil.run(tf.global_variables_initializer())

    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time

    # Choose training parameters and configure training ops.
    # The schedule is computed once here (at cur_nimg=0) and never updated
    # inside the loop.
    sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
    training_set.configure(sched.minibatch)

    # Sum of fetched training losses since the last maintenance point.
    _train_loss = 0

    while cur_nimg < total_kimg * 1000:

        # Run training ops.
        # for _ in range(minibatch_repeats):
        _, loss = tfutil.run([N_train_op, N_loss], {lrate_in: sched.N_lrate})
        _train_loss += loss
        cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        # NOTE(review): apart from the final step, this fires only when
        # cur_nimg is an exact multiple of snapshot_ticks. If sched.minibatch
        # does not divide snapshot_ticks, no intermediate ticks happen.
        # Consider `cur_nimg >= tick_start_nimg + snapshot_ticks` instead.
        if (cur_nimg >= total_kimg * 1000) or (cur_nimg % snapshot_ticks == 0
                                               and cur_nimg > 0):

            cur_tick += 1
            cur_time = time.time()
            # tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            # Normalize the accumulated loss by the number of images consumed
            # this tick.
            # NOTE(review): `loss` is the raw (unreduced) N_loss value, so
            # whether this is a per-image mean depends on N_loss's shape.
            _train_loss = _train_loss / (cur_nimg - tick_start_nimg)
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            # Duration of the previous maintenance block: tick_start_time is
            # reset at the end of maintenance. On the first tick this instead
            # measures the setup time before the loop started.
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Evaluate on the testing set: one test_loss fetch per minibatch
            # step across testing_set_len images, then divide by the set size.
            # NOTE(review): the testing set was loaded with repeat=False. This
            # presumably relies on configure() resetting its iterator on
            # every tick, and on the batch count matching -- confirm.
            testing_set.configure(sched.minibatch)
            _test_loss = 0
            # testing_set_len = 1 # TMP
            for _ in range(0, testing_set_len, sched.minibatch):
                _test_loss += tfutil.run(test_loss)
            _test_loss /= testing_set_len

            # Report progress. # TODO: improved report display
            print(
                'tick %-5d kimg %-6.1f time %-10s sec/tick %-3.1f maintenance %-7.2f train_loss %.4f test_loss %.4f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/maintenance', maintenance_time),
                   tfutil.autosummary('TrainN/train_loss', _train_loss),
                   tfutil.autosummary('TrainN/test_loss', _test_loss)))

            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Snapshot N's predictions (bboxes and refs) on the fixed test batch.
            if cur_tick % image_snapshot_ticks == 0:
                test_bboxes, test_refs = N.run(test_reals,
                                               minibatch_size=snapshot_size)
                misc.save_img_bboxes_ref(
                    test_reals, test_bboxes, test_refs,
                    os.path.join(result_subdir,
                                 'fakes%06d.png' % (cur_nimg // 1000)),
                    snapshot_size)
            if cur_tick % network_snapshot_ticks == 0:
                misc.save_pkl(
                    N,
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Start accumulating the next tick's loss from zero.
            _train_loss = 0

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # NOTE(review): the final network save is disabled. The latest pickle is
    # whatever the last tick wrote, which is every tick with the default
    # network_snapshot_ticks=1.
    # misc.save_pkl(N, os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Exemple #10
0
def train_gan(
        separate_funcs=False,  # NOTE(review): unused in this function.
        D_training_repeats=1,
        G_learning_rate_max=0.0010,
        D_learning_rate_max=0.0010,
        G_smoothing=0.999,
        adam_beta1=0.0,
        adam_beta2=0.99,
        adam_epsilon=1e-8,
        minibatch_default=16,
        minibatch_overrides={},
        rampup_kimg=40 / speed_factor,
        rampdown_kimg=0,
        lod_initial_resolution=4,
        lod_training_kimg=400 / speed_factor,
        lod_transition_kimg=400 / speed_factor,
        #lod_training_kimg       = 40,
        #lod_transition_kimg     = 40,
        total_kimg=10000 / speed_factor,
        dequantize_reals=False,
        gdrop_beta=0.9,
        gdrop_lim=0.5,
        gdrop_coef=0.2,
        gdrop_exp=2.0,
        drange_net=[-1, 1],
        drange_viz=[-1, 1],
        image_grid_size=None,
        #tick_kimg_default       = 1,
        tick_kimg_default=50 / speed_factor,
        tick_kimg_overrides={
            32: 20,
            64: 10,
            128: 10,
            256: 5,
            512: 2,
            1024: 1
        },
        image_snapshot_ticks=4,
        network_snapshot_ticks=40,
        image_grid_type='default',
        #resume_network_pkl      = '006-celeb128-progressive-growing/network-snapshot-002009.pkl',
        resume_network_pkl=None,
        resume_kimg=0,
        resume_time=0.0):
    """Train a GAN (G, D and smoothed generator Gs) with LOD scheduling in Theano.

    Written in Python 2 syntax (print statements, xrange).

    Setup: builds G and D, or unpickles them from ``resume_network_pkl``
    (relative to ``config.result_dir``). It then derives the smoothed
    generator Gs and compiles ``gen_fn`` from Gs.

    Training loop: runs until ``total_kimg * 1000`` real images have been
    consumed. Each iteration:
      - recomputes the current level of detail (LOD), the per-resolution
        minibatch size and tick length, and the ramped learning rates;
      - recompiles the D/G training functions whenever
        floor(LOD)/ceil(LOD) change;
      - runs ``D_training_repeats`` discriminator steps and one generator
        step;
      - adapts D's gdrop noise strength (when D exposes ``gdrop_strength``).

    Once per tick it prints averaged losses and, on schedule, writes image
    grids and network pickles to a freshly created result subdirectory.

    Args:
        D_training_repeats: Discriminator steps per generator step.
        G_learning_rate_max, D_learning_rate_max: Peak Adam learning rates,
            scaled by the rampup/rampdown coefficient.
        G_smoothing: ``beta`` for Gs = temporally smoothed copy of G.
        adam_beta1, adam_beta2, adam_epsilon: Adam hyperparameters.
        minibatch_default, minibatch_overrides: Minibatch size, optionally
            overridden per current resolution.
        rampup_kimg, rampdown_kimg: Learning-rate ramp lengths in kimg.
        lod_initial_resolution: Resolution at which training starts.
        lod_training_kimg, lod_transition_kimg: Length of the stable and the
            fade-in part of each LOD phase.
        total_kimg: Total training length in kimg.
        dequantize_reals: Add U(-0.5, 0.5) noise to reals before range
            adjustment.
        gdrop_beta, gdrop_lim, gdrop_coef, gdrop_exp: Parameters of the
            running-average-driven D noise strength.
        drange_net, drange_viz: Dynamic ranges for network input and for
            saved fake images.
        image_grid_size, image_grid_type: Snapshot grid layout. Only
            ``'default'`` is accepted; any other value raises ValueError.
        tick_kimg_default, tick_kimg_overrides: Tick length in kimg,
            optionally overridden per resolution.
        image_snapshot_ticks, network_snapshot_ticks: Snapshot frequencies
            in ticks.
        resume_network_pkl, resume_kimg, resume_time: Resume state (pickle
            path, starting image count, elapsed seconds).
    """

    # Load dataset and build networks.
    training_set, drange_orig = load_dataset()

    print "*************** test the format of dataset ***************"
    print training_set
    print drange_orig
    # Per the original author's note: training_set is the object the dataset
    # module produces after parsing the HDF5 file, and drange_orig is
    # training_set.get_dynamic_range(). load_dataset() is not visible here.

    if resume_network_pkl:
        print 'Resuming', resume_network_pkl
        G, D, _ = misc.load_pkl(
            os.path.join(config.result_dir, resume_network_pkl))
    else:
        G = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.G)
        D = network.Network(num_channels=training_set.shape[1],
                            resolution=training_set.shape[2],
                            label_size=training_set.labels.shape[1],
                            **config.D)
    Gs = G.create_temporally_smoothed_version(beta=G_smoothing,
                                              explicit_updates=True)

    # G and D come either from misc.load_pkl (when resuming) or from the
    # network module. Gs, the smoothed copy of G, drives gen_fn below and is
    # saved alongside G and D.

    print G
    print D

    #misc.print_network_topology_info(G.output_layers)
    #misc.print_network_topology_info(D.output_layers)

    # Setup snapshot image grid.
    # Configure the layout of the intermediate snapshot images; the latents
    # and labels drawn here stay fixed for every later snapshot.
    if image_grid_type == 'default':
        if image_grid_size is None:
            w, h = G.output_shape[3], G.output_shape[2]
            image_grid_size = np.clip(1920 / w, 3,
                                      16), np.clip(1080 / h, 2, 16)
        example_real_images, snapshot_fake_labels = training_set.get_random_minibatch(
            np.prod(image_grid_size), labels=True)
        snapshot_fake_latents = random_latents(np.prod(image_grid_size),
                                               G.input_shape)
    else:
        raise ValueError('Invalid image_grid_type', image_grid_type)

    # Theano input variables and compile generation func.
    print 'Setting up Theano...'
    real_images_var = T.TensorType('float32', [False] *
                                   len(D.input_shape))('real_images_var')
    # <class 'theano.tensor.var.TensorVariable'>
    # print type(real_images_var),real_images_var
    real_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('real_labels_var')
    fake_latents_var = T.TensorType('float32', [False] *
                                    len(G.input_shape))('fake_latents_var')
    fake_labels_var = T.TensorType(
        'float32', [False] * len(training_set.labels.shape))('fake_labels_var')
    # Every name ending in _var above is a symbolic input tensor.
    G_lrate = theano.shared(np.float32(0.0))
    D_lrate = theano.shared(np.float32(0.0))
    # The learning rates are Theano shared variables, so the loop can update
    # them with set_value() without recompiling the training functions.
    gen_fn = theano.function([fake_latents_var, fake_labels_var],
                             Gs.eval_nd(fake_latents_var,
                                        fake_labels_var,
                                        ignore_unused_inputs=True),
                             on_unused_input='ignore')

    # gen_fn is a compiled function with inputs
    #     [fake_latents_var, fake_labels_var]
    # and output
    #     Gs.eval_nd(fake_latents_var, fake_labels_var, ignore_unused_inputs=True),
    # i.e. images produced by the smoothed generator Gs.

    # Generator (image generation) function.

    # Misc init.
    # log2 of the generator's output resolution (output_shape[2] is the height).
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    # LOD (level of detail) at the start of training: the number of halvings
    # from the full resolution down to lod_initial_resolution, clamped at 0.
    initial_lod = max(
        resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)
    cur_lod = 0.0
    # Sentinels that can never equal floor/ceil of the (non-negative)
    # cur_lod, which forces the first compilation inside the loop.
    min_lod, max_lod = -1.0, -2.0
    # Running average that drives the gdrop noise strength.
    fake_score_avg = 0.0

    # Save example images.
    snapshot_fake_images = gen_fn(snapshot_fake_latents, snapshot_fake_labels)
    result_subdir = misc.create_result_subdir(config.result_dir,
                                              config.run_desc)
    misc.save_image_grid(example_real_images,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=drange_orig,
                         grid_size=image_grid_size)
    misc.save_image_grid(snapshot_fake_images,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_viz,
                         grid_size=image_grid_size)

    # Training loop.
    # This is the main training entry point.
    # Training never leaves the outer while loop, so resolution changes and
    # function recompilation must happen inside it.

    # Number of real images shown so far (starting from resume_kimg).
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0

    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    tick_train_out = []
    train_start_time = tick_start_time - resume_time
    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # Start at initial_lod and drop one level per completed
        # (training + transition) phase. During the transition part of the
        # current phase, fade linearly toward the next level. Clamp at 0.
        # NOTE(review): cur_nimg is multiplied by speed_factor here, while
        # the lod_*_kimg defaults are already divided by speed_factor. With
        # the defaults, phases shrink by speed_factor**2 -- confirm this is
        # intended.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0 / speed_factor)) / (lod_training_kimg +
                                                           lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(
                    1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                    (lod_training_kimg + lod_transition_kimg) /
                    lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2**(resolution_log2 - int(np.floor(cur_lod)))
        # cur_res is the current resolution. Minibatch size and tick length
        # fall back to their defaults when no override exists for it.
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(cur_res,
                                                     tick_kimg_default)

        # Update network config.
        # Scale the learning rates by the rampup/rampdown coefficient (from
        # the misc helpers) and push cur_lod into networks that expose it.
        lrate_coef = misc.rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= misc.rampdown_linear(cur_nimg / 1000.0, total_kimg,
                                           rampdown_kimg)
        G_lrate.set_value(np.float32(lrate_coef * G_learning_rate_max))
        D_lrate.set_value(np.float32(lrate_coef * D_learning_rate_max))

        if hasattr(G, 'cur_lod'): G.cur_lod.set_value(np.float32(cur_lod))
        if hasattr(D, 'cur_lod'): D.cur_lod.set_value(np.float32(cur_lod))

        # Setup training func for current LOD.
        new_min_lod, new_max_lod = int(np.floor(cur_lod)), int(
            np.ceil(cur_lod))

        #print " cur_lod%f\n  min_lod %f\n new_min_lod %f\n max_lod %f\n new_max_lod %f\n"%(cur_lod,min_lod,new_min_lod,max_lod,new_max_lod)

        # Recompile the training functions only when the integer LOD bounds
        # change.
        if min_lod != new_min_lod or max_lod != new_max_lod:
            print 'Compiling training funcs...'
            min_lod, max_lod = new_min_lod, new_max_lod

            # Pre-process reals.
            real_images_expr = real_images_var
            if dequantize_reals:
                rnd = theano.sandbox.rng_mrg.MRG_RandomStreams(
                    lasagne.random.get_rng().randint(1, 2147462579))
                epsilon_noise = rnd.uniform(size=real_images_expr.shape,
                                            low=-0.5,
                                            high=0.5,
                                            dtype='float32')
                real_images_expr = T.cast(
                    real_images_expr, 'float32'
                ) + epsilon_noise  # match original implementation of Improved Wasserstein
            real_images_expr = misc.adjust_dynamic_range(
                real_images_expr, drange_orig, drange_net)
            # Repeat pixels 2**min_lod times along axes 2 and 3 so that
            # minibatches shrunk by get_random_minibatch match the network's
            # full input size.
            if min_lod > 0:  # compensate for shrink_based_on_lod
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=2)
                real_images_expr = T.extra_ops.repeat(real_images_expr,
                                                      2**min_lod,
                                                      axis=3)

            # Optimize loss.
            G_loss, D_loss, real_scores_out, fake_scores_out = evaluate_loss(
                G, D, min_lod, max_lod, real_images_expr, real_labels_var,
                fake_latents_var, fake_labels_var, **config.loss)
            G_updates = adam(G_loss,
                             G.trainable_params(),
                             learning_rate=G_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            D_updates = adam(D_loss,
                             D.trainable_params(),
                             learning_rate=D_lrate,
                             beta1=adam_beta1,
                             beta2=adam_beta2,
                             epsilon=adam_epsilon).items()

            # D_train_fn applies D's updates and returns
            # [G_loss, D_loss, real_scores_out, fake_scores_out].
            # G_train_fn applies G's updates plus Gs's smoothing updates.
            D_train_fn = theano.function([
                real_images_var, real_labels_var, fake_latents_var,
                fake_labels_var
            ], [G_loss, D_loss, real_scores_out, fake_scores_out],
                                         updates=D_updates,
                                         on_unused_input='ignore')
            G_train_fn = theano.function([fake_latents_var, fake_labels_var],
                                         [],
                                         updates=G_updates + Gs.updates,
                                         on_unused_input='ignore')

        # Run D_training_repeats discriminator steps on fresh real minibatches
        # shrunk to the current LOD, then one generator step.
        # NOTE(review): if D_training_repeats < 1, mb_train_out is never
        # assigned (NameError below) and cur_nimg never advances.
        for idx in xrange(D_training_repeats):
            mb_reals, mb_labels = training_set.get_random_minibatch(
                minibatch_size,
                lod=cur_lod,
                shrink_based_on_lod=True,
                labels=True)

            # Debug output printed on every D step.
            # NOTE(review): looks like leftover diagnostics.
            print "******* test minibatch"
            print "mb_reals"
            print idx, D_training_repeats
            print mb_reals.shape, mb_labels.shape
            #print mb_reals
            print "mb_labels"
            #print mb_labels

            mb_train_out = D_train_fn(
                mb_reals, mb_labels,
                random_latents(minibatch_size, G.input_shape),
                random_labels(minibatch_size, training_set))
            cur_nimg += minibatch_size
            tick_train_out.append(mb_train_out)
        G_train_fn(random_latents(minibatch_size, G.input_shape),
                   random_labels(minibatch_size, training_set))

        # Fade in D noise if we're close to becoming unstable
        # NOTE(review): mb_train_out is [G_loss, D_loss, real_scores,
        # fake_scores] (see D_train_fn), so index 1 is D_loss rather than the
        # fake scores that the name fake_score_cur suggests -- confirm which
        # output was intended.
        fake_score_cur = np.clip(np.mean(mb_train_out[1]), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + fake_score_cur * (
            1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * (max(fake_score_avg - gdrop_lim, 0.0)**
                                       gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            D.gdrop_strength.set_value(np.float32(gdrop_strength))

        # Perform maintenance operations once per tick.
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or cur_nimg >= total_kimg * 1000:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            tick_start_time = cur_time
            # Average each D_train_fn output (G_loss, D_loss, real scores,
            # fake scores) over every D step of this tick.
            tick_train_avg = tuple(
                np.mean(np.concatenate([np.asarray(v).flatten()
                                        for v in vals]))
                for vals in zip(*tick_train_out))
            tick_train_out = []

            # Print progress.
            print 'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-9.1f sec/kimg %-6.1f Dgdrop %-8.4f Gloss %-8.4f Dloss %-8.4f Dreal %-8.4f Dfake %-8.4f' % (
                (cur_tick, cur_nimg / 1000.0, cur_lod, minibatch_size,
                 misc.format_time(cur_time - train_start_time), tick_time,
                 tick_time / tick_kimg, gdrop_strength) + tick_train_avg)

            # Visualize generated images.
            if cur_tick % image_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                snapshot_fake_images = gen_fn(snapshot_fake_latents,
                                              snapshot_fake_labels)
                misc.save_image_grid(snapshot_fake_images,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg / 1000)),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            # Save network snapshot every N ticks.
            if cur_tick % network_snapshot_ticks == 0 or cur_nimg >= total_kimg * 1000:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg / 1000)))

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    training_set.close()
    print 'Done.'
    # Create an empty marker file that signals training finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
Exemple #11
0
def train_classifier(
    smoothing=0.999,  # Exponential running average of encoder weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    lr_mirror_augment=True,  # Enable mirror augment?
    ud_mirror_augment=False,  # Enable up-down mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often to export image snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False
):  # Include weight histograms in the tfevents file?

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.training_set)
    validation_set = dataset.load_dataset(data_dir=config.data_dir,
                                          verbose=True,
                                          **config.validation_set)
    network_snapshot_ticks = total_kimg // 100  # How often to export network snapshots?

    # Construct networks.
    with tf.device('/gpu:0'):
        try:
            network_pkl = misc.locate_network_pkl()
            resume_kimg, resume_time = misc.resume_kimg_time(network_pkl)
            print('Loading networks from "%s"...' % network_pkl)
            EG, D_rec, EGs = misc.load_pkl(network_pkl)
        except:
            print('Constructing networks...')
            resume_kimg = 0.0
            resume_time = 0.0
            EG = tfutil.Network('EG',
                                num_channels=training_set.shape[0],
                                resolution=training_set.shape[1],
                                label_size=training_set.label_size,
                                **config.EG)
            D_rec = tfutil.Network('D_rec',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   **config.D_rec)
            EGs = EG.clone('EGs')
        EGs_update_op = EGs.setup_as_moving_average_of(EG, beta=smoothing)
    EG.print_layers()
    D_rec.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    EG_opt = tfutil.Optimizer(name='TrainEG',
                              learning_rate=lrate_in,
                              **config.EG_opt)
    D_rec_opt = tfutil.Optimizer(name='TrainD_rec',
                                 learning_rate=lrate_in,
                                 **config.D_rec_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            EG_gpu = EG if gpu == 0 else EG.clone(EG.name + '_shadow_%d' % gpu)
            D_rec_gpu = D_rec if gpu == 0 else D_rec.clone(D_rec.name +
                                                           '_shadow_%d' % gpu)
            reals_fade_gpu, reals_orig_gpu = process_reals(
                reals_split[gpu], lod_in, lr_mirror_augment, ud_mirror_augment,
                training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('EG_loss'):
                EG_loss = tfutil.call_func_by_name(EG=EG_gpu,
                                                   D_rec=D_rec_gpu,
                                                   reals_orig=reals_orig_gpu,
                                                   labels=labels_gpu,
                                                   **config.EG_loss)
            with tf.name_scope('D_rec_loss'):
                D_rec_loss = tfutil.call_func_by_name(
                    EG=EG_gpu,
                    D_rec=D_rec_gpu,
                    D_rec_opt=D_rec_opt,
                    minibatch_size=minibatch_split,
                    reals_orig=reals_orig_gpu,
                    **config.D_rec_loss)
            EG_opt.register_gradients(tf.reduce_mean(EG_loss),
                                      EG_gpu.trainables)
            D_rec_opt.register_gradients(tf.reduce_mean(D_rec_loss),
                                         D_rec_gpu.trainables)
    EG_train_op = EG_opt.apply_updates()
    D_rec_train_op = D_rec_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, train_reals, train_labels = setup_snapshot_image_grid(
        training_set, drange_net, [450, 10], **config.grid)
    grid_size, val_reals, val_labels = setup_snapshot_image_grid(
        validation_set, drange_net, [450, 10], **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)

    train_recs, train_fingerprints, train_logits = EGs.run(
        train_reals, minibatch_size=sched.minibatch // config.num_gpus)
    train_preds = np.argmax(train_logits, axis=1)
    train_gt = np.argmax(train_labels, axis=1)
    train_acc = np.float32(np.sum(train_gt == train_preds)) / np.float32(
        len(train_gt))
    print('Training Accuracy = %f' % train_acc)

    val_recs, val_fingerprints, val_logits = EGs.run(
        val_reals, minibatch_size=sched.minibatch // config.num_gpus)
    val_preds = np.argmax(val_logits, axis=1)
    val_gt = np.argmax(val_labels, axis=1)
    val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(len(val_gt))
    print('Validation Accuracy = %f' % val_acc)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(train_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'train_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'train_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(train_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'train_fingerrints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_reals[::30, :, :, :],
                         os.path.join(result_subdir, 'val_reals.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_recs[::30, :, :, :],
                         os.path.join(result_subdir, 'val_recs-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])
    misc.save_image_grid(val_fingerprints[::30, :, :, :],
                         os.path.join(result_subdir,
                                      'val_fingerrints-init.png'),
                         drange=drange_net,
                         grid_size=[15, 10])

    est_fingerprints = np.transpose(
        EGs.vars['Conv_fingerprints/weight'].eval(), axes=[3, 2, 0, 1])
    misc.save_image_grid(
        est_fingerprints,
        os.path.join(result_subdir, 'est_fingerrints-init.png'),
        drange=[np.amin(est_fingerprints),
                np.amax(est_fingerprints)],
        grid_size=[est_fingerprints.shape[0], 1])

    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        EG.setup_weight_histograms()
        D_rec.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                EG_opt.reset_optimizer_state()
                D_rec_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            tfutil.run(
                [D_rec_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run(
                [EG_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.lrate,
                    minibatch_in: sched.minibatch
                })
            tfutil.run([EGs_update_op], {})
            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f resolution %-4d minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/resolution', sched.resolution),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Print accuracy.
            if cur_tick % image_snapshot_ticks == 0 or done:

                train_recs, train_fingerprints, train_logits = EGs.run(
                    train_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                train_preds = np.argmax(train_logits, axis=1)
                train_gt = np.argmax(train_labels, axis=1)
                train_acc = np.float32(np.sum(
                    train_gt == train_preds)) / np.float32(len(train_gt))
                print('Training Accuracy = %f' % train_acc)

                val_recs, val_fingerprints, val_logits = EGs.run(
                    val_reals,
                    minibatch_size=sched.minibatch // config.num_gpus)
                val_preds = np.argmax(val_logits, axis=1)
                val_gt = np.argmax(val_labels, axis=1)
                val_acc = np.float32(np.sum(val_gt == val_preds)) / np.float32(
                    len(val_gt))
                print('Validation Accuracy = %f' % val_acc)

                misc.save_image_grid(train_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'train_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(train_fingerprints[::30, :, :, :],
                                     os.path.join(
                                         result_subdir,
                                         'train_fingerrints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_recs[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_recs-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])
                misc.save_image_grid(val_fingerprints[::30, :, :, :],
                                     os.path.join(result_subdir,
                                                  'val_fingerrints-final.png'),
                                     drange=drange_net,
                                     grid_size=[15, 10])

                est_fingerprints = np.transpose(
                    EGs.vars['Conv_fingerprints/weight'].eval(),
                    axes=[3, 2, 0, 1])
                misc.save_image_grid(est_fingerprints,
                                     os.path.join(result_subdir,
                                                  'est_fingerrints-final.png'),
                                     drange=[
                                         np.amin(est_fingerprints),
                                         np.amax(est_fingerprints)
                                     ],
                                     grid_size=[est_fingerprints.shape[0], 1])

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (EG, D_rec, EGs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((EG, D_rec, EGs),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
Exemple #12
0
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    compute_fid_score=False,  # Compute FID during training once sched.lod=0.0
    minimum_fid_kimg=0,  # Only compute FID once at least this many kimg have been shown.
    fid_snapshot_ticks=1,  # How often (in ticks) to compute FID.
    fid_patience=2,  # Stop after this many consecutive FID evaluations without improvement.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None (or "None") = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    result_subdir="./"):
    """Train a progressive GAN (G, D, smoothed Gs) with optional FID early stopping.

    When ``compute_fid_score == True``, FID is computed every
    ``fid_snapshot_ticks`` ticks once ``sched.lod == 0.0`` and at least
    ``minimum_fid_kimg`` kimg have been shown. Each new best FID is saved to
    ``network-final-full-conv.pkl``. Training stops early once
    ``fid_patience`` consecutive evaluations fail to improve on the best.

    If training was not stopped early, a final FID is computed. A
    ``[<parent>/<run>, fid]`` row is then appended to ``results_full_conv.csv``
    two directory levels above the result directory, and the networks are
    saved to ``network-final-full-conv.pkl``.

    NOTE(review): the ``result_subdir`` argument is currently ignored; it is
    overwritten by ``misc.create_result_subdir(config.result_dir, config.desc)``.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    # Both the None object (the default) and the literal string "None" mean
    # "start from scratch". Comparing only against "None" would try to
    # resume from a None run id with the default arguments.
    resuming = resume_run_id is not None and resume_run_id != "None"
    with tf.device('/gpu:0'):
        if resuming:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            resume_pkl_name = os.path.splitext(
                os.path.basename(network_pkl))[0]
            # Snapshots are named "network-snapshot-<kimg>"; any other name
            # (e.g. "network-final-full-conv") leaves resume_kimg unchanged.
            try:
                resume_kimg = int(resume_pkl_name.split('-')[-1])
                print('** Setting resume kimg to', resume_kimg, flush=True)
            except ValueError:
                print('** Keeping resume kimg as:', resume_kimg, flush=True)
            print('Loading networks from "%s"...' % network_pkl, flush=True)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...', flush=True)
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...', flush=True)
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)

    # One tower per GPU: GPU 0 uses the primary networks, others use clones.
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...', flush=True)
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...', flush=True)
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...', flush=True)
    # FID early-stopping state.
    fid_list = []  # FID values in evaluation order.
    fid_stop = False  # True once patience has been exhausted.
    fid_patience_step = 0  # Consecutive FID evaluations without improvement.

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # fid_tick: an FID tick past the kimg threshold (forces a network
            # snapshot regardless of lod). fid_due: additionally at full
            # resolution (lod == 0), i.e. FID is actually computed.
            # The explicit `== True` is kept (rather than truthiness) to
            # preserve the original handling of non-bool flag values.
            fid_tick = (compute_fid_score == True
                        and cur_tick % fid_snapshot_ticks == 0
                        and cur_nimg >= minimum_fid_kimg * 1000)
            fid_due = fid_tick and sched.lod == 0.0

            if fid_due:
                fid = compute_fid(Gs=Gs,
                                  minibatch_size=sched.minibatch,
                                  dataset_obj=training_set,
                                  iter_number=cur_nimg / 1000,
                                  lod=0.0,
                                  num_images=10000,
                                  printing=False)
                fid_list.append(fid)

            # Report progress without FID.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)),
                flush=True)
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save image snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Save network snapshots.
            if cur_tick % network_snapshot_ticks == 0 or done or fid_tick:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # End training based on FID patience.
            if fid_due:
                fid_patience_step += 1
                # The first FID, or any new best FID, becomes the kept model.
                if len(fid_list) == 1 or fid_list[-1] < np.min(fid_list[:-1]):
                    fid_patience_step = 0
                    misc.save_pkl((G, D, Gs),
                                  os.path.join(result_subdir,
                                               'network-final-full-conv.pkl'))
                    print(
                        "Save network-final-full-conv for FID: %.3f at kimg %-8.1f."
                        % (fid_list[-1], cur_nimg // 1000),
                        flush=True)
                else:
                    print("No improvement for FID: %.3f at kimg %-8.1f." %
                          (fid_list[-1], cur_nimg // 1000),
                          flush=True)
                if fid_patience_step == fid_patience:
                    fid_stop = True
                    print("Training stopped due to FID early-stopping.",
                          flush=True)
                    # Forces the while-loop condition to fail.
                    cur_nimg = total_kimg * 1000

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    # Save the final network only if FID early stopping did not trigger; in that
    # case the best network has already been saved above.
    if not fid_stop:
        fid = compute_fid(Gs=Gs,
                          minibatch_size=sched.minibatch,
                          dataset_obj=training_set,
                          iter_number=cur_nimg / 1000,
                          lod=0.0,
                          num_images=10000,
                          printing=False)
        print("Final FID: %.3f at kimg %-8.1f." % (fid, cur_nimg // 1000),
              flush=True)
        # Append "<parent>/<run>, fid" to a CSV two levels above result_subdir.
        csv_file = os.path.join(
            os.path.dirname(os.path.dirname(result_subdir)),
            "results_full_conv.csv")
        run_parts = result_subdir.split("/")
        list_to_append = ["/".join(run_parts[-2:]), fid]
        # newline='' lets the csv writer control line endings itself.
        with open(csv_file, 'a', newline='') as f_object:
            writer(f_object).writerow(list_to_append)
        misc.save_pkl((G, D, Gs),
                      os.path.join(result_subdir,
                                   'network-final-full-conv.pkl'))
        print("Save network-final-full-conv.", flush=True)
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
def train_progressive_gan(
    G_smoothing             = 0.999,        # Exponential running average of generator weights.
    D_repeats               = 1,            # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,            # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,         # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,        # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,        # Enable mirror augment?
    drange_net              = [-1,1],       # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,            # How often to export image snapshots?
    network_snapshot_ticks  = 10,           # How often to export network snapshots?
    save_tf_graph           = False,        # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,        # Include weight histograms in the tfevents file?
    resume_run_id           = None,         # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,         # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,          # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):         # Assumed wallclock time at the beginning. Affects reporting.
    """Train G, D and the moving-average generator Gs with a progressive schedule.

    Settings not given as arguments are read from the module-level ``config``
    (dataset, networks, optimizers, losses, schedule, snapshot grid, GPUs).
    Output goes to a new directory from ``misc.create_result_subdir``:
      - ``reals.png`` and ``fakes%06d.png`` image grids;
      - ``network-snapshot-%06d.pkl`` periodic network pickles;
      - TensorBoard summaries;
      - ``network-final.pkl`` and an empty ``_training-done.txt`` at the end.
    """
    maint_start = time.time()
    train_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)

    # Networks: build new ones, or restore the (G, D, Gs) triple from a pickle.
    with tf.device('/gpu:0'):
        if resume_run_id is None:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=train_set.shape[0],
                               resolution=train_set.shape[1],
                               label_size=train_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=train_set.shape[0],
                               resolution=train_set.shape[1],
                               label_size=train_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        else:
            pkl_path = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % pkl_path)
            G, D, Gs = misc.load_pkl(pkl_path)
        # Op that moves Gs's weights toward G's (beta = G_smoothing).
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    n_gpus = config.num_gpus
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        per_gpu_minibatch = minibatch_in // n_gpus
        reals, labels = train_set.get_minibatch_tf()
        reals_per_gpu = tf.split(reals, n_gpus)
        labels_per_gpu = tf.split(labels, n_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)

    # One tower per GPU: GPU 0 drives the primary networks, the others use
    # "_shadow" clones; gradients from all towers go to the shared optimizers.
    for gpu_idx in range(n_gpus):
        with tf.name_scope('GPU%d' % gpu_idx), tf.device('/gpu:%d' % gpu_idx):
            if gpu_idx == 0:
                gen, disc = G, D
            else:
                gen = G.clone(G.name + '_shadow')
                disc = D.clone(D.name + '_shadow')
            # Copy lod_in into each network's 'lod' variable; both losses
            # below run under a control dependency on these assignments.
            assign_lod = [tf.assign(gen.find_var('lod'), lod_in),
                          tf.assign(disc.find_var('lod'), lod_in)]
            gpu_reals = process_reals(reals_per_gpu[gpu_idx], lod_in, mirror_augment,
                                      train_set.dynamic_range, drange_net)
            gpu_labels = labels_per_gpu[gpu_idx]
            with tf.name_scope('G_loss'), tf.control_dependencies(assign_lod):
                G_loss = tfutil.call_func_by_name(G=gen, D=disc, opt=G_opt,
                                                  training_set=train_set,
                                                  minibatch_size=per_gpu_minibatch,
                                                  **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(assign_lod):
                D_loss = tfutil.call_func_by_name(G=gen, D=disc, opt=D_opt,
                                                  training_set=train_set,
                                                  minibatch_size=per_gpu_minibatch,
                                                  reals=gpu_reals, labels=gpu_labels,
                                                  **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), gen.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), disc.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, train_set, **config.grid)
    schedule = TrainingSchedule(total_kimg * 1000, train_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=schedule.minibatch // n_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'),
                         drange=train_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    print('Training...')
    end_nimg = total_kimg * 1000
    nimg = int(resume_kimg * 1000)
    tick_idx = 0
    tick_nimg0 = nimg
    tick_t0 = time.time()
    train_t0 = tick_t0 - resume_time
    last_lod = -1.0
    while nimg < end_nimg:
        # Refresh the schedule for the current progress and reconfigure the data.
        schedule = TrainingSchedule(nimg, train_set, **config.sched)
        train_set.configure(schedule.minibatch, schedule.lod)
        # Optionally reset optimizer state whenever lod moves across an integer.
        if reset_opt_for_new_lod:
            lod_crossed = (np.floor(schedule.lod) != np.floor(last_lod)
                           or np.ceil(schedule.lod) != np.ceil(last_lod))
            if lod_crossed:
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        last_lod = schedule.lod

        # D_repeats discriminator steps (each also updates Gs) per generator step.
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run([D_train_op, Gs_update_op],
                           {lod_in: schedule.lod, lrate_in: schedule.D_lrate,
                            minibatch_in: schedule.minibatch})
                nimg += schedule.minibatch
            tfutil.run([G_train_op],
                       {lod_in: schedule.lod, lrate_in: schedule.G_lrate,
                        minibatch_in: schedule.minibatch})

        # Maintenance runs once per tick (or on the very last iteration).
        finished = (nimg >= end_nimg)
        if not (nimg >= tick_nimg0 + schedule.tick_kimg * 1000 or finished):
            continue

        tick_idx += 1
        now = time.time()
        kimg_this_tick = (nimg - tick_nimg0) / 1000.0
        tick_nimg0 = nimg
        tick_secs = now - tick_t0
        total_secs = now - train_t0
        maint_secs = tick_t0 - maint_start
        maint_start = now

        # Progress line; every value is also recorded as a summary.
        stats = (
            tfutil.autosummary('Progress/tick', tick_idx),
            tfutil.autosummary('Progress/kimg', nimg / 1000.0),
            tfutil.autosummary('Progress/lod', schedule.lod),
            tfutil.autosummary('Progress/minibatch', schedule.minibatch),
            misc.format_time(tfutil.autosummary('Timing/total_sec', total_secs)),
            tfutil.autosummary('Timing/sec_per_tick', tick_secs),
            tfutil.autosummary('Timing/sec_per_kimg', tick_secs / kimg_this_tick),
            tfutil.autosummary('Timing/maintenance_sec', maint_secs),
        )
        print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % stats)
        tfutil.autosummary('Timing/total_hours', total_secs / (60.0 * 60.0))
        tfutil.autosummary('Timing/total_days', total_secs / (24.0 * 60.0 * 60.0))
        tfutil.save_summaries(summary_log, nimg)

        kimg_done = nimg // 1000
        if tick_idx % image_snapshot_ticks == 0 or finished:
            grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=schedule.minibatch // n_gpus)
            misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % kimg_done),
                                 drange=drange_net, grid_size=grid_size)
        if tick_idx % network_snapshot_ticks == 0 or finished:
            misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % kimg_done))

        # The next tick's timer starts after maintenance has finished.
        tick_t0 = time.time()

    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train a progressive GAN: generator G, discriminator D and the smoothed copy Gs.

    Loads the dataset described by ``config.dataset`` and either constructs
    fresh networks or resumes them from a pickle. It then builds a TensorFlow
    training graph split across ``config.num_gpus`` GPUs and alternates
    ``D_repeats`` discriminator steps with one generator step. This repeats
    until ``total_kimg`` thousand images have been counted.

    Besides the usual real/fake image grids, it writes "shifted" variants:
    the grids multiplied element-wise by a cosine kernel. Everything is
    written to a fresh result subdirectory:
    - ``reals.png``, ``reals_sh.png``;
    - ``fakesNNNNNN.png`` and ``fakesNNNNNN_sh.png`` per image snapshot;
    - ``network-snapshot-NNNNNN.pkl`` and ``network-final.pkl``;
    - tfevents summaries and an empty ``_training-done.txt`` marker.

    Returns:
        None. All results are written to disk.
    """

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)
    #resume_run_id = '/dresden/users/mk1391/evl/pggan_logs/logs_celeba128cc/fsg16_results_0/000-pgan-celeba-preset-v2-2gpus-fp32/network-snapshot-010211.pkl'
    # When True, fresh networks are built and the resumed pickle's weights are
    # copied into them with copy_vars_from. When False, the pickled networks
    # are used as-is.
    resume_with_new_nets = False
    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)
            Gs = G.clone('Gs')
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            rG, rD, rGs = misc.load_pkl(network_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
        # Op that moves Gs towards G with smoothing factor G_smoothing; it is
        # run together with every D training step below.
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()

    ### pyramid draw fsg (comment out for actual training to happen)
    #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'pggan_fsg_draw.png'))
    #print('>>> done printing fsgs.')
    #return

    print('Building TensorFlow graph...')
    # Placeholders fed per step. The real minibatch is split evenly across GPUs.
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    # One loss tower per GPU. GPU 0 uses G/D directly; the other GPUs use
    # '_shadow' clones. Gradients from every tower are registered with the
    # shared optimizers.
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Copy the fed lod into each network's 'lod' variable before the
            # losses are evaluated (enforced via control_dependencies below).
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    ### shift reals
    # Multiply the real grid by cos(2*pi*(fc_x*x + fc_y*y)). With
    # fc_x = fc_y = 0.5 this equals (-1)**(x + y), a checkerboard. It shifts
    # the image spectrum by half the sampling rate along both axes.
    # NOTE(review): assumes grid_reals holds values in [0, 255] (hard-coded /255
    # and clip below), and that images are square (im_size is taken from the
    # last axis and used for both H and W) -- confirm.
    print('>>> reals shape: ', grid_reals.shape)
    fc_x = 0.5
    fc_y = 0.5
    im_size = grid_reals.shape[-1]
    # x varies along the last axis, y along the second-to-last; the leading 1
    # broadcasts over the channel axis.
    kernel_loc = 2.*np.pi*fc_x * np.arange(im_size).reshape((1, 1, im_size)) + \
        2.*np.pi*fc_y * np.arange(im_size).reshape((1, im_size, 1))
    kernel_cos = np.cos(kernel_loc)
    kernel_sin = np.sin(kernel_loc)  # NOTE(review): computed but not used below.
    # Map [0, 255] -> [-1, 1], modulate, then map back to uint8 for saving.
    reals_t = (grid_reals / 255.) * 2. - 1
    reals_t *= kernel_cos
    grid_reals_sh = np.rint(
        (reals_t + 1.) * 255. / 2.).clip(0, 255).astype(np.uint8)
    ### end shift reals
    # Schedule at the end of training. Before the loop it is only used to pick
    # the minibatch size for the initial fake grid; the loop recomputes it.
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    ### fft drawing
    #sys.path.insert(1, '/home/mahyar/CV_Res/ganist')
    #from fig_draw import apply_fft_win
    #data_size = 1000
    #latents = np.random.randn(data_size, *Gs.input_shapes[0][1:])
    #labels = np.zeros([latents.shape[0]] + Gs.input_shapes[1][1:])
    #g_samples = Gs.run(latents, labels, minibatch_size=sched.minibatch//config.num_gpus)
    #g_samples = g_samples.transpose(0, 2, 3, 1)
    #print('>>> g_samples shape: {}'.format(g_samples.shape))
    #apply_fft_win(g_samples, 'fft_pggan_hann.png')
    ### end fft drawing

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    ### drawing shifted real images
    misc.save_image_grid(grid_reals_sh,
                         os.path.join(result_subdir, 'reals_sh.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    ### drawing shifted fake images
    # Same cosine modulation applied directly to generator output, which is
    # saved with drange_net (no re-quantization, unlike the reals above).
    misc.save_image_grid(grid_fakes * kernel_cos,
                         os.path.join(result_subdir, 'fakes%06d_sh.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()

    #### True cosine fft eval
    #fft_data_size = 1000
    #im_size = training_set.shape[1]
    #freq_centers = [(64/128., 64/128.)]
    #true_samples = sample_true(training_set, fft_data_size, dtype=training_set.dtype, batch_size=32).transpose(0, 2, 3, 1) / 255. * 2. - 1.
    #true_fft, true_fft_hann, true_hist = cosine_eval(true_samples, 'true', freq_centers, log_dir=result_subdir)
    #fractal_eval(true_samples, f'koch_snowflake_true', result_subdir)

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    # Shift the start time back so reported total time includes resume_time.
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer state whenever lod crosses an integer boundary,
            # i.e. floor or ceil of lod changed since the previous round.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # Progress (cur_nimg) advances only on D steps, by one full minibatch each.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            # Time spent between the end of the previous tick's training and the
            # start of this tick (snapshots, summaries, etc.).
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                ### drawing shifted fake images
                misc.save_image_grid(
                    grid_fakes * kernel_cos,
                    os.path.join(result_subdir,
                                 'fakes%06d_sh.png' % (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size)
                ### drawing fsg
                #draw_gen_fsg(Gs, 10, os.path.join(config.result_dir, 'fakes%06d_fsg_draw.png' % (cur_nimg // 1000)))
                ### Gen fft eval
                #gen_samples = sample_gen(Gs, fft_data_size).transpose(0, 2, 3, 1)
                #print(f'>>> fake_samples: max={np.amax(grid_fakes)} min={np.amin(grid_fakes)}')
                #print(f'>>> gen_samples: max={np.amax(gen_samples)} min={np.amin(gen_samples)}')
                #misc.save_image_grid(gen_samples[:25], os.path.join(result_subdir, 'fakes%06d_gsample.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
                #cosine_eval(gen_samples, f'gen_{cur_nimg//1000:06d}', freq_centers, log_dir=result_subdir, true_fft=true_fft, true_fft_hann=true_fft_hann, true_hist=true_hist)
                #fractal_eval(gen_samples, f'koch_snowflake_fakes{cur_nimg//1000:06d}', result_subdir)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # Empty marker file signalling that training ran to completion.
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
    'Use grabcut algorithm on the face mask to better segment the foreground',
    type=bool)
parser.add_argument('--scale_mask',
                    default=1.5,
                    help='Look over a wider section of foreground for grabcut',
                    type=float)

parser.add_argument('--use_aligned',
                    default=1,
                    help='align face before recovery',
                    type=int)

args = parser.parse_args()

# manual parameters
result_subdir = misc.create_result_subdir('results', 'inference_test')
misc.init_output_logging()

# All output directories are resolved relative to this run's result subdirectory.
args.aligned_dir = os.path.join(result_subdir, args.aligned_dir)
args.dlatent_dir = os.path.join(result_subdir, args.dlatent_dir)
args.dlabel_dir = os.path.join(result_subdir, args.dlabel_dir)
args.generated_images_dir = os.path.join(result_subdir,
                                         args.generated_images_dir)

# Only the aligned-face directory is created here. makedirs (rather than
# mkdir) also creates missing intermediate directories, so a nested
# --aligned_dir value works.
if not os.path.exists(args.aligned_dir):
    os.makedirs(args.aligned_dir)

# initialize TensorFlow
print('Initializing TensorFlow...')
env = EasyDict()  # Environment variables, set by the main program in train.py.
env.TF_CPP_MIN_LOG_LEVEL = '1'  # Print warnings and errors, but disable debug info.
Exemple #16
0
def train_gan(
    images_dir1,
    images_dir2,
    batch_size,
    img_shape               = (32,32,3),
    D_training_repeats      = 1,
    G_learning_rate_max     = 0.0010,
    D_learning_rate_max     = 0.0010,
    G_smoothing             = 0.999,
    adam_beta1              = 0.0,
    adam_beta2              = 0.99,
    adam_epsilon            = 1e-8,
    minibatch_default       = 16,
    rampup_kimg             = 40/speed_factor,
    rampdown_kimg           = 0,
    lod_initial_resolution  = 4,
    lod_training_kimg       = 400/speed_factor,
    lod_transition_kimg     = 400/speed_factor,
    total_kimg              = 10000/speed_factor,
    dequantize_reals        = False,
    gdrop_beta              = 0.9,
    gdrop_lim               = 0.5,
    gdrop_coef              = 0.2,
    gdrop_exp               = 2.0,
    drange_net              = [-1,1],
    drange_viz              = [-1,1],
    image_grid_size         = None,
    tick_kimg_default       = 50/speed_factor,
    tick_kimg_overrides     = {32:20, 64:10, 128:10, 256:5, 512:2, 1024:1},
    image_snapshot_ticks    = 1,
    network_snapshot_ticks  = 4,
    image_grid_type         = 'default',
    #resume_network          = '000-celeba/network-snapshot-000488',
    resume_network          = None,
    resume_kimg             = 0.0,
    resume_time             = 0.0):
    """Train a pair of encoder-generator / discriminator networks across two image domains.

    Domain 1 images are read from ``images_dir1`` and modelled by ``E_G``/``D``.
    Domain 2 images are read from ``images_dir2`` and modelled by the "twin"
    networks ``E_twin_G_twin``/``D_twin``, built with ``new_batch_norm``. The
    cross-domain translators ``E_G_twin`` and ``E_twin_G`` come from
    ``replace_batch_norm``.

    Each iteration trains:
    - adversarial losses (D vs. E_G; D_twin vs. E_twin_G_twin);
    - cycle consistency (each autoencoder reconstructs its own input);
    - cross-domain adversarial losses (E_G_twin vs. D_twin; E_twin_G vs. D);
    - semantic consistency (translated images re-encode to the source encoding).

    Once per tick it prints progress. Every ``image_snapshot_ticks`` ticks it
    saves paired real/translated grids. Every ``network_snapshot_ticks`` ticks
    it saves the weights of both networks pairs. Learning rates follow the
    rampup/rampdown schedule, scaled by ``G_learning_rate_max`` for
    encoder/generator models and by ``D_learning_rate_max`` for D and D_twin.

    NOTE(review): these parameters are accepted for interface compatibility
    but not used:
    D_training_repeats, G_smoothing, minibatch_default, dequantize_reals,
    gdrop_*, drange_net, drange_viz, image_grid_size, image_grid_type,
    resume_network and resume_time. Resuming from ``resume_network`` is not
    implemented.

    Returns:
        None. Images and weights are written under a new result subdirectory.
    """

    def _adam():
        # lr starts at 0; it is set every iteration from the LR schedule below.
        return optimizers.Adam(lr=0.0, beta_1=adam_beta1, beta_2=adam_beta2, epsilon=adam_epsilon)

    # Domain-1 networks and their twin (domain-2) counterparts.
    E_G = Encoder_Generator(num_channels=img_shape[2], resolution=img_shape[0], **config.G)
    D = Discriminator(num_channels=img_shape[2], resolution=img_shape[0], **config.D)

    E_twin_G_twin = new_batch_norm(E_G)
    D_twin = new_batch_norm(D)

    E = extract_encoder(E_G)
    E_twin = extract_encoder(E_twin_G_twin)

    # Cross-domain translators. Judging by the names, E_twin_G is the twin
    # encoder with the domain-1 generator and E_G_twin the reverse; see
    # replace_batch_norm.
    E_twin_G = replace_batch_norm(E_G, E_twin_G_twin, apply='encoder')
    E_G_twin = replace_batch_norm(E_G, E_twin_G_twin, apply='generator')

    E_G_twin_E_twin = Sequential([E_G_twin, E_twin])

    E_G_D = Sequential([E_G, D])
    E_G_twin_D_twin = Sequential([E_G_twin, D_twin])

    E_twin_G_E = Sequential([E_twin_G, E])
    E_twin_G_D = Sequential([E_twin_G, D])
    E_twin_G_twin_D_twin = Sequential([E_twin_G_twin, D_twin])

    # Misc init.
    resolution_log2 = int(np.round(np.log2(E_G.output_shape[2])))
    initial_lod = max(resolution_log2 - int(np.round(np.log2(lod_initial_resolution))), 0)

    # Usual Keras GAN idiom: the discriminators are frozen while the stacked
    # (generator -> discriminator) models are compiled. They are then
    # unfrozen before being compiled on their own.
    D.trainable = False
    D_twin.trainable = False
    E_G_D.compile(_adam(), loss=adversarial_loss)
    E_twin_G_D.compile(_adam(), loss=adversarial_loss)
    E_G_twin_D_twin.compile(_adam(), loss=adversarial_loss)
    E_twin_G_twin_D_twin.compile(_adam(), loss=adversarial_loss)

    D.trainable = True
    D_twin.trainable = True
    D.compile(_adam(), loss=adversarial_loss)
    D_twin.compile(_adam(), loss=adversarial_loss)

    E_G.compile(_adam(), loss=cycle_consistency_loss)
    E_twin_G_twin.compile(_adam(), loss=cycle_consistency_loss)

    E_twin_G_E.compile(_adam(), loss=semantic_consistency_loss)
    E_G_twin_E_twin.compile(_adam(), loss=semantic_consistency_loss)

    # Models stepped with the generator-side vs. discriminator learning rate.
    generator_models = [E_G_D, E_twin_G_D, E_G_twin_D_twin, E_twin_G_twin_D_twin,
                        E_G, E_twin_G_twin, E_twin_G_E, E_G_twin_E_twin]
    discriminator_models = [D, D_twin]

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()

    # Set up generators
    data_generator1 = DataGenerator(images_dir=images_dir1)
    data_generator2 = DataGenerator(images_dir=images_dir2)

    generator1 = data_generator1.generate(batch_size=2, img_size=2**resolution_log2)
    generator2 = data_generator2.generate(batch_size=2, img_size=2**resolution_log2)

    real_1 = next(generator1)
    real_2 = next(generator2)

    result_subdir = misc.create_result_subdir(config.result_dir, config.run_desc)

    print("real_1.shape:", real_1.shape)
    print("real_2.shape:", real_2.shape)

    misc.save_image_grid_twin(real_1, real_2, os.path.join(result_subdir, 'reals.png'))

    # Adversarial targets: 1 for real images, 0 for generated ones.
    valid = np.ones((batch_size, 1, 1, 1))
    fake = np.zeros((batch_size, 1, 1, 1))

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        # NOTE(review): cur_nimg is multiplied by speed_factor here, and the
        # lod_*_kimg defaults are already divided by it. With default
        # arguments the schedule therefore advances speed_factor**2 times
        # faster -- confirm this is intended.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            tlod = (cur_nimg / (1000.0/speed_factor)) / (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) * (lod_training_kimg + lod_transition_kimg) / lod_transition_kimg, 0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        tick_duration_kimg = tick_kimg_overrides.get(cur_res, tick_kimg_default)

        # Update learning rates and push the current LOD into any model that
        # exposes a `cur_lod` variable.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg / 1000.0, total_kimg, rampdown_kimg)
        for model in generator_models:
            K.set_value(model.optimizer.lr, np.float32(lrate_coef * G_learning_rate_max))
        for model in discriminator_models:
            K.set_value(model.optimizer.lr, np.float32(lrate_coef * D_learning_rate_max))
        for model in generator_models + discriminator_models:
            if hasattr(model, 'cur_lod'):
                K.set_value(model.cur_lod, np.float32(cur_lod))

        min_lod = int(np.floor(cur_lod))

        # NOTE(review): the batch streams are recreated every iteration; confirm
        # DataGenerator.generate is cheap and samples as intended.
        generator1 = data_generator1.generate(batch_size=batch_size, img_size=cur_res, min_lod=min_lod)
        generator2 = data_generator2.generate(batch_size=batch_size, img_size=cur_res, min_lod=min_lod)

        cur_nimg += batch_size

        # Domain 1: train D on real vs. E_G output, then E_G through frozen D.
        images1 = next(generator1)
        img_fakes = E_G.predict_on_batch([images1])
        D.train_on_batch(images1, valid)
        D.train_on_batch(img_fakes, fake)
        E_G_D.train_on_batch(images1, valid)

        # Domain 2: the same for the twin networks.
        images2 = next(generator2)
        img_fakes = E_twin_G_twin.predict_on_batch([images2])
        D_twin.train_on_batch(images2, valid)
        D_twin.train_on_batch(img_fakes, fake)
        E_twin_G_twin_D_twin.train_on_batch(images2, valid)

        # Cycle consistency: each autoencoder reconstructs its own input.
        E_G.train_on_batch(images1, images1)
        E_twin_G_twin.train_on_batch(images2, images2)

        # Translation 1 -> 2, judged by D_twin.
        img_fakes = E_G_twin.predict_on_batch([images1])
        D_twin.train_on_batch(images2, valid)
        D_twin.train_on_batch(img_fakes, fake)
        E_G_twin_D_twin.train_on_batch(images1, valid)

        # Translation 2 -> 1, judged by D.
        img_fakes = E_twin_G.predict_on_batch([images2])
        D.train_on_batch(images1, valid)
        D.train_on_batch(img_fakes, fake)
        E_twin_G_D.train_on_batch(images2, valid)

        # Semantic consistency: a translated image, re-encoded by the other
        # domain's encoder, should match the source image's encoding.
        semantic = E.predict_on_batch(images1)
        E_G_twin_E_twin.train_on_batch(images1, semantic)
        semantic = E_twin.predict_on_batch(images2)
        E_twin_G_E.train_on_batch(images2, semantic)

        done = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_time = cur_time - tick_start_time
            tick_start_nimg = cur_nimg
            tick_start_time = cur_time
            print('tick %-5d kimg %-8.1f lod %-5.2f res %-4d sec/tick %-7.1f sec/kimg %-7.2f' % (
                cur_tick, cur_nimg / 1000.0, cur_lod, cur_res, tick_time, tick_time / tick_kimg))

            # Visualize real images next to their cross-domain translations.
            if cur_tick % image_snapshot_ticks == 0 or done:
                snap_1 = next(generator1)
                snap_2 = next(generator2)
                fake_2 = E_G_twin.predict_on_batch(snap_1)
                fake_1 = E_twin_G.predict_on_batch(snap_2)
                misc.save_image_grid_twin(snap_1, fake_2, os.path.join(result_subdir, 'fakes_dog%06d.png' % (cur_nimg // 1000)))
                misc.save_image_grid_twin(snap_2, fake_1, os.path.join(result_subdir, 'fakes_celeb%06d.png' % (cur_nimg // 1000)))

            # Save both network pairs. The twin networks are needed to
            # reproduce the cross-domain translators.
            if cur_tick % network_snapshot_ticks == 0 or done:
                snapshot_path = os.path.join(result_subdir, 'network-snapshot-%06d' % (cur_nimg // 1000))
                save_GD_weights(E_G, D, snapshot_path)
                save_GD_weights(E_twin_G_twin, D_twin, snapshot_path + '-twin')

    save_GD(E_G, D, os.path.join(result_subdir, 'network-final'))
    save_GD(E_twin_G_twin, D_twin, os.path.join(result_subdir, 'network-final-twin'))
    print('Done.')
Exemple #17
0
def train_progressive_gan(
    G_smoothing=0.999,  # Exponential running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks (never mutated here).
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train a progressive-growing GAN (G, D, smoothed Gs) together with an encoder E.

    Loads the dataset described by ``config.dataset``. It then either resumes
    (G, D, Gs, E) from a pickle located via ``resume_run_id`` /
    ``resume_snapshot`` or constructs fresh ``tfutil.Network`` instances from
    ``config.G`` / ``config.D`` / ``config.E``.

    The loss functions named in ``config.G_loss`` / ``config.D_loss`` /
    ``config.E_loss`` are replicated over ``config.num_gpus`` GPUs. The script
    then runs the training loop driven by ``TrainingSchedule``.

    Side effects (all under a freshly created result sub-directory):
      * ``reals.png`` and ``fakes%06d.png`` image grids,
      * a tfevents summary log,
      * ``network-snapshot-%06d.pkl`` every ``network_snapshot_ticks`` ticks,
      * ``network-final.pkl`` and an empty ``_training-done.txt`` marker.
    """
    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **config.dataset)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs, E = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.G)
            D = tfutil.Network('D',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.D)

            E = tfutil.Network('E',
                               num_channels=training_set.shape[0],
                               resolution=training_set.shape[1],
                               label_size=training_set.label_size,
                               **config.E)

            Gs = G.clone('Gs')
        # Gs tracks an exponential moving average of G's weights.
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers()
    D.print_layers()
    E.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG',
                             learning_rate=lrate_in,
                             **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD',
                             learning_rate=lrate_in,
                             **config.D_opt)
    E_opt = tfutil.Optimizer(name='TrainE',
                             learning_rate=lrate_in,
                             **config.E_opt)
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the real networks; other GPUs get shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            # Every loss must see the current level-of-detail, so the loss
            # graphs below depend on these assignments.
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in),
                tf.assign(E_gpu.find_var('lod'), lod_in)
            ]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment,
                                      training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals_gpu,
                    labels=labels_gpu,
                    **config.D_loss)
            with tf.name_scope('E_loss'), tf.control_dependencies(
                    lod_assign_ops):
                E_loss = tfutil.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    E=E_gpu,
                    opt=E_opt,
                    training_set=training_set,
                    reals=reals_gpu,
                    minibatch_size=minibatch_split,
                    **config.E_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
            E_opt.register_gradients(tf.reduce_mean(E_loss), E_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    E_train_op = E_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(
        G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        minibatch_size=sched.minibatch // config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals,
                         os.path.join(result_subdir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(result_subdir, 'fakes%06d.png' % 0),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        # E is trained alongside G and D, so record its histograms too.
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        E.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            # A change in floor/ceil of lod means layers were added or a
            # transition finished; stale optimizer moments are discarded.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
                E_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _ in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tfutil.run(
                [G_train_op, E_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f'
                % (tfutil.autosummary('Progress/tick', cur_tick),
                   tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                   tfutil.autosummary('Progress/lod', sched.lod),
                   tfutil.autosummary('Progress/minibatch', sched.minibatch),
                   misc.format_time(
                       tfutil.autosummary('Timing/total_sec', total_time)),
                   tfutil.autosummary('Timing/sec_per_tick', tick_time),
                   tfutil.autosummary('Timing/sec_per_kimg',
                                      tick_time / tick_kimg),
                   tfutil.autosummary('Timing/maintenance_sec',
                                      maintenance_time)))
            tfutil.autosummary('Timing/total_hours',
                               total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days',
                               total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    minibatch_size=sched.minibatch //
                                    config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         result_subdir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

                misc.save_all_res(training_set.shape[1],
                                  Gs,
                                  result_subdir,
                                  50,
                                  minibatch_size=sched.minibatch //
                                  config.num_gpus)

            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl(
                    (G, D, Gs, E),
                    os.path.join(
                        result_subdir,
                        'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs, E),
                  os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    # Empty marker file signalling that the run finished.
    with open(os.path.join(result_subdir, '_training-done.txt'), 'wt'):
        pass
Exemple #18
0
def train_gan(
    separate_funcs=False,
    D_training_repeats=1,
    G_learning_rate_max=0.0010,
    D_learning_rate_max=0.0010,
    G_smoothing=0.999,
    adam_beta1=0.0,
    adam_beta2=0.99,
    adam_epsilon=1e-8,
    minibatch_default=16,
    minibatch_overrides={},
    rampup_kimg=40/speed_factor,
    rampdown_kimg=0,
    lod_initial_resolution=4,
    lod_training_kimg=400/speed_factor,
    lod_transition_kimg=400/speed_factor,
    total_kimg=10000/speed_factor,
    dequantize_reals=False,
    gdrop_beta=0.9,
    gdrop_lim=0.5,
    gdrop_coef=0.2,
    gdrop_exp=2.0,
    drange_net=[-1, 1],
    drange_viz=[-1, 1],
    image_grid_size=None,
    tick_kimg_default=50/speed_factor,
    tick_kimg_overrides={32: 20, 64: 10, 128: 10, 256: 5, 512: 2, 1024: 1},
    image_snapshot_ticks=1,
    network_snapshot_ticks=4,
    image_grid_type='default',
    resume_network=None,
    resume_kimg=0.0,
    resume_time=0.0):
    """Train a Keras progressive GAN with a gradient-penalty style D objective.

    D is trained on (fakes, reals, real/fake interpolations). G is trained
    through ``G_train`` with a target of -1. The learning rate follows
    ``rampup``/``rampdown_linear``, and the level-of-detail (``cur_lod``)
    shrinks from ``initial_lod`` towards 0 as images are consumed.

    The dict/list defaults (``minibatch_overrides``, ``tick_kimg_overrides``,
    ``drange_net``, ``drange_viz``) are only read, never mutated.

    Side effects (under a new result sub-directory):
      * ``reals.png`` and ``fakes%06d.png`` image grids,
      * ``network-snapshot-%06d`` weights every ``network_snapshot_ticks``
        ticks,
      * ``network-final`` weights at the end.

    Raises:
        ValueError: if ``image_grid_type`` is not ``'default'``.
    """
    training_set, drange_orig = load_dataset()

    G, G_train, D, D_train = build_trainable_model(
        resume_network,
        training_set.shape[3],
        training_set.shape[1],
        training_set.labels.shape[1],
        adam_beta1, adam_beta2, adam_epsilon)

    # Misc init.
    resolution_log2 = int(np.round(np.log2(G.output_shape[2])))
    initial_lod = max(resolution_log2 -
                      int(np.round(np.log2(lod_initial_resolution))), 0)
    min_lod, max_lod = -1.0, -2.0
    fake_score_avg = 0.0

    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()

    train_start_time = tick_start_time - resume_time

    if image_grid_type != 'default':
        raise ValueError('Invalid image_grid_type', image_grid_type)

    if image_grid_size is None:
        # Images are channel-last (batches come from
        # get_random_minibatch_channel_last and are upsampled along axes 1
        # and 2), so axis 1 is height and axis 2 is width.
        h, w = G.output_shape[1], G.output_shape[2]
        print('w:%d,h:%d' % (w, h))
        image_grid_size = (int(np.clip(1920 // w, 3, 16)),
                           int(np.clip(1080 // h, 2, 16)))

    print('image_grid_size:', image_grid_size)
    num_grid_images = np.prod(image_grid_size)

    example_real_images, snapshot_fake_labels = \
        training_set.get_random_minibatch_channel_last(
            num_grid_images, labels=True)

    # Fixed latents, reused for every snapshot so grids are comparable.
    snapshot_fake_latents = random_latents(num_grid_images, G.input_shape)

    result_subdir = Path(misc.create_result_subdir(
        config.result_dir, config.run_desc))

    print('example_real_images.shape:', example_real_images.shape)
    misc.save_image_grid(example_real_images,
                         str(result_subdir / 'reals.png'),
                         drange=drange_orig, grid_size=image_grid_size)

    snapshot_fake_images = G.predict_on_batch(snapshot_fake_latents)
    misc.save_image_grid(snapshot_fake_images,
                         str(result_subdir/f'fakes{cur_nimg//1000:06}.png'),
                         drange=drange_viz, grid_size=image_grid_size)

    while cur_nimg < total_kimg * 1000:

        # Calculate current LOD.
        cur_lod = initial_lod
        if lod_training_kimg or lod_transition_kimg:
            # NOTE(review): speed_factor scales cur_nimg here AND already
            # scales the lod_*_kimg defaults; confirm the double scaling is
            # intended.
            tlod = (cur_nimg / (1000.0/speed_factor)) / \
                (lod_training_kimg + lod_transition_kimg)
            cur_lod -= np.floor(tlod)
            if lod_transition_kimg:
                cur_lod -= max(1.0 + (np.fmod(tlod, 1.0) - 1.0) *
                               (lod_training_kimg + lod_transition_kimg) /
                               lod_transition_kimg,
                               0.0)
            cur_lod = max(cur_lod, 0.0)

        # Look up resolution-dependent parameters.
        cur_res = 2 ** (resolution_log2 - int(np.floor(cur_lod)))
        minibatch_size = minibatch_overrides.get(cur_res, minibatch_default)
        tick_duration_kimg = tick_kimg_overrides.get(
            cur_res, tick_kimg_default)

        # Update network config.
        lrate_coef = rampup(cur_nimg / 1000.0, rampup_kimg)
        lrate_coef *= rampdown_linear(cur_nimg /
                                      1000.0, total_kimg, rampdown_kimg)

        K.set_value(G.optimizer.lr, np.float32(
            lrate_coef * G_learning_rate_max))
        K.set_value(G_train.optimizer.lr, np.float32(
            lrate_coef * G_learning_rate_max))

        K.set_value(D_train.optimizer.lr, np.float32(
            lrate_coef * D_learning_rate_max))
        if hasattr(G_train, 'cur_lod'):
            K.set_value(G_train.cur_lod, np.float32(cur_lod))
        if hasattr(D_train, 'cur_lod'):
            K.set_value(D_train.cur_lod, np.float32(cur_lod))

        new_min_lod, new_max_lod = int(
            np.floor(cur_lod)), int(np.ceil(cur_lod))
        if min_lod != new_min_lod or max_lod != new_max_lod:
            min_lod, max_lod = new_min_lod, new_max_lod

        # train D
        d_loss = None
        for _ in range(D_training_repeats):
            mb_reals, mb_labels = \
                training_set.get_random_minibatch_channel_last(
                    minibatch_size, lod=cur_lod, shrink_based_on_lod=True,
                    labels=True)
            mb_latents = random_latents(minibatch_size, G.input_shape)

            # compensate for shrink_based_on_lod
            if min_lod > 0:
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=1)
                mb_reals = np.repeat(mb_reals, 2**min_lod, axis=2)

            mb_fakes = G.predict_on_batch([mb_latents])

            # Bring reals into the network range BEFORE mixing them with the
            # fakes, otherwise the interpolation blends two dynamic ranges.
            mb_reals = misc.adjust_dynamic_range(
                mb_reals, drange_orig, drange_net)
            epsilon = np.random.uniform(0, 1, size=(minibatch_size, 1, 1, 1))
            interpolation = epsilon*mb_reals + (1-epsilon)*mb_fakes
            d_loss, d_diff, d_norm = \
                D_train.train_on_batch([mb_fakes, mb_reals, interpolation],
                                       [np.ones((minibatch_size, 1, 1, 1)),
                                        np.ones((minibatch_size, 1))])
            d_score_real = D.predict_on_batch(mb_reals)
            d_score_fake = D.predict_on_batch(mb_fakes)
            # %f: scores are floats; %d would truncate them.
            print('real score: %f fake score: %f' %
                  (np.mean(d_score_real), np.mean(d_score_fake)))
            cur_nimg += minibatch_size

        # train G
        mb_latents = random_latents(minibatch_size, G.input_shape)

        g_loss = G_train.train_on_batch(
            [mb_latents], (-1)*np.ones((mb_latents.shape[0], 1, 1, 1)))

        print('%d [D loss: %f] [G loss: %f]' % (cur_nimg, d_loss, g_loss))

        fake_score_cur = np.clip(np.mean(d_loss), 0.0, 1.0)
        fake_score_avg = fake_score_avg * gdrop_beta + \
            fake_score_cur * (1.0 - gdrop_beta)
        gdrop_strength = gdrop_coef * \
            (max(fake_score_avg - gdrop_lim, 0.0) ** gdrop_exp)
        if hasattr(D, 'gdrop_strength'):
            K.set_value(D.gdrop_strength, np.float32(gdrop_strength))

        is_complete = cur_nimg >= total_kimg * 1000
        is_tick_done = \
            cur_nimg >= tick_start_nimg + tick_duration_kimg * 1000

        if is_tick_done or is_complete:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time

            print(f'tick time: {tick_time}')
            print(f'tick image: {tick_kimg}k')
            tick_start_time = cur_time

            # Visualize generated images.
            is_image_snapshot_ticks = cur_tick % image_snapshot_ticks == 0
            if is_image_snapshot_ticks or is_complete:
                snapshot_fake_images = G.predict_on_batch(
                    snapshot_fake_latents)
                misc.save_image_grid(snapshot_fake_images,
                                     str(result_subdir /
                                         f'fakes{cur_nimg // 1000:06}.png'),
                                     drange=drange_viz,
                                     grid_size=image_grid_size)

            if cur_tick % network_snapshot_ticks == 0 or is_complete:
                save_GD_weights(G, D, str(
                    result_subdir / f'network-snapshot-{cur_nimg // 1000:06}'))
        # (A stray `break` used to sit here and stopped training after one
        # iteration; the loop now runs until total_kimg is reached.)
    save_GD(G, D, str(result_subdir/'network-final'))
    training_set.close()
    print('Done.')

    train_complete_time = time.time()-train_start_time
    print(f'training time: {train_complete_time}')