Example #1
0
def set_optimizer(cN, lrate_in, batch_multiplier, lazy_regularization = True, clip = None):
    """Attach a loss optimizer and a regularization optimizer to ``cN``.

    The optimizer keyword arguments start as a copy of ``cN.opt_args`` and are
    extended with ``batch_multiplier`` and ``learning_rate``.  When
    ``lazy_regularization`` is enabled, the learning rate is multiplied by
    ``reg_interval / (reg_interval + 1)`` and ``beta1`` / ``beta2`` (when
    present) are raised to that same ratio.

    The results are stored on ``cN.opt`` and ``cN.reg_opt``; the latter is
    created with ``share = cN.opt``.  Nothing is returned.
    """
    # Work on a copy so that cN.opt_args itself is never modified.
    opt_kwargs = dict(cN.opt_args)
    opt_kwargs.update(batch_multiplier = batch_multiplier, learning_rate = lrate_in)

    if lazy_regularization:
        interval = cN.reg_interval
        ratio = interval / (interval + 1)
        opt_kwargs["learning_rate"] *= ratio
        for beta_key in ("beta1", "beta2"):
            if beta_key in opt_kwargs:
                opt_kwargs[beta_key] **= ratio

    tag = cN.name
    cN.opt = tflib.Optimizer(name = f"Loss{tag}", clip = clip, **opt_kwargs)
    cN.reg_opt = tflib.Optimizer(name = f"Reg{tag}", share = cN.opt, clip = clip, **opt_kwargs)
Example #2
0
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Build the multi-GPU training graph and run the training loop.

    Args:
        submit_config: Run configuration; ``num_gpus``, ``run_dir`` and
            ``run_id`` are read from it.
        iteration_count: Number of training iterations to run.
        eval_interval: Every this many iterations the loop saves sample
            images, evaluates the validation set, prints/logs timings and
            saves a network snapshot.
        minibatch_size: Batch size forwarded to ``create_dataset``.
        learning_rate: Base learning rate forwarded to
            ``compute_ramped_down_lrate``.
        ramp_down_perc: Ramp-down parameter forwarded to
            ``compute_ramped_down_lrate``.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name`` that
            produce the noise augmenter.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Training data location forwarded to
            ``create_dataset``.
        noise2noise: If true, the loss is the difference of two
            ``fourier_ring_correlation`` terms computed from the noisy
            input/target pair; otherwise it is the mean squared error between
            the clean target and the network output for the noisy input.

    The module-level ``config`` object supplies ``tf_config``, ``net_config``
    and ``optimizer_config``.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)
    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.compat.v1.placeholder(tf.float32,
                                            name='lrate_in',
                                            shape=[])

        #print("DEBUG train:", "dataset iter got called")
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        # Split each batch tensor into one piece per GPU (tf.split default axis).
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        print(len(noisy_input_split), noisy_input_split)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)
        # Split [?, 3, 256, 256] across num_gpus over axis 0 (i.e. the batch)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    # Build one radial mask per radius value (0..127) via the radial_mask helper.
    radii = np.arange(128).reshape(128, 1)  #image size 256, binning = 3
    radial_masks = np.apply_along_axis(radial_mask, 1, radii, 128, 128,
                                       np.arange(0, 256), np.arange(0, 256),
                                       20)
    print("RN SHAPE!!!!!!!!!!:", radial_masks.shape)
    radial_masks = np.expand_dims(radial_masks, 1)  # (128, 1, 256, 256)
    #radial_masks = np.squeeze(np.stack((radial_masks,) * 3, -1)) # 43, 3, 256, 256
    #radial_masks = radial_masks.transpose([0, 3, 1, 2])
    radial_masks = radial_masks.astype(np.complex64)
    radial_masks = tf.expand_dims(radial_masks, 1)

    rn = tf.compat.v1.placeholder_with_default(radial_masks,
                                               [128, None, 1, 256, 256])
    rn_split = tf.split(rn, submit_config.num_gpus, axis=1)
    freq_nyq = int(np.floor(int(256) / 2.0))

    # Radii scaled by freq_nyq, then rescaled so the largest value is 1.
    spatial_freq = radii.astype(np.float32) / freq_nyq
    spatial_freq = spatial_freq / max(spatial_freq)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised_1 = net_gpu.get_output_for(noisy_input_split[gpu])
            denoised_2 = net_gpu.get_output_for(noisy_target_split[gpu])
            print(noisy_input_split[gpu].get_shape(),
                  rn_split[gpu].get_shape())
            if noise2noise:
                meansq_error = fourier_ring_correlation(
                    noisy_target_split[gpu], denoised_1, rn_split[gpu],
                    spatial_freq) - fourier_ring_correlation(
                        noisy_target_split[gpu] - denoised_2,
                        noisy_input_split[gpu] - denoised_1, rn_split[gpu],
                        spatial_freq)
            else:
                # Use the per-GPU output for the noisy input.  The name
                # `denoised` is only bound later inside the training loop, so
                # referencing it here raised UnboundLocalError.
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised_1))
            # Create an autosummary that will average over all GPUs
            #tf.summary.histogram(name, var)
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.compat.v1.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break
        # Dump training status
        if i % eval_interval == 0:

            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            print("DEBUG TRAIN!", noisy_input.dtype, noisy_input[0][0].dtype)
            # Evaluate the input tensors to draw a batch, then denoise it
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.tif".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.tif".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.tif".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

            save_snapshot(submit_config, net, str(i))
        # One optimizer step with the learning rate computed for this iteration.
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
Example #3
0
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    setname                 = None,   # Model name 
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    mirror_augment_v        = False,  # Enable mirror augment vertically?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = 'latest',     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    restore_partial_fn      = None,   # Filename of network for partial restore
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Run the full G/D training loop: build or load networks, build the graph, train.

    The function loads the training set, constructs or restores the ``G``,
    ``D`` and ``Gs`` networks, builds per-GPU loss/optimizer graphs, and then
    iterates until ``total_kimg`` thousand images have been processed (or the
    run context asks to stop), periodically printing progress and saving image
    and network snapshots into the run directory.

    Resume modes selected by ``resume_pkl``:
        * ``None``: construct new networks.
        * ``'latest'``: ask ``misc.locate_latest_pkl`` for a pickle under
          ``dnnlib.submit_config.run_dir_root`` and load it.
        * ``'restore_partial'``: construct new networks and copy compatible
          trainables from the pickle at ``restore_partial_fn``.
        * anything else: load networks from that location (looked up through
          ``misc.locate_latest_pkl`` when ``resume_kimg == 0``).

    NOTE(review): ``setname`` is indexed (``setname[-1]``, ``setname[:-1]``)
    when snapshot filenames are built, so it must be a non-empty string; the
    default ``None`` fails as soon as a network snapshot is saved.

    Returns nothing; results are written to the run directory.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    # custom resolution - for saved model name below
    resolution = training_set.resolution
    if training_set.init_res != [4,4]:
        init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1])
    else:
        init_res_str = ''
    # Image snapshots are written as png when the dataset has 4 channels, jpg otherwise.
    ext = 'png' if training_set.shape[0] == 4 else 'jpg'
    print(' model base resolution', resolution)
    
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        # Resolve 'latest' to a concrete pickle *before* deciding whether to
        # construct or load.  Previously this branch only located the pickle
        # and never loaded it, leaving G/D/Gs undefined (or freshly
        # initialized when resume_with_new_nets was set).
        # NOTE(review): assumes locate_latest_pkl yields resume_pkl=None when
        # nothing is found, in which case new networks are constructed - confirm.
        resolved_latest = (resume_pkl == 'latest')
        if resolved_latest:
            resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
        if resume_pkl is None or resume_with_new_nets:
            print(' Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl == 'restore_partial':
            print(' Restore partially...')
            # Initialize networks
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
            # Load pre-trained networks
            assert restore_partial_fn is not None
            # NOTE: unpickling executes arbitrary code; only use trusted files.
            with open(restore_partial_fn, 'rb') as partial_file:
                G_partial, D_partial, Gs_partial = pickle.load(partial_file)
            # Restore (subset of) pre-trained weights (only parameters that match both name and shape)
            G.copy_compatible_trainables_from(G_partial)
            D.copy_compatible_trainables_from(D_partial)
            Gs.copy_compatible_trainables_from(Gs_partial)
        elif resume_pkl is not None:
            # An explicitly given location with no known progress is looked up
            # through locate_latest_pkl; a pickle already resolved from
            # 'latest' is used as-is.
            if resume_kimg == 0 and not resolved_latest:
                resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl)
            print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers if needed and generate initial image snapshot
    # G.print_layers(); D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print(' Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Copy the option dicts so the caller's (and the default) dicts are not mutated.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net)
                # Pad the freshly fetched rows with the tail of the variable so the assigned value keeps the variable's length.
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                # Lazy mode: the regularizer gets its own optimizer and is scaled by its interval.
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg))
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu) # , sched.lod
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            # One round per (minibatch_gpu * num_gpus) images; more than one round means gradient accumulation.
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Estimate the remaining time from the last tick's speed (only reported when lod is 0).
            if sched.lod == 0:
                left_kimg = total_kimg - cur_nimg / 1000
                left_sec = left_kimg * tick_time / tick_kimg
                finaltime = time.asctime(time.localtime(cur_time + left_sec))
                msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16])
            else:
                msg_final = ''

            # Report progress.
            # print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % (
            print('tick %-4d kimg %-6.1f time %-8s  %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                # autosummary('Progress/lod', sched.lod),
                # autosummary('Progress/minibatch', sched.minibatch_size),
                # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                msg_final,
                autosummary('Timing/min_per_tick', tick_time / 60),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                # autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
                sched.G_lrate))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                # Two pickles per snapshot: the full (G, D, Gs) triple and Gs alone.
                pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000)))

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str)))
    misc.save_pkl((Gs), dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str)))

    # All done.
    summary_log.close()
    training_set.close()
