Exemplo n.º 1
0
def create_initial_pkl(
    G_args                  = None,       # Options for generator network (None = no extra options).
    D_args                  = None,       # Options for discriminator network (None = no extra options).
    tf_config               = None,       # Options for tflib.init_tf() (None = defaults).
    config_id               = "config-f", # config-f is the only one tested ...
    num_channels            = 3,          # number of channels, e.g. 3 for RGB
    resolution_h            = 1024,       # height dimension of real/fake images
    resolution_w            = 1024,       # width dimension of real/fake images
    label_size              = 0,          # number of labels for a conditional model
    ):
    """Construct freshly initialised G, D and Gs networks and pickle them.

    The tuple ``(G, D, Gs)`` (``Gs`` is a clone of ``G``) is written to the
    relative path ``network-initial-<config_id>-<w>x<h>-<label_size>.pkl``,
    i.e. into the current working directory.

    Note: both networks are built with ``resolution=resolution_h``;
    ``resolution_w`` only affects the output filename.
    """
    # None defaults avoid sharing a single mutable dict across calls.
    G_args = {} if G_args is None else G_args
    D_args = {} if D_args is None else D_args
    tf_config = {} if tf_config is None else tf_config

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)

    resolution = resolution_h  # training_set.shape[1]

    # Construct networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        G = tflib.Network('G', num_channels=num_channels, resolution=resolution, label_size=label_size, **G_args)
        D = tflib.Network('D', num_channels=num_channels, resolution=resolution, label_size=label_size, **D_args)
        Gs = G.clone('Gs')

    # Print layers and save the initial snapshot.
    G.print_layers()
    D.print_layers()
    pkl = 'network-initial-%s-%sx%s-%s.pkl' % (config_id, resolution_w, resolution_h, label_size)
    misc.save_pkl((G, D, Gs), pkl)
    print("Saving", pkl)
Exemplo n.º 2
0
def create_model(data_shape, full=False, kwargs_in=None):
    """Build StyleGAN2 networks sized for ``data_shape``.

    Args:
        data_shape: Sequence whose first entry is used as ``num_channels``;
            the remaining entries are handed to ``calc_init_res``.
        full: When exactly ``True`` build G, D and Gs (Gs cloned from G);
            otherwise build only Gs.
        kwargs_in: Optional mapping of extra network kwargs. Its entries may
            override ``num_channels``/``label_size``; ``resolution`` and
            ``init_res`` are always taken from ``calc_init_res``.

    Returns:
        ``(G, D, Gs, res_log2)``; G and D are None when only Gs is built.
    """
    init_res, resolution, res_log2 = calc_init_res(data_shape[1:])
    net_kwargs = dnnlib.EasyDict(num_channels=data_shape[0], label_size=0)
    if kwargs_in is not None:
        net_kwargs.update(kwargs_in)
    # Derived sizes are applied last so they win over anything in kwargs_in.
    net_kwargs.resolution = resolution
    net_kwargs.init_res = init_res
    if a.verbose is True:
        print(['%s: %s' % (key, val) for key, val in sorted(net_kwargs.items())])

    generator_func = 'training.networks_stylegan2.G_main'
    if full is not True:
        Gs = tflib.Network('Gs', func_name=generator_func, **net_kwargs)
        G = D = None
        return G, D, Gs, res_log2

    G = tflib.Network('G', func_name=generator_func, **net_kwargs)
    D = tflib.Network('D',
                      func_name='training.networks_stylegan2.D_stylegan2',
                      **net_kwargs)
    Gs = G.clone('Gs')
    return G, D, Gs, res_log2
def G_main_spatial_biased_dsp(
        latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,  # Second input: Conditioning labels [minibatch, label_size].
        is_training=False,  # Network is under training? Enables and disables specific features.
        is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
        return_dlatents=False,  # Return dlatents in addition to the images?
        is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(
        ),  # Container for sub-networks. Retained between calls.
        mapping_func='G_mapping_spatial_biased_dsp',  # Build func name for the mapping network.
        synthesis_func='G_synthesis_spatial_biased_dsp',  # Build func name for the synthesis network.
        **kwargs):  # Arguments for sub-networks (mapping and synthesis).
    """Top-level generator graph: mapping network followed by synthesis network.

    This variant applies no truncation trick, no dlatent moving average and
    no style mixing. ``is_validation`` is only checked against
    ``is_training`` and is otherwise unused.

    The sub-networks are created lazily and stored in ``components``: the
    synthesis network 'G_spatial_biased_synthesis_dsp' (built from
    ``synthesis_func``) and the mapping network 'G_spatial_biased_mapping_dsp'
    (built from ``mapping_func`` with ``dlatent_broadcast=None``). Both
    function names are resolved via this module's ``globals()``.

    Returns:
        The synthesis output tensor named 'images_out', or the tuple
        ``(images_out, dlatents)`` when ``return_dlatents`` is True.
    """
    # Validate arguments.
    assert not is_training or not is_validation

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_spatial_biased_synthesis_dsp',
            func_name=globals()[synthesis_func],
            **kwargs)
    if 'mapping' not in components:
        # dlatent_broadcast=None: the mapping output is not tiled per layer here.
        # NOTE(review): presumably dlatents are [minibatch, dlatent_size]; confirm
        # against G_mapping_spatial_biased_dsp.
        components.mapping = tflib.Network('G_spatial_biased_mapping_dsp',
                                           func_name=globals()[mapping_func],
                                           dlatent_broadcast=None,
                                           **kwargs)

    # Setup variables.
    # Non-trainable level-of-detail value; copied into the synthesis net below if it has one.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in,
                                                 labels_in,
                                                 is_training=is_training,
                                                 **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Evaluate synthesis network.
    # If the synthesis net owns a 'lod' variable, assign lod_in to it before running it.
    deps = []
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(
            dlatents,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
Exemplo n.º 4
0
def load_model(
        url='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ',  # karras2019stylegan-ffhq-1024x1024.pkl
        session=None,
        cache_dir='cache'):
    """Download a (G, D, Gs) pickle and rebuild the networks with local build funcs.

    Each pickled network is re-created as a ``tflib.Network`` with the same
    name and static kwargs, but with this module's ``G_style`` (for G and Gs)
    or ``D_basic`` (for D) as the build function, and its variables are then
    copied over from the pickled instance.

    Args:
        url: Location of the pickle, fetched through ``dnnlib.util.open_url``.
        session: TF session to use; defaults to ``tf.get_default_session()``.
        cache_dir: Download cache directory passed to ``open_url``.

    Returns:
        ``(G, D, Gs)`` rebuilt networks.
    """
    sess = session or tf.get_default_session()
    with sess.as_default():
        # NOTE: pickle.load can execute arbitrary code; only load trusted URLs.
        with dnnlib.util.open_url(url, cache_dir=cache_dir) as stream:
            _G, _D, _Gs = pickle.load(stream)
        rebuilt = []
        for source_net, build_func in ((_G, G_style), (_D, D_basic), (_Gs, G_style)):
            fresh_net = tflib.Network(source_net.name, build_func, **source_net.static_kwargs)
            fresh_net.copy_vars_from(source_net)
            rebuilt.append(fresh_net)
        G, D, Gs = rebuilt
        return G, D, Gs
Exemplo n.º 5
0
def load_perceptual(
        url='https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2',  # vgg16_zhang_perceptual.pkl
        session=None,
        cache_dir='cache'):
    """Download the perceptual (VGG16/LPIPS) network pickle and rebuild it locally.

    The pickled network is re-created as a ``tflib.Network`` with the same
    name and static kwargs but this module's ``lpips_network`` build
    function, and its variables are copied from the pickled instance.

    Args:
        url: Location of the pickle, fetched through ``dnnlib.util.open_url``.
        session: TF session to use; defaults to ``tf.get_default_session()``.
        cache_dir: Download cache directory passed to ``open_url``.

    Returns:
        The rebuilt perceptual network.
    """
    session = session or tf.get_default_session()
    with session.as_default():
        # Unpickle under the target session, as load_model does. Presumably
        # unpickling a tflib.Network restores its variables through the default
        # session, so loading outside this context could hit the wrong session
        # when an explicit, non-default `session` is passed.
        # NOTE: pickle.load can execute arbitrary code; only load trusted URLs.
        with dnnlib.util.open_url(url, cache_dir=cache_dir) as f:
            _P = pickle.load(f)
        P = tflib.Network(_P.name, lpips_network, **_P.static_kwargs)
        P.copy_vars_from(_P)
        return P
Exemplo n.º 6
0
def create_model(data_shape, full=False):
    """Build StyleGAN2 networks sized for ``data_shape``.

    Args:
        data_shape: Sequence whose first entry is used as ``num_channels``;
            the remaining entries are handed to ``calc_init_res``.
        full: When exactly ``True`` build G, D and Gs (Gs cloned from G);
            otherwise build only Gs.

    Returns:
        ``(G, D, Gs, res_log2)``; G and D are None when only Gs is built.
    """
    init_res, resolution, res_log2 = calc_init_res(data_shape[1:])
    net_kwargs = dnnlib.EasyDict(resolution=resolution,
                                 init_res=init_res,
                                 num_channels=data_shape[0],
                                 label_size=0)
    generator_func = 'training.networks_stylegan2.G_main'

    if full is True:
        G = tflib.Network('G', func_name=generator_func, **net_kwargs)
        D = tflib.Network('D',
                          func_name='training.networks_stylegan2.D_stylegan2',
                          **net_kwargs)
        return G, D, G.clone('Gs'), res_log2

    Gs_only = tflib.Network('Gs', func_name=generator_func, **net_kwargs)
    return None, None, Gs_only, res_log2
Exemplo n.º 7
0
def Decoder_main(
        latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
        labels_in,  # Second input: Conditioning labels [minibatch, label_size].
        is_training=False,  # Network is under training? Enables and disables specific features.
        return_dlatents=False,  # Return dlatents in addition to the images?
        is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
        components=dnnlib.EasyDict(
        ),  # Container for sub-networks. Retained between calls.
        mapping_func='Decoder_mapping',  # Build func name for the mapping network.
        synthesis_func='Decoder_synthesis',  # Build func name for the synthesis network.
        **kwargs):
    """Decoder graph: mapping network followed directly by synthesis network.

    The sub-networks are created lazily and stored in ``components``: the
    synthesis network 'G_synthesis' (from ``synthesis_func``) and the mapping
    network 'G_mapping' (from ``mapping_func``). Both function names are
    resolved via this module's ``globals()``, and ``**kwargs`` is forwarded to
    both builders and both ``get_output_for`` calls.

    There is no 'lod' variable, truncation, moving average or style mixing here.

    Returns:
        The synthesis output tensor named 'images_out', or the tuple
        ``(images_out, dlatents)`` when ``return_dlatents`` is True.
    """
    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    if 'mapping' not in components:
        # No dlatent_broadcast kwarg is passed here; the mapping function's own default applies.
        components.mapping = tflib.Network('G_mapping',
                                           func_name=globals()[mapping_func],
                                           **kwargs)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in,
                                                 labels_in,
                                                 is_training=is_training,
                                                 **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Evaluate synthesis network.
    images_out = components.synthesis.get_output_for(
        dlatents,
        is_training=is_training,
        force_clean_graph=is_template_graph,
        **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
Exemplo n.º 8
0
def get_Gs(opt):
    """Load each requested checkpoint and yield it as a progan ``G_paper`` network.

    For every ``(exp_number, snapshot_kimg)`` pair from ``opt.models`` and
    ``opt.snapshot_kimgs``, the pkl path comes from ``find_model``. If that
    returns nothing, ``exp_number`` is used directly when it ends in
    ``.pkl``; otherwise it is treated as a run number and searched for under
    ``config.result_dir`` via ``find_pkl``.

    Yields:
        ``(Gs, nz)``: a fresh 3-channel, 128px ``G_paper`` network holding the
        checkpoint's Gs weights, and its latent size ``nz``.
    """
    checkpoints = zip(opt.models, opt.snapshot_kimgs)
    for exp_number, snapshot_kimg in checkpoints:
        resume_pkl = find_model(exp_number)
        if not resume_pkl:
            if exp_number.endswith('.pkl'):
                resume_pkl = exp_number
            else:
                # Look for a pkl in the results directory.
                search_root = os.path.join(os.getcwd(), config.result_dir)
                resume_pkl = find_pkl(search_root, int(exp_number), snapshot_kimg)

        tflib.init_tf()
        _, _, loaded_Gs = load_pkl(resume_pkl)
        nz = loaded_Gs.input_shapes[0][1]
        net_kwargs = dict(latent_size=nz, num_channels=3, resolution=128, label_size=0)
        Gs = tflib.Network(name='Gs', func_name='training.networks_progan.G_paper', **net_kwargs)
        Gs.copy_vars_from(loaded_Gs)

        print(f'Visualizing pkl: {resume_pkl} with seed={opt.seed}')
        small_latent = nz < 12
        if small_latent and not opt.interpolate_pre_norm:
            print(f'Model {exp_number} uses a small z vector (nz={nz}); you might want to add '
                  f'--interpolate_pre_norm to your command.')
        yield Gs, nz
Exemplo n.º 9
0
def _write_values(path, values):
    """Write ``values`` to the text file ``path``, one ``str(value)`` per line."""
    with open(path, 'w') as f:
        f.writelines(str(value) + '\n' for value in values)


def embed(batch_size, resolution, imgs, network, iteration, result_dir, seed=6600):
    """Fit a dlatent and a temperature ``T`` to each image in ``imgs``.

    A generator built from ``training.networks_stylegan2_alpha.G_main`` is
    loaded with the weights of ``network``. For each image, the dlatent
    (initialised to the generator's ``dlatent_avg``) and ``T`` (initialised
    to 0.95 and mapped through ``scale_alpha_exp`` into the synthesis
    network's ``alpha_pre``) are reset and then optimised with Adam for
    ``iteration`` steps.

    Files written to ``result_dir`` for each image index ``idx``:
      * ``si<idx>.png``: snapshots taken every 500 steps,
      * ``sifinal<idx>.png``: synthesised image from the last step,
      * ``metric_l/p/m/d<idx>.txt``: per-step total loss, perceptual loss,
        MSE and squared distance of the dlatent from ``dlatent_avg``.
    Also written: ``metric_lmpd.txt`` (per-image summary) and ``mean_metrics.txt``.

    Args:
        batch_size: Unused.
        resolution: Unused; 128x128 is hard-coded for G and the VGG input.
        imgs: Iterable of single images; each gets a leading batch axis.
        network: Passed to ``pretrained_networks.load_networks``.
        iteration: Optimisation steps per image. Must be >= 1 when ``imgs``
            is non-empty, since the per-image summary reads the last step.
        result_dir: Output directory; files are opened there with ``open``,
            so it must already exist.
        seed: Seed for the RandomState that fills the noise variables.
    """
    tf.reset_default_graph()
    G_args = dnnlib.EasyDict(func_name='training.networks_stylegan2_alpha.G_main')
    G_args.fmap_base = 8 << 10
    print('Loading networks from "%s"...' % network)
    tflib.init_tf()
    G = tflib.Network('G', num_channels=3, resolution=128, **G_args)
    _, _, Gs = pretrained_networks.load_networks(network)
    G.copy_vars_from(Gs)
    img_in = tf.placeholder(tf.float32)
    opt = tf.train.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999, epsilon=1e-8)
    opt_T = tf.train.AdamOptimizer(learning_rate=0.002, beta1=0.9, beta2=0.999, epsilon=1e-8)
    noise_vars = [var for name, var in G.components.synthesis.vars.items() if name.startswith('noise')]
    alpha_vars = [var for name, var in G.components.synthesis.vars.items() if name.endswith('alpha')]
    alpha_evals = [alpha.eval() for alpha in alpha_vars]

    G_kwargs = dnnlib.EasyDict()
    G_kwargs.randomize_noise = False
    G_syn = G.components.synthesis

    rnd = np.random.RandomState(seed)
    # Initial dlatent: the generator's dlatent_avg, shaped [1, 12, dlatent_size].
    dlatent_avg = [var for name, var in G.vars.items() if name.startswith('dlatent_avg')][0].eval()
    dlatent_avg = np.expand_dims(np.expand_dims(dlatent_avg, 0), 1)
    dlatent_avg = dlatent_avg.repeat(12, 1)
    dlatent = tf.get_variable('dlatent', dtype=tf.float32, initializer=tf.constant(dlatent_avg),
                              trainable=True)
    T = tf.get_variable('T', dtype=tf.float32, initializer=tf.constant(0.95))
    alpha_pre = [scale_alpha_exp(alpha_eval, T) for alpha_eval in alpha_evals]
    synth_img = G_syn.get_output_for(dlatent, is_training=False, alpha_pre=alpha_pre, **G_kwargs)

    with tf.variable_scope('mse_loss'):
        mse_loss = tf.reduce_mean(tf.square(img_in - synth_img))
    with tf.variable_scope('perceptual_loss'):
        # Real and synthesised images go through VGG16 as one batch of two.
        vgg_in = tf.concat([img_in, synth_img], 0)
        tf.keras.backend.set_image_data_format('channels_first')
        vgg = tf.keras.applications.VGG16(include_top=False, input_tensor=vgg_in, input_shape=(3, 128, 128),
                                          weights='/gdata2/fengrl/metrics/vgg.h5',
                                          pooling=None)
        h1 = vgg.get_layer('block1_conv1').output
        h2 = vgg.get_layer('block1_conv2').output
        h3 = vgg.get_layer('block3_conv2').output
        h4 = vgg.get_layer('block4_conv2').output
        pcep_loss = tf.reduce_mean(tf.square(h1[0] - h1[1])) + tf.reduce_mean(tf.square(h2[0] - h2[1])) + \
                    tf.reduce_mean(tf.square(h3[0] - h3[1])) + tf.reduce_mean(tf.square(h4[0] - h4[1]))
    loss = 0.5 * mse_loss + 0.5 * pcep_loss
    with tf.control_dependencies([loss]):
        # NOTE(review): gradients are taken of mse_loss only, so pcep_loss is
        # reported (and included in `loss`) but not optimised. Confirm this is intended.
        grads = tf.gradients(mse_loss, [dlatent, T])
        train_op1 = opt.apply_gradients(zip([grads[0]], [dlatent]))
        train_op2 = opt_T.apply_gradients(zip([grads[1]], [T]))
        train_op = tf.group(train_op1, train_op2)
    reset_opt = tf.variables_initializer(opt.variables()+opt_T.variables())
    reset_dl = tf.variables_initializer([dlatent, T])

    tflib.init_uninitialized_vars()
    tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})  # [height, width]
    metrics_l = []
    metrics_p = []
    metrics_m = []
    metrics_d = []
    T_list = []
    for idx, img in enumerate(imgs):
        img = np.expand_dims(img, 0)
        loss_list = []
        p_loss_list = []
        m_loss_list = []
        dl_list = []
        si_list = []
        # Restart dlatent, T and both Adam states for every image.
        tflib.run([reset_opt, reset_dl])
        for i in range(iteration):
            loss_, p_loss_, m_loss_, dl_, si_, t_, _ = tflib.run([loss, pcep_loss, mse_loss, dlatent, synth_img, T, train_op],
                                                             {img_in: img})
            loss_list.append(loss_)
            p_loss_list.append(p_loss_)
            m_loss_list.append(m_loss_)
            dl_loss_ = np.sum(np.square(dl_-dlatent_avg))
            dl_list.append(dl_loss_)
            if i % 500 == 0:
                si_list.append(si_)
            if i % 100 == 0:
                print('idx %d, Loss %f, mse %f, ppl %f, dl %f, t %f, step %d' % (idx, loss_, m_loss_, p_loss_, dl_loss_, t_, i))
        print('T: %f, loss: %f, ppl: %f, mse: %f, d: %f' % (t_,
                                                               loss_list[-1],
                                                               p_loss_list[-1],
                                                               m_loss_list[-1],
                                                               dl_list[-1]))
        metrics_l.append(loss_list[-1])
        metrics_p.append(p_loss_list[-1])
        metrics_m.append(m_loss_list[-1])
        metrics_d.append(dl_list[-1])
        T_list.append(t_)
        misc.save_image_grid(np.concatenate(si_list, 0), os.path.join(result_dir, 'si%d.png' % idx), drange=[-1, 1])
        # Save the image from the last step. si_list only holds a snapshot every
        # 500 steps, so its last entry is usually not the final result.
        misc.save_image_grid(si_, os.path.join(result_dir, 'sifinal%d.png' % idx),
                             drange=[-1, 1])
        _write_values(os.path.join(result_dir, 'metric_l%d.txt' % idx), loss_list)
        _write_values(os.path.join(result_dir, 'metric_p%d.txt' % idx), p_loss_list)
        _write_values(os.path.join(result_dir, 'metric_m%d.txt' % idx), m_loss_list)
        _write_values(os.path.join(result_dir, 'metric_d%d.txt' % idx), dl_list)

    l_mean = np.mean(metrics_l)
    p_mean = np.mean(metrics_p)
    m_mean = np.mean(metrics_m)
    d_mean = np.mean(metrics_d)
    with open(os.path.join(result_dir, 'metric_lmpd.txt'), 'w') as f:
        f.write(str(alpha_evals)+'\n')
        # One row per image: T, loss, mse, perceptual, dlatent distance.
        for row in zip(T_list, metrics_l, metrics_m, metrics_p, metrics_d):
            f.write('    '.join(str(value) for value in row) + '\n')

    print('Overall metrics: loss_mean %f, ppl_mean %f, mse_mean %f, d_mean %f' % (l_mean, p_mean, m_mean, d_mean))
    with open(os.path.join(result_dir, 'mean_metrics.txt'), 'w') as f:
        f.write('loss %f\n' % l_mean)
        f.write('mse %f\n' % m_mean)
        f.write('ppl %f\n' % p_mean)
        f.write('dl %f\n' % d_mean)
