Example 1
def show_real_data(data_dir, dataset_name, number):
    tflib.init_tf()
    dataset_args = EasyDict(tfrecord_dir=dataset_name, max_label_size='full')
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    gw = 1
    gh = 1
    for i in range(number):
        reals, _ = training_set.get_minibatch_np(gw * gh)
        misc.save_image_grid(reals,
                             dnnlib.make_run_dir_path('reals%04d.png' % (i)),
                             drange=training_set.dynamic_range,
                             grid_size=None)
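A minimal usage sketch follows; the dataset directory and tfrecord name are hypothetical, and a dnnlib run directory is assumed to be configured so that make_run_dir_path resolves.

# Hypothetical invocation: writes one 1x1 grid per iteration into the run directory,
# as reals0000.png, reals0001.png, ..., reals0007.png.
# show_real_data(data_dir='datasets', dataset_name='ffhq', number=8)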
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for MetricGroup.
    tf_config               = {},       # Options for tflib.init_tf().
    data_dir                = None,     # Directory to load datasets from.
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = None,     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    with tfex.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets: G.copy_vars_from(rG); D.copy_vars_from(rD); Gs.copy_vars_from(rGs)
            else: G = rG; D = rD; Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers(); D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tfex.device('/cpu:0'):
        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tfex.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tfex.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch_size),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir), num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
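The optimizer setup in the loop above rescales the Adam hyperparameters by the lazy-regularization interval and derives the decay of the Gs running average from the minibatch size. A self-contained sketch of those two scalar computations, with illustrative values (the real graph computes them from placeholders):

def lazy_reg_adjustment(learning_rate, beta1, beta2, reg_interval):
    # With lazy regularization the main loss runs reg_interval times per regularization
    # step, so the optimizer hyperparameters are rescaled to keep the effective step
    # size per image comparable (mirrors the loop over G_opt_args/D_opt_args above).
    mb_ratio = reg_interval / (reg_interval + 1)
    return learning_rate * mb_ratio, beta1 ** mb_ratio, beta2 ** mb_ratio

def gs_ema_beta(minibatch_size, g_smoothing_kimg=10.0):
    # Per-update decay of the moving-average generator Gs, chosen so that the average
    # has a half-life of G_smoothing_kimg thousand images regardless of minibatch size.
    return 0.5 ** (minibatch_size / (g_smoothing_kimg * 1000.0))

print(lazy_reg_adjustment(0.002, 0.0, 0.99, reg_interval=4))  # learning rate scaled by 4/5
print(gs_ema_beta(32))                                        # ~0.9978 for a 32-image minibatch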
Example 3
def training_loop(
    classifier_args={},  # Options for classifier network.
    classifier_opt_args={},  # Options for classifier optimizer.
    classifier_loss_args={},  # Options for classifier loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    network_snapshot_ticks=5,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False):

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        shuffle_mb=2 * 4096,
                                        **dataset_args)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        classifier = tflib.Network('classifier',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   label_size=training_set.label_size,
                                   **classifier_args)

    classifier.print_layers()

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers.
    classifier_opt_args = dict(classifier_opt_args)

    classifier_opt_args['minibatch_multiplier'] = minibatch_multiplier
    classifier_opt_args['learning_rate'] = lrate_in

    classifier_opt = tflib.Optimizer(name='TrainClassifier',
                                     **classifier_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('gpu%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create a GPU-specific shadow copy of the classifier.
            classifier_gpu = classifier if gpu == 0 else classifier.clone(
                classifier.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=0, **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros(
                                             [sched.minibatch_gpu, 127]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('classifier_loss'):
                classifier_loss, label = dnnlib.util.call_func_by_name(
                    classifier=classifier_gpu,
                    images=reals_read,
                    labels=labels_read,
                    **classifier_loss_args)

            classifier_opt.register_gradients(tf.reduce_mean(classifier_loss),
                                              classifier_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    classifier_train_op = classifier_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())

    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=0, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([classifier_train_op, data_fetch_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    classifier_loss_out, label_out, _ = tflib.run(
                        [classifier_loss, label, classifier_train_op],
                        feed_dict)
                    print_output = False
                    if print_output:
                        print('label')
                        print(np.round(label_out, 2))
                        print('loss')
                        print(np.round(classifier_loss_out, 2))
        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl(classifier, pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl(classifier, dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
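Both the fast and slow paths above rely on the same bookkeeping: the number of gradient-accumulation rounds equals minibatch_multiplier, which tflib.Optimizer uses to average the accumulated gradients. A plain-Python sketch with illustrative numbers:

def accumulation_plan(minibatch_size, minibatch_gpu, num_gpus):
    # minibatch_size is the logical batch; minibatch_gpu is what one GPU processes per step.
    assert minibatch_size % (minibatch_gpu * num_gpus) == 0
    multiplier = minibatch_size // (minibatch_gpu * num_gpus)    # matches minibatch_multiplier above
    rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)  # one training round per element
    return multiplier, len(rounds)

print(accumulation_plan(32, 4, 2))  # (4, 4): gradients accumulated over 4 rounds
print(accumulation_plan(8, 4, 2))   # (1, 1): fast path, no accumulation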
Example 4
def training_loop_refinement(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=True,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False
):  # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD
            if resume_with_new_nets:
                # Both G and Gs are initialized from the averaged generator rGs.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
                del rG, rGs
            else:
                G = rG
                Gs = rGs

    # Set constant noise input for both G and Gs
    if G_args.get("randomize_noise", None) == False:
        noise_vars = [
            var for name, var in G.components.synthesis.vars.items()
            if name.startswith('noise')
        ]
        rnd = np.random.RandomState(123)
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]

        noise_vars = [
            var for name, var in Gs.components.synthesis.vars.items()
            if name.startswith('noise')
        ]
        rnd = np.random.RandomState(123)
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]

    # TESTS
    # from PIL import Image
    # reals, latents = training_set.get_minibatch_np(4)
    # reals = np.transpose(reals, [0, 2, 3, 1])
    # Image.fromarray(reals[0], 'RGB').save("test_reals.png")

    # labels = training_set.get_random_labels_np(4)
    # Gs_kwargs = dnnlib.EasyDict()
    # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    # fakes = Gs.run(latents, labels, minibatch_size=4, **Gs_kwargs)
    # Image.fromarray(fakes[0], 'RGB').save("test_fakes_Gs_new.png")
    # fakes = G.run(latents, labels, minibatch_size=4, **Gs_kwargs)
    # Image.fromarray(fakes[0], 'RGB').save("test_fakes_G_new.png")

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers
    G_args.freeze_layers = list(G_args.get("freeze_layers", []))

    def freeze_vars(gen, verbose=True):
        assert len(G_args.freeze_layers) > 0
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in G_args.freeze_layers):
                del gen.trainables[name]
                if verbose: print(f"Froze {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create a GPU-specific shadow copy of G.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if G_args.freeze_layers: freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=None,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        latents=labels_read,
                        **G_loss_args)
                    loss_ops.append(G_loss)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        tflib.run(data_fetch_op, feed_dict)
        ### TEST
        # fakes = G.get_output_for(labels_read, training_set.get_random_labels_tf(minibatch_gpu_in), is_training=True) # this is without activation in ~[-1.5, 1.5]
        # fakes = tf.clip_by_value(fakes, drange_net[0], drange_net[1])
        # reals = reals_read
        ### TEST
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                # (loss, reals, fakes), _ = tflib.run([loss_op, G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                # print(f"loss_tf  {np.mean(loss)}")
                # print(f"loss_np  {np.mean(np.square(reals - fakes))}")
                # print(f"loss_abs {np.mean(np.abs(reals - fakes))}")

                loss_per_batch_sum += loss
                #### TEST ####
                # if cur_nimg == sched.minibatch_size or cur_nimg % 2048 == 0:
                #     from PIL import Image
                #     reals = np.transpose(reals, [0, 2, 3, 1])
                #     fakes = np.transpose(fakes, [0, 2, 3, 1])
                #     diff = np.abs(reals - fakes)
                #     print(diff.min(), diff.max())
                #     for idx, (fake, real) in enumerate(zip(fakes, reals)):
                #         fake -= fake.min()
                #         fake /= fake.max()
                #         fake *= 255
                #         fake = fake.astype(np.uint8)
                #         Image.fromarray(fake, 'RGB').save(f"fake_loss_{idx}.png")
                #         real -= real.min()
                #         real /= real.max()
                #         real *= 255
                #         real = real.astype(np.uint8)
                #         Image.fromarray(real, 'RGB').save(f"real_loss_{idx}.png")
                ####
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation. FIXME: Probably wrong
            else:
                for _round in rounds:
                    loss, _, _ = tflib.run(
                        [loss_op, G_train_op, data_fetch_op], feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                    if run_G_reg:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (
                tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   autosummary('Progress/loss_per_px', tick_loss),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
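The DataFetch blocks in these loops all use the same double-buffering trick: a fixed-size, non-trainable variable sized by the initial schedule, into which the fresh minibatch is spliced before the loss reads back only the first minibatch_gpu entries. A NumPy sketch of that splice, with hypothetical sizes:

import numpy as np

def fetch_into_buffer(buffer, new_batch, minibatch_gpu):
    # Write the fresh batch into the first minibatch_gpu slots, keep the tail unchanged,
    # and read back only the slots that were just written (mirrors reals_var/reals_read).
    updated = np.concatenate([new_batch, buffer[minibatch_gpu:]], axis=0)
    return updated, updated[:minibatch_gpu]

buf = np.zeros((8, 3))               # variable sized for the largest minibatch_gpu
batch = np.ones((4, 3))              # current schedule feeds 4 images per GPU
buf, reals_read = fetch_into_buffer(buf, batch, minibatch_gpu=4)
print(buf.shape, reals_read.shape)   # (8, 3) (4, 3)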
Example 5
def training_loop_vae(
        G_args={},  # Options for generator network.
        E_args={},  # Options for encoder network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        data_dir=None,  # Directory to load datasets from.
        minibatch_repeats=1,  # Number of minibatches to run before adjusting training parameters.
        total_kimg=25000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=False,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
        network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
        resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
        traversal_grid=False,  # Used for disentangled representation learning.
        n_discrete=0,  # Number of discrete latents in model.
        n_continuous=4,  # Number of continuous latents in model.
        topk_dims_to_show=20,  # Number of top disentangled dimensions to show in a snapshot.
        subgroup_sizes_ls=None,
        subspace_sizes_ls=None,
        forward_eg=False,
        n_samples_per=10):  # Number of samples for each line in traversal.

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Whether a discriminator is used.
    use_D = D_args is not None

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    grid_fakes = add_outline(grid_reals, width=1)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            E = tflib.Network('E',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              input_shape=[None] + training_set.shape,
                              **E_args)
            G = tflib.Network(
                'G',
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                input_shape=[None, n_discrete +
                             G_args.latent_size] if not forward_eg else [
                                 None, n_discrete + G_args.latent_size +
                                 sum(subgroup_sizes_ls)
                             ],
                **G_args)
            if use_D:
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  input_shape=[None, D_args.latent_size],
                                  **D_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_D:
                rE, rG, rD = misc.load_pkl(resume_pkl)
            else:
                rE, rG = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                E.copy_vars_from(rE)
                G.copy_vars_from(rG)
                if use_D:
                    D.copy_vars_from(rD)
            else:
                E = rE
                G = rG
                if use_D:
                    D = rD

    # Print layers and generate initial image snapshot.
    E.print_layers()
    G.print_layers()
    if use_D:
        D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        if topk_dims_to_show > 0:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
        else:
            topk_dims = np.arange(n_continuous)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_size:', grid_size)
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
        G.run(append_gfeats(grid_latents, G) if forward_eg else grid_latents,
              grid_labels,
              is_validation=True,
              minibatch_size=sched.minibatch_gpu,
              randomize_noise=True), 8)
    print('Lie_vars:', lie_vars[0])
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    G_opt_args['minibatch_multiplier'] = minibatch_multiplier
    G_opt_args['learning_rate'] = lrate_in
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    if use_D:
        D_opt_args = dict(D_opt_args)
        D_opt_args['minibatch_multiplier'] = minibatch_multiplier
        D_opt_args['learning_rate'] = lrate_in
        D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of E, G, and D.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_D:
                D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, 0., mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            if use_D:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)
            else:
                with tf.name_scope('G_loss'):
                    G_loss = dnnlib.util.call_func_by_name(
                        E=E_gpu,
                        G=G_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **G_loss_args)

            # Register gradients.
            EG_gpu_trainables = collections.OrderedDict(
                list(E_gpu.trainables.items()) +
                list(G_gpu.trainables.items()))
            G_opt.register_gradients(tf.reduce_mean(G_loss), EG_gpu_trainables)
            # G_opt.register_gradients(G_loss,
            # EG_gpu_trainables)
            if use_D:
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
                # D_opt.register_gradients(D_loss,
                # D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    if use_D:
        D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        if use_D:
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, 0)

        # Run training ops.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
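            # Each round covers minibatch_gpu images per GPU; when the total minibatch is
            # larger than one round, gradients are accumulated over len(rounds) rounds
            # before the optimizer applies a single update.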
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                if use_D:
                    tflib.run([D_train_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    if use_D:
                        tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                if use_D:
                    misc.save_pkl((E, G, D), pkl)
                else:
                    misc.save_pkl((E, G), pkl)
                met_outs = metrics.run(pkl,
                                       run_dir=dnnlib.make_run_dir_path(),
                                       data_dir=dnnlib.convert_path(data_dir),
                                       num_gpus=num_gpus,
                                       tf_config=tf_config,
                                       is_vae=True,
                                       use_D=use_D,
                                       Gs_kwargs=dict(is_validation=True))
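                # Rank latent dimensions by the per-dimension traversal metric returned by
                # the metrics ('tpl_per_dim') and keep the top-k for the traversal grids below;
                # fall back to the first dimensions when the metric is unavailable.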
                if topk_dims_to_show > 0:
                    if 'tpl_per_dim' in met_outs:
                        avg_distance_per_dim = met_outs[
                            'tpl_per_dim']  # shape: (n_continuous)
                        topk_dims = np.argsort(
                            avg_distance_per_dim
                        )[::-1][:topk_dims_to_show]  # shape: (20)
                    else:
                        topk_dims = np.arange(
                            min(topk_dims_to_show, n_continuous))
                else:
                    topk_dims = np.arange(n_continuous)

            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if traversal_grid:
                    grid_size, grid_latents, grid_labels = get_grid_latents(
                        n_discrete, n_continuous, n_samples_per, G,
                        grid_labels, topk_dims)
                else:
                    grid_latents = np.random.randn(np.prod(grid_size),
                                                   *G.input_shape[1:])

                grid_fakes, _, _, _, _, _, _, lie_vars = get_return_v(
                    G.run(append_gfeats(grid_latents, G)
                          if forward_eg else grid_latents,
                          grid_labels,
                          is_validation=True,
                          minibatch_size=sched.minibatch_gpu,
                          randomize_noise=True), 8)
                print('Lie_vars:', lie_vars[0])
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_D:
        misc.save_pkl((E, G, D), dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((E, G), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
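The accumulation bookkeeping used in the loop above can be illustrated in isolation. A minimal sketch in plain Python, with illustrative values standing in for the training schedule:

# Minimal sketch of the gradient-accumulation bookkeeping above; the values are
# illustrative and would normally come from the training schedule.
minibatch_size = 32   # images per optimizer update
minibatch_gpu = 4     # images per GPU per round
num_gpus = 2

# One round processes minibatch_gpu * num_gpus images, so the loop above takes the
# fast path when len(rounds) == 1 and accumulates gradients over several rounds otherwise.
rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
print(len(rounds))    # 4 -> slow path with gradient accumulation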
Example n. 6
def training_loop(
    # Configurations
    cG={},
    cD={},  # Generator and Discriminator command-line arguments
    dataset_args={},  # dataset.load_dataset() options
    sched_args={},  # train.TrainingSchedule options
    vis_args={},  # vis.eval options
    grid_args={},  # train.setup_snapshot_img_grid() options
    metric_arg_list=[],  # MetricGroup Options
    tf_config={},  # tflib.init_tf() options
    eval=False,  # Evaluation mode
    train=False,  # Training mode
    # Data
    data_dir=None,  # Directory to load datasets from
    total_kimg=25000,  # Total length of the training, measured in thousands of real images
    mirror_augment=False,  # Enable mirror augmentation?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks
    ratio=1.0,  # Image height/width ratio in the dataset
    # Optimization
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters
    lazy_regularization=True,  # Perform regularization as a separate training step?
    smoothing_kimg=10.0,  # Half-life of the running average of generator weights
    clip=None,  # Clip gradients threshold
    # Resumption
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning
    # Affects reporting and training schedule
    resume_time=0.0,  # Assumed wallclock time at the beginning, affects reporting
    recompile=False,  # Recompile network from source code (otherwise loads from snapshot)
    # Logging
    summarize=True,  # Create TensorBoard summaries
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    img_snapshot_ticks=3,  # How often to save image snapshots? None = disable
    network_snapshot_ticks=3,  # How often to save network snapshots? None = only save networks-final.pkl
    last_snapshots=10,  # Maximal number of prior snapshots to save
    eval_images_num=50000,  # Sample size for the metrics
    printname="",  # Experiment name for logging
    # Architecture
    merge=False):  # Generate several images and then merge them

    # Initialize dnnlib and TensorFlow
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus
    cG.name, cD.name = "g", "d"

    # Load dataset, configure training scheduler and metrics object
    dataset = data.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                verbose=True,
                                **dataset_args)
    sched = training_schedule(sched_args,
                              cur_nimg=total_kimg * 1000,
                              dataset=dataset)
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Construct or load networks
    with tf.device("/gpu:0"):
        no_op = tf.no_op()
        G, D, Gs = None, None, None
        if resume_pkl is None or recompile:
            misc.log("Constructing networks...", "white")
            G = tflib.Network("G",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cG.args)
            D = tflib.Network("D",
                              num_channels=dataset.shape[0],
                              resolution=dataset.shape[1],
                              label_size=dataset.label_size,
                              **cD.args)
            Gs = G.clone("Gs")
        if resume_pkl is not None:
            G, D, Gs = load_nets(resume_pkl, G, D, Gs, recompile)

    G.print_layers()
    D.print_layers()

    # Train/Evaluate/Visualize
    # Labels are optional but not essential
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_img_grid(
        dataset, **grid_args)
    misc.save_img_grid(grid_reals,
                       dnnlib.make_run_dir_path("reals.png"),
                       drange=dataset.dynamic_range,
                       grid_size=grid_size)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])

    if eval:
        # Save a snapshot of the current network to evaluate
        pkl = dnnlib.make_run_dir_path("network-eval-snapshot-%06d.pkl" %
                                       resume_kimg)
        misc.save_pkl((G, D, Gs), pkl, remove=False)

        # Quantitative evaluation
        metric = metrics.run(pkl,
                             num_imgs=eval_images_num,
                             run_dir=dnnlib.make_run_dir_path(),
                             data_dir=dnnlib.convert_path(data_dir),
                             num_gpus=num_gpus,
                             ratio=ratio,
                             tf_config=tf_config,
                             mirror_augment=mirror_augment)

        # Qualitative evaluation
        visualize.eval(G,
                       dataset,
                       batch_size=sched.minibatch_gpu,
                       drange_net=drange_net,
                       ratio=ratio,
                       **vis_args)

    if not train:
        dataset.close()
        exit()

    # Setup training inputs
    misc.log("Building TensorFlow graph...", "white")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in_g = tf.placeholder(tf.float32, name="lrate_in_g", shape=[])
        lrate_in_d = tf.placeholder(tf.float32, name="lrate_in_d", shape=[])
        step = tf.placeholder(tf.int32, name="step", shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name="minibatch_size_in",
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name="minibatch_gpu_in",
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32),
                             smoothing_kimg * 1000.0) if smoothing_kimg > 0.0 else 0.0

    # Set optimizers
    for cN, lr in [(cG, lrate_in_g), (cD, lrate_in_d)]:
        set_optimizer(cN, lr, minibatch_multiplier, lazy_regularization, clip)

    # Build training graph for each GPU
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):

            # Create GPU-specific shadow copies of G and D
            for cN, N in [(cG, G), (cD, D)]:
                cN.gpu = N if gpu == 0 else N.clone(N.name + "_shadow")
            Gs_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + "_shadow")

            # Fetch training data via temporary variables
            with tf.name_scope("DataFetch"):
                reals, labels = dataset.get_minibatch_tf()
                reals = process_reals(reals, dataset.dynamic_range, drange_net,
                                      mirror_augment)
                reals, reals_fetch = read_data(
                    reals, "reals", [sched.minibatch_gpu] + dataset.shape,
                    minibatch_gpu_in)
                labels, labels_fetch = read_data(
                    labels, "labels",
                    [sched.minibatch_gpu, dataset.label_size],
                    minibatch_gpu_in)
                data_fetch_ops += [reals_fetch, labels_fetch]

            # Evaluate loss functions
            with tf.name_scope("G_loss"):
                cG.loss, cG.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    minibatch_size=minibatch_gpu_in,
                    **cG.loss_args)

            with tf.name_scope("D_loss"):
                cD.loss, cD.reg = dnnlib.util.call_func_by_name(
                    G=cG.gpu,
                    D=cD.gpu,
                    dataset=dataset,
                    reals=reals,
                    labels=labels,
                    minibatch_size=minibatch_gpu_in,
                    **cD.loss_args)

            for cN in [cG, cD]:
                set_optimizer_ops(cN, lazy_regularization, no_op)

    # Setup training ops
    data_fetch_op = tf.group(*data_fetch_ops)
    for cN in [cG, cD]:
        cN.train_op = cN.opt.apply_updates()
        cN.reg_op = cN.reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=beta)

    # Finalize graph
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # Tensorboard summaries
    if summarize:
        misc.log("Initializing logs...", "white")
        summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()

    # Initialize training
    misc.log("Training for %d kimg..." % total_kimg, "white")
    dnnlib.RunContext.get().update("",
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()

    cur_tick, running_mb_counter = -1, 0
    cur_nimg = int(resume_kimg * 1000)
    tick_start_nimg = cur_nimg
    for cN in [cG, cD]:
        cN.lossvals_agg = {
            k: None
            for k in ["loss", "reg", "norm", "reg_norm"]
        }
        cN.opt.reset_optimizer_state()

    # Training loop
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops
        sched = training_schedule(sched_args,
                                  cur_nimg=cur_nimg,
                                  dataset=dataset)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        dataset.configure(sched.minibatch_gpu)

        # Run training ops
        feed_dict = {
            lrate_in_g: sched.G_lrate,
            lrate_in_d: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            step: sched.kimg
        }

        # Several iterations before updating training parameters
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            for cN in [cG, cD]:
                cN.run_reg = lazy_regularization and (running_mb_counter %
                                                      cN.reg_interval == 0)
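            # With lazy regularization, each network's regularizer runs only once every
            # cN.reg_interval minibatches; run_reg marks whether this minibatch is one of them.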
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            for cN in [cG, cD]:
                cN.lossvals = {
                    k: None
                    for k in ["loss", "reg", "norm", "reg_norm"]
                }

            # Gradient accumulation
            for _round in rounds:
                cG.lossvals.update(
                    tflib.run([cG.train_op, cG.ops], feed_dict)[1])
                if cG.run_reg:
                    _, cG.lossvals["reg_norm"] = tflib.run(
                        [cG.reg_op, cG.reg_norm], feed_dict)

                tflib.run(data_fetch_op, feed_dict)

                cD.lossvals.update(
                    tflib.run([cD.train_op, cD.ops], feed_dict)[1])
                if cD.run_reg:
                    _, cD.lossvals["reg_norm"] = tflib.run(
                        [cD.reg_op, cD.reg_norm], feed_dict)

            tflib.run([Gs_update_op], feed_dict)

            # Track loss statistics
            for cN in [cG, cD]:
                for k in cN.lossvals_agg:
                    cN.lossvals_agg[k] = emaAvg(cN.lossvals_agg[k],
                                                cN.lossvals[k])

        # Perform maintenance tasks once per tick
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress
            print(
                ("tick %s kimg %s   loss/reg: G (%s %s) D (%s %s)   grad norms: G (%s %s) D (%s %s)   "
                 + "time %s sec/kimg %s maxGPU %sGB %s") %
                (misc.bold("%-5d" % autosummary("Progress/tick", cur_tick)),
                 misc.bcolored(
                     "{:>8.1f}".format(
                         autosummary("Progress/kimg", cur_nimg / 1000.0)),
                     "red"),
                 misc.bcolored("{:>6.3f}".format(cG.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cG.lossvals_agg["reg"] or 0)),
                 misc.bcolored("{:>6.3f}".format(cD.lossvals_agg["loss"] or 0),
                               "blue"),
                 misc.bold("{:>6.3f}".format(cD.lossvals_agg["reg"] or 0)),
                 misc.cond_bcolored(cG.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cG.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["norm"], 20.0, "red"),
                 misc.cond_bcolored(cD.lossvals_agg["reg_norm"], 20.0, "red"),
                 misc.bold("%-10s" % dnnlib.util.format_time(
                     autosummary("Timing/total_sec", total_time))),
                 "{:>7.2f}".format(
                     autosummary("Timing/sec_per_kimg",
                                 tick_time / tick_kimg)),
                 "{:>4.1f}".format(
                     autosummary("Resources/peak_gpu_mem_gb",
                                 peak_gpu_mem_op.eval() / 2**30)), printname))

            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots
            if img_snapshot_ticks is not None and (
                    cur_tick % img_snapshot_ticks == 0 or done):
                visualize.eval(G,
                               dataset,
                               batch_size=sched.minibatch_gpu,
                               training=True,
                               step=cur_nimg // 1000,
                               grid_size=grid_size,
                               latents=grid_latents,
                               labels=grid_labels,
                               drange_net=drange_net,
                               ratio=ratio,
                               **vis_args)

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path("network-snapshot-%06d.pkl" %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl, remove=False)

                if cur_tick % network_snapshot_ticks == 0 or done:
                    metric = metrics.run(
                        pkl,
                        num_imgs=eval_images_num,
                        run_dir=dnnlib.make_run_dir_path(),
                        data_dir=dnnlib.convert_path(data_dir),
                        num_gpus=num_gpus,
                        ratio=ratio,
                        tf_config=tf_config,
                        mirror_augment=mirror_augment)

                if last_snapshots > 0:
                    misc.rm(
                        sorted(
                            glob.glob(dnnlib.make_run_dir_path(
                                "network*.pkl")))[:-last_snapshots])

            # Update summaries and RunContext
            if summarize:
                metrics.update_autosummaries()
                tflib.autosummary.save_summaries(summary_log, cur_nimg)

            dnnlib.RunContext.get().update(None,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot
    misc.save_pkl((G, D, Gs),
                  dnnlib.make_run_dir_path("network-final.pkl"),
                  remove=False)

    # All done
    if summarize:
        summary_log.close()
    dataset.close()
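The decay factor used for the Gs moving average above has a simple half-life interpretation. A minimal sketch in plain Python, with illustrative values, showing that the chosen beta halves a weight's contribution after smoothing_kimg thousand images:

# Sketch of the Gs moving-average decay above; values are illustrative.
minibatch_size = 32
smoothing_kimg = 10.0

# beta is chosen so that beta ** (updates per half-life) == 0.5, i.e. the running
# average of G's weights has a half-life of smoothing_kimg thousand images.
beta = 0.5 ** (minibatch_size / (smoothing_kimg * 1000.0))
updates_per_half_life = smoothing_kimg * 1000.0 / minibatch_size
print(beta, beta ** updates_per_half_life)  # second value is ~0.5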
Example n. 7
def training_loop_hd(
    I_args={},  # Options for the I network.
    M_args={},  # Options for the M network.
    I_info_args={},  # Options for the class (info) network.
    I_opt_args={},  # Options for generator optimizer.
    I_loss_args={},  # Options for generator loss.
    resume_G_pkl=None,  # G network pickle to help training.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    I_smoothing_kimg=10.0,  # Half-life of the running average of I network weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    I_reg_interval=4,  # How often to perform regularization for I? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to I_args and M_args before resuming training?
    traversal_grid=False,  # Used for disentangled representation learning.
    n_discrete=0,  # Number of discrete latents in model.
    n_continuous=10,  # Number of continuous latents in model.
    n_samples_per=4,  # Number of samples for each line in traversal.
    use_hd_with_cls=False,  # Whether to use info_loss.
    resolution_manual=1024,  # Resolution of generated images.
    use_level_training=False,  # Whether to use level training (hierarchical optimization strategy).
    level_I_kimg=1000,  # Number of kimg per training level for I.
    use_std_in_m=False,  # Whether to output the prior std in the M net.
    prior_latent_size=512,  # Prior latent size.
    use_hyperplane=False,  # Whether to use the hyperplane model.
    pretrained_type='with_stylegan2'):  # Pretrained type for G.

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            I = tflib.Network('I',
                              num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size,
                              **I_args)
            M = tflib.Network('M',
                              num_channels=training_set.shape[0],
                              resolution=resolution_manual,
                              label_size=training_set.label_size,
                              **M_args)
            Is = I.clone('Is')
            if use_hd_with_cls:
                I_info = tflib.Network('I_info',
                                       num_channels=training_set.shape[0],
                                       resolution=resolution_manual,
                                       label_size=training_set.label_size,
                                       **I_info_args)
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_hd_with_cls:
                rI, rM, rIs, rI_info = misc.load_pkl(resume_pkl)
            else:
                rI, rM, rIs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                I.copy_vars_from(rI)
                M.copy_vars_from(rM)
                Is.copy_vars_from(rIs)
                if use_hd_with_cls:
                    I_info.copy_vars_from(rI_info)
            else:
                I = rI
                M = rM
                Is = rIs
                if use_hd_with_cls:
                    I_info = rI_info

        print('Loading generator from "%s"...' % resume_G_pkl)
        if pretrained_type == 'with_stylegan2':
            rG, rD, rGs = misc.load_pkl(resume_G_pkl)
            G = rG
            D = rD
            Gs = rGs
        elif pretrained_type == 'with_cascadeVAE':
            rG = misc.load_pkl(resume_G_pkl)
            G = rG
            Gs = rG

    # Print layers and generate initial image snapshot.
    I.print_layers()
    M.print_layers()
    # pdb.set_trace()
    training_set_resolution_log2 = int(np.log2(resolution_manual))
    sched = training_schedule(
        cur_nimg=total_kimg * 1000,
        training_set_resolution_log2=training_set_resolution_log2,
        **sched_args)

    if not use_hyperplane:
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
        print('grid_size:', grid_size)
        print('grid_latents.shape:', grid_latents.shape)
        print('grid_labels.shape:', grid_labels.shape)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            grid_latents = grid_latents[:grid_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
        prior_traj_latents = M.run(grid_latents,
                                   is_validation=True,
                                   minibatch_size=sched.minibatch_gpu)
        if use_std_in_m:
            prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
    else:
        grid_size = (n_samples_per, n_continuous)
        grid_labels = np.tile(grid_labels[:1],
                              (n_continuous * n_samples_per, 1))
        latent_dirs = get_latent_dirs(n_continuous)
        prior_traj_latents = get_prior_traj_by_dirs(latent_dirs, M,
                                                    n_samples_per,
                                                    prior_latent_size,
                                                    grid_labels, sched)
        if resolution_manual >= 256:
            grid_size = (grid_size[0], grid_size[1] // 5)
            prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]
            grid_labels = grid_labels[:grid_labels.shape[0] // 5]
    print('prior_traj_latents.shape:', prior_traj_latents.shape)
    # pdb.set_trace()
    prior_traj_latents_show = np.reshape(
        prior_traj_latents, [-1, n_samples_per, prior_latent_size])
    print_traj(prior_traj_latents_show)
    grid_fakes = Gs.run(prior_traj_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu,
                        randomize_noise=True,
                        normalize_latents=False)
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    if (n_continuous == 2) and (n_discrete == 0):
        n_per_line = 20
        ex_latent_value = 3
        prior_grid_latents, prior_grid_labels = get_2d_grid_latents(
            low=-ex_latent_value,
            high=ex_latent_value,
            n_per_line=n_per_line,
            grid_labels=grid_labels)
        grid_showing_fakes = Gs.run(prior_grid_latents,
                                    prior_grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu,
                                    randomize_noise=True,
                                    normalize_latents=False)
        grid_showing_fakes = add_outline(grid_showing_fakes, width=1)
        misc.save_image_grid(
            grid_showing_fakes,
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'),
            drange=drange_net,
            grid_size=[n_per_line, n_per_line])
        img_to_draw = Image.open(
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid.png'))
        img_to_draw = img_to_draw.convert('RGB')
        img_to_draw = draw_traj_on_prior_grid(img_to_draw,
                                              prior_traj_latents_show,
                                              ex_latent_value, n_per_line)
        img_to_draw.save(
            dnnlib.make_run_dir_path('fakes_init_2d_prior_grid_drawn.png'))

    if use_level_training:
        ending_level = training_set_resolution_log2 - 1
    else:
        ending_level = 1

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Is_beta = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32),
                                I_smoothing_kimg * 1000.0) if I_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    I_opt_args = dict(I_opt_args)
    for args, reg_interval in [(I_opt_args, I_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio

    I_opts = []
    I_reg_opts = []
    for n_level in range(ending_level):
        I_opts.append(tflib.Optimizer(name='TrainI_%d' % n_level,
                                      **I_opt_args))
        I_reg_opts.append(
            tflib.Optimizer(name='RegI_%d' % n_level,
                            share=I_opts[-1],
                            **I_opt_args))

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of I and M.
            I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            M_gpu = M if gpu == 0 else M.clone(M.name + '_shadow')
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if use_hd_with_cls:
                I_info_gpu = I_info if gpu == 0 else I_info.clone(I_info.name +
                                                                  '_shadow')

            # Evaluate loss functions.
            lod_assign_ops = []
            I_losses = []
            I_regs = []
            if 'lod' in I_gpu.vars:
                lod_assign_ops += [tf.assign(I_gpu.vars['lod'], lod_in)]
            if 'lod' in M_gpu.vars:
                lod_assign_ops += [tf.assign(M_gpu.vars['lod'], lod_in)]
            for n_level in range(ending_level):
                with tf.control_dependencies(lod_assign_ops):
                    with tf.name_scope('I_loss_%d' % n_level):
                        if use_hd_with_cls:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu,
                                M=M_gpu,
                                G=G_gpu,
                                I_info=I_info_gpu,
                                opt=I_opts[n_level],
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in,
                                **I_loss_args)
                        else:
                            I_loss, I_reg = dnnlib.util.call_func_by_name(
                                I=I_gpu,
                                M=M_gpu,
                                G=G_gpu,
                                opt=I_opts[n_level],
                                n_levels=(n_level +
                                          1) if use_level_training else None,
                                training_set=training_set,
                                minibatch_size=minibatch_gpu_in,
                                **I_loss_args)
                        I_losses.append(I_loss)
                        I_regs.append(I_reg)

                # Register gradients.
                if not lazy_regularization:
                    if I_regs[n_level] is not None:
                        I_losses[n_level] += I_regs[n_level]
                else:
                    if I_regs[n_level] is not None:
                        I_reg_opts[n_level].register_gradients(
                            tf.reduce_mean(I_regs[n_level] * I_reg_interval),
                            I_gpu.trainables)

                if use_hd_with_cls:
                    MIIinfo_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()) +
                        list(I_info_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(
                        tf.reduce_mean(I_losses[n_level]),
                        MIIinfo_gpu_trainables)
                else:
                    MI_gpu_trainables = collections.OrderedDict(
                        list(M_gpu.trainables.items()) +
                        list(I_gpu.trainables.items()))
                    I_opts[n_level].register_gradients(
                        tf.reduce_mean(I_losses[n_level]), MI_gpu_trainables)

    # Setup training ops.
    I_train_ops = []
    I_reg_ops = []
    for n_level in range(ending_level):
        I_train_ops.append(I_opts[n_level].apply_updates())
        I_reg_ops.append(I_reg_opts[n_level].apply_updates(allow_no_op=True))
    Is_update_op = Is.setup_as_moving_average_of(I, beta=Is_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        I.setup_weight_histograms()
        M.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        n_level = 0 if not use_level_training else min(
            cur_nimg // (level_I_kimg * 1000), training_set_resolution_log2 - 2)
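        # Hierarchical level schedule: advance one training level every level_I_kimg
        # thousand images, capped at training_set_resolution_log2 - 2.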
        # Choose training parameters and configure training ops.
        sched = training_schedule(
            cur_nimg=cur_nimg,
            training_set_resolution_log2=training_set_resolution_log2,
            **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                I_opts[n_level].reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.I_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_I_reg = (lazy_regularization
                         and running_mb_counter % I_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run([Is_update_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(I_train_ops[n_level], feed_dict)
                if run_I_reg:
                    for _round in rounds:
                        tflib.run(I_reg_ops[n_level], feed_dict)
                tflib.run(Is_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):

                if not use_hyperplane:
                    prior_traj_latents = M.run(
                        grid_latents,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
                    if use_std_in_m:
                        prior_traj_latents = prior_traj_latents[:, :prior_latent_size]
                else:
                    prior_traj_latents = get_prior_traj_by_dirs(
                        latent_dirs, M, n_samples_per, prior_latent_size,
                        grid_labels, sched)
                    if resolution_manual >= 256:
                        prior_traj_latents = prior_traj_latents[:prior_traj_latents.shape[0] // 5]

                prior_traj_latents_show = np.reshape(
                    prior_traj_latents, [-1, n_samples_per, prior_latent_size])
                print_traj(prior_traj_latents_show)
                grid_fakes = Gs.run(prior_traj_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu,
                                    randomize_noise=True,
                                    normalize_latents=False)
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                if (n_continuous == 2) and (n_discrete == 0):
                    n_per_line = 20
                    ex_latent_value = 3
                    prior_grid_latents, prior_grid_labels = get_2d_grid_latents(
                        low=-ex_latent_value,
                        high=ex_latent_value,
                        n_per_line=n_per_line,
                        grid_labels=grid_labels)
                    grid_showing_fakes = Gs.run(
                        prior_grid_latents,
                        prior_grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu,
                        randomize_noise=True,
                        normalize_latents=False)
                    grid_showing_fakes = add_outline(grid_showing_fakes,
                                                     width=1)
                    misc.save_image_grid(grid_showing_fakes,
                                         dnnlib.make_run_dir_path(
                                             'fakes_2d_prior_grid%06d.png' %
                                             (cur_nimg // 1000)),
                                         drange=drange_net,
                                         grid_size=[n_per_line, n_per_line])
                    img_to_draw = Image.open(
                        dnnlib.make_run_dir_path(
                            'fakes_2d_prior_grid%06d.png' %
                            (cur_nimg // 1000)))
                    img_to_draw = img_to_draw.convert('RGB')
                    img_to_draw = draw_traj_on_prior_grid(
                        img_to_draw, prior_traj_latents_show, ex_latent_value,
                        n_per_line)
                    img_to_draw.save(
                        dnnlib.make_run_dir_path(
                            'fakes_2d_prior_grid_drawn%06d.png' %
                            (cur_nimg // 1000)))

            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((I, M, Is), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((I, M, Is), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
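The lazy-regularization rescaling applied to I_opt_args above can be checked numerically. A minimal sketch with illustrative Adam hyperparameters; the exact values would come from the optimizer arguments:

# Sketch of the lazy-regularization rescaling above; hyperparameters are illustrative.
reg_interval = 4              # regularizer runs every reg_interval minibatches
learning_rate, beta1, beta2 = 0.002, 0.0, 0.99

# The main loss is stepped reg_interval times for each regularization step, so the
# learning rate and Adam moment decays are rescaled by reg_interval / (reg_interval + 1)
# to keep the combined trajectory close to jointly optimizing loss + reg every step.
mb_ratio = reg_interval / (reg_interval + 1)
learning_rate *= mb_ratio
beta1 **= mb_ratio
beta2 **= mb_ratio
print(learning_rate, beta1, beta2)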
Example n. 8
def training_auto_loop(
    Enc_args={},  # Options for encoder network.
    Dec_args={},  # Options for decoder network.
    opt_args={},  # Options for the autoencoder optimizer.
    loss_args={},  # Options for the autoencoder loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False):  # Include weight histograms in the tfevents file?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, _ = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        Enc = tflib.Network('Encoder',
                            resolution=training_set.shape[1],
                            out_channels=16,
                            **Enc_args)
        Dec = tflib.Network('Decoder',
                            resolution=training_set.shape[1] // 4,
                            in_channels=16,
                            **Dec_args)

    # Print layers and generate initial image snapshot.
    Enc.print_layers()
    Dec.print_layers()
    sched = sched_args
    sched.tick_kimg = 4
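    # grid_reals are uint8 images in [0, 255]; (x / 127.5) - 1 rescales them to
    # drange_net = [-1, 1] before the initial encode/decode snapshot below.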
    grid_codes = Enc.run((grid_reals / 127.5) - 1.0,
                         minibatch_size=sched.minibatch_gpu)
    grid_fakes = Dec.run(grid_codes, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
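        # Number of per-GPU minibatches whose gradients are accumulated before the
        # optimizer applies a single update.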

    # Setup optimizers.
    opt_args = dict(opt_args)
    opt_args['minibatch_multiplier'] = minibatch_multiplier
    opt_args['learning_rate'] = lrate_in
    opt = tflib.Optimizer(name='TrainAuto', **opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of Enc and Dec.
            Enc_gpu = Enc if gpu == 0 else Enc.clone(Enc.name + '_shadow')
            Dec_gpu = Dec if gpu == 0 else Dec.clone(Dec.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, _ = process_reals(reals_write, labels_write, 0.0,
                                               False,
                                               training_set.dynamic_range,
                                               drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                reals_read = reals_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            with tf.name_scope('loss'):
                loss = dnnlib.util.call_func_by_name(Enc=Enc_gpu,
                                                     Dec=Dec_gpu,
                                                     opt=opt,
                                                     reals=reals_read,
                                                     **loss_args)

            # Register gradients.
            opt.register_gradients(
                tf.reduce_mean(loss),
                list(Enc_gpu.trainables.values()) +
                list(Dec_gpu.trainables.values()))

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    train_op = opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        Enc.setup_weight_histograms()
        Dec.setup_weight_histograms()

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops.
        feed_dict = {
            lrate_in: sched.lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run(data_fetch_op, feed_dict)
                tflib.run(train_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_codes = Enc.run((grid_reals / 127.5) - 1.0,
                                     minibatch_size=sched.minibatch_gpu)
                grid_fakes = Dec.run(grid_codes,
                                     minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((Enc, Dec), pkl)

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((Enc, Dec), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example n. 9
def training_loop_vc2(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    I_args={},  # Options for infogan-head/vcgan-head network.
    I_info_args={},  # Options for infogan-head/vcgan-head network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    # G2_opt_args={},  # Options for generator2 optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    use_info_gan=False,  # Whether to use info-gan.
    use_vc_head=False,  # Whether to use vc-head.
    use_vc2_info_gan=False,  # Whether to use vc2 infogan.
    use_perdis=False,  # Whether to use a perceptual distance network.
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
    # G2_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
    traversal_grid=False,  # Used for disentangled representation learning.
    n_discrete=3,  # Number of discrete latents in model.
    n_continuous=4,  # Number of continuous latents in model.
    return_atts=False,  # Whether to return attention maps.
    return_I_atts=False,  # Whether to return I's attention maps (vpex).
    avg_mv_for_I=False,  # Whether to use a moving average for I.
    opt_reset_ls=None,  # Reset lr list for gradual latents.
    topk_dims_to_show=20,  # Number of top disentangled dimensions to show in a snapshot.
    cascade_alt_freq_k=1,  # Frequency (in kimg) at which cascade_dim is altered.
    n_samples_per=10):  # Number of samples for each line in traversal.

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Whether to include the I network.
    include_I = use_info_gan or use_vc_head

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    grid_fakes = add_outline(grid_reals, width=1)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            print('G_args:', G_args)
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            if include_I:
                I = tflib.Network('I',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **I_args)
                if avg_mv_for_I:
                    Is = I.clone('Is')
            elif use_perdis:
                DM = misc.load_pkl(
                    'http://d36zk2xti64re0.cloudfront.net/stylegan1/networks/metrics/vgg16_zhang_perceptual.pkl'
                )

            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if include_I:
                if avg_mv_for_I:
                    rG, rD, rI, rGs, rIs = misc.load_pkl(resume_pkl)
                else:
                    rG, rD, rI, rGs = misc.load_pkl(resume_pkl)
            else:
                rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                if include_I:
                    I.copy_vars_from(rI)
                    if avg_mv_for_I:
                        Is.copy_vars_from(rIs)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                if include_I:
                    I = rI
                    if avg_mv_for_I:
                        Is = rIs
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    if include_I:
        I.print_layers()
    if use_perdis:
        DM.print_layers()
    # pdb.set_trace()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        if topk_dims_to_show > 0:
            topk_dims = np.arange(min(topk_dims_to_show, n_continuous))
        else:
            topk_dims = np.arange(n_continuous)
        print('topk_dims_to_show:', topk_dims_to_show)
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels, topk_dims)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_size:', grid_size)
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # pdb.set_trace()
    if return_atts:
        grid_fakes, atts = get_return_v(
            Gs.run(grid_latents,
                   grid_labels,
                   is_validation=True,
                   minibatch_size=sched.minibatch_gpu,
                   randomize_noise=True,
                   return_atts=True,
                   resolution=training_set.shape[1]), 2)
        # atts: [b, n_latents, 1, res, res]
        atts = atts[:, topk_dims]
        save_atts(atts,
                  filename=dnnlib.make_run_dir_path('fakes_atts_init.png'),
                  grid_size=grid_size,
                  drange=[0, 1],
                  grid_fakes=grid_fakes,
                  n_samples_per=n_samples_per)
    else:
        grid_fakes = get_return_v(
            Gs.run(grid_latents,
                   grid_labels,
                   is_validation=True,
                   minibatch_size=sched.minibatch_gpu,
                   randomize_noise=True), 1)
    grid_fakes = add_outline(grid_fakes, width=1)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    if include_I and return_I_atts:
        if avg_mv_for_I:
            I_tmp = Is
        else:
            I_tmp = I
        _, atts = get_return_v(
            I_tmp.run(grid_fakes,
                      grid_fakes,
                      grid_latents,
                      is_validation=True,
                      minibatch_size=sched.minibatch_gpu,
                      return_atts=True,
                      resolution=training_set.shape[1]), 2)
        save_atts(atts,
                  filename=dnnlib.make_run_dir_path('fakes_I_atts_init.png'),
                  grid_size=grid_size,
                  drange=[0, 1],
                  grid_fakes=grid_fakes,
                  n_samples_per=n_samples_per)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
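        # Gs_beta is the per-step EMA decay chosen so that the running average of
        # the generator weights has a half-life of G_smoothing_kimg thousand images:
        # beta = 0.5 ** (minibatch_size / (G_smoothing_kimg * 1000)).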
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0
        cascade_dim = tf.placeholder(tf.int32, name='cascade_dim', shape=[])

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
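    # With lazy regularization, the regularization term is applied only every
    # reg_interval minibatches, so the learning rate and Adam betas are rescaled by
    # mb_ratio = reg_interval / (reg_interval + 1) to keep the effective optimizer
    # behaviour comparable (the StyleGAN2 lazy-regularization trick).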
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
    if use_vc2_info_gan:
        G2_opt = tflib.Optimizer(name='TrainG2', share=G_opt, **G_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            if include_I:
                I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
            if use_perdis:
                DM_gpu = DM if gpu == 0 else DM.clone(DM.name + '_shadow')
            else:
                DM_gpu = None

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    if include_I:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            DM=DM_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            cascade_dim=cascade_dim,
                            **G_loss_args)
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            DM=DM_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)
                if use_vc2_info_gan:
                    with tf.name_scope('G2_loss'):
                        G2_loss, _ = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G2_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            is_G2_loss=True,
                            **G_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            if include_I:
                GI_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()))
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         GI_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            elif use_vc2_info_gan:
                GD_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(D_gpu.trainables.items()))
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         G_gpu.trainables)
                G2_opt.register_gradients(tf.reduce_mean(G2_loss),
                                          GD_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            else:
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         G_gpu.trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    if avg_mv_for_I:
        Is_update_op = Is.setup_as_moving_average_of(I, beta=Gs_beta)
    if use_vc2_info_gan:
        G2_train_op = G2_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        if include_I:
            I.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        # if opt_reset_ls is not None:
        #     if cur_nimg in opt_reset_ls:
        #         G_opt.reset_optimizer_state()
        #         D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Determine which cascade_dim to use.
        cur_nimg_k = cur_nimg // int(cascade_alt_freq_k * 1000)
        sched_cascade_dim = cur_nimg_k % n_continuous

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            cascade_dim: sched_cascade_dim
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                if avg_mv_for_I:
                    tflib.run([D_train_op, Gs_update_op, Is_update_op],
                              feed_dict)
                else:
                    tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)
                if use_vc2_info_gan:
                    tflib.run(G2_train_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                if avg_mv_for_I:
                    tflib.run([Gs_update_op, Is_update_op], feed_dict)
                else:
                    tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)
                if use_vc2_info_gan:
                    for _round in rounds:
                        tflib.run(G2_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                if include_I:
                    if avg_mv_for_I:
                        misc.save_pkl((G, D, I, Gs, Is), pkl)
                    else:
                        misc.save_pkl((G, D, I, Gs), pkl)
                else:
                    misc.save_pkl((G, D, Gs), pkl)
                met_outs = metrics.run(pkl,
                                       run_dir=dnnlib.make_run_dir_path(),
                                       data_dir=dnnlib.convert_path(data_dir),
                                       num_gpus=num_gpus,
                                       tf_config=tf_config,
                                       include_I=include_I,
                                       avg_mv_for_I=avg_mv_for_I,
                                       Gs_kwargs=dict(is_validation=True,
                                                      return_atts=False),
                                       mapping_nodup=True)
                if topk_dims_to_show > 0:
                    if 'tpl_per_dim' in met_outs:
                        avg_distance_per_dim = met_outs[
                            'tpl_per_dim']  # shape: (n_continuous)
                        topk_dims = np.argsort(
                            avg_distance_per_dim
                        )[::-1][:topk_dims_to_show]  # shape: (20)
                    else:
                        topk_dims = np.arange(
                            min(topk_dims_to_show, n_continuous))
                else:
                    topk_dims = np.arange(n_continuous)

            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                if traversal_grid:
                    grid_size, grid_latents, grid_labels = get_grid_latents(
                        n_discrete, n_continuous, n_samples_per, G,
                        grid_labels, topk_dims)
                else:
                    grid_latents = np.random.randn(np.prod(grid_size),
                                                   *G.input_shape[1:])

                if return_atts:
                    grid_fakes, atts = get_return_v(
                        Gs.run(grid_latents,
                               grid_labels,
                               is_validation=True,
                               minibatch_size=sched.minibatch_gpu,
                               randomize_noise=True,
                               return_atts=True,
                               resolution=training_set.shape[1]), 2)
                    # atts: [b, n_latents, 1, res, res]
                    atts = atts[:, topk_dims]
                    save_atts(atts,
                              filename=dnnlib.make_run_dir_path(
                                  'fakes_atts%06d.png' % (cur_nimg // 1000)),
                              grid_size=grid_size,
                              drange=[0, 1],
                              grid_fakes=grid_fakes,
                              n_samples_per=n_samples_per)
                else:
                    grid_fakes = get_return_v(
                        Gs.run(grid_latents,
                               grid_labels,
                               is_validation=True,
                               minibatch_size=sched.minibatch_gpu,
                               randomize_noise=True), 1)
                grid_fakes = add_outline(grid_fakes, width=1)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                if include_I and return_I_atts:
                    if avg_mv_for_I:
                        I_tmp = Is
                    else:
                        I_tmp = I
                    _, atts = get_return_v(
                        I_tmp.run(grid_fakes,
                                  grid_fakes,
                                  grid_latents,
                                  is_validation=True,
                                  minibatch_size=sched.minibatch_gpu,
                                  return_atts=True,
                                  resolution=training_set.shape[1]), 2)
                    atts = atts[:, topk_dims]
                    save_atts(atts,
                              filename=dnnlib.make_run_dir_path(
                                  'fakes_I_atts%06d.png' % (cur_nimg // 1000)),
                              grid_size=grid_size,
                              drange=[0, 1],
                              grid_fakes=grid_fakes,
                              n_samples_per=n_samples_per)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if include_I:
        if avg_mv_for_I:
            misc.save_pkl((G, D, I, Gs, Is),
                          dnnlib.make_run_dir_path('network-final.pkl'))
        else:
            misc.save_pkl((G, D, I, Gs),
                          dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((G, D, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Example n. 10
def generate_images(network_pkl,
                    seeds,
                    truncation_psi,
                    data_dir=None,
                    dataset_name=None,
                    model=None):
    G_args = EasyDict(func_name='training.' + model + '.G_main')
    dataset_args = EasyDict(tfrecord_dir=dataset_name)
    G_args.fmap_base = 8 << 10
    tflib.init_tf()
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    print('Constructing networks...')
    Gs = tflib.Network('G',
                       num_channels=training_set.shape[0],
                       resolution=training_set.shape[1],
                       label_size=training_set.label_size,
                       **G_args)
    print('Loading networks from "%s"...' % network_pkl)
    _, _, _Gs = pretrained_networks.load_networks(network_pkl)
    Gs.copy_vars_from(_Gs)
    noise_vars = [
        var for name, var in Gs.components.synthesis.vars.items()
        if name.startswith('noise')
    ]

    Gs_kwargs = dnnlib.EasyDict()
    # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' %
              (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:])  # [minibatch, component]
        tflib.set_vars(
            {var: rnd.randn(*var.shape.as_list())
             for var in noise_vars})  # [height, width]
        images, x_v, n_v, m_v = Gs.run(
            z, None, **Gs_kwargs)  # [minibatch, height, width, channel]

        print(images.shape, n_v.shape, x_v.shape, m_v.shape)
        misc.convert_to_pil_image(images[0], drange=[-1, 1]).save(
            dnnlib.make_run_dir_path('seed%04d.png' % seed))
        misc.save_image_grid(adjust_range(n_v),
                             dnnlib.make_run_dir_path('seed%04d-nv.png' %
                                                      seed),
                             drange=[-1, 1])
        print(np.linalg.norm(x_v - m_v))
        misc.save_image_grid(adjust_range(x_v).transpose([1, 0, 2, 3]),
                             dnnlib.make_run_dir_path('seed%04d-xv.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(m_v).transpose([1, 0, 2, 3]),
                             dnnlib.make_run_dir_path('seed%04d-mv.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(x_v, 'cat')),
                             dnnlib.make_run_dir_path('seed%04d-xvs.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(m_v, 'ss')),
                             dnnlib.make_run_dir_path('seed%04d-mvs.png' %
                                                      seed),
                             drange=[-1, 1])
        misc.save_image_grid(adjust_range(clip(m_v, 'ffhq')),
                             dnnlib.make_run_dir_path('seed%04d-fmvs.png' %
                                                      seed),
                             drange=[-1, 1])
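A minimal sketch of how generate_images might be invoked, assuming the surrounding repository's dnnlib run-directory context has already been set up; the network pickle path, dataset name, seeds, and model module below are hypothetical placeholders:

# Hypothetical call; pickle path, dataset name, and model module are placeholders.
generate_images(network_pkl='results/network-snapshot-025000.pkl',
                seeds=[0, 1, 2],
                truncation_psi=0.5,
                data_dir='datasets',
                dataset_name='ffhq',
                model='vc2_networks')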
Example n. 11
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for MetricGroup.
    tf_config               = {},       # Options for tflib.init_tf().
    data_dir                = None,     # Directory to load datasets from.
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [0,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = None,     # Network pickle to resume training from, None = train from scratch.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir), verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    # Construct or load networks.
    training_set.configure(minibatch_size=sched_args.batch_size)
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
            start = 0
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets: G.copy_vars_from(rG); D.copy_vars_from(rD); Gs.copy_vars_from(rGs)
            else: G = rG; D = rD; Gs = rGs
            start = int(resume_pkl.split('-')[-1].split('.')[0]) // sched_args.batch_size

    # Print layers and generate initial image snapshot.
    G.print_layers(); D.print_layers()
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.png'), drange=drange_net, grid_size=grid_size)

    global_step = tf.Variable(start, trainable=False, name='learning_rate_step')
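    # The decayed learning rate below is shared by both optimizers; global_step is
    # incremented via the add_global control dependency wrapped around gradient
    # registration, so the rate decays as training progresses.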
    learning_rate = tf.train.exponential_decay(sched_args.lr, global_step, sched_args.decay_step,
                                               sched_args.decay_rate, staircase=sched_args.stair)
    add_global = global_step.assign_add(1)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)
    G_opt = tflib.Optimizer(name='TrainG', learning_rate=learning_rate, **G_opt_args)

    for gpu in range(num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            with tf.name_scope('DataFetch'):
                reals_read, labels_read = training_set.get_minibatch_tf()

                reals_read, labels_read = process_reals(reals_read, labels_read, mirror_augment,
                                                        training_set.dynamic_range, drange_net)

            with tf.name_scope('Loss'), tf.control_dependencies(None):
                loss, reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set,
                                                          minibatch_size=sched_args.batch_size, reals=reals_read,
                                                          labels=labels_read, **D_loss_args)
            with tf.control_dependencies([add_global]):
                G_opt.register_gradients(loss, G_gpu.trainables)
                D_opt.register_gradients(loss, D_gpu.trainables)

    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        loss_, _, _, lr_ = tflib.run([loss, G_train_op, D_train_op, learning_rate])
        cur_nimg += sched_args.batch_size * num_gpus
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched_args.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f '
                'sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f loss %-8.1f lr %-2.5f' % (
                    autosummary('Progress/tick', cur_tick),
                    autosummary('Progress/kimg', cur_nimg / 1000.0),
                    autosummary('Progress/minibatch', sched_args.batch_size),
                    dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                    autosummary('Timing/sec_per_tick', tick_time),
                    autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                    autosummary('Timing/maintenance_sec', maintenance_time),
                    autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2 ** 30),
                    autosummary('loss', loss_),
                    autosummary('lr', lr_)))

            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = G.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched_args.batch_size)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=dnnlib.make_run_dir_path(), data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0.0, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()