Example #4
0
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer (copied before being modified).
    D_opt_args={},  # Options for discriminator optimizer (copied before being modified).
    loss_args={},  # Options for the loss function, forwarded to dnnlib.util.call_func_by_name().
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for training_schedule().
    grid_args={},  # Options for misc.setup_snapshot_image_grid().
    metric_arg_list=[],  # Positional argument of metric_base.MetricGroup().
    metric_args={},  # Keyword options for metric_base.MetricGroup().
    tf_config={},  # Options for tflib.init_tf(); also forwarded to metrics.run().
    ema_start_kimg=None,  # Progress (kimg) at which the Gs moving average starts. None = use G_ema_kimg.
    G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights, in kimg.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=False,  # Apply the D regularization term only on every D_reg_interval-th minibatch?
    G_reg_interval=4,  # How often to perform regularization for G? NOTE(review): not referenced in this function.
    D_reg_interval=4,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=2,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks=1,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False  # Construct new networks according to G_args and D_args before resuming training?
):
    """Build the training graph for G/D/Gs and run the GAN training loop.

    The function proceeds in this order:
      1. Initialize TensorFlow and read the GPU count from
         dnnlib.submit_config.
      2. Load the training set and save a grid of real images ('reals.png').
      3. Construct G, D and Gs, and/or load them from `resume_pkl`.
      4. Save an initial grid of generated images ('fakes_init.png').
      5. Create the input placeholders, the two optimizers and, for each GPU,
         the loss graph whose gradients are registered with the optimizers.
      6. Loop until `total_kimg` thousand images have been processed or the
         RunContext requests a stop. Once per tick it reports progress,
         optionally saves image/network snapshots, runs the metrics on each
         saved network snapshot and writes summaries.
      7. Save 'network-final.pkl' and close the summary writer and dataset.

    All output files are placed via dnnlib.make_run_dir_path().

    NOTE(review): several parameters default to mutable objects ({} / []),
    which are shared between calls. This function itself only modifies
    copies of G_opt_args / D_opt_args; whether the callees that receive the
    other defaults (e.g. tf_config, metric_arg_list) mutate them cannot be
    determined from here - confirm before relying on repeated calls.

    Returns:
        None. Results are written to the run directory as side effects.
    """

    # Unless given explicitly, start the moving average after one half-life.
    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        # Fresh networks are needed when training from scratch, or when the
        # resumed weights are to be copied into newly constructed networks.
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            # The loaded object is unpacked as a (G, D, Gs) triple.
            # NOTE(review): misc.load_pkl presumably unpickles the file, so
            # resume_pkl should only point at trusted data - confirm.
            resume_networks = misc.load_pkl(resume_pkl)
            rG, rD, rGs = resume_networks
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    # Schedule evaluated at the end of training; only its minibatch_gpu is
    # used here, as the batch size for generating the initial image grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    # One latent per grid cell; the same latents are reused for every snapshot.
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Fed per minibatch: whether the D regularization term is added.
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        # Number of per-GPU rounds that make up one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        # Fed as 1 or 0; multiplies Gs_beta when the Gs update op is built.
        Gs_beta_mul_in = tf.placeholder(tf.float32,
                                        name='Gs_beta_in',
                                        shape=[])
        # 0.5 ** (minibatch_size / (G_ema_kimg * 1000)), i.e. the value halves
        # over G_ema_kimg thousand images; 0.0 when G_ema_kimg <= 0.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_ema_kimg * 1000.0) if G_ema_kimg > 0.0 else 0.0

    # Setup optimizers.
    # Work on copies so the caller's (or the shared default) dicts stay untouched.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    G_opt_args['learning_rate'] = G_lrate_in
    D_opt_args['learning_rate'] = D_lrate_in
    for args in [G_opt_args, D_opt_args]:
        args['minibatch_multiplier'] = minibatch_multiplier
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range,
                                           drange_net)

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Evaluate loss functions.
            # Networks that expose a 'lod' variable get it assigned from
            # lod_in before the loss ops run (via the control dependency).
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    # The loss function is resolved from loss_args and returns
                    # the generator loss, discriminator loss and an optional
                    # discriminator regularization term (may be None).
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        real_labels=labels_read,
                        **loss_args)

            # Register gradients.
            if not lazy_regularization:
                # Regularization is part of every discriminator step.
                if D_reg is not None:
                    D_loss += D_reg
            else:
                # Lazy mode: the term is added only when run_D_reg_in is fed
                # True, scaled by D_reg_interval.
                if D_reg is not None:
                    D_loss = tf.cond(run_D_reg_in,
                                     lambda: D_loss + D_reg * D_reg_interval,
                                     lambda: D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    # The effective beta is Gs_beta when Gs_beta_mul_in is fed 1 and 0 when
    # it is fed 0 (before ema_start_kimg).
    Gs_update_op = Gs.setup_as_moving_average_of(G,
                                                 beta=Gs_beta * Gs_beta_mul_in)
    # Ops created by G_opt.apply_updates() carry a control dependency on the
    # Gs update, so running G_train_op also runs Gs_update_op.
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        # Fall back to a constant 0 when the memory-stats op is not available.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    # Starting at -1 makes the first loop iteration perform the maintenance
    # block (tick 0) regardless of progress.
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    # Counts minibatches across the whole run; selects the lazy-reg steps.
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        # The full minibatch must split evenly into per-GPU rounds.
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        if reset_opt_for_new_lod:
            # Reset when lod has crossed an integer boundary since the
            # previous iteration (floor or ceil changed).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            # 0 until ema_start_kimg thousand images have been seen, then 1.
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            # One entry per (minibatch_gpu * num_gpus)-sized slice of the minibatch.
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            feed_dict[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Run one G step and one D step per round; there is no separate
            # single-round path here.
            # NOTE(review): both optimizers were given minibatch_multiplier,
            # so presumably tflib.Optimizer accumulates gradients across the
            # rounds of one minibatch - confirm.
            for _ in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            # Each autosummary() call records the value under the given name
            # and its result is used directly as the print argument.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are evaluated only on snapshots that were just saved.
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # The last update interval minus the training time of this tick;
            # reported as 'maintenance' at the next tick.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
def training_loop_vc(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    I_args={},  # Options for infogan-head/vcgan-head network.
    I_info_args={},  # Options for infogan-head/vcgan-head network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    use_info_gan=False,  # Whether to use info-gan.
    use_vc_head=False,  # Whether to use vc-head.
    use_vc_head_with_cls=False,  # Whether to use classification in discriminator.
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
    traversal_grid=False,  # Used for disentangled representation learning.
    n_discrete=3,  # Number of discrete latents in model.
    n_continuous=4,  # Number of continuous latents in model.
    n_samples_per=10):  # Number of samples for each line in traversal.

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I = tflib.Network('I',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **I_args)
                if use_vc_head_with_cls:
                    I_info = tflib.Network('I_info',
                                           num_channels=training_set.shape[0],
                                           resolution=training_set.shape[1],
                                           label_size=training_set.label_size,
                                           **I_info_args)

            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_info_gan or use_vc_head:
                rG, rD, rI, rGs = misc.load_pkl(resume_pkl)
            elif use_vc_head_with_cls:
                rG, rD, rI, rI_info, rGs = misc.load_pkl(resume_pkl)
            else:
                rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I.copy_vars_from(rI)
                    if use_vc_head_with_cls:
                        I_info.copy_vars_from(rI_info)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I = rI
                    if use_vc_head_with_cls:
                        I_info = rI_info
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    if use_info_gan or use_vc_head or use_vc_head_with_cls:
        I.print_layers()
        if use_vc_head_with_cls:
            I_info.print_layers()
    # pdb.set_trace()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # pdb.set_trace()
    grid_fakes, _ = Gs.run(grid_latents,
                           grid_labels,
                           is_validation=True,
                           minibatch_size=sched.minibatch_gpu,
                           randomize_noise=False)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
                if use_vc_head_with_cls:
                    I_info_gpu = I_info if gpu == 0 else I_info.clone(
                        I_info.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    if use_info_gan or use_vc_head:
                        G_loss, G_reg, I_loss, _ = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    elif use_vc_head_with_cls:
                        G_loss, G_reg, I_loss, I_info_loss = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            I_info=I_info_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            # print('G_gpu.trainables:', G_gpu.trainables)
            # print('D_gpu.trainables:', D_gpu.trainables)
            # print('I_gpu.trainables:', I_gpu.trainables)
            if use_info_gan or use_vc_head:
                GI_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()))
                G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss),
                                         GI_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
                # G_opt.register_gradients(tf.reduce_mean(I_loss),
                # GI_gpu_trainables)
                # D_opt.register_gradients(tf.reduce_mean(I_loss),
                # D_gpu.trainables)
            elif use_vc_head_with_cls:
                GIIinfo_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()) +
                    list(I_info_gpu.trainables.items()))
                G_opt.register_gradients(
                    tf.reduce_mean(G_loss + I_loss + I_info_loss),
                    GIIinfo_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            else:
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         G_gpu.trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)

            # if use_info_gan:
            # # INFO-GAN-HEAD loss
            # G_opt.register_gradients(tf.reduce_mean(I_loss),
            # G_gpu.trainables)
            # G_opt.register_gradients(tf.reduce_mean(I_loss),
            # I_gpu.trainables)
            # D_opt.register_gradients(tf.reduce_mean(I_loss),
            # D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        if use_info_gan or use_vc_head or use_vc_head_with_cls:
            I.setup_weight_histograms()
            if use_vc_head_with_cls:
                I_info.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes, _ = Gs.run(grid_latents,
                                       grid_labels,
                                       is_validation=True,
                                       minibatch_size=sched.minibatch_gpu,
                                       randomize_noise=False)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                if use_info_gan or use_vc_head:
                    misc.save_pkl((G, D, I, Gs), pkl)
                elif use_vc_head_with_cls:
                    misc.save_pkl((G, D, I, I_info, Gs), pkl)
                else:
                    misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_info_gan or use_vc_head:
        misc.save_pkl((G, D, I, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    elif use_vc_head_with_cls:
        misc.save_pkl((G, D, I, I_info, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((G, D, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example #6
0
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Train the denoising network configured in ``config.net_config``.

    Builds a multi-GPU mean-squared-error training graph, runs up to
    ``iteration_count`` optimizer steps and, every ``eval_interval``
    iterations, saves sample images, evaluates the validation set and writes
    TensorBoard summaries. A final snapshot is saved after the loop.

    Args:
        submit_config: Run description; ``num_gpus``, ``run_dir`` and
            ``run_id`` are read here, and the object is forwarded to the
            run context, the validation set and the image/snapshot helpers.
        iteration_count: Number of training iterations to run.
        eval_interval: Status is dumped whenever ``i % eval_interval == 0``.
        minibatch_size: Forwarded to ``create_dataset``.
        learning_rate: Base learning rate handed to
            ``compute_ramped_down_lrate``.
        ramp_down_perc: Ramp-down setting handed to
            ``compute_ramped_down_lrate``.
        noise: Keyword arguments for ``dnnlib.util.call_func_by_name``; the
            resulting object must provide ``add_train_noise_tf`` and
            ``add_validation_noise_np``.
        validation_config: Keyword arguments for ``ValidationSet.load``.
        train_tfrecords: Training data location forwarded to
            ``create_dataset``.
        noise2noise: If true, the loss compares the output against the noisy
            target; otherwise against the clean target.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        # One batch is drawn per session run and split evenly across GPUs.
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            # GPU 0 uses the original network; other GPUs get a clone.
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(
                    tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            # The ETA is logged under its own tag: reusing 'Timing/total_sec'
            # would mix elapsed time and remaining time in one summary series.
            print(
                'iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   dnnlib.util.format_time(
                       autosummary('Timing/eta_sec',
                                   (time_train / eval_interval) *
                                   (iteration_count - i))),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            # Time spent in this status dump, excluding the training time.
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
Example #7
0
def training_loop_refinement(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=True,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False
):  # Construct new networks according to G_args and D_args before resuming training?
    """Train only the generator ``G`` (and its moving average ``Gs``).

    No discriminator is built: the loss function named in ``G_loss_args`` is
    called with ``D=None``, the fetched real images as ``reals`` and the
    fetched dataset labels as ``latents``. The loop periodically reports
    progress, saves image grids and ``(G, None, Gs)`` network pickles, runs
    the configured metrics, and finally writes ``network-final.pkl``.

    ``D_args``, ``D_opt_args``, ``D_loss_args`` and ``D_reg_interval`` are
    accepted but not used by this function.

    ``G_args`` may contain ``freeze_layers``: every trainable of ``G`` whose
    name contains one of the listed substrings is removed from the set of
    trained variables. ``G_args`` itself is not modified.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD  # The discriminator is not used by this loop.
            if resume_with_new_nets:
                # Both fresh networks start from the averaged generator.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
                del rG, rGs
            else:
                # Keep the loaded generator; it must stay bound until here
                # (it used to be deleted before this assignment).
                G = rG
                Gs = rGs

    # Set constant noise input for both G and Gs
    if G_args.get("randomize_noise", None) == False:
        noise_vars = [
            var for name, var in G.components.synthesis.vars.items()
            if name.startswith('noise')
        ]
        rnd = np.random.RandomState(123)
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]

        # Same seed again so that Gs receives the same noise values as G.
        noise_vars = [
            var for name, var in Gs.components.synthesis.vars.items()
            if name.startswith('noise')
        ]
        rnd = np.random.RandomState(123)
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]

    # TESTS
    # from PIL import Image
    # reals, latents = training_set.get_minibatch_np(4)
    # reals = np.transpose(reals, [0, 2, 3, 1])
    # Image.fromarray(reals[0], 'RGB').save("test_reals.png")

    # labels = training_set.get_random_labels_np(4)
    # Gs_kwargs = dnnlib.EasyDict()
    # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # fakes = Gs.run(latents, labels, minibatch_size=4, **Gs_kwargs)
    # Image.fromarray(fakes[0], 'RGB').save("test_fakes_Gs_new.png")
    # fakes = G.run(latents, labels, minibatch_size=4, **Gs_kwargs)
    # Image.fromarray(fakes[0], 'RGB').save("test_fakes_G_new.png")

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)  # Copy so the caller's dict is untouched.
    for args, reg_interval in [(G_opt_args, G_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers. Kept in a local list so that a plain-dict G_args works
    # and the caller's argument is not mutated.
    freeze_layers = list(G_args.get("freeze_layers", []))

    def freeze_vars(gen, verbose=True):
        # Drop every trainable whose name contains a frozen-layer substring.
        assert len(freeze_layers) > 0
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in freeze_layers):
                del gen.trainables[name]
                if verbose: print(f"Freezed {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copy of G.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if freeze_layers: freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad with the tail of the old contents so the assigned value
                # keeps the variable's full first dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=None,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        latents=labels_read,
                        **G_loss_args)
                    loss_ops.append(G_loss)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0  # Sum of per-minibatch losses within the tick.
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        # Fill the data variables before the first training step reads them.
        tflib.run(data_fetch_op, feed_dict)
        ### TEST
        # fakes = G.get_output_for(labels_read, training_set.get_random_labels_tf(minibatch_gpu_in), is_training=True) # this is without activation in ~[-1.5, 1.5]
        # fakes = tf.clip_by_value(fakes, drange_net[0], drange_net[1])
        # reals = reals_read
        ### TEST
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                # (loss, reals, fakes), _ = tflib.run([loss_op, G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                # print(f"loss_tf  {np.mean(loss)}")
                # print(f"loss_np  {np.mean(np.square(reals - fakes))}")
                # print(f"loss_abs {np.mean(np.abs(reals - fakes))}")

                loss_per_batch_sum += loss
                #### TEST ####
                # if cur_nimg == sched.minibatch_size or cur_nimg % 2048 == 0:
                #     from PIL import Image
                #     reals = np.transpose(reals, [0, 2, 3, 1])
                #     fakes = np.transpose(fakes, [0, 2, 3, 1])
                #     diff = np.abs(reals - fakes)
                #     print(diff.min(), diff.max())
                #     for idx, (fake, real) in enumerate(zip(fakes, reals)):
                #         fake -= fake.min()
                #         fake /= fake.max()
                #         fake *= 255
                #         fake = fake.astype(np.uint8)
                #         Image.fromarray(fake, 'RGB').save(f"fake_loss_{idx}.png")
                #         real -= real.min()
                #         real /= real.max()
                #         real *= 255
                #         real = real.astype(np.uint8)
                #         Image.fromarray(real, 'RGB').save(f"real_loss_{idx}.png")
                ####
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation. FIXME: Probably wrong
            # NOTE(review): here the data fetch shares a session run with the
            # train op and regularization runs once per round, unlike the
            # fast path above -- confirm the intended ordering.
            else:
                for _round in rounds:
                    loss, _, _ = tflib.run(
                        [loss_op, G_train_op, data_fetch_op], feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                    if run_G_reg:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time
            # Summed loss divided by the number of minibatches in this tick
            # (tick images / minibatch size).
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (
                tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   autosummary('Progress/loss_per_px', tick_loss),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                # None fills the discriminator slot of the (G, D, Gs) triple.
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example #8
0
def training_loop(
        submit_config,  # Run configuration; batch_size, batch_size_test, image_size, num_gpus and run_dir are read.
        Encoder_args={},  # Extra keyword arguments forwarded to tflib.Network('E', ...).
        E_opt_args={},  # Keyword arguments for the encoder optimizer (tflib.Optimizer).
        D_opt_args={},  # Keyword arguments for the discriminator optimizer (tflib.Optimizer).
        E_loss_args={},  # Keyword arguments forwarded to call_func_by_name for the encoder loss.
        D_loss_args={},  # Keyword arguments forwarded to call_func_by_name for the discriminator loss.
        lr_args=EasyDict(),  # Learning-rate schedule; learning_rate, decay_step, decay_rate and stair are read.
        tf_config={},  # Options for tflib.init_tf().
        dataset_args=EasyDict(),  # Dataset locations; data_train and data_test are read.
        decoder_pkl=EasyDict(),  # Its .decoder_pkl attribute is the pickle holding (G, D, Gs, NE).
        drange_data=[0, 255],  # Dynamic range of the input data, passed to process_reals().
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,  # Passed to process_reals().
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        max_iters=150000,  # Index at which the iteration loop stops (exclusive).
        E_smoothing=0.999):  # Not read by the body (only by a commented-out line).
    """Train encoder E and discriminator D against a pre-loaded generator.

    The function builds one loss tower per GPU, registers the encoder and
    discriminator gradients with two tflib optimizers, and then alternates one
    encoder update and one discriminator update per iteration on the same
    training batch. Progress is printed and summaries are written every 100
    iterations. A "tick" elapses each time at least 65000 further images have
    been processed; reconstruction image grids and network pickles are written
    every `image_snapshot_ticks` / `network_snapshot_ticks` ticks, and a final
    pickle `network-final.pkl` is written to `submit_config.run_dir` when the
    loop ends.

    When `resume_run_id` is given, (E, G, D, Gs, NE) are loaded from the located
    pickle and the starting iteration is derived from the image count embedded
    in the pickle file name. Otherwise (G, D, Gs, NE) are loaded from
    `decoder_pkl.decoder_pkl` and a fresh encoder network is constructed.

    NOTE(review): the loss functions are resolved by
    dnnlib.util.call_func_by_name, so `E_loss_args` / `D_loss_args` presumably
    have to carry the function name -- confirm against the caller; the empty
    dict defaults are unlikely to be usable as-is.

    Returns:
        None.
    """

    tflib.init_tf(tf_config)

    # Input placeholders, each of shape [batch, 3, image_size, image_size].
    with tf.name_scope('input'):
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                    name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                   name='real_image_test')
        # One equally sized slice of the training batch per GPU (split along axis 0).
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs, NE = misc.load_pkl(network_pkl)
            # Snapshots written below are named 'network-snapshot-<cur_nimg>.pkl',
            # so the number after the last '-' is an image count; dividing by the
            # batch size turns it back into an iteration index.
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
        else:
            print('Constructing networks...')
            G, D, Gs, NE = misc.load_pkl(decoder_pkl.decoder_pkl)
            E = tflib.Network('E',
                              size=submit_config.image_size,
                              filter=64,
                              filter_max=1024,
                              phase=True,
                              **Encoder_args)
            start = 0

    Gs.print_layers()
    E.print_layers()
    D.print_layers()

    # Exponentially decayed learning rate driven by a non-trainable step
    # counter that starts at the resumed iteration.
    global_step = tf.Variable(start,
                              trainable=False,
                              name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global = global_step.assign_add(1)
    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=learning_rate,
                            **D_opt_args)

    # Per-GPU loss terms are summed here and divided by num_gpus after the loop.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a clone.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # The generator used inside the losses is Gs, not G.
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(
                img_size=[submit_config.image_size, submit_config.image_size],
                multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    perceptual_model=perceptual_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # NOTE(review): gradients are registered under a control dependency
            # on the step increment, presumably so the learning-rate step
            # advances with training -- confirm how tflib.Optimizer propagates
            # this, and whether the counter advances once or twice per
            # iteration (E and D are run in separate sess.run calls).
            with tf.control_dependencies([add_global]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    # Turn the per-GPU sums into means.
    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    #Es_update_op = Es.setup_as_moving_average_of(E, beta=E_smoothing)

    print('building testing graph...')
    # Tensor evaluated at snapshot time with real_test fed; its values are
    # saved next to the test inputs as reconstructions.
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    # Both results are fetched with sess.run() below to obtain one batch each.
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        E.setup_weight_histograms()
        D.setup_weight_histograms()

    # Image counter continues from the resumed iteration; ticks restart at 0.
    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):

        # One encoder step followed by one discriminator step on the same
        # batch. The loss tensors are fetched but their values are discarded.
        feed_dict = {real_train: sess.run(image_batch_train)}
        sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict)
        sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad],
                 feed_dict)

        cur_nimg += submit_config.batch_size

        # Progress line and summaries every 100 iterations.
        if it % 100 == 0:
            print("Iter: %06d  kimg: %-8.1f time: %-12s" %
                  (it, cur_nimg / 1000,
                   dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        # A tick elapses once at least 65000 images were seen since the last one.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                batch_images_test = sess.run(image_batch_test)
                # NOTE(review): the ranges here are hard-coded to
                # [0, 255] -> [-1, 1] rather than taken from
                # drange_data / drange_net -- confirm this is intended.
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), [0, 255], [-1., 1.])

                samples2 = sess.run(fake_X_val,
                                    feed_dict={real_test: batch_images_test})
                # Move axis 1 (size 3 in the placeholder) to the last position
                # before the images are merged and written.
                samples2 = samples2.transpose(0, 2, 3, 1)
                batch_images_test = batch_images_test.transpose(0, 2, 3, 1)
                # Inputs first, reconstructions second, stacked along axis 0.
                orin_recon = np.concatenate([batch_images_test, samples2],
                                            axis=0)
                imwrite(immerge(orin_recon, 2, submit_config.batch_size_test),
                        '%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg))

            if cur_tick % network_snapshot_ticks == 0:
                # The file name carries cur_nimg; the resume branch above parses it.
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs, NE), pkl)

    misc.save_pkl((E, G, D, Gs, NE),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    def __init__(
            self,
            lr,
            walk_type,
            nsliders,
            loss_type,
            eps,
            N_f,
            stylegan_opts,
            is_train=False,
            submit_config=None,
            G_args={},  # Options for the generator network.
            D_args={},  # Options for the discriminator network.
            G_opt_args={},  # Options for the generator optimizer.
            D_opt_args={},  # Options for the discriminator optimizer.
            G_loss_args={},  # Options for the generator loss.
            D_loss_args={},  # Options for the discriminator loss.
            dataset_args={},  # Options for dataset.load_dataset().
            metric_arg_list=[],  # Options for the metrics.
            tf_config={},  # Options for tflib.init_tf().
            G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
            D_repeats=1,  # How many times the discriminator is trained per G iteration.
            minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
            reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
            total_kimg=15000,  # Total length of the training, measured in thousands of real images.
            mirror_augment=False,  # Enable mirror augment?
            drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
            *args,
            **kwargs):
        """Load a pretrained StyleGAN and build the walk-training graph.

        The pickled (G, D, Gs) triple is fetched from the URL registered in
        `constants.net_info` for `stylegan_opts.dataset`, a latent placeholder
        `self.z` is created, and `self.steer(...)` is called to build the walk
        graph. Afterwards an Adam train op over the variables in `self.scope`
        is stored as `self.train_step`.

        With `is_train=True` the generator is additionally fine-tuned: the
        training set is loaded, one generator-loss tower is built per GPU, and
        the ops/placeholders needed to drive it are stored on the instance
        (`G_train_op`, `Gs_update_op`, `G`, `D`, `Gs`, `G_opt`, `lod_in`,
        `lrate_in`, `minibatch_in`).

        Args:
            lr: Learning rate of the Adam optimizer behind `self.train_step`.
            walk_type: Stored as `self.walk_type`.
            nsliders: Stored as `self.Nsliders`.
            loss_type: 'l2' or 'lpips'; selects `self.loss` or
                `self.loss_lpips` as the walk loss.
            eps: Stored as `self.eps` (step size).
            N_f: Stored as `self.N_f` (number of steps).
            stylegan_opts: Object with `dataset`, `latent` ('z' or 'w') and
                optionally `truncation_psi` (defaults to 1.0 when absent).
            is_train: Build the generator fine-tuning graph as well.
            submit_config: Required when `is_train` is true; `num_gpus` is read.
            G_opt_args: Keyword arguments for the generator tflib.Optimizer.
            G_loss_args: Keyword arguments forwarded to
                dnnlib.util.call_func_by_name for the generator loss.
            dataset_args: Keyword arguments for dataset.load_dataset().
            tf_config: Options passed to tflib.init_tf().
            G_smoothing_kimg: Half-life (in kimg) used to derive the moving
                average rate of Gs; values <= 0 give a rate of 0.
            G_args, D_args, D_opt_args, D_loss_args, metric_arg_list,
            D_repeats, minibatch_repeats, reset_opt_for_new_lod, total_kimg,
            mirror_augment, drange_net, *args, **kwargs: Accepted for
                signature compatibility; not read by this method.

        Raises:
            AssertionError: If `loss_type` or `stylegan_opts.latent` is not one
                of the supported values.
            ValueError: If `is_train` is true but `submit_config` is None.
        """

        assert (loss_type in ['l2', 'lpips']), 'unimplemented loss'
        assert (stylegan_opts.latent in ['z', 'w']), 'unknown latent space'
        # Fail before any TF initialisation or network download; otherwise
        # the first `submit_config.num_gpus` access raises an opaque
        # AttributeError much later.
        if is_train and submit_config is None:
            raise ValueError('submit_config is required when is_train=True')

        self.dataset_name = stylegan_opts.dataset
        self.dataset_args = constants.net_info[stylegan_opts.dataset]
        self.latent = stylegan_opts.latent
        self.is_train = is_train
        self.walk_type = walk_type
        self.N_f = N_f  # NN num_steps
        self.eps = eps  # NN step_size
        self.Nsliders = nsliders
        self.psi = getattr(stylegan_opts, 'truncation_psi', 1.0)

        # Honour the caller-supplied TF options (the empty default behaves
        # exactly like calling init_tf() without arguments).
        tflib.init_tf(tf_config)

        with tf.device('/gpu:0'):
            with dnnlib.util.open_url(self.dataset_args['url'],
                                      cache_dir=config.cache_dir) as f:
                # can only unpickle where dnnlib is importable, so add to syspath
                # NOTE(security): pickle.load executes arbitrary code from the
                # downloaded file; only use URLs from a trusted source.
                G, D, Gs = pickle.load(f)

        # input placeholders
        dim_z = self.dim_z = Gs.input_shape[1]
        self.z = tf.placeholder(tf.float32, shape=(None, dim_z))

        if is_train:
            training_set = dataset.load_dataset(data_dir=config.data_dir,
                                                verbose=True,
                                                **dataset_args)

            # G.print_layers(); D.print_layers()
            # Build the computation graph and the optimizer.
            print('Building TensorFlow graph...')
            with tf.name_scope('Inputs'), tf.device('/cpu:0'):
                # Placeholders act like formal parameters: they define the
                # graph and receive concrete values at run time.
                lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
                lrate_in = tf.placeholder(tf.float32,
                                          name='lrate_in',
                                          shape=[])
                minibatch_in = tf.placeholder(tf.int32,
                                              name='minibatch_in',
                                              shape=[])
                minibatch_split = minibatch_in // submit_config.num_gpus
                Gs_beta = 0.5**tf.div(
                    tf.cast(minibatch_in, tf.float32), G_smoothing_kimg *
                    1000.0) if G_smoothing_kimg > 0.0 else 0.0

            G_opt = tflib.Optimizer(name='TrainG',
                                    learning_rate=lrate_in,
                                    **G_opt_args)
            for gpu in range(submit_config.num_gpus):
                with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
                    # GPU 0 uses the original generator; other GPUs use clones.
                    # The discriminator is shared and is not trained here.
                    G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
                    lod_assign_ops = [
                        tf.assign(G_gpu.find_var('lod'), lod_in),
                        tf.assign(D.find_var('lod'), lod_in)
                    ]

                    # Presumably (re)defines self.loss / self.loss_lpips /
                    # self.scope for this tower -- they are read right below.
                    self.steer(G_gpu, gpu_scope='GPU%d/' % gpu)

                    with tf.name_scope('G_loss'), tf.control_dependencies(
                            lod_assign_ops):
                        G_loss = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_split,
                            **G_loss_args)
                    walk_loss = (self.loss
                                 if loss_type == 'l2' else self.loss_lpips)
                    self.joint_loss = G_loss + walk_loss
                    G_opt.register_gradients(tf.reduce_mean(self.joint_loss),
                                             G_gpu.trainables)

            # The walk variables are optimised on the joint loss of the last
            # tower built above (self.joint_loss is overwritten per GPU).
            train_step = tf.train.AdamOptimizer(lr).minimize(
                self.joint_loss,
                var_list=tf.trainable_variables(scope=self.scope),
                name='AdamOpter')

            G_train_op = G_opt.apply_updates()
            Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

            self.G_train_op = G_train_op
            self.Gs_update_op = Gs_update_op
            self.G = G
            self.D = D
            self.Gs = Gs
            self.G_opt = G_opt
            self.lod_in = lod_in
            self.lrate_in = lrate_in
            self.minibatch_in = minibatch_in

        else:
            self.steer(G)

            # Only the walk variables (those under self.scope) are trained.
            walk_loss = self.loss if loss_type == 'l2' else self.loss_lpips
            train_step = tf.train.AdamOptimizer(lr).minimize(
                walk_loss,
                var_list=tf.trainable_variables(scope=self.scope),
                name='AdamOpter')

        self.train_step = train_step
Example #10
0
def training_loop(
    submit_config,
    HP_args={},  # Options for the Hessian Penalty.
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    interp_snapshot_ticks=20,  # How often to generate interpolation visualizations in TensorBoard?
    network_snapshot_ticks=5,  # How often to export network snapshots?
    network_metric_ticks=5,  # How often to evaluate network snapshots on specified metrics?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Train G and D with a dynamically weighted Hessian Penalty term.

    The loop builds one G/D loss tower per GPU, feeds the Hessian Penalty
    weight to the generator loss through a placeholder, and adds the
    mutual-information term returned by the discriminator loss to both
    objectives, scaled by `G_args.infogan_lambda` / `D_args.infogan_lambda`.
    Training runs until `total_kimg` thousand images have been shown to the
    discriminator or until the run context asks to stop.

    Once per tick the function reports progress, saves fake-image grids for
    both Gs and G, optionally writes interpolation GIFs, saves a network
    snapshot and evaluates metrics. The first tick always produces a network
    snapshot and a metric evaluation. At the end the networks are saved once
    more as 'network-snapshot-<total_kimg>.pkl' in `submit_config.run_dir`.

    NOTE: `HP_args`, `G_args`, `D_args` and `dataset_args` are accessed by
    attribute (e.g. `HP_args.hp_lambda`, `D_args.infogan_nz`), so the plain
    dict defaults cannot be used; callers have to pass attribute-style
    containers such as EasyDict.

    `interp_snapshot_ticks == -1` disables interpolation GIFs; values >= 9999
    skip only the initial ones.

    Returns:
        None.
    """

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create a copy of dataset_args for running the metrics:
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0

    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    # G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Moving-average rate of Gs: halves every G_smoothing_kimg thousand images.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # The loss weighting of the Hessian Penalty can be dynamic over training, so we need to use a placeholder:
    hessian_weight = tf.placeholder(tf.float32,
                                    name='hessian_weight',
                                    shape=[])

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    reg_ops = [
    ]  # Returning the values of the Hessian Penalty/ InfoGAN losses so they can be registered in TensorBoard
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    hp_lambda=hessian_weight,
                    HP_args=HP_args,
                    gpu_ix=gpu,
                    lod_in=lod_in,
                    max_lod=training_set.resolution_log2,
                    **G_loss_args)
                if HP_args.hp_lambda > 0:
                    reg_ops += [G_hessian_penalty]
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz,
                    **D_loss_args)
                if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                    reg_ops += [mutual_info]
            # Note, even though we are adding mutual_info loss here, the only time the loss is non-zero
            # is when infogan_lambda > 0 (in Hessian Penalty experiments, we always set it to zero):
            G_opt.register_gradients(
                G_loss + G_args.infogan_lambda * mutual_info, G_gpu.trainables)
            D_opt.register_gradients(
                tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info,
                D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # Fall back to a constant 0 when the memory-stats op is unavailable.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule evaluated at the end of training; used for the initial
    # snapshot and interpolation minibatch sizes.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    minibatch_gpu = sched.minibatch // submit_config.num_gpus
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=minibatch_gpu)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros(
        [interp_steps * interp_batch_size * nz, training_set.label_size],
        dtype=training_set.label_dtype)

    print('Setting up run dir...')
    reals_path = os.path.join(submit_config.run_dir, 'reals.png')
    initial_fakes_path = os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_reals,
                         reals_path,
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         initial_fakes_path,
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_summary(build_image_summary(reals_path, 'samples/real'),
                            0)
    summary_log.add_summary(
        build_image_summary(initial_fakes_path, 'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs,
            interp_z,
            interp_labels,
            minibatch_size=minibatch_gpu,
            interp_steps=interp_steps,
            interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg,
            summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        minibatch_gpu = sched.minibatch // submit_config.num_gpus
        training_set.configure(minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when the integer part (floor or ceil) of the lod changed.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress is counted in images shown to the discriminator.
                cur_nimg += sched.minibatch
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg,
                HP_args.warmup_nimg)
            tflib.run(
                [G_train_op] + reg_ops, {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    hessian_weight: cur_hessian_weight
                })
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   autosummary('Progress/hessian_weight', cur_hessian_weight),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            # Save snapshots and fake image samples (for both Gs and G):
            if cur_tick % image_snapshot_ticks == 0 or done:
                # Whole thousands of images seen so far; used in file names
                # and as the summary step. (Previously bound to the name
                # `iter`, which shadowed the builtin.)
                snapshot_kimg = cur_nimg // 1000
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=minibatch_gpu)
                grid_fakes_inst = G.run(grid_latents,
                                        grid_labels,
                                        is_validation=True,
                                        minibatch_size=minibatch_gpu)
                fake_path = os.path.join(submit_config.run_dir,
                                         'fakes%06d.png' % snapshot_kimg)
                ifake_path = os.path.join(submit_config.run_dir,
                                          'ifakes%06d.png' % snapshot_kimg)
                misc.save_image_grid(grid_fakes,
                                     fake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst,
                                     ifake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                summary_log.add_summary(
                    build_image_summary(fake_path, 'samples/Gs'),
                    snapshot_kimg)
                summary_log.add_summary(
                    build_image_summary(ifake_path, 'samples/G'),
                    snapshot_kimg)

            # Generate/Save Interpolation Visualizations:
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs,
                    interp_z,
                    interp_labels,
                    minibatch_size=minibatch_gpu,
                    interp_steps=interp_steps,
                    interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000,
                    summary_log=summary_log)

            # Save snapshot and run metrics:
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Metrics are only evaluated on ticks that also saved a snapshot.
                if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                    metrics.run(pkl,
                                dataset_args=metrics_dataset_args,
                                mirror_augment=mirror_augment,
                                num_gpus=submit_config.num_gpus,
                                tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent on the tick bookkeeping above, reported with the next tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir,
                               'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()

    ctx.close()
Example #11
0
def training_loop(
        submit_config,
        Encoder_args={},
        D_args={},
        G_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args=EasyDict(),
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        filter=64,  # Minimum number of feature maps in any layer.
        filter_max=512,  # Maximum number of feature maps in any layer.
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        d_scale=0.1,  # Decide whether to update discriminator.
        pretrained_D=True,  # Whether to use pre trained Discriminator.
        max_iters=150000):
    """Train an encoder ``E`` (and optionally a discriminator ``D``) against a loaded generator.

    The function builds the whole TensorFlow graph (input placeholders,
    networks, per-GPU losses, optimizers, a test/reconstruction graph), then
    runs a plain iteration loop that periodically prints losses, writes
    reconstruction images and pickles the networks into
    ``submit_config.run_dir``.

    Args:
        submit_config: Run configuration. This function reads ``batch_size``,
            ``batch_size_test``, ``image_size``, ``num_gpus`` and ``run_dir``.
        Encoder_args: Extra keyword arguments forwarded to the ``'E_gpu0'``
            ``tflib.Network`` constructor.
        D_args: Extra keyword arguments forwarded to the ``'D_ori'`` network.
        G_args: Extra keyword arguments forwarded to the ``'G_StyleMod'``
            network.
        E_opt_args: Extra keyword arguments for the ``'TrainE'`` optimizer.
        D_opt_args: Extra keyword arguments for the ``'TrainD'`` optimizer.
        E_loss_args: Keyword arguments passed to
            ``dnnlib.util.call_func_by_name`` to build the encoder loss. Must
            also expose ``perceptual_img_size`` (read as an attribute and, as
            part of the same mapping, forwarded to the loss function).
        D_loss_args: Keyword arguments passed to
            ``dnnlib.util.call_func_by_name`` to build the discriminator loss.
        lr_args: Learning-rate schedule; ``learning_rate``, ``decay_step``,
            ``decay_rate`` and ``stair`` are read.
        tf_config: Passed to ``tflib.init_tf``.
        dataset_args: ``data_train`` and ``data_test`` are read and handed to
            ``get_train_data``.
        decoder_pkl: Object whose ``decoder_pkl`` attribute is the pickle path
            loaded when not resuming; the pickle is unpacked as
            ``(G, PreD, Gs)``.
        drange_data: Passed to ``process_reals`` together with ``drange_net``.
        drange_net: Dynamic range used when feeding image data to the networks.
        mirror_augment: Passed to ``process_reals``.
        filter: Forwarded to the encoder network (minimum feature maps).
            Shadows the builtin of the same name inside this function.
        filter_max: Forwarded to the encoder network (maximum feature maps).
        resume_run_id: If not ``None``, networks are loaded from the pickle
            located by ``misc.locate_network_pkl`` instead of being built.
        resume_snapshot: Snapshot selector passed to
            ``misc.locate_network_pkl``.
        image_snapshot_ticks: Save a reconstruction image every this many
            ticks.
        network_snapshot_ticks: Save a network pickle every this many ticks.
        d_scale: The discriminator optimizer and train op are only created and
            run when this is ``> 0``; the value is not otherwise used here.
        pretrained_D: When building from scratch, use the discriminator from
            the decoder pickle (``True``) or the freshly built ``'D_ori'``
            network (``False``).
        max_iters: Exclusive upper bound of the iteration counter.

    Returns:
        None. Results are written to ``submit_config.run_dir``.

    Note:
        The mutable default arguments are only read/unpacked in this function,
        never mutated; keep it that way since the defaults are shared between
        calls.
    """

    tflib.init_tf(tf_config)

    # Input placeholders. The 4-D shapes are
    # [batch, 3, image_size, image_size] (presumably NCHW — TODO confirm).
    with tf.name_scope('Input'):
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                    name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                   name='real_image_test')
        # One equal slice of the training batch per GPU (split along axis 0).
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            # Resume: the snapshot pickle holds (E, G, D, Gs).
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            # Recover the starting iteration from the file name: the text
            # between the last '-' and the following '.' is parsed as the image
            # count and divided by the batch size.
            # NOTE(review): a path like the 'network-final.pkl' written at the
            # end of this function would make int() raise ValueError here.
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            # The decoder pickle provides the generator, a pre-trained
            # discriminator and the averaged generator Gs.
            G, PreD, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E_gpu0',
                              size=submit_config.image_size,
                              filter=filter,
                              filter_max=filter_max,
                              num_layers=num_layers,
                              is_training=True,
                              num_gpus=submit_config.num_gpus,
                              **Encoder_args)
            OriD = tflib.Network('D_ori',
                                 resolution=submit_config.image_size,
                                 label_size=0,
                                 **D_args)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            if pretrained_D:
                D = PreD
            else:
                D = OriD
            start = 0
        # Copy the current values of Gs' synthesis variables into the freshly
        # built G_StyleMod network, matching variables by name. Every
        # G_StyleMod variable name must exist in Gs.components.synthesis.vars,
        # otherwise the lookup raises KeyError.
        Gs_vars_pairs = {
            name: tflib.run(val)
            for name, val in Gs.components.synthesis.vars.items()
        }
        for g_name, g_val in G_style_mod.vars.items():
            tflib.set_vars({g_val: Gs_vars_pairs[g_name]})

    E.print_layers()
    Gs.print_layers()
    D.print_layers()

    # Learning-rate schedule: exponential decay driven by a non-trainable step
    # counter that starts at the resumed iteration.
    global_step0 = tf.Variable(start,
                               trainable=False,
                               name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step0,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    # The discriminator optimizer only exists when d_scale > 0; every later use
    # of D_opt / D_train_op is guarded by the same condition.
    if d_scale > 0:
        D_opt = tflib.Optimizer(name='TrainD',
                                learning_rate=learning_rate,
                                **D_opt_args)

    # Loss terms accumulated over GPUs; divided by num_gpus below for reporting.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('Building Graph on GPU %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; other GPUs use clones. The
            # encoder clone name replaces the last character of E's name with
            # the GPU index (e.g. 'E_gpu0' -> 'E_gpu1').
            E_gpu = E if gpu == 0 else E.clone(E.name[:-1] + str(gpu))
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = G_style_mod if gpu == 0 else G_style_mod.clone(
                G_style_mod.name + '_shadow')
            # A separate perceptual feature model is constructed per GPU.
            feature_model = PerceptualModel(img_size=[
                E_loss_args.perceptual_img_size,
                E_loss_args.perceptual_img_size
            ],
                                            multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            # control_dependencies(None) clears control dependencies inherited
            # from the enclosing context for the loss graph.
            # The loss function to call is presumably named inside E_loss_args
            # (e.g. a 'func_name' entry) — TODO confirm against the config.
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    feature_model=feature_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            # The discriminator loss graph is built even when d_scale <= 0;
            # only its optimizer registration is conditional.
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # Ops created while registering gradients depend on the step
            # increment. NOTE(review): presumably this advances the decay step
            # whenever a train op runs — confirm how tflib.Optimizer honours it.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                if d_scale > 0:
                    D_opt.register_gradients(D_loss, D_gpu.trainables)

    # Average the reported loss terms over GPUs.
    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    if d_scale > 0:
        D_train_op = D_opt.apply_updates()

    print('Building testing graph...')
    # Reconstruction graph used for image snapshots (test() is defined
    # elsewhere in the module).
    fake_X_val = test(E, G_style_mod, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    # Both results are evaluated with sess.run() below to obtain image batches.
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)

    # Progress bookkeeping, measured in images.
    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):

        # One encoder step, then (optionally) one discriminator step on the
        # same batch of images.
        batch_images = sess.run(image_batch_train)
        feed_dict = {real_train: batch_images}
        _, recon_, adv_, lr = sess.run(
            [E_train_op, E_loss_rec, E_loss_adv, learning_rate], feed_dict)
        if d_scale > 0:
            _, d_r_, d_f_, d_g_ = sess.run(
                [D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict)

        cur_nimg += submit_config.batch_size

        # ETA from the average time per iteration since this call started.
        run_time = time.time() - start_time
        iter_time = run_time / (it - start + 1)
        eta_time = iter_time * (max_iters - it - 1)

        # Console report and summary flush every 50 iterations.
        if it % 50 == 0:
            print(
                'Iter: %06d/%d, lr: %-.8f recon_loss: %-6.4f adv_loss: %-6.4f run_time:%-12s eta_time:%-12s'
                % (it, max_iters, lr, recon_, adv_,
                   dnnlib.util.format_time(time.time() - start_time),
                   dnnlib.util.format_time(eta_time)))
            if d_scale > 0:
                print('d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f ' %
                      (d_r_, d_f_, d_g_))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        # A "tick" is a hard-coded 65000 processed images.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                batch_images_test = sess.run(image_batch_test)
                # NOTE(review): the ranges here are hard-coded to
                # [0, 255] -> [-1, 1] rather than using drange_data /
                # drange_net — confirm this is intended.
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
                recon = sess.run(fake_X_val,
                                 feed_dict={real_test: batch_images_test})
                orin_recon = np.concatenate([batch_images_test, recon], axis=0)
                orin_recon = adjust_pixel_range(orin_recon)
                orin_recon = fuse_images(orin_recon,
                                         row=2,
                                         col=submit_config.batch_size_test)
                # save image results during training, first row is original images and the second row is reconstructed images
                save_image(
                    '%s/iter_%09d.png' % (submit_config.run_dir, cur_nimg),
                    orin_recon)

            if cur_tick % network_snapshot_ticks == 0:
                # The snapshot file name encodes the image count; the resume
                # branch above parses it back into the starting iteration.
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%09d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)

    # Write final results.
    misc.save_pkl((E, G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,
        weight_args={},
        train_stage_args={},
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=87,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=2364,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=2364,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        **_kwargs  # Extra keyword arguments are accepted and ignored.
):
    """Run the GAN training loop with a 3D face rendering block in the loss.

    Loads or constructs the networks ``G``, ``D`` and the averaged generator
    ``Gs``, builds per-GPU losses through ``dnnlib.util.call_func_by_name``,
    renders a fixed snapshot grid (generated images next to face renders),
    then alternates discriminator and generator steps until ``total_kimg``
    thousand images have been processed. Once per tick it reports progress,
    saves image grids and network pickles, and runs the metrics.

    Args:
        submit_config: Run configuration; ``num_gpus`` and ``run_dir`` are
            read here.
        noise_dim: Number of noise dimensions appended to the 254 coefficient
            dimensions of the generator latent when building ``G`` from
            scratch; also forwarded to the loss builder.
        weight_args: Forwarded to the loss builder.
        train_stage_args: Unpacked into the loss-builder call and passed to
            ``restore_weights_and_initialize``.
        G_args, D_args: Extra options for the generator / discriminator
            network constructors.
        G_opt_args, D_opt_args: Extra options for the two optimizers.
        G_loss_args, D_loss_args: Forwarded to the loss builder.
        dataset_args: Options for ``dataset.load_dataset``.
        sched_args: Options for ``training_schedule``.
        grid_args: Options for ``misc.setup_snapshot_image_grid``.
        metric_arg_list: Options for ``metric_base.MetricGroup``.
        tf_config: Options for ``tflib.init_tf``; also passed to
            ``metrics.run``.
        G_smoothing_kimg: Half-life, in thousands of images, of the running
            average ``Gs``; ``<= 0`` gives a beta of ``0.0``.
        D_repeats: Discriminator steps per generator step.
        minibatch_repeats: Generator steps between schedule re-evaluations.
        reset_opt_for_new_lod: Reset both optimizers' state when the floor or
            ceiling of the level of detail changes.
        total_kimg: Training length in thousands of images.
        mirror_augment: Passed to ``process_reals`` for training images.
        drange_net: Dynamic range used when feeding image data to the
            networks.
        image_snapshot_ticks: Save an image grid every this many ticks.
        network_snapshot_ticks: Save a network pickle every this many ticks.
        save_tf_graph: Add the default graph to the summary log.
        save_weight_histograms: Set up weight histograms for ``G`` and ``D``.
        resume_run_id: If not ``None``, networks are loaded from an existing
            pickle. NOTE(review): the default is ``87``, so with default
            arguments the function resumes rather than starting from scratch.
        resume_snapshot: Snapshot selector for ``misc.locate_network_pkl``.
        resume_kimg: Assumed progress at start, in thousands of images.
        resume_time: Assumed wall-clock seconds already spent.
        **_kwargs: Accepted and ignored.

    Returns:
        None. Results are written to ``submit_config.run_dir``.

    Note:
        The mutable default arguments are only read/unpacked in this function,
        never mutated; keep it that way since the defaults are shared between
        calls.
    """

    # Initialize dnnlib and TensorFlow.
    # NOTE(review): PI is not referenced again anywhere in this function.
    PI = 3.1415927
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)
    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            # Generator latent = 254 coefficient dims + noise_dim noise dims.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    # Scalar placeholders fed on every training step.
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Moving-average decay: 0.5 ** (minibatch / (G_smoothing_kimg * 1000)),
        # i.e. the average halves every G_smoothing_kimg thousand images.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            # GPU 0 uses the original networks; other GPUs use '_shadow' clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Ops that push the fed level of detail into both networks; they
            # are handed to the loss builder below.
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)

            #---------------------------------------------------------------
            # Modified by Deng et al.
            # The loss function to call is presumably named inside
            # train_stage_args (e.g. a 'func_name' entry) — TODO confirm.
            G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\
                G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\
                lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\
                drange_net=drange_net,lod_in=lod_in,**train_stage_args)
            #---------------------------------------------------------------

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # Fall back to a constant 0 when the memory-stats op is unavailable.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    # The schedule is evaluated at the END of training (total_kimg) to obtain
    # the lod and minibatch size used for the snapshot grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # One random latent row per grid cell with 128 + 32 + 16 + 3 = 179 columns,
    # mapped to face coefficients by z_to_lambda_mapping (defined elsewhere).
    grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    # Append three zero columns before rendering (presumably a zero
    # translation term — TODO confirm against Face3D).
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False)
        # Move axis 3 to position 1 (presumably NHWC -> NCHW — TODO confirm).
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range, drange_net)

    # Evaluate the coefficients and renders once; the resulting arrays are
    # reused for every image snapshot so the grid stays fixed.
    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    # NOTE(review): the noise width is hard-coded to 32, which equals the
    # default noise_dim; presumably it should track noise_dim — confirm.
    grid_noise = np.random.randn(np.prod(grid_size), 32)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise],
                                             axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)
    # Join generated images and renders along axis 3 so each grid cell shows
    # both (side by side if the arrays are NCHW — TODO confirm).
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    # Progress is tracked in images; resume_kimg is in thousands of images.
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    # -1.0 guarantees the first lod comparison below reports a change.
    prev_lod = -1.0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when the lod crosses an integer boundary in either
            # direction (its floor or its ceiling changed).
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # Each repeat runs D_repeats discriminator steps (which also update the
        # Gs moving average) followed by one generator step. cur_nimg advances
        # once per discriminator step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })

            # print('iter')
        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            # Each autosummary() call records the value under the given name;
            # its result is also used as the value that gets printed.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                # Re-run Gs on the fixed grid inputs and pair the result with
                # the renders computed before the training loop.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            #---------------------------------------------------------------

            # Network snapshots are also forced on the first tick and at the
            # end; metrics are run on every saved snapshot.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent on the maintenance work above, reported next tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()