Exemplo n.º 10
0
def G_main(
    latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,  # Second input: Conditioning labels [minibatch, label_size].
    latmask,  # mask for split-frame latents blending
    dconst,  # initial (const) layer displacement
    truncation_psi=0.5,  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=None,  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,  # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
    is_training=False,  # Network is under training? Enables and disables specific features.
    is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
    return_dlatents=False,  # Return dlatents in addition to the images?
    is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(
    ),  # Container for sub-networks. Retained between calls.
    mapping_func='G_mapping',  # Build func name for the mapping network.
    synthesis_func='G_synthesis_stylegan2',  # Build func name for the synthesis network.
    **kwargs):  # Arguments for sub-networks (mapping and synthesis).
    """Top-level generator: mapping -> (W average / style mix / truncation) -> synthesis.

    ``latmask`` and ``dconst`` are forwarded positionally to the synthesis
    network right after the dlatents. Their shapes are defined by the
    synthesis function, which is not shown here.

    Behaviour by mode:
      * Training: truncation is disabled. ``dlatent_avg`` is updated as a
        moving average with ``dlatent_avg_beta`` (skipped if None or 1), and
        style mixing is applied with probability ``style_mixing_prob``
        (skipped if None or <= 0).
      * Validation: ``truncation_psi_val`` / ``truncation_cutoff_val`` replace
        ``truncation_psi`` / ``truncation_cutoff``.
      * Outside training: dlatents are lerped toward ``dlatent_avg`` by
        ``truncation_psi`` (skipped if None or 1). This applies to the first
        ``truncation_cutoff`` layers, or to all layers if the cutoff is None.

    Returns:
        Tensor named 'images_out', or ``(images_out, dlatents)`` when
        ``return_dlatents`` is True.
    """

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    if is_training or (truncation_psi is not None
                       and not tflib.is_tf_expression(truncation_psi)
                       and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None
                           and not tflib.is_tf_expression(dlatent_avg_beta)
                           and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None
                           and not tflib.is_tf_expression(style_mixing_prob)
                           and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network(
            'G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping',
                                           func_name=globals()[mapping_func],
                                           dlatent_broadcast=num_layers,
                                           **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg',
                                  shape=[dlatent_size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in,
                                                 labels_in,
                                                 is_training=is_training,
                                                 **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(
                dlatent_avg,
                tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(
                latents2, labels_in, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # original version
            mixing_cutoff = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, cur_layers, dtype=tf.int32),
                lambda: cur_layers)
            # Diff Augment version (per-sample cutoff), kept for reference:
            # mixing_cutoff = tf.where_v2(
            #     tf.random_uniform([tf.shape(dlatents)[0]], 0.0, 1.0) < style_mixing_prob,
            #     tf.random_uniform([tf.shape(dlatents)[0]], 1, cur_layers, dtype=tf.int32),
            #     cur_layers[np.newaxis])[:, np.newaxis, np.newaxis]
            # Apply the mix: layers before the cutoff keep dlatents, the rest take
            # dlatents2. Previously this line sat inside a string literal, so
            # mixing_cutoff/dlatents2 were computed but never used.
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
            if truncation_cutoff is None:
                layer_psi *= truncation_psi
            else:
                layer_psi = tf.where(layer_idx < truncation_cutoff,
                                     layer_psi * truncation_psi, layer_psi)
            dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)

    # Evaluate synthesis network.
    deps = []
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(
            dlatents,
            latmask,
            dconst,
            is_training=is_training,
            force_clean_graph=is_template_graph,
            **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
Exemplo n.º 11
0
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,
        weight_args={},
        train_stage_args={},
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,
        **_kwargs
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    PI = 3.1415927
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)
    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)

            #---------------------------------------------------------------
            # Modified by Deng et al.
            G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\
                G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\
                lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\
                drange_net=drange_net,lod_in=lod_in,**train_stage_args)
            #---------------------------------------------------------------

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False)
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range, drange_net)

    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    grid_noise = np.random.randn(np.prod(grid_size), 32)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise],
                                             axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })

            # print('iter')
        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            #---------------------------------------------------------------

            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