#----------------------------------------------------------------------------
Example #13
0
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device("/gpu:0"):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print("Constructing networks...")
            G = tflib.Network(
                name="G",
                num_inputs=2,  # one for latents and one for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                name="D",
                num_inputs=int(np.log2(training_set.shape[1])) - 1 +
                1,  # +1 for labels :)
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone("Gs")
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_in,
                                       tf.float32), G_smoothing_kimg *
                               1000.0) if G_smoothing_kimg > 0.0 else 0.0)

    G_opt = tflib.Optimizer(name="TrainG",
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD",
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow")
            D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow")
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(
                reals,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )
            with tf.name_scope("G_loss"):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope("D_loss"):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Choose training parameters and configure training ops.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch_size // submit_config.num_gpus,
    )

    print("Setting up run dir...")
    fake_multi_scale_dirs = [
        os.path.join(submit_config.run_dir,
                     str(2**res) + "x" + str(2**res))
        for res in range(2, 2 + len(grid_fakes))
    ]
    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, "reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )
    misc.save_image_grids(
        grid_fakes,
        [
            os.path.join(fake_multi_scale_dir, "fakes%06d.png" % resume_kimg)
            for fake_multi_scale_dir in fake_multi_scale_dirs
        ],
        drange=drange_net,
        grid_size=grid_size,
    )
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # configure the training_set to a proper minibatch size
    training_set.configure(sched.minibatch_size // submit_config.num_gpus)

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op],
                    {
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch_size
                    },
                )
                cur_nimg += sched.minibatch_size
            tflib.run(
                [G_train_op],
                {
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch_size
                },
            )

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
                % (
                    autosummary("Progress/tick", cur_tick),
                    autosummary("Progress/kimg", cur_nimg / 1000.0),
                    autosummary("Progress/minibatch", sched.minibatch_size),
                    dnnlib.util.format_time(
                        autosummary("Timing/total_sec", total_time)),
                    autosummary("Timing/sec_per_tick", tick_time),
                    autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
                    autosummary("Timing/maintenance_sec", maintenance_time),
                    autosummary("Resources/peak_gpu_mem_gb",
                                peak_gpu_mem_op.eval() / 2**30),
                ))
            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(
                    grid_latents,
                    grid_labels,
                    is_validation=True,
                    minibatch_size=sched.minibatch_size //
                    submit_config.num_gpus,
                )
                misc.save_image_grids(
                    grid_fakes,
                    [
                        os.path.join(fake_multi_scale_dir, "fakes%06d.png" %
                                     (cur_nimg // 1000))
                        for fake_multi_scale_dir in fake_multi_scale_dirs
                    ],
                    drange=drange_net,
                    grid_size=grid_size,
                )
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    "network-snapshot-%06d.pkl" % (cur_nimg // 1000),
                )
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(
                    pkl,
                    run_dir=submit_config.run_dir,
                    num_gpus=submit_config.num_gpus,
                    tf_config=tf_config,
                )

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, "network-final.pkl"))
    summary_log.close()

    ctx.close()