Exemplo n.º 12
0
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    AE_opt_args=None,  # Options for autoencoder optimizer (optimizer is built but not used below).
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    AE_loss_args=None,  # Options for autoencoder loss (not used below).
    dataset_args={},  # Options for dataset.load_dataset() (training set).
    dataset_args_eval={},  # Options for dataset.load_dataset() (evaluation set).
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    train_data_dir=None,  # Directory to load the training set from.
    eval_data_dir=None,  # Directory to load the evaluation set from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl' (plus a snapshot when done).
    save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=True,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None or '' = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to G_args and D_args, then copy the resumed weights into them?
    resume_with_own_vars=False  # Construct new networks even when resuming (replaced by the loaded ones unless resume_with_new_nets).
):
    """Train a (optionally SVD-factorized) generator/discriminator pair.

    Steps:
      1. Load the training and evaluation sets and save the real-image grid.
      2. Construct G, D and the moving-average copy Gs, and/or load them from
         `resume_pkl`.
      3. When `G_args` contains 'syn_svd' or 'map_svd', rebuild all three
         networks with `factorized=True` and copy the weights across.
      4. Build the per-GPU training graph, then train until
         `total_kimg * 1000` images have been shown.

    Image grids ('fakes%06d.png') and network pickles
    ('network-snapshot-%06d.pkl', followed by a metrics run) are written to
    the run directory every `image_snapshot_ticks` / `network_snapshot_ticks`
    ticks. 'network-final.pkl' is written at the end.

    Every G/Gs/D call receives an extra input `rho`, fed as np.array([1])
    throughout. NOTE(review): its meaning is defined by the network code,
    which is not visible here.

    Note: `sched_args.minibatch_gpu_base` is overwritten in place for
    resolutions >= 1024.
    """
    # Work on local copies. G_args / D_args are written to below
    # (train_scope, lambda_mask, reduce_dims), and writing into the caller's
    # dict, or into the shared mutable default, would leak state across calls.
    G_args = dict(G_args)
    D_args = dict(D_args)

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training and evaluation sets.
    print("Loading train set from %s..." % dataset_args.tfrecord_dir)
    training_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(train_data_dir),
        verbose=True,
        **dataset_args)
    print("Loading eval set from %s..." % dataset_args_eval.tfrecord_dir)
    eval_set = dataset.load_dataset(
        data_dir=dnnlib.convert_path(eval_data_dir),
        verbose=True,
        **dataset_args_eval)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Freeze the discriminator. D's train_scope is restricted to the output
    # layers plus name patterns built from the first `layers` powers of two
    # (~3/8 of log2(resolution) - 1). Presumably these select D's
    # lowest-resolution blocks; TODO confirm against D's layer naming.
    if D_args.get('freeze'):
        num_layers = np.log2(training_set.resolution) - 1
        layers = int(np.round(num_layers * 3. / 8.))
        scope = ['Output', 'scores_out']
        for layer in range(layers):
            scope += ['.*%d' % 2**layer]
            if 'train_scope' in D_args:
                scope[-1] += '.*%d' % D_args['train_scope']
        D_args['train_scope'] = scope

    # Construct or load networks. `resume_pkl` may be None (the default) or
    # '' to mean "train from scratch", so test truthiness. Identity against a
    # string literal (`is ''`) is unreliable and warns on Python 3.8+.
    with tf.device('/gpu:0'):
        if not resume_pkl or resume_with_new_nets or resume_with_own_vars:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl:
            print('Loading networks from "%s"...' % resume_pkl)
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                Gs = rGs

    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    # SVD stuff
    if 'syn_svd' in G_args or 'map_svd' in G_args:
        # Run each network once on a single sample so its graph (and the SVD
        # variables matched by 'SVD/s' below) is built before inspection.
        grid_latents_smol = grid_latents[:1]
        rho = np.array([1])
        grid_fakes = G.run(grid_latents_smol,
                           grid_labels,
                           rho,
                           is_validation=True)
        grid_fakes = Gs.run(grid_latents_smol,
                            grid_labels,
                            rho,
                            is_validation=True)
        load_d_fake = D.run(grid_reals[:1], rho, is_validation=True)
        with tf.device('/gpu:0'):
            # Create SVD-decomposed graph. Keep the current networks as the
            # source of the weights that are copied in below.
            rG, rD, rGs = G, D, Gs
            # All-ones mask over the last axis of every 'SVD/s' variable.
            G_lambda_mask = {
                var: np.ones(G.vars[var].shape[-1])
                for var in G.vars if 'SVD/s' in var
            }
            D_lambda_mask = {
                'D/' + var: np.ones(D.vars[var].shape[-1])
                for var in D.vars if 'SVD/s' in var
            }
            G_reduce_dims = {
                var: (0, int(Gs.vars[var].shape[-1]))
                for var in Gs.vars if 'SVD/s' in var
            }
            G_args['lambda_mask'] = G_lambda_mask
            G_args['reduce_dims'] = G_reduce_dims
            D_args['lambda_mask'] = D_lambda_mask

            # Create graph with no SVD operations
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rG.input_shapes[1][1],
                              factorized=True,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=rD.input_shapes[1][1],
                              factorized=True,
                              **D_args)
            Gs = G.clone('Gs')

            # Build the factorized graphs before copying weights into them.
            grid_fakes = G.run(grid_latents_smol,
                               grid_labels,
                               rho,
                               is_validation=True,
                               minibatch_size=1)
            grid_fakes = Gs.run(grid_latents_smol,
                                grid_labels,
                                rho,
                                is_validation=True,
                                minibatch_size=1)

            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)

    # Reduce per-gpu minibatch size to fit in 16GB GPU memory.
    # NOTE: this writes into the caller's sched_args object.
    if grid_reals.shape[2] >= 1024:
        sched_args.minibatch_gpu_base = 2
    print('Batch size', sched_args.minibatch_gpu_base)

    # Generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    rho = np.array([1])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        rho,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)
    # Sanity check after resuming: the current D and the loaded rD should
    # score the same images identically. rD only exists when a pickle was
    # actually loaded.
    if resume_pkl:
        load_d_real = rD.run(grid_reals[:1], rho, is_validation=True)
        load_d_fake = rD.run(grid_fakes[:1], rho, is_validation=True)
        d_fake = D.run(grid_fakes[:1], rho, is_validation=True)
        d_real = D.run(grid_reals[:1], rho, is_validation=True)
        print('Factorized fake', d_fake, 'loaded fake', load_d_fake,
              'factorized real', d_real, 'loaded real', load_d_real)
        print('(should match)')
    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of gradient-accumulation rounds per minibatch.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Scale lr and Adam betas to compensate for the regularization
            # step only running every reg_interval minibatches.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)
    if AE_opt_args is not None:
        # NOTE(review): AE_opt is constructed but never used by this loop.
        AE_opt_args = dict(AE_opt_args)
        AE_opt_args['minibatch_multiplier'] = minibatch_multiplier
        AE_opt_args['learning_rate'] = lrate_in
        AE_opt = tflib.Optimizer(name='TrainAE', **AE_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # Per-GPU copy so each GPU's reals do not overwrite the
                    # caller's G_loss_args. The G_l1 loss additionally
                    # receives this GPU's real images (it is assumed to
                    # accept a `reals` keyword, as this branch implies).
                    G_loss_kwargs = dict(G_loss_args)
                    if G_loss_args['func_name'] == 'training.loss.G_l1':
                        G_loss_kwargs['reals'] = reals_read
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        **G_loss_kwargs)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    # Start at -1 so the first iteration always triggers a tick (tick 0).
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    # Most recent generator loss value, printed at image snapshots.
    g_loss_value = None

    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation. G_loss is fetched
            # here as well so the snapshot log line always has a value.
            if len(rounds) == 1:
                g_loss_value, _, _ = tflib.run(
                    [G_loss, G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    g_loss_value, _ = tflib.run([G_loss, G_train_op],
                                                feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots. Gs takes the same (latents, labels, rho) inputs
            # as in the initial snapshot above.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                print('g loss', g_loss_value)
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    rho,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)

            # A snapshot is always written when training is done, even if
            # network_snapshot_ticks is None.
            if (network_snapshot_ticks is not None
                    and cur_tick % network_snapshot_ticks == 0) or done:
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(eval_data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config,
                            rho=rho)
            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
    eval_set.close()
Exemplo n.º 13
0
def training_loop(
    classifier_args={},  # Options for the classifier network (forwarded to tflib.Network).
    classifier_opt_args={},  # Options for the classifier optimizer (tflib.Optimizer).
    classifier_loss_args={},  # Options for the loss, forwarded to dnnlib.util.call_func_by_name().
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    network_snapshot_ticks=5,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False):  # Include full TensorFlow computation graph in the tfevents file?
    """Train a single classifier network on a labelled dataset.

    A network named 'classifier' is built with the training set's channel
    count, resolution and label size, then replicated onto every GPU in
    `dnnlib.submit_config.num_gpus`. On each GPU the loss function selected
    by `classifier_loss_args` is evaluated on a minibatch of images and
    labels, and the mean loss is minimised with a tflib.Optimizer.

    The learning rate comes from the training schedule's `G_lrate` field.

    Every `network_snapshot_ticks` ticks, and on the final tick, the
    classifier is pickled into the run directory as
    'network-snapshot-XXXXXX.pkl' and the metrics in `metric_arg_list` are
    run on it. When the loop ends (finished or stopped by the RunContext),
    'network-final.pkl' is written.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        shuffle_mb=2 * 4096,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        print('Constructing networks...')
        classifier = tflib.Network('classifier',
                                   num_channels=training_set.shape[0],
                                   resolution=training_set.shape[1],
                                   label_size=training_set.label_size,
                                   **classifier_args)

    classifier.print_layers()

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Ratio of the full minibatch to what all GPUs process in one round;
        # handed to tflib.Optimizer for gradient accumulation.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)

    # Setup optimizer. Copy the args so the caller's dict is never mutated.
    classifier_opt_args = dict(classifier_opt_args)
    classifier_opt_args['minibatch_multiplier'] = minibatch_multiplier
    classifier_opt_args['learning_rate'] = lrate_in
    classifier_opt = tflib.Optimizer(name='TrainClassifier',
                                     **classifier_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('gpu%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create a GPU-specific shadow copy of the classifier.
            classifier_gpu = classifier if gpu == 0 else classifier.clone(
                classifier.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=0, **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                # The label buffer must be as wide as the dataset's labels, or
                # the concat/assign below fail. This was previously hard-coded
                # to 127. Assumes process_reals() keeps the label width
                # unchanged (TODO confirm).
                labels_var = tf.Variable(
                    name='labels',
                    trainable=False,
                    initial_value=tf.zeros(
                        [sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, mirror_augment,
                    training_set.dynamic_range, drange_net)
                # Put the freshly fetched batch in front of the buffer's tail
                # so that each variable keeps its static shape.
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss function. Only the loss (first output) is used here.
            with tf.name_scope('classifier_loss'):
                classifier_loss, _label_out = dnnlib.util.call_func_by_name(
                    classifier=classifier_gpu,
                    images=reals_read,
                    labels=labels_read,
                    **classifier_loss_args)

            classifier_opt.register_gradients(tf.reduce_mean(classifier_loss),
                                              classifier_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    classifier_train_op = classifier_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())

    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('', cur_epoch=0, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = 0
    cur_tick = -1
    tick_start_nimg = cur_nimg
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)

        # Run training ops. The classifier uses the schedule's G learning rate.
        feed_dict = {
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            cur_nimg += sched.minibatch_size

            if len(rounds) == 1:
                # Fast path without gradient accumulation.
                tflib.run([classifier_train_op, data_fetch_op], feed_dict)
            else:
                # Slow path with gradient accumulation.
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(classifier_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start()

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save a snapshot and evaluate metrics on it.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl(classifier, pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext (status is a fixed '0.00' here).
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % 0,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl(classifier, dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Exemplo n.º 14
0
def training_loop(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    loss_args={},  # Options for loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for metrics.
    metric_args={},  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    ema_start_kimg=None,  # Start of the exponential moving average. Defaults to G_ema_kimg (the half-life).
    G_ema_kimg=10,  # Half-life of the exponential moving average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=False,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often to perform regularization for G? NOTE(review): not read anywhere in this function.
    D_reg_interval=4,  # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=2,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks=1,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False  # Construct new networks according to G_args and D_args before resuming training?
):
    """Train a generator/discriminator pair plus a moving-average generator.

    Networks G and D are built from the training set shape, or loaded from
    `resume_pkl`. Gs starts as a clone of G and is updated after every
    generator step as a moving average of G (see `G_ema_kimg` and
    `ema_start_kimg`). The loss function is selected through `loss_args`
    and must return a (G_loss, D_loss, D_reg) triple; D_reg may be None.

    Training runs until `total_kimg` thousand images have been consumed or
    the RunContext requests a stop. Files written into the run directory:
    'reals.png' and 'fakes_init.png' up front, 'fakes%06d.png' every
    `image_snapshot_ticks` ticks, and 'network-snapshot-%06d.pkl' every
    `network_snapshot_ticks` ticks. Metrics are evaluated right after each
    network snapshot. 'network-final.pkl' is written at the end.
    """

    if ema_start_kimg is None:
        ema_start_kimg = G_ema_kimg

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set and save a grid of real images for reference.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            resume_networks = misc.load_pkl(resume_pkl)
            rG, rD, rGs = resume_networks
            if resume_with_new_nets:
                # Keep the freshly built architectures; only copy the weights.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    # Schedule at the end of training, used here only to pick the minibatch size for rendering.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    # Fixed latents, reused for every image snapshot so progress is comparable.
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        G_lrate_in = tf.placeholder(tf.float32, name='G_lrate_in', shape=[])
        D_lrate_in = tf.placeholder(tf.float32, name='D_lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        run_D_reg_in = tf.placeholder(tf.bool, name='run_D_reg', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        # Fed 0 before ema_start_kimg and 1 afterwards (see feed_dict below);
        # multiplies the EMA beta.
        Gs_beta_mul_in = tf.placeholder(tf.float32,
                                        name='Gs_beta_in',
                                        shape=[])
        # Per-minibatch EMA decay giving a half-life of G_ema_kimg thousand images.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in, tf.float32),
                              G_ema_kimg * 1000.0) if G_ema_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    G_opt_args['learning_rate'] = G_lrate_in
    D_opt_args['learning_rate'] = D_lrate_in
    for args in [G_opt_args, D_opt_args]:
        args['minibatch_multiplier'] = minibatch_multiplier
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)

    # Build training graph for each GPU.
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            with tf.name_scope('DataFetch'):
                # Reals are read straight from the dataset pipeline (no staging variables).
                reals_read, labels_read = training_set.get_minibatch_tf()
                reals_read = process_reals(reals_read, lod_in, mirror_augment,
                                           training_set.dynamic_range,
                                           drange_net)

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Evaluate loss functions after pushing the current lod into the networks.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('loss'):
                    G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        real_labels=labels_read,
                        **loss_args)

            # Register gradients. With lazy regularization, D_reg is added
            # (scaled by D_reg_interval) only on steps where run_D_reg_in is True.
            if not lazy_regularization:
                if D_reg is not None:
                    D_loss += D_reg
            else:
                if D_reg is not None:
                    D_loss = tf.cond(run_D_reg_in,
                                     lambda: D_loss + D_reg * D_reg_interval,
                                     lambda: D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops. The Gs moving-average update runs before each G update.
    Gs_update_op = Gs.setup_as_moving_average_of(G,
                                                 beta=Gs_beta * Gs_beta_mul_in)
    with tf.control_dependencies([Gs_update_op]):
        G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list, **metric_args)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu)
        # Reset optimizer state whenever the lod crosses an integer boundary
        # (also on the first iteration, since prev_lod starts at -1).
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            G_lrate_in: sched.G_lrate,
            D_lrate_in: sched.D_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu,
            Gs_beta_mul_in: 1 if cur_nimg >= ema_start_kimg * 1000 else 0,
        }
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            feed_dict[run_D_reg_in] = run_D_reg
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # One G step and one D step per round. With several rounds per
            # minibatch, the optimizers presumably accumulate gradients via
            # minibatch_multiplier (single code path for both cases).
            for _ in rounds:
                tflib.run(G_train_op, feed_dict)
                tflib.run(D_train_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots (image grid from the fixed grid_latents, then networks + metrics).
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()

# Load a trained StyleGAN pickle and render the same latent through two
# generators that share its Gs weights: one built from
# networks_stylegan_cutoff.G_style and one from the stock
# networks_stylegan.G_style. The plotting loop that consumes the results
# follows this block.
run_id = 15
snapshot = 15326
G_args = {}
synthesis_kwargs = dict(output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True), minibatch_size=8)

tflib.init_tf()
# baseline model
# network_pkl = '../../results/00015-sgan-ffhq256-1gpu-baseline/network-snapshot-014526.pkl'
network_pkl = '../../results/00046-sgan-ffhq256-2gpu-adain-pixel-norm-continue/network-snapshot-012126.pkl'

# no noise model
# network_pkl = 'results/00022-sgan-ffhq256-2gpu/network-snapshot-005726.pkl'
_G, _D, Gs = misc.load_pkl(network_pkl)

# Generator built from the "cutoff" variant, initialised from the trained Gs.
G = tflib.Network('G', func_name='training.networks_stylegan_cutoff.G_style', num_channels=3, resolution=256,
                  label_size=0, structure='linear', **G_args)
G.copy_vars_from(Gs)

# Reference generator built from the stock G_style, same weights.
G_original = tflib.Network('G', func_name='training.networks_stylegan.G_style', num_channels=3, resolution=256,
                           label_size=0, structure='linear', **G_args)
G_original.copy_vars_from(Gs)

# np.stack requires a sequence (list/tuple). Passing a generator is deprecated
# since NumPy 1.16 and raises TypeError in current NumPy releases.
latents = np.stack([np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in [8]])
images = G.run(latents, None, use_instance_norm=False, **synthesis_kwargs)
images_original = G_original.run(latents, None, use_instance_norm=False, **synthesis_kwargs)
print(images.shape)
fig, axs = plt.subplots(3, 3)
im = images[0]

counter = 20
for i in range(3):
Exemplo n.º 16
0
def G_style(
    latents_in,  # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,  # Second input: Conditioning labels [minibatch, label_size].
    truncation_psi=0.7,  # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=8,  # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,  # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,  # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,  # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,  # Probability of mixing styles during training. None = disable.
    is_training=False,  # Network is under training? Enables and disables specific features.
    is_validation=False,  # Network is under validation? Chooses which value to use for truncation_psi.
    is_template_graph=False,  # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(
    ),  # Container for sub-networks. Retained between calls.
    **kwargs):  # Arguments for sub-networks (G_mapping and G_synthesis).
    """Style-based generator: a mapping network (Z -> W) feeding a synthesis network.

    During training, a running average of W is tracked in the 'dlatent_avg'
    variable and style mixing with a second random latent may be applied.
    Outside training, the truncation trick pulls the first
    `truncation_cutoff` layers of W toward that average by `truncation_psi`.
    In validation mode the `*_val` truncation settings replace the regular
    ones.

    Sub-networks are created on first use and cached in `components`. The
    shared default `components` object is intentional: it keeps the
    networks alive across calls.

    Returns the synthesis network's output, wrapped in an identity op named
    'images_out'.
    """
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)

    def _is_static(value):
        # A concrete Python/NumPy value, i.e. neither None nor a TF expression.
        return value is not None and not tflib.is_tf_expression(value)

    # Resolve which features are active for this call.
    if is_validation:
        truncation_psi, truncation_cutoff = truncation_psi_val, truncation_cutoff_val
    if is_training:
        # Truncation is only used outside training.
        truncation_psi = None
        truncation_cutoff = None
    else:
        # W tracking and style mixing are training-only.
        dlatent_avg_beta = None
        style_mixing_prob = None
    # Statically known degenerate settings switch the feature off entirely.
    if _is_static(truncation_psi) and truncation_psi == 1:
        truncation_psi = None
    if _is_static(truncation_cutoff) and truncation_cutoff <= 0:
        truncation_cutoff = None
    if _is_static(dlatent_avg_beta) and dlatent_avg_beta == 1:
        dlatent_avg_beta = None
    if _is_static(style_mixing_prob) and style_mixing_prob <= 0:
        style_mixing_prob = None

    # Sub-networks: synthesis first, because its input shape sizes the mapping output.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis',
                                             func_name=G_synthesis,
                                             **kwargs)
    synthesis = components.synthesis
    num_layers = synthesis.input_shape[1]
    dlatent_size = synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping',
                                           func_name=G_mapping,
                                           dlatent_broadcast=num_layers,
                                           **kwargs)
    mapping = components.mapping

    # Per-layer index, shaped to broadcast against dlatents along axis 1.
    layer_ids = np.arange(num_layers)[np.newaxis, :, np.newaxis]

    # Persistent state.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg',
                                  shape=[dlatent_size],
                                  initializer=tf.initializers.zeros(),
                                  trainable=False)

    # Z -> W.
    dlatents = mapping.get_output_for(latents_in, labels_in, **kwargs)

    # Track the running mean of W; the update is forced to run before dlatents is consumed.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            w_batch_mean = tf.reduce_mean(dlatents[:, 0], axis=0)
            refresh_avg = tf.assign(
                dlatent_avg,
                tflib.lerp(w_batch_mean, dlatent_avg, dlatent_avg_beta))
            with tf.control_dependencies([refresh_avg]):
                dlatents = tf.identity(dlatents)

    # Style mixing: layers at or past a random crossover take W from a second latent.
    if style_mixing_prob is not None:
        with tf.name_scope('StyleMix'):
            mix_latents = tf.random_normal(tf.shape(latents_in))
            mix_dlatents = mapping.get_output_for(mix_latents, labels_in,
                                                  **kwargs)
            active_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            crossover = tf.cond(
                tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
                lambda: tf.random_uniform([], 1, active_layers, dtype=tf.int32),
                lambda: active_layers)
            keep_primary = tf.broadcast_to(layer_ids < crossover,
                                           tf.shape(dlatents))
            dlatents = tf.where(keep_primary, dlatents, mix_dlatents)

    # Truncation trick: pull the first truncation_cutoff layers toward the average W.
    if truncation_psi is not None and truncation_cutoff is not None:
        with tf.variable_scope('Truncation'):
            unit = np.ones(layer_ids.shape, dtype=np.float32)
            psi_per_layer = tf.where(layer_ids < truncation_cutoff,
                                     truncation_psi * unit, unit)
            dlatents = tflib.lerp(dlatent_avg, dlatents, psi_per_layer)

    # W -> images, after syncing the synthesis network's lod with ours.
    sync_lod = tf.assign(synthesis.find_var('lod'), lod_in)
    with tf.control_dependencies([sync_lod]):
        images_out = synthesis.get_output_for(
            dlatents, force_clean_graph=is_template_graph, **kwargs)
    return tf.identity(images_out, name='images_out')
Exemplo n.º 17
0
def D_stylegan2(
        images_in,  # First input: Images [minibatch, channel, height, width].
        labels_in,  # Second input: Labels [minibatch, label_size].
        num_channels=3,  # Number of input color channels. Overridden based on dataset.
        resolution=1024,  # Input resolution. Overridden based on dataset.
        label_size=0,  # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
        fmap_base=16 <<
    10,  # Overall multiplier for the number of feature maps.
        fmap_decay=1.0,  # log2 feature map reduction when doubling the resolution.
        fmap_min=1,  # Minimum number of feature maps in any layer.
        fmap_max=512,  # Maximum number of feature maps in any layer.
        architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
        nonlinearity='lrelu',  # Activation function: 'relu', 'lrelu', etc.
        mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, <= 1 = disable.
        mbstd_num_features=1,  # Number of features for the minibatch standard deviation layer.
        dtype='float32',  # Data type to use for activations and outputs.
        resample_kernel=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations. None = no filtering.
        mapping_label_func='D_mapping_label',  # Name of a module-level function that builds the label-mapping sub-network.
        components=dnnlib.EasyDict(
        ),  # Container for sub-networks. Retained between calls.
        dlabel_size=128,  # Forwarded as `dlabel_size` to the label-mapping sub-network.
        **_kwargs):  # Ignore unrecognized keyword args.
    """Label-conditioned StyleGAN2 discriminator.

    The labels are first passed through a sub-network named
    'D_mapping_label'. Its function is looked up by name
    (`mapping_label_func`) in this module's globals, and the network is cached
    in `components.mapping_label`. The resulting embedding `dlabel` is the
    modulation input of every `modulated_conv2d_layer` call: FromRGB, Conv0
    and Conv1_down in each downsampling block, and the final 4x4 Conv. The
    resnet Skip path, Dense0 and the Output layer do not use `dlabel`.

    The output layer applies projection-style label conditioning. It produces
    max(label_size, 1) values per sample and, when labels are present, sums
    them weighted by `labels_in` (keepdims=True).

    Returns:
        Tensor named 'scores_out' with dtype `dtype`. Presumably
        [minibatch, 1]; this depends on dense_layer's output layout, TODO
        confirm.
    """

    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

    # Number of feature maps at a given stage: fmap_base decayed by
    # 2**(stage * fmap_decay), then clipped to [fmap_min, fmap_max].
    def nf(stage):
        return np.clip(int(fmap_base / (2.0**(stage * fmap_decay))), fmap_min,
                       fmap_max)

    assert architecture in ['orig', 'skip', 'resnet']
    act = nonlinearity

    images_in.set_shape([None, num_channels, resolution, resolution])
    labels_in.set_shape([None, label_size])
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)

    # dlabel = D_mapping_label(labels_in=labels_in, label_size=label_size)
    # dlabel = tf.cast(dlabel, dtype)

    # Build the label-mapping sub-network only on the first call. `components`
    # is retained between calls, so later calls reuse the same Network object.
    if 'mapping_label' not in components:
        components.mapping_label = tflib.Network(
            'D_mapping_label',
            func_name=globals()[mapping_label_func],
            label_size=label_size,
            dlabel_size=dlabel_size)

    # Label embedding used as the modulation input of the modulated convs.
    dlabel = components.mapping_label.get_output_for(labels_in)
    dlabel = tf.cast(dlabel, dtype)

    # Building blocks for main layers.
    # FromRGB: 1x1 modulated conv on the image `y`. Adds to `x` unless `x` is None.
    def fromrgb(x, y, res):  # res = 2..resolution_log2
        with tf.variable_scope('FromRGB'):
            t = apply_bias_act(modulated_conv2d_layer(y,
                                                      dlabel,
                                                      fmaps=nf(res - 1),
                                                      kernel=1),
                               act=act)
            return t if x is None else x + t

    # One downsampling block: 3x3 modulated conv, then 3x3 modulated conv with
    # down=True. For 'resnet' an unmodulated 1x1 downsampling skip is added and
    # the sum is scaled by 1/sqrt(2).
    def block(x, res):  # res = 2..resolution_log2
        t = x
        with tf.variable_scope('Conv0'):
            x = apply_bias_act(modulated_conv2d_layer(x,
                                                      dlabel,
                                                      fmaps=nf(res - 1),
                                                      kernel=3),
                               act=act)
        with tf.variable_scope('Conv1_down'):
            x = apply_bias_act(modulated_conv2d_layer(
                x,
                dlabel,
                fmaps=nf(res - 2),
                kernel=3,
                down=True,
                resample_kernel=resample_kernel),
                               act=act)
        if architecture == 'resnet':
            with tf.variable_scope('Skip'):
                t = conv2d_layer(t,
                                 fmaps=nf(res - 2),
                                 kernel=1,
                                 down=True,
                                 resample_kernel=resample_kernel)
                x = (x + t) * (1 / np.sqrt(2))
        return x

    # Downsample the raw image with the low-pass resample kernel ('skip' only).
    def downsample(y):
        with tf.variable_scope('Downsample'):
            return downsample_2d(y, k=resample_kernel)

    # Main layers.
    # 'skip' feeds the downsampled image into every resolution via FromRGB.
    # 'orig' and 'resnet' apply FromRGB only at the top resolution.
    x = None
    y = images_in
    for res in range(resolution_log2, 2, -1):
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if architecture == 'skip' or res == resolution_log2:
                x = fromrgb(x, y, res)
            x = block(x, res)
            if architecture == 'skip':
                y = downsample(y)

    # Final layers.
    with tf.variable_scope('4x4'):
        if architecture == 'skip':
            x = fromrgb(x, y, 2)
        # Minibatch stddev is skipped when the group size is 0 or 1.
        if mbstd_group_size > 1:
            with tf.variable_scope('MinibatchStddev'):
                x = minibatch_stddev_layer(x, mbstd_group_size,
                                           mbstd_num_features)
        with tf.variable_scope('Conv'):
            x = apply_bias_act(modulated_conv2d_layer(x,
                                                      dlabel,
                                                      fmaps=nf(1),
                                                      kernel=3),
                               act=act)
        with tf.variable_scope('Dense0'):
            x = apply_bias_act(dense_layer(x, fmaps=nf(0)), act=act)

    # Output layer with label conditioning from "Which Training Methods for GANs do actually Converge?"
    # One output per label; with labels present, sum them weighted by labels_in.
    with tf.variable_scope('Output'):
        x = apply_bias_act(dense_layer(x, fmaps=max(labels_in.shape[1], 1)))
        if labels_in.shape[1] > 0:
            x = tf.reduce_sum(x * labels_in, axis=1, keepdims=True)
    scores_out = x

    # Output.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out
Exemplo n.º 18
0
def training_loop_refinement(
    G_args={},  # Options for generator network (read-only here).
    D_args={},  # Options for discriminator network (unused: no D is trained).
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer (unused).
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss (unused).
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # Unused (no discriminator is trained).
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes_init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'network-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=True,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to G_args before resuming training?
):
    """Train only the generator G (and its moving average Gs); no discriminator.

    This is a variant of the StyleGAN2 training loop. Only G is optimized,
    using the loss function named in `G_loss_args`. That function is called
    with `D=None`, and the dataset's labels are passed as `latents`.
    `D_args`, `D_opt_args`, `D_loss_args` and `D_reg_interval` are accepted
    for config compatibility but are not used. The dict-valued arguments are
    only read, never mutated.

    Layers can be frozen by listing name substrings in
    `G_args['freeze_layers']`. Any trainable variable whose name contains one
    of them is removed from the optimized set.

    Snapshots and the final network are saved as `(G, None, Gs)` pickles.
    """
    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            Gs = G.clone('Gs')
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            # Pickles hold (G, D, Gs). D is not needed because only G is trained.
            rG, _rD, rGs = misc.load_pkl(resume_pkl)
            del _rD
            if resume_with_new_nets:
                # Both the trained G and its average Gs start from the
                # pickle's averaged weights.
                G.copy_vars_from(rGs)
                Gs.copy_vars_from(rGs)
                del rG, rGs
            else:
                G = rG
                Gs = rGs

    # Use constant noise inputs in both G and Gs. Each network is filled from
    # its own RandomState(123), so both get the same seeded draws.
    # `== False` (not `not ...`) keeps a missing/None option from matching.
    if G_args.get("randomize_noise", None) == False:
        for net in (G, Gs):
            noise_vars = [
                var for name, var in net.components.synthesis.vars.items()
                if name.startswith('noise')
            ]
            rnd = np.random.RandomState(123)
            tflib.set_vars(
                {var: rnd.randn(*var.shape.as_list())
                 for var in noise_vars})  # [height, width]

    # Print layers and generate initial image snapshot.
    G.print_layers()
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers. Copy the args so the caller's dict is not modified.
    G_opt_args = dict(G_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Compensate for the extra regularization steps (lazy regularization).
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)

    # Freeze layers. Kept as a local so G_args is never mutated: attribute
    # assignment on it failed for plain dicts, including the `{}` default.
    freeze_layers = list(G_args.get("freeze_layers", []))

    def freeze_vars(gen, verbose=True):
        """Drop trainables whose name contains a frozen-layer substring."""
        assert len(freeze_layers) > 0
        for name in list(gen.trainables.keys()):
            if any(layer in name for layer in freeze_layers):
                del gen.trainables[name]
                if verbose: print(f"Freezed {name}")

    # Build training graph for each GPU.
    data_fetch_ops = []
    loss_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G. On GPU 0 this is G
            # itself, so freezing there edits G.trainables in place.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            if freeze_layers: freeze_vars(G_gpu, verbose=False)

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions. The dataset labels act as G's latents.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=None,
                        opt=G_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        latents=labels_read,
                        **G_loss_args)
                    loss_ops.append(G_loss)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    loss_op = tf.reduce_mean(tf.concat(loss_ops, axis=0))
    G_train_op = G_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    loss_per_batch_sum = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        tflib.run(data_fetch_op, feed_dict)
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                tflib.run([data_fetch_op], feed_dict)
                loss_per_batch_sum += loss
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([Gs_update_op], feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    # Fetch the next batch in a separate run. Session.run does
                    # not order independent fetches, so fetching it together
                    # with G_train_op could overwrite reals/labels while the
                    # loss is still reading them.
                    loss, _ = tflib.run([loss_op, G_train_op], feed_dict)
                    tflib.run(data_fetch_op, feed_dict)
                    loss_per_batch_sum += loss / len(rounds)
                    if run_G_reg:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time
            # Average of the per-minibatch mean losses accumulated this tick.
            tick_loss = loss_per_batch_sum * sched.minibatch_size / (
                tick_kimg * 1000)
            loss_per_batch_sum = 0

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d loss/px %-12.8f time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   autosummary('Progress/loss_per_px', tick_loss),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                misc.save_pkl((G, None, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    misc.save_pkl((G, None, Gs), dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Exemplo n.º 19
0
import cv2

G_args = {}  # Extra kwargs forwarded to the rebuilt generator networks (none by default).
# Options for running the generators: convert outputs with
# tflib.convert_images_to_uint8 (NCHW -> NHWC per the option name), 8 per batch.
synthesis_kwargs = dict(output_transform=dict(
    func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
                        minibatch_size=8)

tflib.init_tf()
# baseline model
baseline_network_pkl = '../results/00015-sgan-ffhq256-1gpu-baseline/network-snapshot-014526.pkl'

# Load the snapshot and rebuild its averaged generator (Gs) with the
# networks_stylegan_cutoff.G_style architecture by copying variables into a
# freshly constructed 256x256 RGB, unconditional network.
_G, _D, Gs = misc.load_pkl(baseline_network_pkl)
G_baseline = tflib.Network(
    'G',
    func_name='training.networks_stylegan_cutoff.G_style',
    num_channels=3,
    resolution=256,
    label_size=0,
    structure='linear',
    **G_args)
G_baseline.copy_vars_from(Gs)

# Second model: per its run directory name, trained without progressive growing.
without_progan_network_pkl = '../results/00001-sgan-ffhq256-2gpu-remove-progan/network-snapshot-014800.pkl'

# Rebinds Gs; G_baseline already holds its own copy of the baseline weights.
_G, _D, Gs = misc.load_pkl(without_progan_network_pkl)
G_without_noise = tflib.Network(
    'G',
    func_name='training.networks_stylegan_cutoff.G_style',
    num_channels=3,
    resolution=256,
    label_size=0,
    structure='linear',
Exemplo n.º 20
0
def load_from(name, cfg, truncation_psi=0.7):
    """Convert a StyleGAN (v1) TensorFlow pickle into the PyTorch `Model`.

    Args:
        name: Path to a pickle whose third element is the averaged generator
            Gs. NOTE: unpickling can execute arbitrary code, so only load
            trusted files.
        cfg: Config node providing the MODEL.* sizes used to build `Model`.
        truncation_psi: Truncation psi passed to `Model`. The default of 0.7
            matches the value previously hard-coded here in place of
            cfg.MODEL.TRUNCATIOM_PSI.

    Returns:
        (model, Gs_): the PyTorch model with weights copied from the pickle's
        Gs, and a freshly built TF `G_style` network (1024x1024, 3 channels)
        with Gs's variables copied into it.
    """
    dnnlib.tflib.init_tf()
    with open(name, 'rb') as f:
        m = pickle.load(f)  # Trusted input only (see docstring).

    Gs = m[2]

    Gs_ = tflib.Network(
        'G',
        func_name='stylegan.training.networks_stylegan.G_style',
        num_channels=3,
        resolution=1024)

    Gs_.copy_vars_from(Gs)

    model = Model(
        startf=cfg.MODEL.START_CHANNEL_COUNT,
        layer_count=cfg.MODEL.LAYER_COUNT,
        maxf=cfg.MODEL.MAX_CHANNEL_COUNT,
        latent_size=cfg.MODEL.LATENT_SPACE_SIZE,
        mapping_layers=cfg.MODEL.MAPPING_LAYERS,
        truncation_psi=truncation_psi,
        truncation_cutoff=cfg.MODEL.TRUNCATIOM_CUTOFF,
        channels=3)

    def tensor(x, transpose=None):
        """Evaluate TF variable `x` of Gs into a torch tensor, optionally transposed."""
        x = Gs.vars[x].eval()
        if transpose:
            x = np.transpose(x, transpose)
        return torch.tensor(x)

    # Parameters are overwritten in place via slice assignment. Autograd
    # rejects in-place writes to leaf tensors that require grad, so all
    # copies run with gradient tracking disabled.
    with torch.no_grad():
        # Mapping network: TF dense weights are transposed with (1, 0) and
        # multiplied by the layer's `std`; biases are multiplied by `lrmul`.
        # Presumably this folds TF's runtime weight scaling into the stored
        # values -- TODO confirm against the Model's layer definitions.
        for i in range(cfg.MODEL.MAPPING_LAYERS):
            block = getattr(model.mapping, "block_%d" % (i + 1))
            block.fc.weight[:] = tensor('G_mapping/Dense%d/weight' % i,
                                        (1, 0)) * block.fc.std
            block.fc.bias[:] = tensor(
                'G_mapping/Dense%d/bias' % i) * block.fc.lrmul

        model.dlatent_avg.buff[:] = tensor('dlatent_avg')
        model.generator.const[:] = tensor('G_synthesis/4x4/Const/const')

        for i in range(model.generator.layer_count):
            # ToRGB layers are named by LOD, i.e. counted from the top resolution.
            j = model.generator.layer_count - i - 1
            name = '%dx%d' % (2**(2 + i), 2**(2 + i))
            block = model.generator.decode_block[i]

            prefix = 'G_synthesis/%s' % name

            # The first (4x4) block has Const + Conv; later blocks have
            # Conv0_up + Conv1.
            if not block.has_first_conv:
                prefix_1 = '%s/Const' % prefix
                prefix_2 = '%s/Conv' % prefix
            else:
                prefix_1 = '%s/Conv0_up' % prefix
                prefix_2 = '%s/Conv1' % prefix

            block.noise_weight_1[0, :, 0, 0] = tensor('%s/Noise/weight' % prefix_1)
            block.noise_weight_2[0, :, 0, 0] = tensor('%s/Noise/weight' % prefix_2)

            if block.has_first_conv:
                # Fused upscale weights use a different axis order than
                # regular convs, hence the different transpose.
                if block.fused_scale:
                    block.conv_1.weight[:] = tensor(
                        '%s/weight' % prefix_1, (2, 3, 0, 1)) * block.conv_1.std
                else:
                    block.conv_1.weight[:] = tensor(
                        '%s/weight' % prefix_1, (3, 2, 0, 1)) * block.conv_1.std

            block.conv_2.weight[:] = tensor('%s/weight' % prefix_2,
                                            (3, 2, 0, 1)) * block.conv_2.std
            block.bias_1[0, :, 0, 0] = tensor('%s/bias' % prefix_1)
            block.bias_2[0, :, 0, 0] = tensor('%s/bias' % prefix_2)
            block.style_1.weight[:] = tensor('%s/StyleMod/weight' % prefix_1,
                                             (1, 0)) * block.style_1.std
            block.style_1.bias[:] = tensor('%s/StyleMod/bias' % prefix_1)
            block.style_2.weight[:] = tensor('%s/StyleMod/weight' % prefix_2,
                                             (1, 0)) * block.style_2.std
            block.style_2.bias[:] = tensor('%s/StyleMod/bias' % prefix_2)

            model.generator.to_rgb[i].to_rgb.weight[:] = tensor(
                'G_synthesis/ToRGB_lod%d/weight' % (j),
                (3, 2, 0, 1)) * model.generator.to_rgb[i].to_rgb.std
            model.generator.to_rgb[i].to_rgb.bias[:] = tensor(
                'G_synthesis/ToRGB_lod%d/bias' % (j))

    return model, Gs_
Exemplo n.º 21
0
def main():
    """Render frames by interpolating saved dlatents through a StyleGAN2 generator.

    All settings come from the module-level ``a`` namespace:

    * ``a.model`` -- network pickle path. The file actually opened is always
      ``<a.model without extension>.pkl``. If ``a.model`` contains ``.pkl``
      the pickled Gs is used as-is; otherwise a new ``Gs`` is built from
      ``Gs_kwargs`` (``a.verbose``, ``a.size``, ``a.scale_type``, ``a.ops``,
      with any same-named pickled ``static_kwargs`` taking precedence) and
      the pickled weights are copied into it.
    * ``a.dlatents`` -- a single latent file or a directory of ``.npy`` files
      holding the key dlatents.
    * ``a.style_npy_file`` -- optional latent whose layers 5 and above replace
      those layers in every key dlatent.
    * ``a.fstep`` / ``a.cubic`` -- passed to ``latent_anima``; the timeline
      has ``num_keys * a.fstep`` frames.
    * ``a.trunc`` -- truncation factor pulling dlatent layers 0..7 toward
      ``dlatent_avg``.
    * ``a.digress`` -- if > 0, scales an animated perturbation fed to the
      synthesis network as its third input (only used when the synthesis
      network does not take exactly 2 inputs).
    * ``a.out_dir`` -- output folder; one ``%06d`` image per frame is written
      (``png`` for 4-channel output, ``jpg`` otherwise).

    Exits with status 1 if no dlatents are found or the model carries no
    usable ``resolution`` kwarg.
    """
    import sys  # local import keeps this block self-contained

    os.makedirs(a.out_dir, exist_ok=True)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # load model with arguments
    tflib.init_tf({'allow_soft_placement': True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE: pickle.load can execute arbitrary code -- only load trusted models.
    with open(pkl_name + '.pkl', 'rb') as f:
        network = pickle.load(f, encoding='latin1')
    # Training snapshots are tuples whose last element is always Gs:
    # (G, D, Gs), (G, D, I, Gs) or (G, D, I, I_info, Gs).
    # A bare network object is used unchanged.
    if isinstance(network, (tuple, list)):
        network = network[-1]
    for k in list(network.static_kwargs.keys()):
        Gs_kwargs[k] = network.static_kwargs[k]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else:  # reconstruct network
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)

    dz_dim = 512  # dlatent_size
    try:
        # Number of dlatent layers: 2 per resolution level from 4x4 upward.
        dl_dim = 2 * (int(np.floor(np.log2(Gs_kwargs.resolution))) - 1)
    except (AttributeError, KeyError, TypeError, ValueError):
        print(' Resave model, no resolution kwarg found!')
        sys.exit(1)
    dlat_shape = (1, dl_dim, dz_dim)  # [1,18,512]

    # read saved latents
    if a.dlatents is not None and osp.isfile(a.dlatents):
        key_dlatents = load_latents(a.dlatents)
        if len(key_dlatents.shape) == 2:
            key_dlatents = np.expand_dims(key_dlatents, 0)
    elif a.dlatents is not None and osp.isdir(a.dlatents):
        key_dlatents = []
        npy_list = file_list(a.dlatents, 'npy')
        for npy in npy_list:
            key_dlatent = load_latents(npy)
            if len(key_dlatent.shape) == 2:
                key_dlatent = np.expand_dims(key_dlatent, 0)
            key_dlatents.append(key_dlatent)
        key_dlatents = np.concatenate(key_dlatents)  # [frm,18,512]
    else:
        print(' No input dlatents found')
        sys.exit(1)
    key_dlatents = key_dlatents[:, np.newaxis]  # [frm,1,18,512]
    print(' key dlatents', key_dlatents.shape)

    # replace higher layers with single (style) latent
    if a.style_npy_file is not None:
        print(' styling with latent', a.style_npy_file)
        style_dlatent = load_latents(a.style_npy_file)
        while len(style_dlatent.shape) < 4:
            style_dlatent = np.expand_dims(style_dlatent, 0)
        # try replacing 5 by other value, less than dl_dim
        key_dlatents[:, :, range(5, dl_dim), :] = \
            style_dlatent[:, :, range(5, dl_dim), :]

    frames = key_dlatents.shape[0] * a.fstep

    dlatents = latent_anima(dlat_shape,
                            frames,
                            a.fstep,
                            key_latents=key_dlatents,
                            cubic=a.cubic,
                            verbose=True)  # [frm,1,512]
    print(' dlatents', dlatents.shape)
    frame_count = dlatents.shape[0]

    # truncation trick
    dlatent_avg = Gs.get_var('dlatent_avg')  # (512,)
    tr_range = range(0, 8)
    dlatents[:, :, tr_range, :] = dlatent_avg + (dlatents[:, :, tr_range, :] -
                                                 dlatent_avg) * a.trunc

    # distort image by tweaking initial const layer
    if a.digress > 0:
        latent_size = Gs.static_kwargs.get('latent_size', 512)  # default latent size
        init_res = Gs.static_kwargs.get('init_res', (4, 4))  # default initial layer size
        dconst = a.digress * latent_anima([1, latent_size, *init_res],
                                          frames,
                                          a.fstep,
                                          cubic=True,
                                          verbose=False)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.components.synthesis.run(dlatents[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)
        else:
            # The digress perturbation (or zeros) is passed as the third input.
            output = Gs.components.synthesis.run(dlatents[i], [None],
                                                 dconst[i],
                                                 randomize_noise=False,
                                                 output_transform=fmt,
                                                 minibatch_size=1)

        ext = 'png' if output.shape[3] == 4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i, ext))
        imsave(filename, output[0])
        pbar.upd()
Exemplo n.º 22
0
def training_loop_vc(
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    I_args={},  # Options for infogan-head/vcgan-head network.
    I_info_args={},  # Options for infogan-head/vcgan-head network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    use_info_gan=False,  # Whether to use info-gan.
    use_vc_head=False,  # Whether to use vc-head.
    use_vc_head_with_cls=False,  # Whether to use classification in discriminator.
    data_dir=None,  # Directory to load datasets from.
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    lazy_regularization=True,  # Perform regularization as a separate training step?
    G_reg_interval=4,  # How often the perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval=16,  # How often the perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=25000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=50,  # How often to save image snapshots? None = only save 'reals.png' and 'fakes-init.png'.
    network_snapshot_ticks=50,  # How often to save network snapshots? None = only save 'networks-final.pkl'.
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_pkl=None,  # Network pickle to resume training from, None = train from scratch.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
    resume_with_new_nets=False,  # Construct new networks according to G_args and D_args before resuming training?
    traversal_grid=False,  # Used for disentangled representation learning.
    n_discrete=3,  # Number of discrete latents in model.
    n_continuous=4,  # Number of continuous latents in model.
    n_samples_per=10):  # Number of samples for each line in traversal.
    """Train a StyleGAN2-style G/D pair, optionally with an InfoGAN/VC head.

    When any of ``use_info_gan``, ``use_vc_head`` or ``use_vc_head_with_cls``
    is set, an auxiliary network ``I`` is built; ``use_vc_head_with_cls``
    additionally builds ``I_info``. The auxiliary heads have no optimizer of
    their own: their trainables are merged with G's and updated by ``G_opt``
    using ``G_loss + I_loss`` (``+ I_info_loss`` for the cls variant).

    Flag precedence: ``use_info_gan`` / ``use_vc_head`` are checked before
    ``use_vc_head_with_cls`` for the loss call, pickle loading and pickle
    saving. Network pickles written here (and expected in ``resume_pkl``)
    therefore contain:
      * ``(G, D, I, Gs)`` when ``use_info_gan`` or ``use_vc_head`` is set,
      * ``(G, D, I, I_info, Gs)`` when only ``use_vc_head_with_cls`` is set,
      * ``(G, D, Gs)`` otherwise.

    Image grids, network snapshots, ``network-final.pkl`` and the tfevents log
    are written into the dnnlib run directory. With ``traversal_grid`` the
    snapshot grid latents come from ``get_grid_latents`` instead of random
    normals.

    NOTE(review): the dict/list defaults are shared across calls. The visible
    body does not mutate them (the optimizer args are copied with ``dict()``),
    so this is currently harmless.
    """

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(data_dir=dnnlib.convert_path(data_dir),
                                        verbose=True,
                                        **dataset_args)
    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        training_set, **grid_args)
    misc.save_image_grid(grid_reals,
                         dnnlib.make_run_dir_path('reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)

    # Construct or load networks.
    # I / I_info exist only when a head flag is set; every later use of them
    # is guarded by the same flags.
    with tf.device('/gpu:0'):
        if resume_pkl is None or resume_with_new_nets:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I = tflib.Network('I',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **I_args)
                if use_vc_head_with_cls:
                    I_info = tflib.Network('I_info',
                                           num_channels=training_set.shape[0],
                                           resolution=training_set.shape[1],
                                           label_size=training_set.label_size,
                                           **I_info_args)

            Gs = G.clone('Gs')
        # The pickle tuple layout must match the flags used when it was saved.
        if resume_pkl is not None:
            print('Loading networks from "%s"...' % resume_pkl)
            if use_info_gan or use_vc_head:
                rG, rD, rI, rGs = misc.load_pkl(resume_pkl)
            elif use_vc_head_with_cls:
                rG, rD, rI, rI_info, rGs = misc.load_pkl(resume_pkl)
            else:
                rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                # Keep the freshly built graphs, copy only the weights.
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I.copy_vars_from(rI)
                    if use_vc_head_with_cls:
                        I_info.copy_vars_from(rI_info)
                Gs.copy_vars_from(rGs)
            else:
                G = rG
                D = rD
                if use_info_gan or use_vc_head or use_vc_head_with_cls:
                    I = rI
                    if use_vc_head_with_cls:
                        I_info = rI_info
                Gs = rGs

    # Print layers and generate initial image snapshot.
    G.print_layers()
    D.print_layers()
    if use_info_gan or use_vc_head or use_vc_head_with_cls:
        I.print_layers()
        if use_vc_head_with_cls:
            I_info.print_layers()
    # pdb.set_trace()
    # Schedule evaluated at the end of training; its minibatch_gpu is used as
    # the batch size when running Gs for the snapshot grids.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)
    if traversal_grid:
        grid_size, grid_latents, grid_labels = get_grid_latents(
            n_discrete, n_continuous, n_samples_per, G, grid_labels)
    else:
        grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    print('grid_latents.shape:', grid_latents.shape)
    print('grid_labels.shape:', grid_labels.shape)
    # pdb.set_trace()
    grid_fakes, _ = Gs.run(grid_latents,
                           grid_labels,
                           is_validation=True,
                           minibatch_size=sched.minibatch_gpu,
                           randomize_noise=False)
    misc.save_image_grid(grid_fakes,
                         dnnlib.make_run_dir_path('fakes_init.png'),
                         drange=drange_net,
                         grid_size=grid_size)

    # Setup training inputs.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in = tf.placeholder(tf.int32,
                                           name='minibatch_size_in',
                                           shape=[])
        minibatch_gpu_in = tf.placeholder(tf.int32,
                                          name='minibatch_gpu_in',
                                          shape=[])
        # Number of per-GPU rounds that make up one full minibatch; handed to
        # the optimizers so they can accumulate gradients across rounds.
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in *
                                                     num_gpus)
        # Per-step EMA decay for Gs, chosen so the weight average has a
        # half-life of G_smoothing_kimg thousand images.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_size_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        # With lazy regularization the learning rate and beta1/beta2 are
        # scaled by reg_interval / (reg_interval + 1), since the
        # regularization step runs separately every reg_interval minibatches.
        if lazy_regularization:
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            if use_info_gan or use_vc_head or use_vc_head_with_cls:
                I_gpu = I if gpu == 0 else I.clone(I.name + '_shadow')
                if use_vc_head_with_cls:
                    I_info_gpu = I_info if gpu == 0 else I_info.clone(
                        I_info.name + '_shadow')

            # Fetch training data via temporary variables.
            # Reals/labels are staged in non-trainable variables. Each fetch
            # writes a new minibatch into the first minibatch_gpu_in rows and
            # keeps the remaining rows. The losses read only the first
            # minibatch_gpu_in rows.
            # NOTE(review): the buffers are sized with minibatch_gpu from the
            # schedule at resume_kimg; this presumably assumes minibatch_gpu
            # stays constant over training -- confirm.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg * 1000),
                                          training_set=training_set,
                                          **sched_args)
                reals_var = tf.Variable(
                    name='reals',
                    trainable=False,
                    initial_value=tf.zeros([sched.minibatch_gpu] +
                                           training_set.shape))
                labels_var = tf.Variable(name='labels',
                                         trainable=False,
                                         initial_value=tf.zeros([
                                             sched.minibatch_gpu,
                                             training_set.label_size
                                         ]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(
                    reals_write, labels_write, lod_in, mirror_augment,
                    training_set.dynamic_range, drange_net)
                reals_write = tf.concat(
                    [reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat(
                    [labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars:
                lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars:
                lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    # With a head, the G loss function (G_loss_args.func_name)
                    # returns (G_loss, G_reg, I_loss, I_info_loss); the fourth
                    # value is discarded unless use_vc_head_with_cls is the
                    # active branch.
                    if use_info_gan or use_vc_head:
                        G_loss, G_reg, I_loss, _ = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    elif use_vc_head_with_cls:
                        G_loss, G_reg, I_loss, I_info_loss = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            I=I_gpu,
                            I_info=I_info_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                    else:
                        G_loss, G_reg = dnnlib.util.call_func_by_name(
                            G=G_gpu,
                            D=D_gpu,
                            opt=G_opt,
                            training_set=training_set,
                            minibatch_size=minibatch_gpu_in,
                            **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(
                        G=G_gpu,
                        D=D_gpu,
                        opt=D_opt,
                        training_set=training_set,
                        minibatch_size=minibatch_gpu_in,
                        reals=reals_read,
                        labels=labels_read,
                        **D_loss_args)

            # Register gradients.
            # Without lazy regularization the regularizers are folded into the
            # main losses; otherwise they go to the separate Reg optimizers,
            # scaled by their interval.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(G_reg * G_reg_interval),
                        G_gpu.trainables)
                if D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(D_reg * D_reg_interval),
                        D_gpu.trainables)
            # print('G_gpu.trainables:', G_gpu.trainables)
            # print('D_gpu.trainables:', D_gpu.trainables)
            # print('I_gpu.trainables:', I_gpu.trainables)
            # Head trainables are merged with G's so that G_opt updates them
            # from the combined generator + head loss.
            if use_info_gan or use_vc_head:
                GI_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()))
                G_opt.register_gradients(tf.reduce_mean(G_loss + I_loss),
                                         GI_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
                # G_opt.register_gradients(tf.reduce_mean(I_loss),
                # GI_gpu_trainables)
                # D_opt.register_gradients(tf.reduce_mean(I_loss),
                # D_gpu.trainables)
            elif use_vc_head_with_cls:
                GIIinfo_gpu_trainables = collections.OrderedDict(
                    list(G_gpu.trainables.items()) +
                    list(I_gpu.trainables.items()) +
                    list(I_info_gpu.trainables.items()))
                G_opt.register_gradients(
                    tf.reduce_mean(G_loss + I_loss + I_info_loss),
                    GIIinfo_gpu_trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)
            else:
                G_opt.register_gradients(tf.reduce_mean(G_loss),
                                         G_gpu.trainables)
                D_opt.register_gradients(tf.reduce_mean(D_loss),
                                         D_gpu.trainables)

            # if use_info_gan:
            # # INFO-GAN-HEAD loss
            # G_opt.register_gradients(tf.reduce_mean(I_loss),
            # G_gpu.trainables)
            # G_opt.register_gradients(tf.reduce_mean(I_loss),
            # I_gpu.trainables)
            # D_opt.register_gradients(tf.reduce_mean(I_loss),
            # D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    # Peak GPU memory is only reported when the contrib op is available;
    # otherwise a constant 0 is reported.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
        if use_info_gan or use_vc_head or use_vc_head_with_cls:
            I.setup_weight_histograms()
            if use_vc_head_with_cls:
                I_info.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training for %d kimg...\n' % total_kimg)
    dnnlib.RunContext.get().update('',
                                   cur_epoch=resume_kimg,
                                   max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick after the first iteration
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu, sched.lod)
        # Reset optimizer state whenever lod crosses an integer boundary.
        # The head trainables are registered with G_opt, so they are reset
        # together with G.
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # NOTE(review): lrate_in feeds both G_opt and D_opt (see optimizer
        # setup), so D also trains at sched.G_lrate -- confirm this is intended.
        feed_dict = {
            lod_in: sched.lod,
            lrate_in: sched.G_lrate,
            minibatch_size_in: sched.minibatch_size,
            minibatch_gpu_in: sched.minibatch_gpu
        }
        for _repeat in range(minibatch_repeats):
            # One entry per accumulation round needed to cover the minibatch.
            rounds = range(0, sched.minibatch_size,
                           sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start(
            ) + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch_size),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes, _ = Gs.run(grid_latents,
                                       grid_labels,
                                       is_validation=True,
                                       minibatch_size=sched.minibatch_gpu,
                                       randomize_noise=False)
                misc.save_image_grid(grid_fakes,
                                     dnnlib.make_run_dir_path(
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # Snapshot tuple layout mirrors the one expected by resume_pkl.
            if network_snapshot_ticks is not None and (
                    cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('network-snapshot-%06d.pkl' %
                                               (cur_nimg // 1000))
                if use_info_gan or use_vc_head:
                    misc.save_pkl((G, D, I, Gs), pkl)
                elif use_vc_head_with_cls:
                    misc.save_pkl((G, D, I, I_info, Gs), pkl)
                else:
                    misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=dnnlib.make_run_dir_path(),
                            data_dir=dnnlib.convert_path(data_dir),
                            num_gpus=num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod,
                                           cur_epoch=cur_nimg // 1000,
                                           max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get(
            ).get_last_update_interval() - tick_time

    # Save final snapshot.
    if use_info_gan or use_vc_head:
        misc.save_pkl((G, D, I, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    elif use_vc_head_with_cls:
        misc.save_pkl((G, D, I, I_info, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))
    else:
        misc.save_pkl((G, D, Gs),
                      dnnlib.make_run_dir_path('network-final.pkl'))

    # All done.
    summary_log.close()
    training_set.close()
Exemplo n.º 23
0
def training_loop(
        run_dir='.',  # Output directory.
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        loss_args={},  # Options for loss function.
        train_dataset_args={},  # Options for dataset to train with.
        # Options for dataset to evaluate metrics against.
    metric_dataset_args={},
        augment_args={},  # Options for adaptive augmentations.
        metric_arg_list=[],  # Metrics to evaluate during training.
        num_gpus=1,  # Number of GPUs to use.
        minibatch_size=32,  # Global minibatch size.
        minibatch_gpu=4,  # Number of samples processed at a time by one GPU.
        # Half-life of the exponential moving average (EMA) of generator weights, in thousands of images.
    G_smoothing_kimg=10,
        G_smoothing_rampup=None,  # EMA ramp-up: caps the half-life at cur_nimg * G_smoothing_rampup images. None = no ramp-up.
        # Number of minibatches to run between EMA-beta / augmentation updates and tick checks.
    minibatch_repeats=4,
        lazy_regularization=True,  # Perform regularization as a separate training step?
        # How often (in minibatches) to perform regularization for G? Ignored if lazy_regularization=False.
    G_reg_interval=4,
        # How often (in minibatches) to perform regularization for D? Ignored if lazy_regularization=False.
        D_reg_interval=16,
        # Total length of the training, measured in thousands of real images.
        total_kimg=25000,
        kimg_per_tick=4,  # Minimum thousands of images between ticks (progress reports / snapshots).
        # How often (in ticks) to save image snapshots? None = only save 'reals.png' (the 'fakes_init.png' export is commented out).
    image_snapshot_ticks=50,
        # How often (in ticks) to save network snapshots and run metrics? None = never save network pickles (no final pickle is written either).
        network_snapshot_ticks=50,
        resume_pkl=None,  # Network pickle to resume training from.
        # Callback function for determining whether to abort training.
    abort_fn=None,
        progress_fn=None,  # Callback function for updating training progress.
):
    """Train a GAN (G, D and the moving-average generator Gs), optionally with augmentation.

    Builds G and D from the training set's shape and label size (optionally
    copying weights from ``resume_pkl``), replicates them over ``num_gpus``
    GPUs, builds the per-GPU loss/optimizer graph, then trains until
    ``total_kimg`` thousand real images have been shown or ``abort_fn()``
    returns true. Once per tick it prints progress and writes TensorBoard
    summaries. On the configured tick intervals (and on the final tick) it
    writes ``fakesNNNNNN.png`` grids and ``network-snapshot-NNNNNN.pkl``
    pickles of ``(G, D, Gs)`` to ``run_dir``, and evaluates the configured
    metrics on each pickle.

    Note: no separate ``network-final.pkl`` is written. The last pickle is the
    snapshot taken on the final tick, and only when ``network_snapshot_ticks``
    is not None.

    The mutable default arguments are not mutated. ``G_opt_args`` and
    ``D_opt_args`` are copied with ``dict()`` before being modified; the others
    are only read or unpacked.

    Raises:
        AssertionError: if ``minibatch_size`` is not a multiple of
            ``num_gpus * minibatch_gpu``.
    """
    assert minibatch_size % (num_gpus * minibatch_gpu) == 0
    start_time = time.time()

    print('Loading training set...')
    training_set = dataset.load_dataset(**train_dataset_args)
    print('Image shape:', np.int32(training_set.shape).tolist())
    print('Label shape:', [training_set.label_size])
    print()

    print('Constructing networks...')
    with tf.device('/gpu:0'):
        G = tflib.Network('G',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **G_args)
        D = tflib.Network('D',
                          num_channels=training_set.shape[0],
                          resolution=training_set.shape[1],
                          label_size=training_set.label_size,
                          **D_args)
        # Gs tracks a moving average of G's weights (see Gs_update_op below).
        Gs = G.clone('Gs')
        if resume_pkl is not None:
            print(f'Resuming from "{resume_pkl}"')
            # NOTE(review): pickle.load can execute arbitrary code; only resume from trusted pickles.
            with dnnlib.util.open_url(resume_pkl) as f:
                rG, rD, rGs = pickle.load(f)
            G.copy_vars_from(rG)
            D.copy_vars_from(rD)
            Gs.copy_vars_from(rGs)
    G.print_layers()
    D.print_layers()

    print('Exporting sample images...')
    grid_size, grid_reals, grid_labels = setup_snapshot_image_grid(
        training_set)
    save_image_grid(grid_reals,
                    os.path.join(run_dir, 'reals.png'),
                    drange=[0, 255],
                    grid_size=grid_size)
    # Fixed latents reused for every fakesNNNNNN.png grid, so snapshots are comparable.
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    # NOTE(review): grid_fakes is never used here because the save below is
    # commented out; this generator pass is effectively wasted work.
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=minibatch_gpu)
    # save_image_grid(grid_fakes, os.path.join(
    #     run_dir, 'fakes_init.png'), drange=[-1, 1], grid_size=grid_size)

    print(f'Replicating networks across {num_gpus} GPUs...')
    G_gpus = [G]
    D_gpus = [D]
    for gpu in range(1, num_gpus):
        with tf.device(f'/gpu:{gpu}'):
            G_gpus.append(G.clone(f'{G.name}_gpu{gpu}'))
            D_gpus.append(D.clone(f'{D.name}_gpu{gpu}'))

    print('Initializing augmentations...')
    # Augmentation is enabled only when augment_args names a class to construct.
    aug = None
    if augment_args.get('class_name', None) is not None:
        aug = dnnlib.util.construct_class_by_name(**augment_args)
        aug.init_validation_set(D_gpus=D_gpus, training_set=training_set)

    print('Setting up optimizers...')
    # Copy so the caller's dicts (and the shared mutable defaults) stay untouched.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval),
                               (D_opt_args, D_reg_interval)]:
        # Number of per-GPU sub-batches accumulated into one global minibatch update.
        args[
            'minibatch_multiplier'] = minibatch_size // num_gpus // minibatch_gpu
        if lazy_regularization:
            # Regularization adds extra optimizer steps every reg_interval minibatches.
            # To compensate, scale the learning rate by reg_interval / (reg_interval + 1)
            # and raise beta1/beta2 (when given) to the same power.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args:
                args['beta1'] **= mb_ratio
            if 'beta2' in args:
                args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    # Regularization optimizers are created with share=<main optimizer>.
    # Presumably they share optimizer state; confirm in tflib.Optimizer.
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    print('Constructing training graph...')
    data_fetch_ops = []
    training_set.configure(minibatch_gpu)
    for gpu, (G_gpu, D_gpu) in enumerate(zip(G_gpus, D_gpus)):
        with tf.name_scope(f'Train_gpu{gpu}'), tf.device(f'/gpu:{gpu}'):

            # Fetch training data via temporary variables.
            # The loss reads the *_var variables; data_fetch_op refills them with a new minibatch.
            with tf.name_scope('DataFetch'):
                real_images_var = tf.Variable(
                    name='images',
                    trainable=False,
                    initial_value=tf.zeros([minibatch_gpu] +
                                           training_set.shape))
                real_labels_var = tf.Variable(name='labels',
                                              trainable=False,
                                              initial_value=tf.zeros([
                                                  minibatch_gpu,
                                                  training_set.label_size
                                              ]))
                real_images_write, real_labels_write = training_set.get_minibatch_tf(
                )
                real_images_write = tflib.convert_images_from_uint8(
                    real_images_write)
                data_fetch_ops += [
                    tf.assign(real_images_var, real_images_write)
                ]
                data_fetch_ops += [
                    tf.assign(real_labels_var, real_labels_write)
                ]

            # Evaluate loss function and register gradients.
            fake_labels = training_set.get_random_labels_tf(minibatch_gpu)
            terms = dnnlib.util.call_func_by_name(G=G_gpu,
                                                  D=D_gpu,
                                                  aug=aug,
                                                  fake_labels=fake_labels,
                                                  real_images=real_images_var,
                                                  real_labels=real_labels_var,
                                                  **loss_args)
            if lazy_regularization:
                # Reg terms get their own optimizers. They are scaled by the interval
                # because they run only once every *_reg_interval minibatches.
                if terms.G_reg is not None:
                    G_reg_opt.register_gradients(
                        tf.reduce_mean(terms.G_reg * G_reg_interval),
                        G_gpu.trainables)
                if terms.D_reg is not None:
                    D_reg_opt.register_gradients(
                        tf.reduce_mean(terms.D_reg * D_reg_interval),
                        D_gpu.trainables)
            else:
                # Without lazy regularization, reg terms are folded into the main losses.
                if terms.G_reg is not None:
                    terms.G_loss += terms.G_reg
                if terms.D_reg is not None:
                    terms.D_loss += terms.D_reg
            G_opt.register_gradients(tf.reduce_mean(terms.G_loss),
                                     G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(terms.D_loss),
                                     D_gpu.trainables)

    print('Finalizing training ops...')
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    # allow_no_op: the reg optimizers may have no registered gradients (e.g. no reg terms).
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_beta_in = tf.placeholder(tf.float32, name='Gs_beta_in', shape=[])
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta_in)
    tflib.init_uninitialized_vars()
    with tf.device('/gpu:0'):
        peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()

    print('Initializing metrics...')
    summary_log = tf.summary.FileWriter(run_dir)
    metrics = []
    for args in metric_arg_list:
        metric = dnnlib.util.construct_class_by_name(**args)
        metric.configure(dataset_args=metric_dataset_args, run_dir=run_dir)
        metrics.append(metric)

    print(f'Training for {total_kimg} kimg...')
    print()
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    cur_nimg = 0
    # -1 forces a tick (report/snapshot) after the very first pass through the loop.
    cur_tick = -1
    tick_start_nimg = cur_nimg
    # Counts minibatches globally; drives the lazy-regularization schedule.
    running_mb_counter = 0

    done = False
    while not done:

        # Compute EMA decay parameter.
        # beta = 0.5 ** (minibatch_size / Gs_nimg): the weight of old Gs values halves
        # every Gs_nimg images. max(..., 1e-8) avoids division by zero when the
        # ramp-up gives Gs_nimg == 0 at cur_nimg == 0 (beta then becomes 0).
        Gs_nimg = G_smoothing_kimg * 1000.0
        if G_smoothing_rampup is not None:
            Gs_nimg = min(Gs_nimg, cur_nimg * G_smoothing_rampup)
        Gs_beta = 0.5**(minibatch_size / max(Gs_nimg, 1e-8))

        # Run training ops.
        for _repeat_idx in range(minibatch_repeats):
            # One entry per group of (minibatch_gpu * num_gpus) samples within the global minibatch.
            rounds = range(0, minibatch_size, minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization
                         and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization
                         and running_mb_counter % D_reg_interval == 0)
            cur_nimg += minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            # (minibatch_size == minibatch_gpu * num_gpus, so one round is a full minibatch.)
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op])
                if run_G_reg:
                    tflib.run(G_reg_op)
                tflib.run([D_train_op, Gs_update_op], {Gs_beta_in: Gs_beta})
                if run_D_reg:
                    tflib.run(D_reg_op)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op)
                    if run_G_reg:
                        tflib.run(G_reg_op)
                tflib.run(Gs_update_op, {Gs_beta_in: Gs_beta})
                for _round in rounds:
                    tflib.run(data_fetch_op)
                    tflib.run(D_train_op)
                    if run_D_reg:
                        tflib.run(D_reg_op)

            # Run validation.
            if aug is not None:
                aug.run_validation(minibatch_size=minibatch_size)

        # Tune augmentation parameters.
        if aug is not None:
            aug.tune(minibatch_size * minibatch_repeats)

        # Perform maintenance tasks once per tick.
        # A tick fires on the first pass (cur_tick < 0), on the final pass (done),
        # and whenever at least kimg_per_tick thousand images have passed since the last tick.
        done = (cur_nimg >= total_kimg * 1000) or (abort_fn is not None
                                                   and abort_fn())
        if done or cur_tick < 0 or cur_nimg >= tick_start_nimg + kimg_per_tick * 1000:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_end_time = time.time()
            total_time = tick_end_time - start_time
            tick_time = tick_end_time - tick_start_time

            # Report progress.
            print(' '.join([
                f"tick {autosummary('Progress/tick', cur_tick):<5d}",
                f"kimg {autosummary('Progress/kimg', cur_nimg / 1000.0):<8.1f}",
                f"time {dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)):<12s}",
                f"sec/tick {autosummary('Timing/sec_per_tick', tick_time):<7.1f}",
                f"sec/kimg {autosummary('Timing/sec_per_kimg', tick_time / tick_kimg):<7.2f}",
                f"maintenance {autosummary('Timing/maintenance_sec', maintenance_time):<6.1f}",
                f"gpumem {autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30):<5.1f}",
                f"augment {autosummary('Progress/augment', aug.strength if aug is not None else 0):.3f}",
            ]))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            if progress_fn is not None:
                progress_fn(cur_nimg // 1000, total_kimg)

            # Save snapshots.
            if image_snapshot_ticks is not None and (
                    done or cur_tick % image_snapshot_ticks == 0):
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=minibatch_gpu)
                save_image_grid(grid_fakes,
                                os.path.join(
                                    run_dir,
                                    f'fakes{cur_nimg // 1000:06d}.png'),
                                drange=[-1, 1],
                                grid_size=grid_size)
            if network_snapshot_ticks is not None and (
                    done or cur_tick % network_snapshot_ticks == 0):
                pkl = os.path.join(
                    run_dir, f'network-snapshot-{cur_nimg // 1000:06d}.pkl')
                with open(pkl, 'wb') as f:
                    pickle.dump((G, D, Gs), f)
                # Metrics are evaluated from the pickle that was just written.
                if len(metrics):
                    print('Evaluating metrics...')
                    for metric in metrics:
                        metric.run(pkl, num_gpus=num_gpus)

            # Update summaries.
            for metric in metrics:
                metric.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            tick_start_time = time.time()
            # Time spent on reporting/snapshots/metrics; reported on the next tick.
            maintenance_time = tick_start_time - tick_end_time

    print()
    print('Exiting...')
    summary_log.close()
    training_set.close()
Exemplo n.º 24
0
def main():
    """Render a latent-space animation from a StyleGAN2 model, one image per frame.

    All settings come from the module-level argument namespace ``a``. Latents
    for ``n_mult`` sub-regions are animated with ``latent_anima``. They are
    blended in one of two ways:

    * by splitting the frame into an ``a.nXY`` ("X-Y") grid, or
    * with an external mask ``a.latmask``, which is either a single image file
      or a directory holding one mask image per frame.

    Frames are written to ``a.out_dir`` as ``%06d.png`` (4-channel output) or
    ``%06d.jpg``. When ``a.save_lat`` is set, the mapped dlatents are saved as
    an ``.npy`` file in ``osp.dirname(a.out_dir)``.

    Exits with status 1 if ``a.latmask`` is neither a file nor a directory.
    """
    os.makedirs(a.out_dir, exist_ok=True)
    np.random.seed(seed=696)

    # setup generator
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.func_name = 'training.stylegan2_multi.G_main'
    Gs_kwargs.verbose = a.verbose
    Gs_kwargs.size = a.size
    Gs_kwargs.scale_type = a.scale_type
    Gs_kwargs.impl = a.ops

    # mask/blend latents with external latmask or by splitting the frame
    if a.latmask is None:
        # a.nXY is "X-Y"; reversed to [H, W] counts
        nHW = [int(s) for s in a.nXY.split('-')][::-1]
        assert len(nHW)==2, ' Wrong count nXY: %d (must be 2)' % len(nHW)
        n_mult = nHW[0] * nHW[1]
        if a.verbose is True and n_mult > 1: print(' Latent blending w/split frame %d x %d' % (nHW[1], nHW[0]))
        # placeholder (None) masks, one per latent slot: [1,X,1,1];
        # the split layout is passed to the network via Gs_kwargs.countHW
        lmask = np.tile(np.asarray([[[[None]]]]), (1,n_mult,1,1))
        Gs_kwargs.countHW = nHW
        Gs_kwargs.splitfine = a.splitfine
    else:
        if a.verbose is True: print(' Latent blending with mask', a.latmask)
        n_mult = 2
        if os.path.isfile(a.latmask): # single file -> the same mask for every frame
            lmask = np.asarray([[img_read(a.latmask)[:,:,0] / 255.]]) # [1,1,h,w]
        elif os.path.isdir(a.latmask): # directory with frame sequence
            # stack per-frame masks along axis 0 so that lmask[i] selects frame i
            # (axis 1 is the latent-slot axis that the concatenation below extends to 2)
            lmask = np.expand_dims(np.asarray([img_read(f)[:,:,0] / 255. for f in img_list(a.latmask)]), 1) # [frm,1,h,w]
        else:
            print(' !! Blending mask not found:', a.latmask)
            raise SystemExit(1)
        lmask = np.concatenate((lmask, 1 - lmask), 1) # [frm,2,h,w]: mask and its complement
        Gs_kwargs.latmask_res = lmask.shape[2:]

    # load model with arguments
    tflib.init_tf({'allow_soft_placement':True})
    pkl_name = osp.splitext(a.model)[0]
    # NOTE: pickle.load can execute arbitrary code - only load trusted model files
    with open(pkl_name + '.pkl', 'rb') as file:
        network = pickle.load(file, encoding='latin1')
    # training snapshots are (G, D, Gs) tuples -> keep Gs; a bare network is used as-is
    try:
        _, _, network = network
    except (TypeError, ValueError):
        pass
    for k in list(network.static_kwargs.keys()):
        Gs_kwargs[k] = network.static_kwargs[k]

    # reload custom network, if needed
    if '.pkl' in a.model.lower():
        print(' .. Gs from pkl ..', basename(a.model))
        Gs = network
    else: # reconstruct network with the custom kwargs, then copy the trained weights
        print(' .. Gs custom ..', basename(a.model))
        Gs = tflib.Network('Gs', **Gs_kwargs)
        Gs.copy_vars_from(network)
    if a.verbose is True: print('kwargs:', ['%s: %s'%(kv[0],kv[1]) for kv in sorted(Gs.static_kwargs.items())])

    if a.verbose is True: print(' out shape', Gs.output_shape[1:])
    if a.size is None: a.size = Gs.output_shape[2:]

    if a.verbose is True: print(' making timeline..')
    lats = [] # list of [frm,1,512]
    for i in range(n_mult):
        lat_tmp = latent_anima((1, Gs.input_shape[1]), a.frames, a.fstep, cubic=a.cubic, gauss=a.gauss, verbose=False) # [frm,1,512]
        lats.append(lat_tmp) # list of [frm,1,512]
    latents = np.concatenate(lats, 1) # [frm,X,512]
    print(' latents', latents.shape)
    frame_count = latents.shape[0]

    # distort image by tweaking initial const layer
    if a.digress > 0:
        latent_size = Gs.static_kwargs.get('latent_size', 512) # default latent size
        init_res = Gs.static_kwargs.get('init_res', (4,4)) # default initial layer size
        dconst = []
        for i in range(n_mult):
            dc_tmp = a.digress * latent_anima([1, latent_size, *init_res], a.frames, a.fstep, cubic=True, verbose=False)
            dconst.append(dc_tmp)
        dconst = np.concatenate(dconst, 1)
    else:
        dconst = np.zeros([frame_count, 1, 1, 1, 1])

    # labels / conditions
    label_size = Gs_kwargs.get('label_size', 0)
    if label_size > 0:
        labels = np.zeros((frame_count, n_mult, label_size)) # [frm,X,lbl]
        if a.labels is None:
            label_ids = []
            for i in range(n_mult):
                label_ids.append(random.randint(0, label_size-1))
        else:
            label_ids = [int(x) for x in a.labels.split('-')]
            label_ids = label_ids[:n_mult] # at most one label per latent slot; missing slots stay all-zero
        for i, l in enumerate(label_ids):
            labels[:,i,l] = 1
    else:
        labels = [None]

    # generate images from latent timeline
    pbar = ProgressBar(frame_count)
    for i in range(frame_count):

        latent  = latents[i] # [X,512]
        label   = labels[i % len(labels)]
        latmask = lmask[i % len(lmask)] if lmask is not None else [None] # [X,h,w]
        dc      = dconst[i % len(dconst)] # [X,512,4,4]

        # generate multi-latent result
        if Gs.num_inputs == 2:
            output = Gs.run(latent, label, truncation_psi=a.trunc, randomize_noise=False, output_transform=fmt)
        else:
            output = Gs.run(latent, label, latmask, dc, truncation_psi=a.trunc, randomize_noise=False, output_transform=fmt)

        # save image
        ext = 'png' if output.shape[3]==4 else 'jpg'
        filename = osp.join(a.out_dir, "%06d.%s" % (i,ext))
        imsave(filename, output[0])
        pbar.upd()

    # convert latents to dlatents, save them
    if a.save_lat is True:
        latents = latents.squeeze(1) # [frm,512]; squeeze(1) requires a single latent slot
        dlatents = Gs.components.mapping.run(latents, label, dtype='float16') # [frm,18,512]
        filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1], a.size[0])
        filename = osp.join(osp.dirname(a.out_dir), filename)
        np.save(filename, dlatents)
        print('saved dlatents', dlatents.shape, 'to', filename)
Exemplo n.º 25
0
# Defining Input Placeholders
image_path = tf.placeholder(tf.string)
audio_path = tf.placeholder(tf.string)

# Loading High-Resolution Image, Downsampled Low-Resolution Image, Preprocessed Audio, Nearest-Neighbor Interpolation of Inout Low-Resolution Image
high_res_image, low_res_image, audio, low_res_image_nearest = load_test_sample(
    image_path, audio_path)

# Constructing All Encoders
with tf.device("/GPU:0"):
    tflib.init_tf()

    _, _, G = pickle.load(open(FLAGS.STYLEGAN_CHECKPOINT, "rb"))

    Gs = tflib.Network(name=G.name,
                       func_name="networks_stylegan.G_style",
                       **G.static_kwargs)

    with tf.variable_scope(LR_ENCODER_SCOPE, reuse=tf.AUTO_REUSE):
        encoded_input = LowResEncoder(
            input=low_res_image,
            num_channels=3,
            resolution=8,
            batch_size=FLAGS.BATCH_SIZE,
            num_scales=3,
            n_filters=128,
            output_feature_size=512,
        )

    with tf.variable_scope(AUDIO_ENCODER_SCOPE, reuse=tf.AUTO_REUSE):
        audio_encoded_input = SpectrogramEncoder(
Exemplo n.º 26
0
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    """Train a denoising network with a mean-squared-error loss.

    Args:
        submit_config: dnnlib run configuration. Its ``num_gpus``, ``run_id``
            and ``run_dir`` are used here.
        iteration_count: number of optimizer steps to run.
        eval_interval: every this many steps, dump status, save sample images
            and run validation.
        minibatch_size: minibatch size passed to ``create_dataset``.
        learning_rate: base learning rate given to ``compute_ramped_down_lrate``.
        ramp_down_perc: ramp-down setting given to ``compute_ramped_down_lrate``.
        noise: kwargs for ``dnnlib.util.call_func_by_name`` that build the
            noise augmenter.
        validation_config: kwargs for ``ValidationSet.load``.
        train_tfrecords: training data path passed to ``create_dataset``.
        noise2noise: if True, the loss target is the noisy target
            (Noise2Noise); otherwise it is the clean target.

    The loop stops early if the run context asks to; the final snapshot is
    saved in either case.
    """
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        # Split each minibatch evenly across the GPUs.
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            # Noise2Noise trains against a second noisy realization instead of the clean image.
            target = noisy_target_split[gpu] if noise2noise else clean_target_split[gpu]
            meansq_error = tf.reduce_mean(tf.square(target - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            sec_per_iter = time_train / eval_interval
            time_remaining = sec_per_iter * (iteration_count - i)

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            # The ETA uses its own autosummary tag. Autosummary averages all values
            # logged under one tag, so reusing 'Timing/total_sec' would corrupt that curve.
            print(
                'iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   dnnlib.util.format_time(
                       autosummary('Timing/eta_sec', time_remaining)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter', sec_per_iter),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            # Time spent on this status dump, excluding the training time measured above.
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(
        dnnlib.util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
Exemplo n.º 27
0
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=7000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[-1, 1],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train a generator/discriminator pair with a level-of-detail schedule.

    Builds (or loads) the networks G, D and the moving-average generator Gs.
    It then builds a multi-GPU training graph and runs the training loop until
    ``total_kimg`` thousand real images have been consumed or the run
    context asks to stop. Image grids, network pickles, metrics and
    TensorBoard summaries are written into ``submit_config.run_dir``.

    Special values of ``resume_run_id``:
      * ``None``              -- construct fresh networks.
      * ``'latest'``          -- download and load the pickle at the hard-coded
                                 FFHQ URL below (NOT the newest local snapshot,
                                 despite the name).
      * ``'restore_partial'`` -- construct fresh networks, then copy every
                                 trainable whose name and shape match from the
                                 pickle at ``restore_partial_fn``.
      * anything else         -- passed to ``misc.locate_network_pkl`` together
                                 with ``resume_snapshot``.

    Args:
        submit_config: dnnlib submit config. Its ``num_gpus``, ``run_dir`` and
            ``run_id`` fields are read here.
        Other args: see the inline comments on each parameter.

    Raises:
        ValueError: if ``resume_run_id == 'restore_partial'`` but
            ``restore_partial_fn`` is None.
    """

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        # Load pre-trained
        if resume_run_id is not None:
            if resume_run_id == 'latest':
                URL_FFHQ = 'https://s3-us-west-2.amazonaws.com/nanonets/blogs/karras2019stylegan-ffhq-1024x1024.pkl'
                tflib.init_tf()
                # SECURITY NOTE(review): pickle.load executes arbitrary code from
                # the downloaded file; only use trusted URLs here.
                with dnnlib.util.open_url(URL_FFHQ,
                                          cache_dir=config.cache_dir) as f:
                    G, D, Gs = pickle.load(f)
                # Previously (disabled): resume from the newest local snapshot.
                #   network_pkl, resume_kimg = misc.locate_latest_pkl()
                #   print('Loading networks from "%s"...' % network_pkl)
                #   G, D, Gs = misc.load_pkl(network_pkl)
            elif resume_run_id == 'restore_partial':
                print('Restore partially...')
                # Initialize networks sized for the current dataset.
                G = tflib.Network('G',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **G_args)
                D = tflib.Network('D',
                                  num_channels=training_set.shape[0],
                                  resolution=training_set.shape[1],
                                  label_size=training_set.label_size,
                                  **D_args)
                Gs = G.clone('Gs')

                # Load pre-trained networks.
                # NOTE(review): `restore_partial_fn` is not a parameter of this
                # function, so it must be defined at module scope elsewhere in
                # the file. Confirm.
                if restore_partial_fn is None:
                    raise ValueError(
                        "resume_run_id='restore_partial' requires "
                        "restore_partial_fn to be set")
                # Context manager so the pickle file handle is always closed.
                with open(restore_partial_fn, 'rb') as partial_file:
                    G_partial, D_partial, Gs_partial = pickle.load(
                        partial_file)

                # Restore (subset of) pre-trained weights
                # (only parameters that match both name and shape)
                G.copy_compatible_trainables_from(G_partial)
                D.copy_compatible_trainables_from(D_partial)
                Gs.copy_compatible_trainables_from(Gs_partial)

            else:
                network_pkl = misc.locate_network_pkl(resume_run_id,
                                                      resume_snapshot)
                print('Loading networks from "%s"...' % network_pkl)
                G, D, Gs = misc.load_pkl(network_pkl)

        # Start from scratch
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # Each GPU processes an equal share of the global minibatch.
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay factor so the Gs running average has a half-life of
        # G_smoothing_kimg thousand images; 0.0 disables smoothing.
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the real networks; other GPUs use shadow clones.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            # The losses depend on the lod assignment so every step sees the
            # current level of detail.
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Memory-stats op unavailable: report 0 instead of failing.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule evaluated at the end of training. It is only used here to pick
    # the minibatch size for rendering snapshot grids.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer moments whenever the lod crosses an integer
            # boundary, i.e. when a new layer starts fading in or out.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress counts real images seen by D only.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            # A network snapshot is also written on the very first tick.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent on maintenance = full tick interval minus training.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
Exemplo n.º 28
0
def training_loop(
        submit_config,
        Encoder_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args={},
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        max_iters=150000,
        E_smoothing=0.999,
        tick_kimg=65):  # Thousands of training images between ticks.
    """Train an encoder E against a fixed pre-trained generator.

    The encoder E is trained adversarially together with a discriminator D.
    Gs is used as the decoder. G, D, Gs and NE come from
    ``decoder_pkl.decoder_pkl`` when training starts fresh, or from a resumed
    snapshot.

    Args:
        submit_config: run config. Its ``batch_size``, ``batch_size_test``,
            ``image_size``, ``num_gpus`` and ``run_dir`` fields are read here.
        Encoder_args: extra kwargs for the 'E' ``tflib.Network``.
        E_opt_args, D_opt_args: extra kwargs for the two ``tflib.Optimizer``s.
        E_loss_args, D_loss_args: kwargs (including ``func_name``) passed to
            ``dnnlib.util.call_func_by_name`` to build the losses.
        lr_args: must provide ``learning_rate``, ``decay_step``,
            ``decay_rate`` and ``stair`` for ``tf.train.exponential_decay``.
        tf_config: options for ``tflib.init_tf``.
        dataset_args: must provide ``data_train`` and ``data_test``.
        decoder_pkl: must provide ``decoder_pkl``, the path to a pickle
            holding ``(G, D, Gs, NE)``. Only used when not resuming.
        drange_data: dynamic range of the images produced by the data loader.
        drange_net: dynamic range fed to the networks.
        mirror_augment: forwarded to ``process_reals``.
        resume_run_id, resume_snapshot: forwarded to
            ``misc.locate_network_pkl``. The snapshot file name is expected to
            end in ``-<cur_nimg>.pkl``; the starting iteration is derived
            from that number.
        image_snapshot_ticks, network_snapshot_ticks: snapshot frequency,
            measured in ticks.
        save_tf_graph, save_weight_histograms: TensorBoard extras.
        max_iters: total number of optimisation iterations.
        E_smoothing: currently unused; the Es moving average is disabled.
        tick_kimg: thousands of images between ticks (default 65, i.e. the
            previous hard-coded 65000).
    """

    tflib.init_tf(tf_config)

    with tf.name_scope('input'):
        # Placeholder layout is [batch, 3, image_size, image_size].
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                    name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                   name='real_image_test')
        # One equal slice of the training batch per GPU.
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs, NE = misc.load_pkl(network_pkl)
            # Snapshots are named '...-<cur_nimg>.pkl'; convert the image
            # count back into an iteration index.
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
        else:
            print('Constructing networks...')
            G, D, Gs, NE = misc.load_pkl(decoder_pkl.decoder_pkl)
            E = tflib.Network('E',
                              size=submit_config.image_size,
                              filter=64,
                              filter_max=1024,
                              phase=True,
                              **Encoder_args)
            start = 0

    Gs.print_layers()
    E.print_layers()
    D.print_layers()

    global_step = tf.Variable(start,
                              trainable=False,
                              name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global = global_step.assign_add(1)
    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=learning_rate,
                            **D_opt_args)

    # Loss components accumulated across GPUs (for monitoring only).
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the real networks; other GPUs use shadow clones.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(
                img_size=[submit_config.image_size, submit_config.image_size],
                multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    perceptual_model=perceptual_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # Gradient ops created here depend on `add_global`, so running a
            # train op also advances the learning-rate step.
            with tf.control_dependencies([add_global]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    # Encoder moving average (Es) is disabled, so E_smoothing is unused:
    # Es_update_op = Es.setup_as_moving_average_of(E, beta=E_smoothing)

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        E.setup_weight_histograms()
        D.setup_weight_histograms()

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):

        # Feed the same real batch to the encoder step and the discriminator step.
        feed_dict = {real_train: sess.run(image_batch_train)}
        sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict)
        sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad],
                 feed_dict)

        cur_nimg += submit_config.batch_size

        if it % 100 == 0:
            print("Iter: %06d  kimg: %-8.1f time: %-12s" %
                  (it, cur_nimg / 1000,
                   dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        if cur_nimg >= tick_start_nimg + tick_kimg * 1000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                batch_images_test = sess.run(image_batch_test)
                # Convert with the same ranges used for training inputs,
                # instead of a hard-coded [0, 255] -> [-1, 1].
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), drange_data,
                    drange_net)

                samples2 = sess.run(fake_X_val,
                                    feed_dict={real_test: batch_images_test})
                # NCHW -> NHWC for image writing.
                samples2 = samples2.transpose(0, 2, 3, 1)
                batch_images_test = batch_images_test.transpose(0, 2, 3, 1)
                # First rows: originals, following rows: reconstructions.
                orin_recon = np.concatenate([batch_images_test, samples2],
                                            axis=0)
                imwrite(immerge(orin_recon, 2, submit_config.batch_size_test),
                        '%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg))

            if cur_tick % network_snapshot_ticks == 0:
                # The file name encodes cur_nimg; resume parses it back.
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs, NE), pkl)

    misc.save_pkl((E, G, D, Gs, NE),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
Exemplo n.º 29
0
def training_loop(
    G_args                  = {},       # Options for generator network.
    D_args                  = {},       # Options for discriminator network.
    G_opt_args              = {},       # Options for generator optimizer.
    D_opt_args              = {},       # Options for discriminator optimizer.
    G_loss_args             = {},       # Options for generator loss.
    D_loss_args             = {},       # Options for discriminator loss.
    dataset_args            = {},       # Options for dataset.load_dataset().
    sched_args              = {},       # Options for train.TrainingSchedule.
    grid_args               = {},       # Options for train.setup_snapshot_image_grid().
    setname                 = None,     # Model name (required); setname[:-1] and setname[-1] are embedded in snapshot filenames.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    lazy_regularization     = True,     # Perform regularization as a separate training step?
    G_reg_interval          = 4,        # How often to perform regularization for G? Ignored if lazy_regularization=False.
    D_reg_interval          = 16,       # How often to perform regularization for D? Ignored if lazy_regularization=False.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    mirror_augment_v        = False,    # Enable mirror augment vertically?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = only save the '_reals' and 'fakes_init' grids.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = only save the final pickles.
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_pkl              = 'latest', # None = train from scratch; 'latest' = newest snapshot under run_dir_root; 'restore_partial' = partial restore from restore_partial_fn; otherwise a pickle path.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    restore_partial_fn      = None,     # Network pickle for partial restore (required when resume_pkl='restore_partial').
    resume_with_new_nets    = False):   # Construct new networks according to G_args and D_args before resuming training?
    """Train (or resume training of) a generator/discriminator pair and save snapshots.

    Builds or loads the networks according to ``resume_pkl``, builds the per-GPU
    training graph, then trains until ``total_kimg`` thousand real images have been
    shown or the RunContext asks to stop. Image grids and network pickles are written
    to the run directory every ``image_snapshot_ticks`` / ``network_snapshot_ticks``
    ticks, and a final pair of pickles is written when the loop ends.

    Raises:
        ValueError: if ``setname`` is None/empty (it is needed for every snapshot
            filename), or if ``resume_pkl='restore_partial'`` is given without
            ``restore_partial_fn``.
    """
    # setname is sliced into every snapshot filename below; fail now instead of at the
    # first network snapshot (or, with network_snapshot_ticks=None, after all of training).
    if not setname:
        raise ValueError('training_loop: setname must be a non-empty string (it is used to name snapshot pickles)')

    # Initialize dnnlib and TensorFlow.
    tflib.init_tf(tf_config)
    num_gpus = dnnlib.submit_config.num_gpus

    # Load training set.
    training_set = dataset.load_dataset(verbose=True, **dataset_args)
    # custom resolution - for saved model name below
    resolution = training_set.resolution
    # A non-default base resolution is tagged into the snapshot filenames. list() makes the
    # comparison work whether init_res is a list, a tuple or an array.
    if list(training_set.init_res) != [4, 4]:
        init_res_str = '-%dx%d' % (training_set.init_res[0], training_set.init_res[1])
    else:
        init_res_str = ''
    ext = 'png' if training_set.shape[0] == 4 else 'jpg'
    print(' model base resolution', resolution)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(training_set, **grid_args)
    misc.save_image_grid(grid_reals, dnnlib.make_run_dir_path('_reals.%s'%ext), drange=training_set.dynamic_range, grid_size=grid_size)

    # Construct or load networks.
    restore_partial = (resume_pkl == 'restore_partial')
    with tf.device('/gpu:0'):
        # Fresh networks are needed when training from scratch, when resumed weights are copied
        # into newly built nets, and for a partial restore (which copies a compatible subset).
        if resume_pkl is None or resume_with_new_nets or restore_partial:
            print(' Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=resolution, label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
        if restore_partial:
            print(' Restore partially...')
            if restore_partial_fn is None:
                raise ValueError("training_loop: resume_pkl='restore_partial' requires restore_partial_fn")
            # NOTE: pickle.load can execute arbitrary code; only restore from trusted files.
            with open(restore_partial_fn, 'rb') as f:
                G_partial, D_partial, Gs_partial = pickle.load(f)
            # Restore (subset of) pre-trained weights (only parameters that match both name and shape)
            G.copy_compatible_trainables_from(G_partial)
            D.copy_compatible_trainables_from(D_partial)
            Gs.copy_compatible_trainables_from(Gs_partial)
        elif resume_pkl is not None:
            # Resolve which pickle to resume from, then always load it. Previously 'latest'
            # was only resolved and never loaded, which left G/D/Gs undefined.
            if resume_pkl == 'latest':
                resume_pkl, resume_kimg = misc.locate_latest_pkl(dnnlib.submit_config.run_dir_root)
            elif resume_kimg == 0:
                resume_pkl, resume_kimg = misc.locate_latest_pkl(resume_pkl)
            print(' Loading networks from "%s", kimg %.3g' % (resume_pkl, resume_kimg))
            rG, rD, rGs = misc.load_pkl(resume_pkl)
            if resume_with_new_nets:
                G.copy_vars_from(rG)
                D.copy_vars_from(rD)
                Gs.copy_vars_from(rGs)
            else:
                G, D, Gs = rG, rD, rGs

    # Print layers if needed and generate initial image snapshot
    # G.print_layers(); D.print_layers()
    # The schedule at the end of training supplies minibatch_gpu for the snapshot grids.
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, **sched_args)
    grid_latents = np.random.randn(np.prod(grid_size), *G.input_shape[1:])
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
    misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fakes_init.%s'%ext), drange=drange_net, grid_size=grid_size)

    # Setup training inputs.
    print(' Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in               = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in             = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_size_in    = tf.placeholder(tf.int32, name='minibatch_size_in', shape=[])
        minibatch_gpu_in     = tf.placeholder(tf.int32, name='minibatch_gpu_in', shape=[])
        minibatch_multiplier = minibatch_size_in // (minibatch_gpu_in * num_gpus)
        Gs_beta              = 0.5 ** tf.div(tf.cast(minibatch_size_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # Setup optimizers. Copies keep the caller's dicts (and the shared {} defaults) untouched.
    G_opt_args = dict(G_opt_args)
    D_opt_args = dict(D_opt_args)
    for args, reg_interval in [(G_opt_args, G_reg_interval), (D_opt_args, D_reg_interval)]:
        args['minibatch_multiplier'] = minibatch_multiplier
        args['learning_rate'] = lrate_in
        if lazy_regularization:
            # Compensate learning rate and Adam betas for the extra regularization steps.
            mb_ratio = reg_interval / (reg_interval + 1)
            args['learning_rate'] *= mb_ratio
            if 'beta1' in args: args['beta1'] **= mb_ratio
            if 'beta2' in args: args['beta2'] **= mb_ratio
    G_opt = tflib.Optimizer(name='TrainG', **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', **D_opt_args)
    G_reg_opt = tflib.Optimizer(name='RegG', share=G_opt, **G_opt_args)
    D_reg_opt = tflib.Optimizer(name='RegD', share=D_opt, **D_opt_args)

    # Build training graph for each GPU.
    data_fetch_ops = []
    for gpu in range(num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):

            # Create GPU-specific shadow copies of G and D.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')

            # Fetch training data via temporary variables.
            with tf.name_scope('DataFetch'):
                sched = training_schedule(cur_nimg=int(resume_kimg*1000), training_set=training_set, **sched_args)
                reals_var = tf.Variable(name='reals', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu] + training_set.shape))
                labels_var = tf.Variable(name='labels', trainable=False, initial_value=tf.zeros([sched.minibatch_gpu, training_set.label_size]))
                reals_write, labels_write = training_set.get_minibatch_tf()
                reals_write, labels_write = process_reals(reals_write, labels_write, lod_in, mirror_augment, mirror_augment_v, training_set.dynamic_range, drange_net)
                reals_write = tf.concat([reals_write, reals_var[minibatch_gpu_in:]], axis=0)
                labels_write = tf.concat([labels_write, labels_var[minibatch_gpu_in:]], axis=0)
                data_fetch_ops += [tf.assign(reals_var, reals_write)]
                data_fetch_ops += [tf.assign(labels_var, labels_write)]
                reals_read = reals_var[:minibatch_gpu_in]
                labels_read = labels_var[:minibatch_gpu_in]

            # Evaluate loss functions.
            lod_assign_ops = []
            if 'lod' in G_gpu.vars: lod_assign_ops += [tf.assign(G_gpu.vars['lod'], lod_in)]
            if 'lod' in D_gpu.vars: lod_assign_ops += [tf.assign(D_gpu.vars['lod'], lod_in)]
            with tf.control_dependencies(lod_assign_ops):
                with tf.name_scope('G_loss'):
                    G_loss, G_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, **G_loss_args)
                with tf.name_scope('D_loss'):
                    D_loss, D_reg = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_gpu_in, reals=reals_read, labels=labels_read, **D_loss_args)

            # Register gradients.
            if not lazy_regularization:
                if G_reg is not None: G_loss += G_reg
                if D_reg is not None: D_loss += D_reg
            else:
                if G_reg is not None: G_reg_opt.register_gradients(tf.reduce_mean(G_reg * G_reg_interval), G_gpu.trainables)
                if D_reg is not None: D_reg_opt.register_gradients(tf.reduce_mean(D_reg * D_reg_interval), D_gpu.trainables)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    # Setup training ops.
    data_fetch_op = tf.group(*data_fetch_ops)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()
    G_reg_op = G_reg_opt.apply_updates(allow_no_op=True)
    D_reg_op = D_reg_opt.apply_updates(allow_no_op=True)
    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    # Finalize graph.
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    tflib.init_uninitialized_vars()

    # print('Initializing logs...')
    summary_log = tf.summary.FileWriter(dnnlib.make_run_dir_path())
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print(' Training for %d kimg (%d left) \n' % (total_kimg, total_kimg-resume_kimg))
    dnnlib.RunContext.get().update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = dnnlib.RunContext.get().get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = -1  # -1 forces a maintenance tick after the very first round of training
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    running_mb_counter = 0
    while cur_nimg < total_kimg * 1000:
        if dnnlib.RunContext.get().should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, **sched_args)
        assert sched.minibatch_size % (sched.minibatch_gpu * num_gpus) == 0
        training_set.configure(sched.minibatch_gpu) # , sched.lod
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        # NOTE: a single lrate_in feed (sched.G_lrate) drives both the G and D optimizers.
        feed_dict = {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_size_in: sched.minibatch_size, minibatch_gpu_in: sched.minibatch_gpu}
        for _repeat in range(minibatch_repeats):
            rounds = range(0, sched.minibatch_size, sched.minibatch_gpu * num_gpus)
            run_G_reg = (lazy_regularization and running_mb_counter % G_reg_interval == 0)
            run_D_reg = (lazy_regularization and running_mb_counter % D_reg_interval == 0)
            cur_nimg += sched.minibatch_size
            running_mb_counter += 1

            # Fast path without gradient accumulation.
            if len(rounds) == 1:
                tflib.run([G_train_op, data_fetch_op], feed_dict)
                if run_G_reg:
                    tflib.run(G_reg_op, feed_dict)
                tflib.run([D_train_op, Gs_update_op], feed_dict)
                if run_D_reg:
                    tflib.run(D_reg_op, feed_dict)

            # Slow path with gradient accumulation.
            else:
                for _round in rounds:
                    tflib.run(G_train_op, feed_dict)
                if run_G_reg:
                    for _round in rounds:
                        tflib.run(G_reg_op, feed_dict)
                tflib.run(Gs_update_op, feed_dict)
                for _round in rounds:
                    tflib.run(data_fetch_op, feed_dict)
                    tflib.run(D_train_op, feed_dict)
                if run_D_reg:
                    for _round in rounds:
                        tflib.run(D_reg_op, feed_dict)

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_tick < 0 or cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = dnnlib.RunContext.get().get_time_since_last_update()
            total_time = dnnlib.RunContext.get().get_time_since_start() + resume_time

            # ETA is only reported once progressive growing has finished (lod == 0).
            if sched.lod == 0:
                left_kimg = total_kimg - cur_nimg / 1000
                left_sec = left_kimg * tick_time / tick_kimg
                finaltime = time.asctime(time.localtime(cur_time + left_sec))
                msg_final = '%ss left till %s ' % (shortime(left_sec), finaltime[11:16])
            else:
                msg_final = ''

            # Report progress.
            # print('tick %-4d kimg %-6.1f lod %-5.2f minibch %-3d:%d time %-8s min/tick %-6.3g %s sec/kimg %-7.3g gpumem %-4.1f %d lr %.2g ' % (
            print('tick %-4d kimg %-6.1f time %-8s  %s min/tick %-6.3g sec/kimg %-7.3g gpumem %-4.1f lr %.2g ' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                # autosummary('Progress/lod', sched.lod),
                # autosummary('Progress/minibatch', sched.minibatch_size),
                # autosummary('Progress/minibatch_gpu', sched.minibatch_gpu),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                msg_final,
                autosummary('Timing/min_per_tick', tick_time / 60),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                # autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30),
                sched.G_lrate))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if image_snapshot_ticks is not None and (cur_tick % image_snapshot_ticks == 0 or done):
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch_gpu)
                misc.save_image_grid(grid_fakes, dnnlib.make_run_dir_path('fake-%04d.%s' % (cur_nimg // 1000, ext)), drange=drange_net, grid_size=grid_size)
            if network_snapshot_ticks is not None and (cur_tick % network_snapshot_ticks == 0 or done):
                pkl = dnnlib.make_run_dir_path('snapshot-%d-%s%s-%04d.pkl' % (resolution, setname[-1], init_res_str, cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Second pickle holds Gs alone (not a tuple), for inference use.
                misc.save_pkl(Gs, dnnlib.make_run_dir_path('%s-%d-%s%s-%04d.pkl' % (setname[:-1], resolution, setname[-1], init_res_str, cur_nimg // 1000)))

            # Update summaries and RunContext.
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            dnnlib.RunContext.get().update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = dnnlib.RunContext.get().get_last_update_interval() - tick_time

    # Save final snapshot: full (G, D, Gs) tuple plus Gs on its own.
    misc.save_pkl((G, D, Gs), dnnlib.make_run_dir_path('snapshot-%d-%s%s-final.pkl' % (resolution, setname[-1], init_res_str)))
    misc.save_pkl(Gs, dnnlib.make_run_dir_path('%s-%d-%s%s-final.pkl' % (setname[:-1], resolution, setname[-1], init_res_str)))

    # All done.
    summary_log.close()
    training_set.close()
Exemplo n.º 30
0
    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):