Example #14
0
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss (accepted but not read by this loop).
    D_loss_args             = {},       # Options for the loss function; its result is registered with both optimizers.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Schedule options, read by attribute: batch_size, lr, decay_step, decay_rate, stair, tick_kimg.
    grid_args               = {},       # Options for misc.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for MetricGroup.
    tf_config               = {},       # Options for tflib.init_tf().
    data_dir                = None,     # Directory to load datasets from.
    G_smoothing_kimg        = 10.0,     # Accepted but not read by this loop.
    minibatch_repeats       = 4,        # Accepted but not read by this loop.
    lazy_regularization     = True,     # Accepted but not read by this loop.
    G_reg_interval          = 4,        # Accepted but not read by this loop.
    D_reg_interval          = 16,       # Accepted but not read by this loop.
    reset_opt_for_new_lod   = True,     # Accepted but not read by this loop.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [0,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = None,     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and the loop's starting image count.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Build the per-GPU training graph and run the training loop.

    The function loads the training set, constructs (or resumes) the G, D and
    Gs networks, builds one loss per GPU through the function described by
    ``D_loss_args``, registers that loss with both the G and the D optimizer,
    and then runs both train ops once per iteration until ``total_kimg``
    thousand images have been processed. Once per tick it prints and records
    progress and, at the configured intervals, saves an image grid and a
    network snapshot (followed by a metrics run). A final network pickle is
    written when the loop ends.

    ``sched_args`` is accessed with attribute syntax, so callers must pass an
    attribute-style mapping; the plain ``{}`` default has no ``batch_size``
    attribute and cannot be used as-is.

    Returns:
        None.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set and save the grid of real images once.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    training_set.configure(minibatch_size=sched_args.batch_size)
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
            start = 0
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Keep the freshly built architectures, copy the stored weights in.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs
            # Initial learning-rate step, parsed from the trailing number of the pickle name.
            # NOTE(review): snapshots written by this loop are named by kimg
            # (cur_nimg // 1000), while this divides the parsed number by
            # batch_size only -- confirm the units are what the schedule expects.
            start = int(resume_pkl.split('-')[-1].split('.')[0]) // sched_args.batch_size

    # Print layers and generate initial image snapshot.
    G.print_layers(); D.print_layers()
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    # Exponentially decayed learning rate driven by a non-trainable step counter.
    global_step = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(sched_args.lr, global_step, sched_args.decay_step,
                                               sched_args.decay_rate, staircase=sched_args.stair)
    add_global = global_step.assign_add(1)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)
    G_opt = tflib.Optimizer(name='TrainG', learning_rate=learning_rate, **G_opt_args)

    for gpu in range(num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()

                reals_read, labels_read = process_reals(reals_read, labels_read, mirror_augment,
                                                        training_set.dynamic_range, drange_net)

            with tf.name_scope('Loss'), tf.control_dependencies(None):
                # The second return value is not used by this loop.
                loss, _reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                                                           minibatch_size=sched_args.batch_size, reals=reals_read,
                                                           labels=labels_read, **D_loss_args)
            # Gradients are registered under a control dependency on the step
            # increment (presumably so the LR step advances with training -- confirm).
            with tf.control_dependencies([add_global]):
                G_opt.register_gradients(loss, G_gpu.trainables)
                D_opt.register_gradients(loss, D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report below still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces the maintenance block to run after the first iteration.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # One training iteration: both train ops run in the same session call.
        # `loss` is the tensor left over from the last GPU built above.
        loss_, _, _, lr_ = tflib.run([loss, G_train_op, D_train_op, learning_rate])
        cur_nimg += sched_args.batch_size * num_gpus

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f '
                'sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f loss %-8.1f lr %-2.5f' % (
                    autosummary('Progress/tick', cur_tick),
                    autosummary('Progress/kimg', cur_nimg / 1000.0),
                    autosummary('Progress/minibatch', sched_args.batch_size),
                    dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                    autosummary('Timing/sec_per_tick', tick_time),
                    autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                    autosummary('Timing/maintenance_sec', maintenance_time),
                    autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2 ** 30),
                    autosummary('loss', loss_),
                    autosummary('lr', lr_)))

            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot (runs once, after the loop has ended).
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example #15
0
def training_loop(
                  submit_config,
                  Encoder_args            = {},
                  E_opt_args              = {},
                  D_opt_args              = {},
                  E_loss_args             = EasyDict(),
                  D_loss_args             = {},
                  lr_args                 = EasyDict(),
                  tf_config               = {},
                  dataset_args            = EasyDict(),
                  decoder_pkl             = EasyDict(),
                  drange_data             = [0, 255],
                  drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
                  mirror_augment          = False,
                  resume_run_id           = config.ENCODER_PICKLE_DIR,     # Run ID or network pkl to resume training from, None = start from scratch.
                  resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
                  image_snapshot_ticks    = 1,        # Accepted but not read; image snapshots are written every 500 iterations.
                  network_snapshot_ticks  = 4,       # How often to export network snapshots?
                  max_iters               = 150000):
    """Train the encoder E (and discriminator D) against a loaded generator.

    The function either resumes (E, G, D, Gs) from a network pickle or loads
    (G, D, Gs) from ``decoder_pkl.decoder_pkl`` and builds a new encoder. It
    then builds the encoder and discriminator losses on every GPU, and runs
    ``E_train_op`` followed by ``D_train_op`` once per iteration until
    ``max_iters`` is reached.

    Fixed cadences used by the loop:
      * every 50 iterations: print the losses and flush autosummaries;
      * every 500 iterations: reconstruct a test batch and save the
        original/reconstruction image to the run dir and to
        ``config.GDRIVE_PATH/images``;
      * every 65000 images a tick elapses; every ``network_snapshot_ticks``
        ticks a network pickle is written to the run dir and to
        ``config.GDRIVE_PATH/snapshots``.

    ``submit_config`` is read for ``batch_size``, ``batch_size_test``,
    ``image_size``, ``num_gpus`` and ``run_dir``. ``lr_args`` is read for
    ``learning_rate``, ``decay_step``, ``decay_rate`` and ``stair``.
    ``dataset_args`` is read for ``data_train`` and ``data_test``.

    Returns:
        None.
    """

    tflib.init_tf(tf_config)

    # Placeholders for the training batch (split evenly across GPUs) and the test batch.
    with tf.name_scope('input'):
        real_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='real_image_train')
        real_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='real_image_test')
        real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            # Snapshots written below are named with the image count, so the
            # trailing number divided by the batch size is the iteration to resume at.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, D, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args)
            start = 0

    E.print_layers(); Gs.print_layers(); D.print_layers()

    # Exponentially decayed learning rate driven by a non-trainable step counter.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Per-GPU loss terms are summed here and averaged after the loop (reporting only).
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # The generator used for training is Gs (or a shadow clone of it), not G.
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model, reals=real_gpu, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # NOTE(review): both optimizers register gradients under the same
            # step-increment dependency and their train ops run in separate
            # session calls below -- confirm how often the LR step advances.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')

    summary_log = tf.summary.FileWriter(config.GDRIVE_PATH)

    # Make sure the Drive sub-folders written to below exist before the first save.
    gdrive_image_dir = os.path.join(config.GDRIVE_PATH, 'images')
    gdrive_snapshot_dir = os.path.join(config.GDRIVE_PATH, 'snapshots')
    os.makedirs(gdrive_image_dir, exist_ok=True)
    os.makedirs(gdrive_snapshot_dir, exist_ok=True)

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    # Explicitly initialize the step counter. tf.variables_initializer is the
    # non-deprecated name for tf.initialize_variables.
    init_pascal = tf.variables_initializer(
        [global_step0],
        name='init_pascal'
    )
    sess.run(init_pascal)

    print('Optimization starts!!!')

    for it in range(start, max_iters):

        # One encoder step followed by one discriminator step on the same batch.
        batch_images = sess.run(image_batch_train)
        feed_dict_1 = {real_train: batch_images}
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_ = sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

        cur_nimg += submit_config.batch_size

        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        if it % 500 == 0:
            batch_images_test = sess.run(image_batch_test)
            # Test images are mapped with fixed ranges here, not drange_data / drange_net.
            batch_images_test = misc.adjust_dynamic_range(batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
            samples2 = sess.run(fake_X_val, feed_dict={real_test: batch_images_test})
            orin_recon = np.concatenate([batch_images_test, samples2], axis=0)
            orin_recon = adjust_pixel_range(orin_recon)
            orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test)
            # save image results during training, first row is original images and the second row is reconstructed images
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), orin_recon)

            # save image to gdrive
            img_path = os.path.join(gdrive_image_dir, ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, orin_recon)

        # A tick elapses every 65000 processed images.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)

                # save network snapshot to gdrive
                pkl_drive = os.path.join(gdrive_snapshot_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
Example #16
0
    # Fragment of a larger routine (its enclosing definition is not shown here):
    # sets up a latent variable that is optimized against a target image.

    # Define variables: a (1, 512) latent seed and a (1, 3, 1024, 1024) target
    # holder, both zero-initialized.
    seed_latent = tf.Variable(np.zeros((1, 512)), dtype=tf.float32)
    target = tf.Variable(np.zeros((1, 3, 1024, 1024)), dtype=np.float32)


    # Define network, the optional feature extractor, and the loss / generated
    # image tensors. The helper name `get_nosie_output` is spelled this way at
    # its call site here.
    Gs = load_network()
    # Maps the feature-extractor option value to the substring searched for in
    # variable names further below.
    fea_ext_name_to_var_name = {'inception_v3': 'InceptionV3'}
    feature_extractor = get_feature_extractor(args)
    noise_loss, fake_images_out = get_nosie_output(args, Gs, seed_latent, None, target, feature_extractor)


    # Apply optimizers: the learning rate is fed at run time, and the only
    # variable registered for gradients is `seed_latent`.
    lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
    noise_opt = tflib.Optimizer(
        name='Noise', learning_rate=lrate_in, beta1=0.0, beta2=0.99, epsilon=1e-8)
    noise_opt.register_gradients(tf.reduce_mean(
        noise_loss), OrderedDict([('latents', seed_latent)]))
    noise_update_op = noise_opt.apply_updates()

    # Get TF Session
    sess = tf.get_default_session()

    # Initialize the two variables created above.
    uninstizalized_vars = [seed_latent, target]
    init_op = tf.variables_initializer(uninstizalized_vars)
    sess.run([init_op])
    # When a feature extractor other than 'D' is selected, keep only the
    # variables whose name contains the mapped extractor scope name.
    if args.feature_extractor is not None and args.feature_extractor != 'D':
        extractor_variables = slim.get_variables_to_restore()
        extractor_variables = [
            x for x in extractor_variables if fea_ext_name_to_var_name[args.feature_extractor] in x.name]
Example #17
0
def training_auto_loop(
    Enc_args={},  # Options for encoder network.
    Dec_args={},  # Options for decoder network.
    opt_args={},  # Options for the single optimizer shared by encoder and decoder.
    loss_args={},  # Options for the autoencoder loss function.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Schedule options, read by attribute: lrate, minibatch_size, minibatch_gpu.
    grid_args={},  # Options for misc.setup_snapshot_image_grid().
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    minibatch_repeats=4,  # Number of minibatches to run between maintenance checks.
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
):
    """Train an encoder/decoder pair with one shared optimizer.

    The function loads the training set, builds the 'Encoder' and 'Decoder'
    networks, creates a per-GPU loss through the function described by
    ``loss_args``, and registers the mean loss against the trainables of both
    networks. It then runs fetch + train steps until ``total_kimg`` thousand
    images have been processed. Once per tick it reports progress and, at the
    configured intervals, saves a reconstruction grid and a network snapshot.
    A final ``(Enc, Dec)`` pickle is written when the loop ends.

    ``sched_args`` is accessed with attribute syntax and ``tick_kimg`` is
    overwritten on it in place, so callers must pass an attribute-style
    mapping; the plain ``{}`` default cannot be used as-is.

    Returns:
        None.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set and save the grid of real images once.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, _ = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct networks. The decoder is built at a quarter of the dataset
    # resolution and takes the encoder's 16 output channels as input.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        Enc = tflib.Network('Encoder',
                            resolution=training_set.shape[1],
                            out_channels=16,
                            **Enc_args)
        Dec = tflib.Network('Decoder',
                            resolution=training_set.shape[1] // 4,
                            in_channels=16,
                            **Dec_args)

    # Print layers and generate initial image snapshot.
    Enc.print_layers()
    Dec.print_layers()
    # NOTE: `sched` aliases the caller's sched_args object, so the tick length
    # override below is visible to the caller as well.
    sched = sched_args
    sched.tick_kimg = 4
    # Grid reals are rescaled with (x / 127.5) - 1.0 before encoding.
    grid_codes = Enc.run((grid_reals / 127.5) - 1.0,
                         minibatch_size=sched.minibatch_gpu)
    grid_fakes = Dec.run(grid_codes, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers (copy first so the caller's opt_args is not modified).
    opt_args = dict(opt_args)
    opt_args['minibatch_multiplier'] = minibatch_multiplier
    opt_args['learning_rate'] = lrate_in
    opt = tflib.Optimizer(name='TrainAuto', **opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of Enc and Dec.
            Enc_gpu = Enc if gpu == 0 else Enc.clone(Enc.name + '_shadow')
            Dec_gpu = Dec if gpu == 0 else Dec.clone(Dec.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, _ = process_reals(reals_write, labels_write, 0.0,
                                               False,
                                               training_set.dynamic_range,
                                               drange_net)
                # Keep the variable's full size: new reals go first, the
                # remainder is padded with the variable's previous tail.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                reals_read = reals_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('loss'):
                loss = dnnlib.util.call_func_by_name(Enc=Enc_gpu,
                                                     Dec=Dec_gpu,
                                                     opt=opt,
                                                     reals=reals_read,
                                                     **loss_args)

            # Register gradients for the encoder and decoder trainables together.
            opt.register_gradients(
                tf.reduce_mean(loss),
                list(Enc_gpu.trainables.values()) +
                list(Dec_gpu.trainables.values()))

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    train_op = opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so the progress report below still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        Enc.setup_weight_histograms()
        Dec.setup_weight_histograms()

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1  # -1 forces the maintenance block to run after the first pass.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops.
        feed_dict = {
            lrate_in: sched.lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size

            # One data fetch followed by one train step per round. A single
            # round (no gradient accumulation) is simply the one-iteration
            # case of this loop, so no separate fast path is needed.
            for _round in rounds:
                tflib.run(data_fetch_op, feed_dict)
                tflib.run(train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_codes = Enc.run((grid_reals / 127.5) - 1.0,
                                     minibatch_size=sched.minibatch_gpu)
                grid_fakes = Dec.run(grid_codes,
                                     minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((Enc, Dec), pkl)

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((Enc, Dec), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example #18
0
def training_loop(
    classifier_args={},  # Options for classifier network.
    classifier_opt_args={},  # Options for classifier optimizer.
    classifier_loss_args={},  # Options for classifier loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for training_schedule().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    network_snapshot_ticks=5,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False):  # Include full TensorFlow computation graph in the tfevents file?
    """Train a single 'classifier' network on the training set.

    Builds the network and one optimizer, replicates the loss graph on every
    GPU reported by dnnlib.submit_config, then alternates data-fetch and
    training steps until total_kimg thousand images have been consumed or the
    RunContext asks to stop. Once per tick it prints/logs progress, and every
    network_snapshot_ticks ticks it pickles the classifier and runs the
    configured metrics on that pickle. A final pickle is always written to
    'network-final.pkl' in the run directory.

    The dict/list default arguments are only read (classifier_opt_args is
    copied before being modified), so sharing them across calls is harmless.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        shuffle_mb=2 * 4096,
                                        **dataset_args)

    # Construct the network.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        classifier = tflib.Network('classifier',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   label_size=training_set.label_size,
                                   **classifier_args)

    classifier.print_layers()

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of accumulation rounds needed to cover one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers.
    # Work on a copy so the caller's dict (or the shared default) is not modified.
    classifier_opt_args = dict(classifier_opt_args)

    classifier_opt_args['minibatch_multiplier'] = minibatch_multiplier
    classifier_opt_args['learning_rate'] = lrate_in

    classifier_opt = tflib.Optimizer(name='TrainClassifier',
                                     **classifier_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('gpu%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copy of the classifier.
            classifier_gpu = classifier if gpu == 0 else classifier.clone(
                classifier.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                # Schedule at cur_nimg=0 is only used to size the buffers.
                sched = training_schedule(cur_nimg=0, **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                # NOTE(review): the label width is hard-coded to 127 while the
                # network is built with training_set.label_size. Presumably
                # process_reals() yields 127-wide labels -- confirm, or derive
                # the width from the data instead.
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros(
                                             [sched.minibatch_gpu, 127]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad the fresh batch with the tail of the old buffer so the
                # assigned value keeps the variable's full first dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('classifier_loss'):
                # The loss function returns a (loss, label) pair; only the
                # loss is needed for training.
                classifier_loss, _label = dnnlib.util.call_func_by_name(
                    classifier=classifier_gpu,
                    images=reals_read,
                    labels=labels_read,
                    **classifier_loss_args)

            classifier_opt.register_gradients(tf.reduce_mean(classifier_loss),
                                              classifier_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    classifier_train_op = classifier_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Memory stats op unavailable: report 0 instead of failing.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())

    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=0, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1  # -1 forces a maintenance pass after the very first batch group.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops.
        # The schedule's G_lrate field is reused as the classifier learning rate.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size

            # One round per accumulation step (a single round when the whole
            # minibatch fits on the GPUs). The fetch and the training step are
            # deliberately separate session runs: the training op reads the
            # very variables the fetch op assigns, and ops inside one run have
            # no guaranteed order, so combining them could train on stale or
            # still-zero buffers.
            for _round in rounds:
                tflib.run(data_fetch_op, feed_dict)
                tflib.run(classifier_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots and evaluate metrics on the saved pickle.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl(classifier, pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent in this maintenance section, reported on the next tick.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl(classifier, dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example #19
0
def training_loop_hd(
    I_args={},  # Options for generator network.
    M_args={},  # Options for discriminator network.
    I_info_args={},  # Options for class network.
    I_opt_args={},  # Options for generator optimizer.
    I_loss_args={},  # Options for generator loss.
    resume_G_pkl=None,  # G network pickle to help training.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    I_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    I_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to I_args and M_args before resuming training?
    traversal_grid=False,  # Used for disentangled representation learning.
    n_discrete=0,  # Number of discrete latents in model.
    n_continuous=10,  # Number of continuous latents in model.
    n_samples_per=4,  # Number of samples for each line in traversal.
    use_hd_with_cls=False,  # If use info_loss.
    resolution_manual=1024,  # Resolution of generated images.
    use_level_training=False,  # If use level training (hierarchical optimization strategy).
    level_I_kimg=1000,  # Number of kimg of tick for I_level training.
    use_std_in_m=False,  # If output prior std in M net.
    prior_latent_size=512,  # Prior latent size.
    use_hyperplane=False,  # If use hyperplane model.
    pretrained_type='with_stylegan2'):  # Pretrained type for G.

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            I = tflib.Network('I',
                              num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size,
                              **I_args)
            M = tflib.Network('M',
                              num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size,
                              **M_args)
            Is = I.clone('Is')
            if use_hd_with_cls:
                I_info = tflib.Network('I_info',
                                       num_channels=training_set.shape[0],
                                       resolution=resolution_manual,
                                       label_size=training_set.label_size,
                                       **I_info_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_hd_with_cls:
                rI, rM, rIs, rI_info = misc.load_pkl(resume_pkl)
            else:
                rI, rM, rIs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                M.copy_vars_from(rM)
                Is.copy_vars_from(rIs)
                if use_hd_with_cls:
                    I_info.copy_vars_from(rI_info)
            else:
                I = rI
                M = rM
                Is = rIs
                if use_hd_with_cls:
                    I_info = rI_info

        print('Loading generator from "%s"...' % resume_G_pkl)
        if pretrained_type == 'with_stylegan2':
            rG, rD, rGs = misc.load_pkl(resume_G_pkl)
            G = rG
            D = rD
            Gs = rGs
        elif pretrained_type == 'with_cascadeVAE':
            rG = misc.load_pkl(resume_G_pkl)
            G = rG
            Gs = rG

    # Print layers and generate initial image snapshot.
    I.print_layers()
    M.print_layers()
    # pdb.set_trace()
    training_set_resolution_log2 = int(np.log2(resolution_manual))
    sched = training_schedule(
        cur_nimg=total_kimg * 1000,
        training_set_resolution_log2=training_set_resolution_log2,
        **sched_args)

    if not use_hyperplane:
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
        print('grid_size:', grid_size)
        print('grid_latents.shape:', grid_latents.shape)
        print('grid_labels.shape:', grid_labels.shape)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            grid_latents = grid_latents[:grid_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
        prior_traj_latents = M.run(grid_latents,
                                   is_validation=True,
                                   minibatch_size=sched.minibatch_gpu)
        if use_std_in_m:
            prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
    else:
        grid_size = (n_samples_per, n_continuous)
        grid_labels = np.tile(grid_labels[:1],
                              (n_continuous * n_samples_per, 1))
        latent_dirs = get_latent_dirs(n_continuous)
        prior_traj_latents = get_prior_traj_by_dirs(latent_dirs, M,
                                                    n_samples_per,
                                                    prior_latent_size,
                                                    grid_labels, sched)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            prior_traj_latents = prior_traj_latents[:prior_traj_latents.
                                                    shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
    print('prior_traj_latents.shape:', prior_traj_latents.shape)
    # pdb.set_trace()
    prior_traj_latents_show = np.reshape(
        prior_traj_latents, [-1, n_samples_per, prior_latent_size])
    print_traj(prior_traj_latents_show)
    grid_fakes = Gs.run(prior_traj_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu,
                        randomize_noise=True,
                        normalize_latents=False)
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if (n_continuous == 2) and (n_discrete == 0):
        n_per_line = 20
        ex_latent_value = 3
        prior_grid_latents, prior_grid_labels = get_2d_grid_latents(
            low=-ex_latent_value,
            high=ex_latent_value,
            n_per_line=n_per_line,
            grid_labels=grid_labels)
        grid_showing_fakes = Gs.run(prior_grid_latents,
                                    prior_grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu,
                                    randomize_noise=True,
                                    normalize_latents=False)
        grid_showing_fakes = add_outline(grid_showing_fakes, width=1)
        misc.save_image_grid(
            grid_showing_fakes,
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'),
            drange=drange_net,
            grid_size=[n_per_line, n_per_line])
        img_to_draw = Image.open(
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'))
        img_to_draw = img_to_draw.convert('RGB')
        img_to_draw = draw_traj_on_prior_grid(img_to_draw,
                                              prior_traj_latents_show,
                                              ex_latent_value, n_per_line)
        img_to_draw.save(
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid_drawn.png'))

    if use_level_training:
        ending_level = training_set_resolution_log2 - 1
    else:
        ending_level = 1

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Is_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), I_smoothing_kimg *
                              1000.0) if I_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    I_opt_args = dict(I_opt_args)
    for args, reg_interval in [(I_opt_args, I_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio

    I_opts = []
    I_reg_opts = []
    for n_level in range(ending_level):
        I_opts.append(tflib.Optimizer(name='TrainI_%d' % n_level,
                                      **I_opt_args))
        I_reg_opts.append(
            tflib.Optimizer(name='RegI_%d' % n_level,
                            share=I_opts[-1],
                            **I_opt_args))

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of I and M.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            M_gpu = M if gpu == 0 else M.clone(M.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_hd_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(I_info.name +
                                                                  '_shadow')

            # Evaluate loss functions.
            lod_assign_ops = []
            I_losses = []
            I_regs = []
            if 'lod' in I_gpu.vars:
                lod_assign_ops += [tf.assign(I_gpu.vars['lod'], lod_in)]
            if 'lod' in M_gpu.vars:
                lod_assign_ops += [tf.assign(M_gpu.vars['lod'], lod_in)]
            for n_level in range(ending_level):
                with tf.control_dependencies(lod_assign_ops):
                    with tf.name_scope('I_loss_%d' % n_level):
                        if use_hd_with_cls:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu,
                                M=M_gpu,
                                G=G_gpu,
                                I_info=I_info_gpu,
                                opt=I_opts[n_level],
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in,
                                **I_loss_args)
                        else:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu,
                                M=M_gpu,
                                G=G_gpu,
                                opt=I_opts[n_level],
                                n_levels=(n_level +
                                          1) if use_level_training else None,
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in,
                                **I_loss_args)
                        I_losses.append(I_loss)
                        I_regs.append(I_reg)

                # Register gradients.
                if not lazy_regularization:
                    if I_regs[n_level] is not None:
                        I_losses[n_level] += I_regs[n_level]
                else:
                    if I_regs[n_level] is not None:
                        I_reg_opts[n_level].register_gradients(
                            tf.reduce_mean(I_regs[n_level] * I_reg_interval),
                            I_gpu.trainables)

                if use_hd_with_cls:
                    MIIinfo_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()) +
                        list(I_info_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(
                        tf.reduce_mean(I_losses[n_level]),
                        MIIinfo_gpu_trainables)
                else:
                    MI_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(
                        tf.reduce_mean(I_losses[n_level]), MI_gpu_trainables)

    # Setup training ops.
    I_train_ops = []
    I_reg_ops = []
    for n_level in range(ending_level):
        I_train_ops.append(I_opts[n_level].apply_updates())
        I_reg_ops.append(I_reg_opts[n_level].apply_updates(allow_no_op=True))
    Is_update_op = Is.setup_as_moving_average_of(I, beta=Is_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
        M.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        n_level = 0 if not use_level_training else min(
            cur_nimg // (level_I_kimg * 1000), training_set_resolution_log2 -
            2)
        # Choose training parameters and configure training ops.
        sched = training_schedule(
            cur_nimg=cur_nimg,
            training_set_resolution_log2=training_set_resolution_log2,
            **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                I_opts[n_level].reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.I_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_I_reg = (lazy_regularization
                         and running_mb_counter % I_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run([Is_update_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    for _round in rounds:
                        tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run(Is_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 100 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):

                if not use_hyperplane:
                    prior_traj_latents = M.run(
                        grid_latents,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
                    if use_std_in_m:
                        prior_traj_latents = prior_traj_latents[:, :
                                                                prior_latent_size]
                else:
                    prior_traj_latents = get_prior_traj_by_dirs(
                        latent_dirs, M, n_samples_per, prior_latent_size,
                        grid_labels, sched)
                    if resolution_manual >= 256:
                        prior_traj_latents = prior_traj_latents[:
                                                                prior_traj_latents
                                                                .shape[0] // 5]

                prior_traj_latents_show = np.reshape(
                    prior_traj_latents, [-1, n_samples_per, prior_latent_size])
                print_traj(prior_traj_latents_show)
                grid_fakes = Gs.run(prior_traj_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu,
                                    randomize_noise=True,
                                    normalize_latents=False)
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                if (n_continuous == 2) and (n_discrete == 0):
                    n_per_line = 20
                    ex_latent_value = 3
                    prior_grid_latents, prior_grid_labels = get_2d_grid_latents(
                        low=-ex_latent_value,
                        high=ex_latent_value,
                        n_per_line=n_per_line,
                        grid_labels=grid_labels)
                    grid_showing_fakes = Gs.run(
                        prior_grid_latents,
                        prior_grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu,
                        randomize_noise=True,
                        normalize_latents=False)
                    grid_showing_fakes = add_outline(grid_showing_fakes,
                                                     width=1)
                    misc.save_image_grid(grid_showing_fakes,
                                         dnnlib.make_run_dir_path(
                                             'fakes_2d_prior_grid%06d.png' %
                                             (cur_nimg // 1000)),
                                         drange=drange_net,
                                         grid_size=[n_per_line, n_per_line])
                    img_to_draw = Image.open(
                        dnnlib.make_run_dir_path(
                            'fakes_2d_prior_grid%06d.png' %
                            (cur_nimg // 1000)))
                    img_to_draw = img_to_draw.convert('RGB')
                    img_to_draw = draw_traj_on_prior_grid(
                        img_to_draw, prior_traj_latents_show, ex_latent_value,
                        n_per_line)
                    img_to_draw.save(
                        dnnlib.make_run_dir_path(
                            'fakes_2d_prior_grid_drawn%06d.png' %
                            (cur_nimg // 1000)))

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((I, M, Is), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, M, Is), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
Example #21
0
def training_loop(
                  submit_config,
                  Encoder_args            = {},
                  E_opt_args              = {},
                  D_opt_args              = {},
                  E_loss_args             = EasyDict(),
                  D_loss_args             = {},
                  lr_args                 = EasyDict(),
                  tf_config               = {},
                  dataset_args            = EasyDict(),
                  decoder_pkl             = EasyDict(),
                  inversion_pkl           = EasyDict(),
                  drange_data             = [0, 255],
                  drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
                  mirror_augment          = False,
                  resume_run_id           = config.ENCODER_PICKLE_DIR,     # Run ID or network pkl to resume training from, None = start from scratch.
                  resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
                  image_snapshot_ticks    = 1,        # How often to export image snapshots?
                  network_snapshot_ticks  = 4,       # How often to export network snapshots?
                  max_iters               = 150000):
    """Train encoder E and discriminator D on (portrait, landmark) image pairs.

    Builds one loss tower per GPU, then alternates one encoder update and one
    discriminator update per iteration. Progress is printed every 50
    iterations, a debug image grid is written every 500 iterations, and
    network snapshots are written every `network_snapshot_ticks` ticks, where
    one tick is 65000 processed images.

    Args:
        submit_config: Run configuration; this function reads `batch_size`,
            `batch_size_test`, `image_size`, `num_gpus` and `run_dir`.
        Encoder_args: Extra keyword arguments for `tflib.Network('E', ...)`.
        E_opt_args: Keyword arguments for the encoder `tflib.Optimizer`.
        D_opt_args: Keyword arguments for the discriminator `tflib.Optimizer`.
        E_loss_args: Keyword arguments forwarded to
            `dnnlib.util.call_func_by_name` to build the encoder loss; must
            also expose `perceptual_img_size`.
        D_loss_args: Keyword arguments forwarded to
            `dnnlib.util.call_func_by_name` to build the discriminator loss.
        lr_args: Exponential-decay settings; reads `learning_rate`,
            `decay_step`, `decay_rate` and `stair`.
        tf_config: Passed to `tflib.init_tf`.
        dataset_args: Reads `data_train` and `data_test`.
        decoder_pkl: Its `decoder_pkl` attribute is the pickle loaded as
            `(G, _, Gs)` when training starts from scratch.
        inversion_pkl: Its `inversion_pkl` attribute is the pickle whose first
            element is used as the `Inv` network.
        drange_data: Passed to `process_reals` (input data range).
        drange_net: Passed to `process_reals` (network data range).
        mirror_augment: Passed to `process_reals`.
        resume_run_id: Passed to `misc.locate_network_pkl`; None = build new
            E and D instead of resuming.
        resume_snapshot: Passed to `misc.locate_network_pkl`.
        image_snapshot_ticks: Not read by this function (debug images are
            written on a fixed 500-iteration cadence).
        network_snapshot_ticks: Save a network snapshot every this many ticks;
            None = only save 'network-final.pkl'.
        max_iters: Iteration index at which training stops.
    """

    tflib.init_tf(tf_config)

    with tf.name_scope('input'):
        placeholder_real_portraits_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_train')
        placeholder_real_landmarks_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_train')
        placeholder_real_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_train')
        placeholder_landmarks_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_landmarks_shuffled_train')


        placeholder_real_portraits_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_test')
        placeholder_real_landmarks_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_test')
        placeholder_real_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_test')
        placeholder_real_landmarks_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_shuffled_test')

        # Split every training batch along the batch axis, one slice per GPU.
        real_split_landmarks = tf.split(placeholder_real_landmarks_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_portraits = tf.split(placeholder_real_portraits_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_shuffled = tf.split(placeholder_real_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_lm_shuffled = tf.split(placeholder_landmarks_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        
        placeholder_training_flag = tf.placeholder(tf.string, name='placeholder_training_flag')

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            # Snapshots are named 'network-snapshot-<cur_nimg>.pkl' (see below), so the
            # trailing number divided by the batch size recovers the iteration index.
            # NOTE: a file name without a trailing integer (e.g. 'network-final.pkl')
            # makes int() raise ValueError here.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, _, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) # don't use pre-trained discriminator!
            num_layers = Gs.components.synthesis.input_shape[1]

            # here we add a new discriminator!
            D = tflib.Network('D',  # name of the network how we call it
                              num_channels=3, resolution=128, label_size=0,  #some needed for this build function
                              func_name="training.networks_stylegan.D_basic") # function of that network. more was not passed in d_args!
                              # input is not passed here (just construction - note that we do not call the actual function!). Instead, network will inspect build function and require it for the get_output_for function.
            print("Created new Discriminator!")

            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args)
            start = 0
        Inv, _, _, _ = misc.load_pkl(inversion_pkl.inversion_pkl)

    E.print_layers(); Gs.print_layers(); D.print_layers()

    # Learning-rate schedule driven by a dedicated, non-trainable step counter.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Per-GPU loss terms are summed here and averaged after the loop (reporting only).
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; the other GPUs use '_shadow' clones.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            Inv_gpu = Inv if gpu == 0 else Inv.clone(Inv.name + '_shadow')
            perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False)
            real_portraits_gpu = process_reals(real_split_portraits[gpu], mirror_augment, drange_data, drange_net)
            shuffled_portraits_gpu = process_reals(real_split_shuffled[gpu], mirror_augment, drange_data, drange_net)
            real_landmarks_gpu = process_reals(real_split_landmarks[gpu], mirror_augment, drange_data, drange_net)
            shuffled_landmarks_gpu = process_reals(real_split_lm_shuffled[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, perceptual_model=perceptual_model, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, shuffled_landmarks=shuffled_landmarks_gpu, training_flag=placeholder_training_flag, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, training_flag=placeholder_training_flag, **D_loss_args) # change signature in ...
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # Gradient ops are created under a control dependency on the step increment.
            # NOTE(review): this presumably advances the LR step once per optimizer run
            # (i.e. more than once per iteration) - confirm against tflib.Optimizer.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)
    inv_X_val = test_inversion(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)
    
    #sampled_portraits_val = sample_random_portraits(Gs, submit_config.batch_size)
    #sampled_portraits_val_test = sample_random_portraits(Gs, submit_config.batch_size_test)

    sess = tf.get_default_session()

    print('Getting training data...')
    # x_batch is a batch of (2, ..., ..., ...) records!
    stack_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    stack_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')
    
    stack_batch_train_secondary = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train_secondary')
    stack_batch_test_secondary = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test_secondary')

    summary_log = tf.summary.FileWriter(config.getGdrivePath())

    # Everything after the writer is opened runs under try/finally so the event
    # file is closed even if training aborts with an exception.
    try:
        cur_nimg = start * submit_config.batch_size
        cur_tick = 0
        tick_start_nimg = cur_nimg
        start_time = time.time()

        # (Re)initialise only the step counter; tf.variables_initializer is the
        # non-deprecated name for tf.initialize_variables.
        init_fix = tf.variables_initializer(
            [global_step0],
            name='init_fix'
        )
        sess.run(init_fix)

        print('Optimization starts!!!')


        # here is the actual training loop: all iterations
        for it in range(start, max_iters):
            # Index 0 / 1 on axis 1 of each stacked record selects portrait / landmark.
            batch_stacks = sess.run(stack_batch_train)
            batch_portraits = batch_stacks[:,0,:,:,:]
            batch_landmarks = batch_stacks[:,1,:,:,:]

            batch_stacks_secondary = sess.run(stack_batch_train_secondary)
            batch_shuffled = batch_stacks_secondary[:,0,:,:,:]
            batch_lm_shuffled = batch_stacks_secondary[:,1,:,:,:]


            training_flag = "pose"

            feed_dict_1 = {placeholder_real_portraits_train: batch_portraits, placeholder_real_landmarks_train: batch_landmarks, placeholder_real_shuffled_train:batch_shuffled, placeholder_landmarks_shuffled_train:batch_lm_shuffled, placeholder_training_flag: training_flag}
            # here we query these encoder- and discriminator losses. as input we provide: batch_stacks = batch of images + landmarks.
            _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
            _, d_r_, d_f_, d_g_= sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

            cur_nimg += submit_config.batch_size

            if it % 50 == 0:
                print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                    it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
                sys.stdout.flush()
                tflib.autosummary.save_summaries(summary_log, it)

            if it % 500 == 0:
                batch_stacks_test = sess.run(stack_batch_test)
                batch_portraits_test = batch_stacks_test[:,0,:,:,:]
                batch_landmarks_test = batch_stacks_test[:,1,:,:,:]

                batch_stacks_test_secondary = sess.run(stack_batch_test_secondary)
                batch_shuffled_test = batch_stacks_test_secondary[:,0,:,:,:]
                batch_shuffled_lm_test = batch_stacks_test_secondary[:,1,:,:,:]


                batch_portraits_test = misc.adjust_dynamic_range(batch_portraits_test.astype(np.float32), [0, 255], [-1., 1.])
                batch_landmarks_test = misc.adjust_dynamic_range(batch_landmarks_test.astype(np.float32), [0, 255], [-1., 1.])
                batch_shuffled_test = misc.adjust_dynamic_range(batch_shuffled_test.astype(np.float32), [0, 255], [-1., 1.])
                batch_shuffled_lm_test = misc.adjust_dynamic_range(batch_shuffled_lm_test.astype(np.float32), [0, 255], [-1., 1.])

                # NOTE(review): placeholder_real_shuffled_test is never fed below, so the
                # test graphs presumably do not depend on it - confirm in test()/test_inversion().

                # first: input + target landmarks = manipulated image
                samples_manipulated = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_shuffled_lm_test})

                # 2nd: manipulated + original landmarks
                samples_reconstructed = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: samples_manipulated, placeholder_real_landmarks_test: batch_landmarks_test})

                # also: show direct reconstruction
                samples_direct_rec = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test})

                # show results of the inversion
                portraits_inverted = sess.run(inv_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test})

                # One grid row per panel, in this order.
                debug_panels = [
                        batch_landmarks_test, # original landmarks
                        batch_portraits_test, # original portraits
                        samples_direct_rec, # direct reconstruction
                        batch_shuffled_lm_test, # shuffled landmarks
                        samples_manipulated, # manipulated images
                        samples_reconstructed, # cycle reconstructed images
                        portraits_inverted, # inversion results
                    ]
                debug_img = np.concatenate(debug_panels, axis=0)

                debug_img = adjust_pixel_range(debug_img)
                # The row count follows the number of panels; it used to be hard-coded
                # to 6 although 7 panels are concatenated above.
                debug_img = fuse_images(debug_img, row=len(debug_panels), col=submit_config.batch_size_test)
                # save the debug grid for this iteration into the run directory
                save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), debug_img)

                # save image to gdrive
                img_path = os.path.join(config.getGdrivePath(), 'images', ('iter_%08d.png' % (cur_nimg)))
                save_image(img_path, debug_img)

            # A tick elapses every 65000 processed images.
            if cur_nimg >= tick_start_nimg + 65000:
                cur_tick += 1
                tick_start_nimg = cur_nimg

                # None disables periodic snapshots instead of raising TypeError on '%'.
                if network_snapshot_ticks is not None and cur_tick % network_snapshot_ticks == 0:
                    pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                    misc.save_pkl((E, G, D, Gs), pkl)

                    # save network snapshot to gdrive
                    pkl_drive = os.path.join(config.getGdrivePath(), 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg))
                    misc.save_pkl((E, G, D, Gs), pkl_drive)

        misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    finally:
        summary_log.close()
Example #22
0
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    AE_opt_args=None,  # Options for autoencoder optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    AE_loss_args=None,  # Options for autoencoder loss.
    dataset_args={},  # Options for dataset.load_dataset().
    dataset_args_eval={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    train_data_dir=None,  # Directory to load datasets from.
    eval_data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=True,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,
    resume_with_own_vars=False
):  # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    print("Loading train set from %s..." % dataset_args.tfrecord_dir)
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(train_data_dir),
        verbose=True,
        **dataset_args)
    print("Loading eval set from %s..." % dataset_args_eval.tfrecord_dir)
    eval_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(eval_data_dir),
        verbose=True,
        **dataset_args_eval)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    # Freeze Discriminator
    if D_args['freeze']:
        num_layers = np.log2(training_set.resolution) - 1
        layers = int(np.round(num_layers * 3. / 8.))
        scope = ['Output', 'scores_out']
        for layer in range(layers):
            scope += ['.*%d' % 2**layer]
            if 'train_scope' in D_args:
                scope[-1] += '.*%d' % D_args['train_scope']
        D_args['train_scope'] = scope

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is '' or resume_with_new_nets or resume_with_own_vars:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not '':
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    # SVD stuff
    if 'syn_svd' in G_args or 'map_svd' in G_args:
        # Run graph to calculate SVD
        grid_latents_smol = grid_latents[:1]
        rho = np.array([1])
        grid_fakes = G.run(grid_latents_smol,
                           grid_labels,
                           rho,
                           is_validation=True)
        grid_fakes = Gs.run(grid_latents_smol,
                            grid_labels,
                            rho,
                            is_validation=True)
        load_d_fake = D.run(grid_reals[:1], rho, is_validation=True)
        with tf.device('/gpu:0'):
            # Create SVD-decomposed graph
            rG, rD, rGs = G, D, Gs
            G_lambda_mask = {
                var: np.ones(G.vars[var].shape[-1])
                for var in G.vars if 'SVD/s' in var
            }
            D_lambda_mask = {
                'D/' + var: np.ones(D.vars[var].shape[-1])
                for var in D.vars if 'SVD/s' in var
            }
            G_reduce_dims = {
                var: (0, int(Gs.vars[var].shape[-1]))
                for var in Gs.vars if 'SVD/s' in var
            }
            G_args['lambda_mask'] = G_lambda_mask
            G_args['reduce_dims'] = G_reduce_dims
            D_args['lambda_mask'] = D_lambda_mask

            # Create graph with no SVD operations
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rG.input_shapes[1][1],
                              factorized=True,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rD.input_shapes[1][1],
                              factorized=True,
                              **D_args)
            Gs = G.clone('Gs')

            grid_fakes = G.run(grid_latents_smol,
                               grid_labels,
                               rho,
                               is_validation=True,
                               minibatch_size=1)
            grid_fakes = Gs.run(grid_latents_smol,
                                grid_labels,
                                rho,
                                is_validation=True,
                                minibatch_size=1)

            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)

    # Reduce per-gpu minibatch size to fit in 16GB GPU memory
    if grid_reals.shape[2] >= 1024:
        sched_args.minibatch_gpu_base = 2
    print('Batch size', sched_args.minibatch_gpu_base)

    # Generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    rho = np.array([1])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        rho,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if resume_pkl is not '':
        load_d_real = rD.run(grid_reals[:1], rho, is_validation=True)
        load_d_fake = rD.run(grid_fakes[:1], rho, is_validation=True)
        d_fake = D.run(grid_fakes[:1], rho, is_validation=True)
        d_real = D.run(grid_reals[:1], rho, is_validation=True)
        print('Factorized fake', d_fake, 'loaded fake', load_d_fake,
              'factorized real', d_real, 'loaded real', load_d_real)
        print('(should match)')
    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
    if AE_opt_args is not None:
        AE_opt_args = dict(AE_opt_args)
        AE_opt_args['minibatch_multiplier'] = minibatch_multiplier
        AE_opt_args['learning_rate'] = lrate_in
        AE_opt = tflib.Optimizer(name='TrainAE', **AE_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    if G_loss_args['func_name'] == 'training.loss.G_l1':
                        G_loss_args['reals'] = reals_read
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0

    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            ae_iter_mul = 10
            ae_rounds = range(0, sched.minibatch_size,
                              sched.minibatch_gpu * num_gpus * ae_iter_mul)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    _g_loss, _ = tflib.run([G_loss, G_train_op], feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                print('g loss', _g_loss)
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            if network_snapshot_ticks is not None and cur_tick % network_snapshot_ticks == 0 or done:
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(eval_data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            rho=rho)
            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
    eval_set.close()
def training_loop(
    submit_config,
    G_args                  = {},       # Options for the generator network.
    D_args                  = {},       # Options for the discriminator network.
    G_opt_args              = {},       # Options for the generator optimizer.
    D_opt_args              = {},       # Options for the discriminator optimizer.
    G_loss_args             = {},       # Options for the generator loss.
    D_loss_args             = {},       # Options for the discriminator loss.
    dataset_args            = {},       # Options for the dataset (forwarded to dataset.load_dataset()).
    sched_args              = {},       # Options for the training schedule (forwarded to training_schedule()).
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for the metrics (passed to metric_base.MetricGroup).
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of the generator weights.
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting the training parameters.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augmentation?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    save_tf_graph           = False,    # Include the full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and the training schedule.
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.
    """Build the G/D training graph and run the GAN training loop.

    The function loads the training set, constructs (or loads from a pickle)
    the networks G, D and Gs, builds one loss/gradient tower per GPU for
    ``submit_config.num_gpus`` GPUs, and then alternates discriminator and
    generator updates until ``total_kimg * 1000`` images have been counted or
    ``ctx.should_stop()`` returns true.

    Once per tick it prints a progress line, records autosummaries, and
    (depending on ``image_snapshot_ticks`` / ``network_snapshot_ticks``)
    writes a fake-image grid and a ``network-snapshot-*.pkl`` into
    ``submit_config.run_dir`` and runs the metrics on that pickle. At the end
    it writes ``network-final.pkl`` and closes the summary writer and the run
    context. Nothing is returned.

    Notes for maintainers:
      * ``train`` and ``config`` are read from the enclosing module; they are
        not parameters of this function.
      * ``image_snapshot_ticks`` and ``network_snapshot_ticks`` are used as
        modulo divisors without a ``None``/zero guard, so they must be
        non-zero integers here.
      * NOTE(review): the dict/list defaults are shared between calls. Most
        are only expanded with ``**``, but ``metric_arg_list`` and
        ``tf_config`` are handed to callees as-is -- confirm those callees do
        not mutate them.
    """

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load the training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct the networks (either restored from a pickle or built from scratch).
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            # Gs starts as a copy of G; it is later updated as a moving average of G.
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()
    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        # tf.placeholder: think of it as a formal parameter that is declared while
        # the graph is defined and only receives a concrete value when the graph is run.
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # Per-GPU share of the global minibatch.
        minibatch_split = minibatch_in // submit_config.num_gpus
        # beta = 0.5 ** (minibatch / (G_smoothing_kimg * 1000)), i.e. beta compounds to 0.5
        # after G_smoothing_kimg thousand images; 0.0 when G_smoothing_kimg <= 0.
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Both optimizers share the same learning-rate placeholder; the value fed at
    # run time differs between the D step and the G step (see the training loop).
    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a '_shadow' clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Assign the fed lod value to each network's 'lod' variable; the loss ops
            # below are built under a control dependency on these assignments.
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # Fall back to a constant 0 if the memory-stats op raises NotFoundError.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # The schedule is evaluated at the final image count (total_kimg*1000) here; only
    # its minibatch size is used, as the per-GPU batch for rendering the initial fakes.
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
    # Set up the run directory: save the reference grids and create the summary writer.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)
    # Training.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose the training parameters and configure the training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when the integer bracket [floor(lod), ceil(lod)] differs from the previous iteration's.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run the training ops: D_repeats discriminator steps (each also updating Gs),
        # followed by one generator step. cur_nimg advances once per discriminator step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            # A network snapshot is also written on the very first tick (cur_tick == 1).
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and the RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            # Maintenance time = interval since the previous ctx.update() minus the tick's measured time.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final result.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
Example #24
0
def training_loop(
        run_dir='.',  # Output directory.
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        loss_args={},  # Options for loss function.
        train_dataset_args={},  # Options for dataset to train with.
        # Options for dataset to evaluate metrics against.
    metric_dataset_args={},
        augment_args={},  # Options for adaptive augmentations.
        metric_arg_list=[],  # Metrics to evaluate during training.
        num_gpus=1,  # Number of GPUs to use.
        minibatch_size=32,  # Global minibatch size.
        minibatch_gpu=4,  # Number of samples processed at a time by one GPU.
        # Half-life of the exponential moving average (EMA) of generator weights.
    G_smoothing_kimg=10,
        G_smoothing_rampup=None,  # EMA ramp-up coefficient.
        # Number of minibatches to run in the inner loop.
    minibatch_repeats=4,
        lazy_regularization=True,  # Perform regularization as a separate training step?
        # How often the perform regularization for G? Ignored if lazy_regularization=False.
    G_reg_interval=4,
        # How often the perform regularization for D? Ignored if lazy_regularization=False.
        D_reg_interval=16,
        # Total length of the training, measured in thousands of real images.
        total_kimg=25000,
        kimg_per_tick=4,  # Progress snapshot interval.
        # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    image_snapshot_ticks=50,
        # How often to save network snapshots? None = only save 'networks-final.pkl'.
        network_snapshot_ticks=50,
        resume_pkl=None,  # Network pickle to resume training from.
        # Callback function for determining whether to abort training.
    abort_fn=None,
        progress_fn=None,  # Callback function for updating training progress.
):
    assert minibatch_size % (num_gpus * minibatch_gpu) == 0
    start_time = time.time()

    print('Loading training set...')
    training_set = dataset.load_dataset(**train_dataset_args)
    print('Image shape:', np.int32(training_set.shape).tolist())
    print('Label shape:', [training_set.label_size])
    print()

    print('Constructing networks...')
    with tf.device('/gpu:0'):
        G = tflib.Network('G',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **G_args)
        D = tflib.Network('D',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **D_args)
        Gs = G.clone('Gs')
        if resume_pkl is not None:
            print(f'Resuming from "{resume_pkl}"')
            with dnnlib.util.open_url(resume_pkl) as f:
                rG, rD, rGs = pickle.load(f)
            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)
    G.print_layers()
    D.print_layers()

    print('Exporting sample images...')
    grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(
        training_set)
    save_image_grid(grid_reals,
                    os.path.join(run_dir, 'reals.png'),
                    drange=[0, 255],
                    grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=minibatch_gpu)
    # save_image_grid(grid_fakes, os.path.join(
    #     run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    print(f'Replicating networks across {num_gpus} GPUs...')
    G_gpus = [G]
    D_gpus = [D]
    for gpu in range(1, num_gpus):
        with tf.device(f'/gpu:{gpu}'):
            G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
            D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))

    print('Initializing augmentations...')
    aug = None
    if augment_args.get('class_name', None) is not None:
        aug = dnnlib.util.construct_class_by_name(**augment_args)
        aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)

    print('Setting up optimizers...')
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args[
            'minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    print('Constructing training graph...')
    data_fetch_ops = []
    training_set.configure(minibatch_gpu)
    for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)):
        with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'):

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                real_images_var = tf.Variable(
                    name='images',
                    trainable=False,
                    initial_value=tf.zeros([minibatch_gpu] +
                                           training_set.shape))
                real_labels_var = tf.Variable(name='labels',
                                              trainable=False,
                                              initial_value=tf.zeros([
                                                  minibatch_gpu,
                                                  training_set.label_size
                                              ]))
                real_images_write, real_labels_write = training_set.get_minibatch_tf(
                )
                real_images_write = tflib.convert_images_from_uint8(
                    real_images_write)
                data_fetch_ops += [
                    tf.assign(real_images_var, real_images_write)
                ]
                data_fetch_ops += [
                    tf.assign(real_labels_var, real_labels_write)
                ]

            # Evaluate loss function and register gradients.
            fake_labels = training_set.get_random_labels_tf(minibatch_gpu)
            terms = dnnlib.util.call_func_by_name(G=G_gpu,
                                                  D=D_gpu,
                                                  aug=aug,
                                                  fake_labels=fake_labels,
                                                  real_images=real_images_var,
                                                  real_labels=real_labels_var,
                                                  **loss_args)
            if lazy_regularization:
                if terms.G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(terms.G_reg * G_reg_interval),
                        G_gpu.trainables)
                if terms.D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(terms.D_reg * D_reg_interval),
                        D_gpu.trainables)
            else:
                if terms.G_reg is not None:
                    terms.G_loss += terms.G_reg
                if terms.D_reg is not None:
                    terms.D_loss += terms.D_reg
            G_opt.register_gradients(tf.reduce_mean(terms.G_loss),
                                     G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(terms.D_loss),
                                     D_gpu.trainables)

    print('Finalizing training ops...')
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
    tflib.init_uninitialized_vars()
    with tf.device('/gpu:0'):
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()

    print('Initializing metrics...')
    summary_log = tf.summary.FileWriter(run_dir)
    metrics = []
    for args in metric_arg_list:
        metric = dnnlib.util.construct_class_by_name(**args)
        metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir)
        metrics.append(metric)

    print(f'Training for {total_kimg} kimg...')
    print()
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    running_mb_counter = 0

    done = False
    while not done:

        # Compute EMA decay parameter.
        Gs_nimg = G_smoothing_kimg * 1000.0
        if G_smoothing_rampup is not None:
            Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
        Gs_beta = 0.5**(minibatch_size / max(Gs_nimg, 1e-8))

        # Run training ops.
        for _repeat_idx in range(minibatch_repeats):
            rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op])
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta})
                if run_D_reg:
                    tflib.run(D_reg_op)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op)
                    if run_G_reg:
                        tflib.run(G_reg_op)
                tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta})
                for _round in rounds:
                    tflib.run(data_fetch_op)
                    tflib.run(D_train_op)
                    if run_D_reg:
                        tflib.run(D_reg_op)

            # Run validation.
            if aug is not None:
                aug.run_validation(minibatch_size=minibatch_size)

        # Tune augmentation parameters.
        if aug is not None:
            aug.tune(minibatch_size * minibatch_repeats)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None
                                                   and abort_fn())
        if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_end_time = time.time()
            total_time = tick_end_time - start_time
            tick_time = tick_end_time - tick_start_time

            # Report progress.
            print(' '.join([
                f"tick {autosummary('Progress/tick', cur_tick):<5d}",
                f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}",
                f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}",
                f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}",
                f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}",
                f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}",
                f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}",
                f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}",
            ]))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            if progress_fn is not None:
                progress_fn(cur_nimg // 1000, total_kimg)

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    done or cur_tick % image_snapshot_ticks == 0):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=minibatch_gpu)
                save_image_grid(grid_fakes,
                                os.path.join(
                                    run_dir,
                                    f'fakes{cur_nimg // 1000:06d}.png'),
                                drange=[-1, 1],
                                grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    done or cur_tick % network_snapshot_ticks == 0):
                pkl = os.path.join(
                    run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
                with open(pkl, 'wb') as f:
                    pickle.dump((G, D, Gs), f)
                if len(metrics):
                    print('Evaluating metrics...')
                    for metric in metrics:
                        metric.run(pkl, num_gpus=num_gpus)

            # Update summaries.
            for metric in metrics:
                metric.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            tick_start_time = time.time()
            maintenance_time = tick_start_time - tick_end_time

    print()
    print('Exiting...')
    summary_log.close()
    training_set.close()
def training_loop_vae(
        G_args={},  # Options for generator network.
        E_args={},  # Options for encoder network.
        D_args={},  # Options for discriminator network. None = train without a discriminator.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=1,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to E_args, G_args and D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        topk_dims_to_show=20,  # Number of top disentangled dimensions to show in a snapshot.
        subgroup_sizes_ls=None,
        subspace_sizes_ls=None,
        forward_eg=False,
        n_samples_per=10):  # Number of samples for each line in traversal.
    """Train an encoder/generator pair, optionally with a discriminator.

    The function loads the dataset, constructs (or restores from
    ``resume_pkl``) the networks ``E``, ``G`` and, when ``D_args`` is not
    None, ``D``; builds one training graph per GPU; and then runs the
    training loop until ``total_kimg`` thousand images have been processed
    or the run context asks to stop.  Once per tick it reports progress,
    optionally pickles the networks and runs the metrics, and optionally
    writes an image grid generated by ``G``.  A final pickle is written to
    'network-final.pkl' in the run directory.

    Notes:
        * Passing ``D_args=None`` disables every discriminator-related step.
        * When new networks are constructed, ``G_args`` (and ``D_args``)
          must support attribute access (``.latent_size``); the plain-dict
          defaults do not, so callers have to pass suitable objects.
        * ``subgroup_sizes_ls`` must be a sequence of numbers when
          ``forward_eg`` is True, because it is summed to size G's input.
        * Real images are fetched immediately before every optimizer step,
          so each step (including every gradient-accumulation round) sees
          a freshly fetched micro-batch.

    Returns:
        None.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # If use Discriminator.
    use_D = D_args is not None

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    # NOTE(review): the outlined copy below is never saved -- the plain
    # grid_reals is written to 'reals.png' and grid_fakes is overwritten
    # further down. Kept as-is in case add_outline() has side effects.
    grid_fakes = add_outline(grid_reals, width=1)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            E = tflib.Network('E',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              input_shape=[None] + training_set.shape,
                              **E_args)
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                input_shape=[None, n_discrete +
                             G_args.latent_size] if not forward_eg else [
                                 None, n_discrete + G_args.latent_size +
                                 sum(subgroup_sizes_ls)
                             ],
                **G_args)
            if use_D:
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  input_shape=[None, D_args.latent_size],
                                  **D_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # The pickle layout must match use_D: (E, G, D) or (E, G).
            if use_D:
                rE, rG, rD = misc.load_pkl(resume_pkl)
            else:
                rE, rG = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                E.copy_vars_from(rE)
                G.copy_vars_from(rG)
                if use_D:
                    D.copy_vars_from(rD)
            else:
                E = rE
                G = rG
                if use_D:
                    D = rD

    # Print layers and generate initial image snapshot.
    E.print_layers()
    G.print_layers()
    if use_D:
        D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        if topk_dims_to_show > 0:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
        else:
            topk_dims = np.arange(n_continuous)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_size:', grid_size)
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # Only the first and the eighth value returned by get_return_v() are used.
    grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
        G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents,
              grid_labels,
              is_validation=True,
              minibatch_size=sched.minibatch_gpu,
              randomize_noise=True), 8)
    print('Lie_vars:', lie_vars[0])
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers. The option dicts are copied before being extended so
    # the caller's (or the default) dict is left untouched.
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    if use_D:
        D_opt_args = dict(D_opt_args)
        D_opt_args['minibatch_multiplier'] = minibatch_multiplier
        D_opt_args['learning_rate'] = lrate_in
        D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of E, G and D.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_D:
                D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, 0., mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Pad the new batch with the tail of the old buffer so the
                # assigned tensor keeps the variable's full first dimension.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            if use_D:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)
            else:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)

            # Register gradients. The "G" optimizer updates the encoder and
            # the generator jointly.
            EG_gpu_trainables = collections.OrderedDict(
                list(E_gpu.trainables.items()) +
                list(G_gpu.trainables.items()))
            G_opt.register_gradients(tf.reduce_mean(G_loss), EG_gpu_trainables)
            if use_D:
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    if use_D:
        D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        if use_D:
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, 0)

        # Run training ops. Note that the same learning rate (sched.G_lrate)
        # is fed to both optimizers through the shared lrate_in placeholder.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Both losses read reals_var/labels_var, which only change when
            # data_fetch_op runs. Fetch right before every optimizer step so
            # that the first step does not see the zero-initialised buffers
            # and accumulation rounds do not reuse one stale micro-batch.

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([data_fetch_op], feed_dict)
                tflib.run([G_train_op], feed_dict)
                if use_D:
                    tflib.run([data_fetch_op], feed_dict)
                    tflib.run([D_train_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(G_train_op, feed_dict)
                if use_D:
                    for _round in rounds:
                        tflib.run(data_fetch_op, feed_dict)
                        tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                if use_D:
                    misc.save_pkl((E, G, D), pkl)
                else:
                    misc.save_pkl((E, G), pkl)
                met_outs = metrics.run(pkl,
                                       run_dir=dnnlib.make_run_dir_path(),
                                       data_dir=dnnlib.convert_path(data_dir),
                                       num_gpus=num_gpus,
                                       tf_config=tf_config,
                                       is_vae=True,
                                       use_D=use_D,
                                       Gs_kwargs=dict(is_validation=True))
                # Re-rank the latent dimensions shown in traversal grids by
                # the per-dimension metric output, when it is available.
                if topk_dims_to_show > 0:
                    if 'tpl_per_dim' in met_outs:
                        avg_distance_per_dim = met_outs[
                            'tpl_per_dim']  # shape: (n_continuous)
                        topk_dims = np.argsort(
                            avg_distance_per_dim
                        )[::-1][:topk_dims_to_show]  # at most topk_dims_to_show entries
                    else:
                        topk_dims = np.arange(
                            min(topk_dims_to_show, n_continuous))
                else:
                    topk_dims = np.arange(n_continuous)

            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if traversal_grid:
                    grid_size, grid_latents, grid_labels = get_grid_latents(
                        n_discrete, n_continuous, n_samples_per, G,
                        grid_labels, topk_dims)
                else:
                    grid_latents = np.random.randn(np.prod(grid_size),
                                                   *G.input_shape[1:])

                grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
                    G.run(append_gfeats(grid_latents, G)
                          if forward_eg else grid_latents,
                          grid_labels,
                          is_validation=True,
                          minibatch_size=sched.minibatch_gpu,
                          randomize_noise=True), 8)
                print('Lie_vars:', lie_vars[0])
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_D:
        misc.save_pkl((E, G, D), dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((E, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example #26
0
    def set_network(self, Gs, dtype='float16'):
        """Attach a generator and build the projection graph around it.

        Passing ``Gs=None`` only clears the stored network and returns.
        Otherwise ``Gs`` is cloned (``randomize_noise=False``, the given
        ``dtype``, ``num_fp16_res=0``, ``fused_modconv=True``) and the
        following are created on ``self``: dlatent statistics, ops that
        initialise/normalise the synthesis noise variables, the image
        output expressions, the LPIPS-based loss with noise regularisation,
        and the optimizer step.
        """
        if Gs is None:
            self._Gs = None
            return

        self._Gs = Gs.clone(randomize_noise=False,
                            dtype=dtype,
                            num_fp16_res=0,
                            fused_modconv=True)
        mapping_net = self._Gs.components.mapping
        synthesis_net = self._Gs.components.synthesis

        # Compute dlatent stats.
        self._info(
            f'Computing W midpoint and stddev using {self.dlatent_avg_samples} samples...'
        )
        z_shape = self._Gs.input_shapes[0][1:]
        z_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples,
                                                     *z_shape)
        w_samples = mapping_net.run(z_samples, None)  # [N, L, C]
        w_samples = w_samples[:, :1, :].astype(np.float32)  # [N, 1, C]
        self._dlatent_avg = np.mean(w_samples, axis=0,
                                    keepdims=True)  # [1, 1, C]
        squared_dev_sum = np.sum((w_samples - self._dlatent_avg)**2)
        self._dlatent_std = (squared_dev_sum / self.dlatent_avg_samples)**0.5
        self._info(f'std = {self._dlatent_std:g}')

        # Setup noise inputs: collect consecutively numbered noise variables
        # until the first missing index.
        self._info('Setting up noise inputs...')
        self._noise_vars = []
        init_ops = []
        normalize_ops = []
        net_vars = self._Gs.vars
        var_name = f'G_synthesis/noise{len(self._noise_vars)}'
        while var_name in net_vars:
            noise_var = net_vars[var_name]
            self._noise_vars.append(noise_var)
            fresh_noise = tf.random_normal(tf.shape(noise_var),
                                           dtype=tf.float32)
            init_ops.append(tf.assign(noise_var, fresh_noise))
            mean = tf.reduce_mean(noise_var)
            std = tf.reduce_mean((noise_var - mean)**2)**0.5
            normalize_ops.append(
                tf.assign(noise_var, (noise_var - mean) / std))
            var_name = f'G_synthesis/noise{len(self._noise_vars)}'
        self._noise_init_op = tf.group(*init_ops)
        self._noise_normalize_op = tf.group(*normalize_ops)

        # Build image output graph.
        self._info('Building image output graph...')
        self._minibatch_size = 1
        dlatent_shape = [self._minibatch_size] + list(
            self._dlatent_avg.shape[1:])
        self._dlatents_var = tf.Variable(tf.zeros(dlatent_shape),
                                         name='dlatents_var')
        self._dlatent_noise_in = tf.placeholder(tf.float32, [],
                                                name='noise_in')
        perturbation = tf.random.normal(
            shape=self._dlatents_var.shape) * self._dlatent_noise_in
        # Repeat the single dlatent along axis 1 to the length the synthesis
        # network's input shape specifies.
        tile_multiples = [1, synthesis_net.input_shape[1], 1]
        self._dlatents_expr = tf.tile(self._dlatents_var + perturbation,
                                      tile_multiples)
        synthesized = synthesis_net.get_output_for(self._dlatents_expr)
        self._images_float_expr = tf.cast(synthesized, tf.float32)
        self._images_uint8_expr = tflib.convert_images_to_uint8(
            self._images_float_expr, nchw_to_nhwc=True)

        # Rescale by (x + 1) * 255/2 and, when axis 2 exceeds 256, shrink it
        # by average pooling (VGG was built for 224x224 images).
        proc_images_expr = (self._images_float_expr + 1) * (255 / 2)
        dims = proc_images_expr.shape.as_list()
        if dims[2] > 256:
            factor = dims[2] // 256
            reduced = dims[2] // factor
            pooled_shape = [-1, dims[1], reduced, factor, reduced, factor]
            proc_images_expr = tf.reduce_mean(
                tf.reshape(proc_images_expr, pooled_shape), axis=[3, 5])

        # Build loss graph.
        self._info('Building loss graph...')
        self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape),
                                              name='target_images_var')
        if self._lpips is None:
            with dnnlib.util.open_url(
                    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl'
            ) as f:
                self._lpips = pickle.load(f)
        self._dist = self._lpips.get_output_for(proc_images_expr,
                                                self._target_images_var)
        self._loss = tf.reduce_sum(self._dist)

        # Build noise regularization graph: penalise the shifted
        # self-correlation of every noise map at successively halved sizes.
        self._info('Building noise regularization graph...')
        reg_loss = 0.0
        for noise_map in self._noise_vars:
            size = noise_map.shape[2]
            while True:
                corr_axis3 = tf.reduce_mean(
                    noise_map * tf.roll(noise_map, shift=1, axis=3))**2
                corr_axis2 = tf.reduce_mean(
                    noise_map * tf.roll(noise_map, shift=1, axis=2))**2
                reg_loss += corr_axis3 + corr_axis2
                if size <= 8:
                    break  # Small enough already
                half = size // 2
                noise_map = tf.reshape(noise_map,
                                       [1, 1, half, 2, half, 2])  # Downscale
                noise_map = tf.reduce_mean(noise_map, axis=[3, 5])
                size = size // 2
        self._loss += reg_loss * self.regularize_noise_weight

        # Setup optimizer.
        self._info('Setting up optimizer...')
        self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in')
        self._opt = tflib.Optimizer(learning_rate=self._lrate_in)
        optimized_vars = [self._dlatents_var] + self._noise_vars
        self._opt.register_gradients(self._loss, optimized_vars)
        self._opt_step = self._opt.apply_updates()
def training_loop_infernet(
    I_args={},  # Options for the I (infogan-head/vcgan-head) network.
    I_opt_args={},  # Options for the I optimizer.
    loss_args={},  # Options for the I loss function.
    sched_args={},  # Schedule values, read by attribute (see docstring).
    grid_args={},  # Options for train.setup_snapshot_image_grid(). Unused here.
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    minibatch_repeats=4,  # Number of minibatches to run before checking whether a tick has elapsed.
    lazy_regularization=True,  # Unused here; the regularizer is always added to the loss.
    total_kimg=25000,  # Total length of the training, measured in thousands of images.
    mirror_augment=False,  # Unused here.
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks. Unused here.
    image_snapshot_ticks=50,  # Unused here; no image snapshots are written.
    network_snapshot_ticks=5,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    G_pkl=None,  # Pickle holding (G, D, I, Gs); Gs is the generator I is trained against.
    resume_pkl=None,  # (I, Gs) pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct a new I according to I_args before resuming training?
    n_samples_per=10):  # Number of samples for each line in traversal. Unused here.
    """Train the network ``I`` against a generator loaded from ``G_pkl``.

    The function loads ``(G, D, I, Gs)`` from ``G_pkl``, optionally rebuilds
    ``I`` from ``I_args`` and/or restores ``(I, Gs)`` from ``resume_pkl``,
    builds one loss graph per GPU, and then runs ``I``'s training op until
    ``total_kimg`` thousand images have been counted or the run context asks
    to stop. Once per tick it prints progress, optionally saves a network
    snapshot and evaluates metrics on it, and flushes summaries. A final
    snapshot is written to ``network-final.pkl`` in the run directory.

    Only ``I_gpu.trainables`` are registered with the optimizer, so the
    generator's weights are not updated by this loop.

    ``sched_args`` is accessed by attribute (``minibatch_size``,
    ``minibatch_gpu``, ``lrate``, ``tick_kimg``), so it must be an object
    supporting attribute access; the plain-dict default only works when the
    training loop body never executes.

    The dict/list defaults are shared between calls; this function itself
    does not modify them in place (``I_opt_args`` is copied before being
    extended).

    Returns:
        None.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Construct or load networks.
    with tf.device('/gpu:0'):
        # The discriminator stored in the pickle is not needed here.
        G, _D, I, Gs = misc.load_pkl(G_pkl)
        print('Gs.output_shapes:', Gs.output_shapes)
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            # Replaces the I loaded from G_pkl with a freshly built one whose
            # channel count and resolution are read from Gs' first output.
            I = tflib.Network('I',
                              num_channels=Gs.output_shapes[0][1],
                              resolution=Gs.output_shapes[0][2],
                              **I_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rI, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                Gs.copy_vars_from(rGs)
            else:
                I = rI
                Gs = rGs

    # Print layers.
    Gs.print_layers()
    I.print_layers()

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of per-GPU rounds needed to cover one full minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizer. Copy the options first so the caller's dict (or the
    # shared default) is not modified.
    I_opt_args = dict(I_opt_args)
    I_opt_args['minibatch_multiplier'] = minibatch_multiplier
    I_opt_args['learning_rate'] = lrate_in
    I_opt = tflib.Optimizer(name='TrainI', **I_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of I and Gs.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')

            # Evaluate the loss function named in loss_args.
            with tf.name_scope('I_loss'):
                loss, reg = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    I=I_gpu,
                    opt=I_opt,
                    minibatch_size=minibatch_gpu_in,
                    **loss_args)

            # The regularization term, if any, is folded into the main loss.
            if reg is not None:
                loss += reg

            # Register gradients for I's trainable variables only.
            I_opt.register_gradients(tf.reduce_mean(loss), I_gpu.trainables)

    # Setup training ops.
    I_train_op = I_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to reporting zero when the op is unavailable.
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # Negative so the first loop iteration always reports.
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # The full minibatch must split evenly into per-GPU rounds.
        assert sched_args.minibatch_size % (sched_args.minibatch_gpu *
                                            num_gpus) == 0

        # Run training ops.
        feed_dict = {
            lrate_in: sched_args.lrate,
            minibatch_size_in: sched_args.minibatch_size,
            minibatch_gpu_in: sched_args.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched_args.minibatch_size,
                           sched_args.minibatch_gpu * num_gpus)
            cur_nimg += sched_args.minibatch_size

            # One run of the training op per accumulation round. With a
            # single round this is a plain update, so no separate fast path
            # is needed.
            for _round in rounds:
                tflib.run(I_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                %
                (autosummary('Progress/tick', cur_tick),
                 autosummary('Progress/kimg', cur_nimg / 1000.0),
                 autosummary('Progress/minibatch', sched_args.minibatch_size),
                 dnnlib.util.format_time(
                     autosummary('Timing/total_sec', total_time)),
                 autosummary('Timing/sec_per_tick', tick_time),
                 autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                 autosummary('Timing/maintenance_sec', maintenance_time),
                 autosummary('Resources/peak_gpu_mem_gb',
                             peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            # NOTE(review): snapshots pickle (I, G), while the resume path
            # above unpacks a snapshot as (I, Gs) and trains against Gs.
            # Resuming from one of these files therefore substitutes G for
            # Gs -- confirm this is intended before changing either side.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((I, G), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            train_infernet=True)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update(cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            # Time spent on reporting/snapshots, shown in the next tick.
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
Example #28
0
def training_loop(
    submit_config,
    G_args={},  # Options for generator network. Unused: construction from scratch is disabled.
    D_args={},  # Options for discriminator network. Unused: construction from scratch is disabled.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = load the pretrained network.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.
    """Continue GAN training from an existing (G, D, Gs) network triple.

    Networks are loaded either from the run/pickle identified by
    ``resume_run_id`` or, when that is ``None``, from a pretrained pickle
    downloaded from a fixed URL. The function then builds per-GPU G and D
    loss graphs, runs the D and G training ops until ``total_kimg`` thousand
    images have been counted or the run context asks to stop, and once per
    tick prints progress, writes image and network snapshots, evaluates
    metrics and flushes summaries. Fake-image grids are additionally copied
    to ``gs://stylegan_out`` with ``gsutil`` on a best-effort basis.

    The final networks are saved to ``network-final.pkl`` in
    ``submit_config.run_dir``.

    The dict/list defaults are shared between calls; this function does not
    modify them in place.

    Returns:
        None.
    """

    def _upload_snapshot(local_path):
        """Copy ``local_path`` to the output bucket without failing the run."""
        # An argument list (no shell) keeps paths containing spaces or shell
        # metacharacters intact. The exit status is deliberately ignored, as
        # before; a missing gsutil binary is reported instead of raising so
        # the upload stays best-effort.
        try:
            subprocess.run(['gsutil', 'cp', local_path, 'gs://stylegan_out'])
        except OSError as err:
            print('Warning: could not run gsutil for "%s": %s' %
                  (local_path, err))

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Load networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            # Building fresh networks from G_args/D_args is disabled; a
            # pretrained triple is downloaded instead.
            # NOTE(review): pickle.load executes code embedded in the file,
            # so this is only safe while the URL's content is trusted.
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs running average, derived from the
        # half-life G_smoothing_kimg; 0.0 disables averaging.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the master networks; the others get shadow copies.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # The losses below run only after both networks' 'lod'
            # variables have been assigned from the placeholder.
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to reporting zero when the op is unavailable.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # The schedule at the end of training supplies the minibatch size used
    # for rendering the initial grid.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    init_fakes_path = os.path.join(submit_config.run_dir,
                                   'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes,
                         init_fakes_path,
                         drange=drange_net,
                         grid_size=grid_size)
    _upload_snapshot(init_fakes_path)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops: D_repeats discriminator steps (each also
        # updating Gs and advancing the image counter), then one G step.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                fakes_path = os.path.join(
                    submit_config.run_dir,
                    'fakes%06d.png' % (cur_nimg // 1000))
                misc.save_image_grid(grid_fakes,
                                     fakes_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                _upload_snapshot(fakes_path)
            # Network snapshots are also written on the very first tick.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent on reporting/snapshots, shown in the next tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
Example #29
0
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,
):  # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(
        grid_reals,
        dnnlib.make_run_dir_path("reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )

    # Construct or load networks.
    with tf.device("/gpu:0"):
        if resume_pkl is None or resume_with_new_nets:
            print("Constructing networks...")
            G = tflib.Network("G",
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network("D",
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch_gpu,
    )
    misc.save_image_grid(
        grid_fakes,
        dnnlib.make_run_dir_path("fakes_init.png"),
        drange=drange_net,
        grid_size=grid_size,
    )

    # Setup training inputs.
    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lod_in = tf.placeholder(tf.float32, name="lod_in", shape=[])
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_size_in,
                                       tf.float32), G_smoothing_kimg *
                               1000.0) if G_smoothing_kimg > 0.0 else 0.0)

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [
        (G_opt_args, G_reg_interval),
        (D_opt_args, D_reg_interval),
    ]:
        args["minibatch_multiplier"] = minibatch_multiplier
        args["learning_rate"] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args["learning_rate"] *= mb_ratio
            if "beta1" in args:
                args["beta1"] **= mb_ratio
            if "beta2" in args:
                args["beta2"] **= mb_ratio
    G_opt = tflib.Optimizer(name="TrainG", **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD", **D_opt_args)
    G_reg_opt = tflib.Optimizer(name="RegG", share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name="RegD", share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow")
            D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow")

            # Fetch training data via temporary variables.
            with tf.name_scope("DataFetch"):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name="reals",
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape),
                )
                labels_var = tf.Variable(
                    name="labels",
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]),
                )
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write,
                    labels_write,
                    lod_in,
                    mirror_augment,
                    training_set.dynamic_range,
                    drange_net,
                )
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if "lod" in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars["lod"], lod_in)]
            if "lod" in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars["lod"], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope("G_loss"):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        **G_loss_args)
                with tf.name_scope("D_loss"):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None:
                    G_loss += G_reg
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print("Initializing logs...")
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training for %d kimg...\n" % total_kimg)
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = lazy_regularization and running_mb_counter % G_reg_interval == 0
            run_D_reg = lazy_regularization and running_mb_counter % D_reg_interval == 0
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                "tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f"
                % (
                    autosummary("Progress/tick", cur_tick),
                    autosummary("Progress/kimg", cur_nimg / 1000.0),
                    autosummary("Progress/lod", sched.lod),
                    autosummary("Progress/minibatch", sched.minibatch_size),
                    dnnlib.util.format_time(
                        autosummary("Timing/total_sec", total_time)),
                    autosummary("Timing/sec_per_tick", tick_time),
                    autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
                    autosummary("Timing/maintenance_sec", maintenance_time),
                    autosummary("Resources/peak_gpu_mem_gb",
                                peak_gpu_mem_op.eval() / 2**30),
                ))
            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(
                    grid_latents,
                    grid_labels,
                    is_validation=True,
                    minibatch_size=sched.minibatch_gpu,
                )
                misc.save_image_grid(
                    grid_fakes,
                    dnnlib.make_run_dir_path("fakes%06d.png" %
                                             (cur_nimg // 1000)),
                    drange=drange_net,
                    grid_size=grid_size,
                )
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(
                    pkl,
                    run_dir=dnnlib.make_run_dir_path(),
                    data_dir=dnnlib.convert_path(data_dir),
                    num_gpus=num_gpus,
                    tf_config=tf_config,
                )

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update("%.2f" % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = (
                dnnlib.RunContext.get().get_last_update_interval() - tick_time)

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path("network-final.pkl"))

    # All done.
    summary_log.close()
    training_set.close()