示例#1
0
def transitionAtoB(
    run_id          = 102,     # Run ID or network pkl handed to misc.locate_network_pkl().
    snapshot        = None,    # Snapshot index handed to misc.locate_network_pkl().
    num_gen_noise   = 100):    # Number of random draws discarded between latent A and latent B.
    """Render PNG frames that linearly interpolate between two random latents.

    Two latent vectors are drawn from a fixed-seed ``RandomState(10)``:
    ``latentsA`` first, then ``num_gen_noise`` throw-away draws, then
    ``latentsB``.  Changing ``num_gen_noise`` therefore selects a different
    ``latentsB`` while keeping ``latentsA`` fixed.

    Frames are written to ``<config.result_dir>/video/photoNNNN.png``.
    Frame 0 is generated from ``latentsB`` and later frames move towards
    ``latentsA``.

    Based on http://cedro3.com/ai/stylegan/
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = misc.load_pkl(network_pkl)

    # Print network details.
    Gs.print_layers()

    # Pick the two end-point latent vectors.
    latent_dim = Gs.input_shape[1]
    rnd = np.random.RandomState(10)  # seed = 10
    latentsA = rnd.randn(1, latent_dim)
    for _ in range(num_gen_noise):
        # The values are discarded on purpose: the draw only advances the
        # generator state so that latentsB depends on num_gen_noise.
        rnd.randn(1, latent_dim)
    latentsB = rnd.randn(1, latent_dim)

    num_split = 39   # The segment between the two vectors is divided into 39 steps.
    num_images = 30  # NOTE(review): only 30 of the 40 possible steps are rendered, so the
                     # sequence stops before reaching latentsA -- confirm this is intended.

    # Loop-invariant setup: output transform and destination directory.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    video_dir = os.path.join(config.result_dir, 'video')
    os.makedirs(video_dir, exist_ok=True)

    for i in range(num_images):
        latents = latentsB + (latentsA - latentsB) * i / num_split

        # Generate image.
        images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

        # Save image.
        png_filename = os.path.join(video_dir, 'photo' + '{0:04d}'.format(i) + '.png')
        PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
示例#2
0
def run_snapshot(submit_config, metric_args, run_id, snapshot):
    """Evaluate a single metric on one network snapshot of a run.

    The metric object is built from ``metric_args`` via
    ``dnnlib.util.call_func_by_name`` and executed against the pickle that
    ``misc.locate_network_pkl`` resolves for ``snapshot`` inside the run
    directory of ``run_id``.
    """
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()

    banner = 'Evaluating %s metric on run_id %s, snapshot %s...'
    print(banner % (metric_args.name, run_id, snapshot))

    # Resolve the run directory first; the pickle lookup is relative to it.
    results_dir = misc.locate_run_dir(run_id)
    pkl_path = misc.locate_network_pkl(results_dir, snapshot)
    evaluator = dnnlib.util.call_func_by_name(**metric_args)

    # Blank lines frame whatever the metric prints.
    print()
    evaluator.run(
        pkl_path,
        run_dir=results_dir,
        num_gpus=submit_config.num_gpus,
    )
    print()

    ctx.close()
示例#3
0
def test_d(submit_config,
           resume_run_id,
           dataset_args,
           tf_config={},
           resume_snapshot=None):
    """Print discriminator outputs for real and generated images of a saved run.

    Loads ``(G, D, Gs)`` from the pickle located for ``resume_run_id`` /
    ``resume_snapshot``, builds one graph that scores a real minibatch and a
    generated image with ``D``, then runs it 15 times.  Each iteration prints
    both scores and stores the real batch as ``real_<i>.nii.gz`` inside
    ``submit_config.run_dir``.
    """
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    pkl_path = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % pkl_path)
    G, D, Gs = misc.load_pkl(pkl_path)

    # Generator input; no labels are fed to either network.
    z_in = tf.placeholder(tf.float32)
    no_labels = None

    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Generated image: mapping network followed by synthesis network.
    w = Gs.components.mapping.get_output_for(z_in, no_labels, is_validation=True)
    fake_op = Gs.components.synthesis.get_output_for(w, is_validation=True, randomize_noise=False)

    # Real image: minibatch from the dataset, pre-processed to the [-1, 1] range.
    real_op, _unused_labels = training_set.get_minibatch_tf()
    lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
    real_op = process_reals(real_op, lod_in, False, training_set.dynamic_range, [-1, 1])

    score_real_op = D.get_output_for(real_op, no_labels)
    score_fake_op = D.get_output_for(fake_op, no_labels)

    training_set.configure(1, 0)

    fetches = [score_real_op, score_fake_op, real_op]
    latent_shape = (1, *G.input_shape[1:])
    for step in range(15):
        z_val = np.random.randn(*latent_shape)
        score_real, score_fake, real_image = tflib.run(
            fetches, feed_dict={z_in: z_val, lod_in: 0})

        print(score_real, score_fake)
        out_path = os.path.join(submit_config.run_dir, 'real_{}.nii.gz'.format(step))
        misc.save_mri_image(real_image, out_path, drange=[-1, 1])
示例#4
0
def main(
    run_id          = 101,     # Run ID or network pkl handed to misc.locate_network_pkl().
    snapshot        = None,     # Snapshot index handed to misc.locate_network_pkl().
    grid_size       = [1,1],
    image_shrink    = 1,
    image_zoom      = 1,
    duration_sec    = 3.0,
    smoothing_sec   = 1.0,
    mp4             = None,
    mp4_fps         = 30,
    mp4_codec       = 'libx265',
    mp4_bitrate     = '16M',
    random_seed     = 1000,
    minibatch_size  = 8 ):
    """Write a sequence of PNG frames generated from smoothed mapping outputs.

    ``duration_sec * mp4_fps`` random latents are drawn from a fixed-seed
    ``RandomState(5)``, pushed through the mapping network, and the result is
    Gaussian-filtered along the frame axis (sigma ``smoothing_sec * mp4_fps``,
    wrap-around).  Each frame is then rendered and saved as
    ``frame_NNNNNNNN.png`` in ``config.result_dir``.

    Several parameters (``grid_size``, ``image_shrink``, ``image_zoom``,
    ``mp4_codec``, ``mp4_bitrate``, ``random_seed``, ``minibatch_size``) are
    accepted but not read by this function, and the ``mp4`` argument is
    overwritten before use.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    # The caller-supplied mp4 value is replaced; the name is not used afterwards.
    mp4 = '%s-lerp.mp4' % network_pkl
    frame_count = duration_sec * mp4_fps
    num_frames = int(np.rint(frame_count))

    # Print network details.
    Gs.print_layers()

    # One random latent per frame.
    print('Generating latent vectors...')
    rnd = np.random.RandomState(5)
    shape = [num_frames, Gs.input_shape[1]]
    all_latents = rnd.randn(*shape)
    all_latents = all_latents.astype(np.float32)

    # Map to the intermediate space and smooth along the frame axis only.
    all_dlatents = Gs.components.mapping.run(all_latents, None)
    sigma = [smoothing_sec * mp4_fps, 0, 0]
    all_dlatents = scipy.ndimage.gaussian_filter(all_dlatents, sigma, mode='wrap')
    print (shape, all_latents.shape, all_dlatents.shape)

    # Render and save every frame.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    for idx in range(num_frames):
        frame_input = all_dlatents[idx]
        images = Gs.run(frame_input, None, truncation_psi=0.0, randomize_noise=True, output_transform=fmt)
        frame_name = 'frame_%08d.png' % idx
        frame_image = PIL.Image.fromarray(images[0], 'RGB')
        frame_image.save(os.path.join(config.result_dir, frame_name))
def mixing(submit_config, resume_run_id, tf_config = {}, resume_snapshot=None):
    """Save style-mixing images for two random latents of a saved generator.

    Writes to ``submit_config.run_dir``:

    * ``fake_image_1.png`` / ``fake_image_2.png`` -- the two unmixed images;
    * ``fake_mix_12_<i>.png`` -- mapping output 1 up to index ``i`` along
      axis 1, mapping output 2 from ``i`` on, for ``i`` in ``0..14``;
    * ``fake_mix_21_<i>.png`` -- the same with the two roles swapped.
    """
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    pkl_path = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % pkl_path)
    G, D, Gs = misc.load_pkl(pkl_path)

    # Concrete latent values, drawn from the global NumPy generator.
    z1_val = np.random.randn(1,*G.input_shape[1:])
    z2_val = np.random.randn(1,*G.input_shape[1:])

    # Graph inputs; both latents use the same fixed one-hot label.
    z1 = tf.placeholder(tf.float32)
    labels_1 = tf.constant([[0,0,0,0,1,0]])

    z2 = tf.placeholder(tf.float32)
    labels_2 = tf.constant([[0,0,0,0,1,0]])

    w_1 = Gs.components.mapping.get_output_for(z1, labels_1, is_validation=True)
    w_2 = Gs.components.mapping.get_output_for(z2, labels_2, is_validation=True)

    def synthesize(w):
        # Build a synthesis op for the given mapping output.
        return Gs.components.synthesis.get_output_for(w, is_validation=True, randomize_noise=False)

    def evaluate(op):
        # Both placeholders are always fed, whichever of them the op uses.
        return tflib.run(op, feed_dict={z1: z1_val, z2: z2_val})

    def save(image_batch, filename):
        target = os.path.join(submit_config.run_dir, filename)
        misc.save_image(image_batch[0], target, drange=[-1,1])

    # Unmixed reference images.
    fake_image_1_op = synthesize(w_1)
    fake_image_2_op = synthesize(w_2)

    fake_image_1 = evaluate(fake_image_1_op)
    fake_image_2 = evaluate(fake_image_2_op)

    save(fake_image_1, 'fake_image_1.png')
    save(fake_image_2, 'fake_image_2.png')

    # Crossover at every index 0..14, first 1->2 then 2->1.
    for tag, head, tail in (('12', w_1, w_2), ('21', w_2, w_1)):
        for i in range(15):
            w_mix = tf.concat([head[:,:i], tail[:,i:]], axis=1)
            mixed = evaluate(synthesize(w_mix))
            save(mixed, 'fake_mix_{}_{}.png'.format(tag, i))
示例#6
0
def transitionAtoB_v2(
    run_id          = 102,     # Run ID or network pkl handed to misc.locate_network_pkl().
    snapshot        = None,
    num_frames      = 20,
    interpolate_dim = 350):
    """Render frames that sweep one latent dimension around a fixed random latent.

    A single latent is drawn from ``RandomState(10)``.  For each of
    ``num_frames`` evenly spaced offsets in ``[-15, 15]`` the offset is added
    to component ``interpolate_dim`` of a copy of that latent.  All frames are
    generated in one ``Gs.run`` call and written as ``frame_NNNN.png`` into a
    sub-directory of ``config.result_dir`` named after the network pickle.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained network.
    network_pkl = misc.locate_network_pkl(run_id, snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    # Print network details.
    Gs.print_layers()

    # Base latent vector.
    rnd = np.random.RandomState(10)  # seed = 10
    base_latent = rnd.randn(1, Gs.input_shape[1])[0]

    # One shifted copy of the base latent per frame.
    offsets = np.linspace(0., 30., num_frames) - 15
    shifted_latents = []
    for offset in offsets:
        shifted = np.copy(base_latent)
        shifted[interpolate_dim] += offset
        shifted_latents.append(shifted)
    latents = np.array(shifted_latents)

    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)

    # Output directory: pickle file name with any ".mp4" substring removed.
    run_name = os.path.basename(network_pkl).replace(".mp4","")
    out_dir = os.path.join(config.result_dir, run_name)
    os.makedirs(out_dir, exist_ok=True)

    # Save images.
    for idx in range(num_frames):
        frame_path = os.path.join(out_dir, 'frame_{0:04d}.png'.format(idx))
        PIL.Image.fromarray(images[idx], 'RGB').save(frame_path)
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0       # Assumed training progress at the beginning. Affects reporting and training schedule.
    #resume_kimg             = 5645,
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
示例#8
0
def training_loop(
        submit_config,
        Encoder_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args={},
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        max_iters=150000,
        E_smoothing=0.999):
    """Train an encoder ``E`` and a discriminator ``D`` against a loaded generator.

    Networks:
        * With ``resume_run_id`` set, ``(E, G, D, Gs, NE)`` are loaded from the
          located pickle and the start iteration is parsed from the pickle
          file name (integer after the last ``-``, divided by
          ``submit_config.batch_size``).
        * Otherwise ``(G, D, Gs, NE)`` come from ``decoder_pkl.decoder_pkl``
          and a new ``E`` is constructed; training starts at iteration 0.

    Graph: for each of ``submit_config.num_gpus`` GPUs the encoder and
    discriminator losses are built through ``dnnlib.util.call_func_by_name``
    with ``E_loss_args`` / ``D_loss_args`` and their gradients are registered
    with two ``tflib.Optimizer`` instances that share an exponentially
    decaying learning rate configured by ``lr_args``.

    Loop (``start`` .. ``max_iters``): one E update and one D update per
    iteration on the same training batch; progress is printed and summaries
    saved every 100 iterations; a "tick" elapses every 65000 images, and on
    ticks selected by ``image_snapshot_ticks`` / ``network_snapshot_ticks`` a
    reconstruction image / a network pickle is written to
    ``submit_config.run_dir``.  ``network-final.pkl`` is written at the end.

    ``E_smoothing`` is not read by the active code (its only use is in a
    commented-out line).
    """

    tflib.init_tf(tf_config)

    # Input placeholders; the 4-D shape lists batch size, 3, and image_size twice
    # (presumably NCHW RGB -- TODO confirm).
    with tf.name_scope('input'):
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                    name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                   name='real_image_test')
        # One slice of the training batch per GPU.
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    # Load or construct the networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs, NE = misc.load_pkl(network_pkl)
            # The number after the last '-' in the pickle name is treated as an
            # image count and converted to an iteration index.
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
        else:
            print('Constructing networks...')
            G, D, Gs, NE = misc.load_pkl(decoder_pkl.decoder_pkl)
            E = tflib.Network('E',
                              size=submit_config.image_size,
                              filter=64,
                              filter_max=1024,
                              phase=True,
                              **Encoder_args)
            start = 0

    Gs.print_layers()
    E.print_layers()
    D.print_layers()

    # Step counter that drives the learning-rate decay; initialised to the
    # resume iteration.
    global_step = tf.Variable(start,
                              trainable=False,
                              name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global = global_step.assign_add(1)
    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=learning_rate,
                            **D_opt_args)

    # Loss terms accumulated over GPUs (start as Python floats, become tensors
    # after the first +=).
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; other GPUs use clones.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(
                img_size=[submit_config.image_size, submit_config.image_size],
                multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                # The loss function itself is named inside E_loss_args.
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    perceptual_model=perceptual_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # NOTE(review): the step increment is attached as a control
            # dependency for both optimizers on every GPU; how many times the
            # counter advances per iteration depends on tflib.Optimizer -- confirm.
            with tf.control_dependencies([add_global]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    # Average the reported loss terms over GPUs.
    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    #Es_update_op = Es.setup_as_moving_average_of(E, beta=E_smoothing)

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        E.setup_weight_histograms()
        D.setup_weight_histograms()

    # Progress bookkeeping, expressed in images seen.
    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    print('Optimization starts!!!')
    for it in range(start, max_iters):

        # E and D are updated in two separate session runs on the same batch.
        feed_dict = {real_train: sess.run(image_batch_train)}
        sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict)
        sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad],
                 feed_dict)

        cur_nimg += submit_config.batch_size

        if it % 100 == 0:
            print("Iter: %06d  kimg: %-8.1f time: %-12s" %
                  (it, cur_nimg / 1000,
                   dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)

        # A tick elapses once 65000 further images have been processed.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg

            if cur_tick % image_snapshot_ticks == 0:
                # Reconstruct a test batch and save originals plus
                # reconstructions in one merged image.
                batch_images_test = sess.run(image_batch_test)
                batch_images_test = misc.adjust_dynamic_range(
                    batch_images_test.astype(np.float32), [0, 255], [-1., 1.])

                samples2 = sess.run(fake_X_val,
                                    feed_dict={real_test: batch_images_test})
                # Move axis 1 to the end before merging.
                samples2 = samples2.transpose(0, 2, 3, 1)
                batch_images_test = batch_images_test.transpose(0, 2, 3, 1)
                orin_recon = np.concatenate([batch_images_test, samples2],
                                            axis=0)
                imwrite(immerge(orin_recon, 2, submit_config.batch_size_test),
                        '%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg))

            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir,
                                   'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs, NE), pkl)

    # Final snapshot after the last iteration.
    misc.save_pkl((E, G, D, Gs, NE),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
示例#9
0
def training_loop(
                  submit_config,
                  Encoder_args            = {},
                  E_opt_args              = {},
                  D_opt_args              = {},
                  E_loss_args             = EasyDict(),
                  D_loss_args             = {},
                  lr_args                 = EasyDict(),
                  tf_config               = {},
                  dataset_args            = EasyDict(),
                  decoder_pkl             = EasyDict(),
                  drange_data             = [0, 255],
                  drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
                  mirror_augment          = False,
                  resume_run_id           = config.ENCODER_PICKLE_DIR,     # Run ID or network pkl to resume training from, None = start from scratch.
                  resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
                  image_snapshot_ticks    = 1,        # How often to export image snapshots?
                  network_snapshot_ticks  = 4,       # How often to export network snapshots?
                  max_iters               = 150000):
    """Train an encoder ``E`` and a discriminator ``D`` against a loaded generator.

    Networks:
        * With ``resume_run_id`` set (the default is
          ``config.ENCODER_PICKLE_DIR``), ``(E, G, D, Gs)`` are loaded from the
          located pickle and the start iteration is parsed from the pickle
          file name (integer after the last ``-``, divided by
          ``submit_config.batch_size``).
        * Otherwise ``(G, D, Gs)`` come from ``decoder_pkl.decoder_pkl`` and a
          new ``E`` is constructed; training starts at iteration 0.

    Loop (``start`` .. ``max_iters``): one E update and one D update per
    iteration on the same batch; losses are printed and summaries saved every
    50 iterations; a reconstruction image is written every 500 iterations to
    both ``submit_config.run_dir`` and ``config.GDRIVE_PATH/images``; every
    ``network_snapshot_ticks`` ticks (one tick = 65000 images) a network
    pickle is written to ``submit_config.run_dir`` and
    ``config.GDRIVE_PATH/snapshots``.  ``network-final.pkl`` is written to
    ``submit_config.run_dir`` at the end.

    ``image_snapshot_ticks`` is accepted but not read by this function.
    """

    tflib.init_tf(tf_config)

    # Input placeholders; the training batch is split into one slice per GPU.
    with tf.name_scope('input'):
        real_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='real_image_train')
        real_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='real_image_test')
        real_split = tf.split(real_train, num_or_size_splits=submit_config.num_gpus, axis=0)

    # Load or construct the networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            # The number after the last '-' in the pickle name is treated as an
            # image count and converted to an iteration index.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, D, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            # The encoder's num_layers is taken from the synthesis network's input shape.
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args)
            start = 0

    E.print_layers(); Gs.print_layers(); D.print_layers()

    # Step counter that drives the learning-rate decay; initialised to the
    # resume iteration.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Loss terms accumulated over GPUs (start as Python floats, become tensors
    # after the first +=).
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; other GPUs use clones.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment, drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                # The loss function itself is named inside E_loss_args.
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, perceptual_model=perceptual_model, reals=real_gpu, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # NOTE(review): the step increment is attached as a control
            # dependency for both optimizers on every GPU; how many times the
            # counter advances per iteration depends on tflib.Optimizer -- confirm.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    # Average the reported loss terms over GPUs.
    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    image_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')

    # Summaries are written to config.GDRIVE_PATH rather than the run directory.
    summary_log = tf.summary.FileWriter(config.GDRIVE_PATH)

    # Progress bookkeeping, expressed in images seen.
    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    # Explicitly run the initializer of the step variable before training.
    init_pascal = tf.initialize_variables(
        [global_step0],
        name='init_pascal'
    )
    sess.run(init_pascal)

    print('Optimization starts!!!')


    for it in range(start, max_iters):

        # E and D are updated in two separate session runs on the same batch.
        batch_images = sess.run(image_batch_train)
        feed_dict_1 = {real_train: batch_images}
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_ = sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

        cur_nimg += submit_config.batch_size

        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)




        # Image snapshots are iteration-based here (every 500 iterations),
        # independent of the tick counter below.
        if it % 500 == 0:
            batch_images_test = sess.run(image_batch_test)
            batch_images_test = misc.adjust_dynamic_range(batch_images_test.astype(np.float32), [0, 255], [-1., 1.])
            samples2 = sess.run(fake_X_val, feed_dict={real_test: batch_images_test})
            orin_recon = np.concatenate([batch_images_test, samples2], axis=0)
            orin_recon = adjust_pixel_range(orin_recon)
            orin_recon = fuse_images(orin_recon, row=2, col=submit_config.batch_size_test)
            # save image results during training, first row is original images and the second row is reconstructed images
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), orin_recon)

            # save image to gdrive
            img_path = os.path.join(config.GDRIVE_PATH, 'images', ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, orin_recon)

        # A tick elapses once 65000 further images have been processed.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg



            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)

                # save network snapshot to gdrive
                pkl_drive = os.path.join(config.GDRIVE_PATH, 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    # Final snapshot after the last iteration (run directory only).
    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
示例#10
0
def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, mirror_augment, metrics, resume_id):
    """Assemble training options for a resumed StyleGAN2 run and submit it locally.

    Args:
        dataset: Name used as ``tfrecord_dir`` and appended to the run description.
        data_dir: Stored as ``train.data_dir``.
        result_dir: Stored as ``submit_config.run_dir_root``.
        config_id: Must be in ``_valid_configs``; selects the architecture and
            schedule overrides applied below.
        num_gpus: One of 1, 2, 4, 8.
        total_kimg: Stored as ``train.total_kimg``.
        gamma: If not ``None``, overrides ``D_loss.gamma`` after all
            config-specific settings.
        mirror_augment: Stored as ``train.mirror_augment``.
        metrics: Iterable of keys into ``metric_defaults``.
        resume_id: Passed to ``misc.locate_network_pkl`` to find the pickle to
            resume from.

    Raises:
        AssertionError: If ``num_gpus`` or ``config_id`` is not an accepted value.
    """
    train     = EasyDict(run_func_name='training.training_loop.training_loop') # Options for training loop.
    G         = EasyDict(func_name='training.networks_stylegan2.G_main')       # Options for generator network.
    D         = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')  # Options for discriminator network.
    G_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for generator optimizer.
    D_opt     = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8)                  # Options for discriminator optimizer.
    G_loss    = EasyDict(func_name='training.loss.G_logistic_ns_pathreg_face')      # Options for generator loss.
    D_loss    = EasyDict(func_name='training.loss.D_logistic_r1')              # Options for discriminator loss.
    sched     = EasyDict()                                                     # Options for TrainingSchedule.
    grid      = EasyDict(size='8k', layout='random')                           # Options for setup_snapshot_image_grid().
    sc        = dnnlib.SubmitConfig()                                          # Options for dnnlib.submit_run().
    tf_config = {'rnd.np_random_seed': 1000}                                   # Options for tflib.init_tf().

    # Resume point: the kimg counter is parsed from the third '-'-separated
    # field of the pickle's base name, e.g. 'network-snapshot-015041.pkl' -> 15041.
    train.resume_pkl = misc.locate_network_pkl(resume_id, None)
    train.resume_kimg = int(os.path.basename(train.resume_pkl).split('.')[0].split('-')[2])

    train.data_dir = data_dir
    train.total_kimg = total_kimg
    train.mirror_augment = mirror_augment
    # Snapshot cadence.  Earlier assignments of "1," were removed: the trailing
    # comma made them the tuple (1,), and they were overwritten here anyway.
    train.image_snapshot_ticks = train.network_snapshot_ticks = 10
    sched.G_lrate_base = sched.D_lrate_base = 0.002
    sched.minibatch_size_base = 32
    sched.minibatch_gpu_base = 4
    D_loss.gamma = 10
    metrics = [metric_defaults[x] for x in metrics]
    desc = 'stylegan2'

    desc += '-' + dataset
    dataset_args = EasyDict(tfrecord_dir=dataset)

    assert num_gpus in [1, 2, 4, 8]
    sc.num_gpus = num_gpus
    desc += '-%dgpu' % num_gpus

    assert config_id in _valid_configs
    desc += '-' + config_id

    # Configs A-E: Shrink networks to match original StyleGAN.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # Config E: Set gamma to 100 and override G & D architecture.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig'   in config_id: G.architecture = 'orig'
        if 'Gskip'   in config_id: G.architecture = 'skip' # (default)
        if 'Gresnet' in config_id: G.architecture = 'resnet'
        if 'Dorig'   in config_id: D.architecture = 'orig'
        if 'Dskip'   in config_id: D.architecture = 'skip'
        if 'Dresnet' in config_id: D.architecture = 'resnet' # (default)

    # Configs A-D: Enable progressive growing and switch to networks that support it.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32 # (default)
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4 # (default)
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # Configs A-C: Disable path length regularization.
    if config_id in ['config-a', 'config-b', 'config-c']:
        G_loss = EasyDict(func_name='training.loss.G_logistic_ns')

    # Configs A-B: Disable lazy regularization.
    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # Config A: Switch to original StyleGAN networks.
    # This replaces G and D wholesale, discarding the overrides made above.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    # An explicit gamma wins over every config-specific value.
    if gamma is not None:
        D_loss.gamma = gamma

    # Bundle everything into keyword arguments for a local submission.
    sc.submit_target = dnnlib.SubmitTarget.LOCAL
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, G_opt_args=G_opt, D_opt_args=D_opt, G_loss_args=G_loss, D_loss_args=D_loss)
    kwargs.update(dataset_args=dataset_args, sched_args=sched, grid_args=grid, metric_arg_list=metrics, tf_config=tf_config)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_dir_root = result_dir
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_run(**kwargs)
def training_loop(
    submit_config,
    G_args                  = {},       # Options for the generator network.
    D_args                  = {},       # Options for the discriminator network.
    G_opt_args              = {},       # Options for the generator optimizer.
    D_opt_args              = {},       # Options for the discriminator optimizer.
    G_loss_args             = {},       # Options for the generator loss.
    D_loss_args             = {},       # Options for the discriminator loss.
    dataset_args            = {},       # Options for the dataset.
    sched_args              = {},       # Options for the training schedule.
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for the metrics.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augmentation?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    save_tf_graph           = False,    # Include the full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.
    """Build the G/D training graph and run the training loop.

    Loads the training set, constructs (or resumes from a pickle) the networks
    G, D and the averaged generator Gs, builds per-GPU loss and gradient ops,
    and then alternates discriminator and generator updates until
    ``total_kimg * 1000`` images have been counted. Once per tick it prints a
    progress line and, depending on the snapshot settings, writes fake-image
    grids, network pickles, metric results and TensorFlow summaries to
    ``submit_config.run_dir``. A final ``network-final.pkl`` is written at the
    end.

    The dict/list defaults are shared objects; this function body only reads
    or forwards them (``**`` unpacking, passing to helpers) and does not
    mutate them itself.
    """

    # Initialize dnnlib and TensorFlow.
    # NOTE(review): `train` is not defined inside this function; presumably a
    # module-level import elsewhere in the file -- confirm.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks (either resumed from a pickle or built from scratch).
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            # Gs starts as a clone of G and is later updated as its moving average.
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()
    # Build the computation graph and the optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        # tf.placeholder: think of it as a formal parameter -- it is used while
        # defining the graph, and the concrete value is supplied at run time.
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        # Per-GPU share of the minibatch.
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Moving-average decay chosen so that the weight of old values halves
        # every G_smoothing_kimg * 1000 images; 0.0 disables smoothing.
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Ops that copy the fed level-of-detail into both networks' 'lod' variables.
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            # Both losses are built under a control dependency on the lod
            # assignments, so lod is set before the losses are evaluated.
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        # Peak GPU memory reporting is optional; fall back to a constant 0
        # when the memory-stats op is not available.
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    # Evaluate the schedule at the end of training to pick the minibatch size
    # used for rendering the initial grid of fakes.
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)
    # Training.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    # Sentinel: assuming sched.lod is never negative, this makes the first
    # iteration's floor/ceil comparison differ and so triggers an optimizer
    # reset (when reset_opt_for_new_lod is set) -- TODO confirm lod range.
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset when lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                # D step (also updates Gs); the image counter advances per D step.
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            # One G step after D_repeats discriminator steps.
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            # Network snapshots are also taken on the very first tick and at the end.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            # Time spent on this maintenance block = interval since the
            # previous update minus the training time of the tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write the final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
示例#12
0
def training_loop(
                  submit_config,
                  Encoder_args            = {},
                  E_opt_args              = {},
                  D_opt_args              = {},
                  E_loss_args             = EasyDict(),
                  D_loss_args             = {},
                  lr_args                 = EasyDict(),
                  tf_config               = {},
                  dataset_args            = EasyDict(),
                  decoder_pkl             = EasyDict(),
                  inversion_pkl           = EasyDict(),
                  drange_data             = [0, 255],
                  drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
                  mirror_augment          = False,
                  resume_run_id           = config.ENCODER_PICKLE_DIR,     # Run ID or network pkl to resume training from, None = start from scratch.
                  resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
                  image_snapshot_ticks    = 1,        # How often to export image snapshots? NOTE(review): not referenced in this function body.
                  network_snapshot_ticks  = 4,       # How often to export network snapshots?
                  max_iters               = 150000):
    """Train encoder E and discriminator D against a pre-trained generator.

    Either resumes (E, G, D, Gs) from a snapshot pickle or builds a new E and
    D on top of the generator loaded from ``decoder_pkl.decoder_pkl``; an
    additional network ``Inv`` is always loaded from
    ``inversion_pkl.inversion_pkl``. Portrait / landmark batches are fed
    through placeholders. Each iteration runs one E update followed by one D
    update. Every 50 iterations the losses are printed and summaries saved;
    every 500 iterations a debug image grid is written to
    ``submit_config.run_dir`` and to ``config.getGdrivePath()``; every 65000
    images a tick is counted and, every ``network_snapshot_ticks`` ticks, a
    network snapshot is written to both locations. A final
    ``network-final.pkl`` is written to the run dir at the end.
    """

    tflib.init_tf(tf_config)

    with tf.name_scope('input'):
        # Training inputs: all four placeholders share the shape
        # [batch_size, 3, image_size, image_size].
        placeholder_real_portraits_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_train')
        placeholder_real_landmarks_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_train')
        placeholder_real_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_train')
        placeholder_landmarks_shuffled_train = tf.placeholder(tf.float32, [submit_config.batch_size, 3, submit_config.image_size, submit_config.image_size], name='placeholder_landmarks_shuffled_train')


        # Test inputs: same layout, but sized by batch_size_test.
        placeholder_real_portraits_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_portraits_test')
        placeholder_real_landmarks_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_test')
        placeholder_real_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_shuffled_test')
        # NOTE(review): this placeholder is not referenced again in this function.
        placeholder_real_landmarks_shuffled_test = tf.placeholder(tf.float32, [submit_config.batch_size_test, 3, submit_config.image_size, submit_config.image_size], name='placeholder_real_landmarks_shuffled_test')

        # Split each training batch along axis 0 into one slice per GPU.
        real_split_landmarks = tf.split(placeholder_real_landmarks_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_portraits = tf.split(placeholder_real_portraits_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_shuffled = tf.split(placeholder_real_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)
        real_split_lm_shuffled = tf.split(placeholder_landmarks_shuffled_train, num_or_size_splits=submit_config.num_gpus, axis=0)

        # String flag forwarded to both loss functions.
        placeholder_training_flag = tf.placeholder(tf.string, name='placeholder_training_flag')

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            # Snapshots written below are named 'network-snapshot-<cur_nimg>.pkl';
            # recover the image count from the filename and convert it to an
            # iteration index.
            # NOTE(review): a pickle whose last '-'-separated part is not a
            # number (e.g. 'network-final.pkl') would make int() raise here.
            start = int(network_pkl.split('-')[-1].split('.')[0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, _, Gs = misc.load_pkl(decoder_pkl.decoder_pkl) # don't use pre-trained discriminator!
            num_layers = Gs.components.synthesis.input_shape[1]

            # Build a fresh discriminator instead of the pre-trained one.
            D = tflib.Network('D',  # network name
                              num_channels=3, resolution=128, label_size=0,  # arguments passed on to the build function
                              func_name="training.networks_stylegan.D_basic") # build function of the network; no further D args are passed
                              # No input is passed here: this only constructs the network, it does not evaluate it on data.
            print("Created new Discriminator!")

            E = tflib.Network('E', size=submit_config.image_size, filter=64, filter_max=1024, num_layers=num_layers, phase=True, **Encoder_args)
            start = 0
        # Only the first element of the inversion pickle is used.
        Inv, _, _, _ = misc.load_pkl(inversion_pkl.inversion_pkl)

    E.print_layers(); Gs.print_layers(); D.print_layers()

    # Learning rate: exponential decay driven by a step counter that starts at
    # `start`, so a resumed run continues the decay schedule.
    global_step0 = tf.Variable(start, trainable=False, name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate, global_step0, lr_args.decay_step,
                                               lr_args.decay_rate, staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE', learning_rate=learning_rate, **E_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=learning_rate, **D_opt_args)

    # Accumulators for the individual loss terms; summed over GPUs in the loop
    # below and divided by num_gpus afterwards.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('build graph on gpu %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a clone.
            E_gpu = E if gpu == 0 else E.clone(E.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            # Note: the per-GPU generator is taken from Gs, not from G.
            G_gpu = Gs if gpu == 0 else Gs.clone(Gs.name + '_shadow')
            Inv_gpu = Inv if gpu == 0 else Inv.clone(Inv.name + '_shadow')
            # A separate PerceptualModel instance is created for every GPU.
            perceptual_model = PerceptualModel(img_size=[E_loss_args.perceptual_img_size, E_loss_args.perceptual_img_size], multi_layers=False)
            real_portraits_gpu = process_reals(real_split_portraits[gpu], mirror_augment, drange_data, drange_net)
            shuffled_portraits_gpu = process_reals(real_split_shuffled[gpu], mirror_augment, drange_data, drange_net)
            real_landmarks_gpu = process_reals(real_split_landmarks[gpu], mirror_augment, drange_data, drange_net)
            shuffled_landmarks_gpu = process_reals(real_split_lm_shuffled[gpu], mirror_augment, drange_data, drange_net)
            # control_dependencies(None) clears any inherited control dependencies.
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, perceptual_model=perceptual_model, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, shuffled_landmarks=shuffled_landmarks_gpu, training_flag=placeholder_training_flag, **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                # Unlike the E loss call, shuffled_landmarks and perceptual_model are not passed here.
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(E=E_gpu, G=G_gpu, D=D_gpu, Inv=Inv_gpu, real_portraits=real_portraits_gpu, shuffled_portraits=shuffled_portraits_gpu, real_landmarks=real_landmarks_gpu, training_flag=placeholder_training_flag, **D_loss_args) # change signature in ...
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # Gradient ops are created under a dependency on the step
            # increment -- presumably so the learning-rate step advances
            # whenever gradients are computed; TODO confirm how often it fires.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                D_opt.register_gradients(D_loss, D_gpu.trainables)

    # Average the reported loss terms over the GPUs.
    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('building testing graph...')
    fake_X_val = test(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)
    inv_X_val = test_inversion(E, Gs, Inv, placeholder_real_portraits_test, placeholder_real_landmarks_test, placeholder_real_shuffled_test, submit_config)

    #sampled_portraits_val = sample_random_portraits(Gs, submit_config.batch_size)
    #sampled_portraits_val_test = sample_random_portraits(Gs, submit_config.batch_size_test)

    sess = tf.get_default_session()

    print('Getting training data...')
    # Each fetched batch is indexed below as [:, 0] (portraits) and [:, 1]
    # (landmarks), i.e. a stack of two images per record.
    stack_batch_train = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train')
    stack_batch_test = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test')

    # Second, independently fetched stream used for the "shuffled" inputs.
    stack_batch_train_secondary = get_train_data(sess, data_dir=dataset_args.data_train, submit_config=submit_config, mode='train_secondary')
    stack_batch_test_secondary = get_train_data(sess, data_dir=dataset_args.data_test, submit_config=submit_config, mode='test_secondary')

    summary_log = tf.summary.FileWriter(config.getGdrivePath())

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    # Explicitly initialize the learning-rate step variable created above.
    init_fix = tf.initialize_variables(
        [global_step0],
        name='init_fix'
    )
    sess.run(init_fix)

    print('Optimization starts!!!')


    # The actual training loop: iterations start..max_iters-1.
    for it in range(start, max_iters):
        batch_stacks = sess.run(stack_batch_train)
        batch_portraits = batch_stacks[:,0,:,:,:]
        batch_landmarks = batch_stacks[:,1,:,:,:]

        batch_stacks_secondary = sess.run(stack_batch_train_secondary)
        batch_shuffled = batch_stacks_secondary[:,0,:,:,:]
        batch_lm_shuffled = batch_stacks_secondary[:,1,:,:,:]


        # The flag is fixed to "pose" for every iteration.
        training_flag = "pose"

        feed_dict_1 = {placeholder_real_portraits_train: batch_portraits, placeholder_real_landmarks_train: batch_landmarks, placeholder_real_shuffled_train:batch_shuffled, placeholder_landmarks_shuffled_train:batch_lm_shuffled, placeholder_training_flag: training_flag}
        # One encoder update, then one discriminator update, both on the same
        # feed; the averaged loss terms are fetched for logging.
        _, recon_, adv_ = sess.run([E_train_op, E_loss_rec, E_loss_adv], feed_dict_1)
        _, d_r_, d_f_, d_g_= sess.run([D_train_op, D_loss_real, D_loss_fake, D_loss_grad], feed_dict_1)

        cur_nimg += submit_config.batch_size

        # Log the losses and flush summaries every 50 iterations.
        if it % 50 == 0:
            print('Iter: %06d recon_loss: %-6.4f adv_loss: %-6.4f d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f time:%-12s' % (
                it, recon_, adv_, d_r_, d_f_, d_g_, dnnlib.util.format_time(time.time() - start_time)))
            sys.stdout.flush()
            tflib.autosummary.save_summaries(summary_log, it)




        # Render a debug image grid from a test batch every 500 iterations.
        if it % 500 == 0:
            batch_stacks_test = sess.run(stack_batch_test)
            batch_portraits_test = batch_stacks_test[:,0,:,:,:]
            batch_landmarks_test = batch_stacks_test[:,1,:,:,:]

            batch_stacks_test_secondary = sess.run(stack_batch_test_secondary)
            batch_shuffled_test = batch_stacks_test_secondary[:,0,:,:,:]
            batch_shuffled_lm_test = batch_stacks_test_secondary[:,1,:,:,:]


            # Convert the test batches from [0, 255] to [-1, 1] before feeding them.
            batch_portraits_test = misc.adjust_dynamic_range(batch_portraits_test.astype(np.float32), [0, 255], [-1., 1.])
            batch_landmarks_test = misc.adjust_dynamic_range(batch_landmarks_test.astype(np.float32), [0, 255], [-1., 1.])
            # NOTE(review): batch_shuffled_test is converted here but not used afterwards in this function.
            batch_shuffled_test = misc.adjust_dynamic_range(batch_shuffled_test.astype(np.float32), [0, 255], [-1., 1.])
            batch_shuffled_lm_test = misc.adjust_dynamic_range(batch_shuffled_lm_test.astype(np.float32), [0, 255], [-1., 1.])

            # 1st: input portraits + target (shuffled) landmarks -> manipulated images
            samples_manipulated = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_shuffled_lm_test})

            # 2nd: manipulated images + original landmarks -> cycle reconstruction
            samples_reconstructed = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: samples_manipulated, placeholder_real_landmarks_test: batch_landmarks_test})

            # also: direct reconstruction (original portraits + original landmarks)
            samples_direct_rec = sess.run(fake_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test})

            # results of the inversion graph on the original inputs
            portraits_inverted = sess.run(inv_X_val, feed_dict={placeholder_real_portraits_test: batch_portraits_test, placeholder_real_landmarks_test: batch_landmarks_test})

            # Stack all groups along axis 0 in display order: original
            # landmarks, original portraits, direct reconstruction, shuffled
            # landmarks, manipulated, cycle-reconstructed, inverted.
            debug_img = np.concatenate([
                    batch_landmarks_test, # original landmarks
                    batch_portraits_test, # original portraits
                    samples_direct_rec, # direct reconstruction
                    batch_shuffled_lm_test, # shuffled landmarks
                    samples_manipulated, # manipulated images
                    samples_reconstructed, # cycle-reconstructed images
                    portraits_inverted # inversion results
                ], axis=0)

            debug_img = adjust_pixel_range(debug_img)
            # NOTE(review): seven arrays are concatenated above but row=6 is
            # requested -- confirm how fuse_images treats the extra group.
            debug_img = fuse_images(debug_img, row=6, col=submit_config.batch_size_test)
            # Save the fused debug grid to the run dir; the file name carries the image count.
            save_image('%s/iter_%08d.png' % (submit_config.run_dir, cur_nimg), debug_img)

            # Save the same image under the gdrive path.
            img_path = os.path.join(config.getGdrivePath(), 'images', ('iter_%08d.png' % (cur_nimg)))
            save_image(img_path, debug_img)

        # A tick is counted every 65000 processed images.
        if cur_nimg >= tick_start_nimg + 65000:
            cur_tick += 1
            tick_start_nimg = cur_nimg



            if cur_tick % network_snapshot_ticks == 0:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl)

                # Save the network snapshot under the gdrive path as well.
                pkl_drive = os.path.join(config.getGdrivePath(), 'snapshots', 'network-snapshot-%08d.pkl' % (cur_nimg))
                misc.save_pkl((E, G, D, Gs), pkl_drive)

    # Write the final networks to the run dir.
    misc.save_pkl((E, G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
示例#13
0
def training_loop(
    submit_config,
    HP_args={},  # Options for the Hessian Penalty.
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    interp_snapshot_ticks=20,  # How often to generate interpolation visualizations in TensorBoard?
    network_snapshot_ticks=5,  # How often to export network snapshots?
    network_metric_ticks=5,  # How often to evaluate network snapshots on specified metrics?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create a copy of dataset_args for running the metrics:
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0

    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    # G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # The loss weighting of the Hessian Penalty can be dynamic over training, so we need to use a placeholder:
    hessian_weight = tf.placeholder(tf.float32,
                                    name='hessian_weight',
                                    shape=[])

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    reg_ops = [
    ]  # Returning the values of the Hessian Penalty/ InfoGAN losses so they can be registered in TensorBoard
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    hp_lambda=hessian_weight,
                    HP_args=HP_args,
                    gpu_ix=gpu,
                    lod_in=lod_in,
                    max_lod=training_set.resolution_log2,
                    **G_loss_args)
                if HP_args.hp_lambda > 0:
                    reg_ops += [G_hessian_penalty]
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz,
                    **D_loss_args)
                # print([name for name in D_gpu.trainables.keys()])
                # gps = [weight for name, weight in G_gpu.trainables.items()][0]
                # dps = [weight for name, weight in D_gpu.trainables.items() if 'Q_Encoder' in name][0]
                # gg = autosummary('Loss/G_info_grad', tf.reduce_sum(tf.gradients(mutual_info, gps)[0]**2))
                # dg = autosummary('Loss/D_info_grad', tf.reduce_sum(tf.gradients(mutual_info, dps)[0]**2))
                # reg_ops.extend([dg, gg, dps, gps])
                if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                    reg_ops += [mutual_info]
            # Note, even though we are adding mutual_info loss here, the only time the loss is non-zero
            # is when infogan_lambda > 0 (in Hessian Penalty experiments, we always set it to zero):
            G_opt.register_gradients(
                G_loss + G_args.infogan_lambda * mutual_info, G_gpu.trainables)
            D_opt.register_gradients(
                tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info,
                D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros(
        [interp_steps * interp_batch_size * nz, training_set.label_size],
        dtype=training_set.label_dtype)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_summary(
        build_image_summary(os.path.join(submit_config.run_dir, 'reals.png'),
                            'samples/real'), 0)
    summary_log.add_summary(
        build_image_summary(
            os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
            'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs,
            interp_z,
            interp_labels,
            minibatch_size=sched.minibatch // submit_config.num_gpus,
            interp_steps=interp_steps,
            interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg,
            summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg,
                HP_args.warmup_nimg)
            tflib.run(
                [G_train_op] + reg_ops, {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    hessian_weight: cur_hessian_weight
                })
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   autosummary('Progress/hessian_weight', cur_hessian_weight),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            # Save snapshots and fake image samples (for both Gs and G):
            if cur_tick % image_snapshot_ticks == 0 or done:
                iter = (cur_nimg // 1000)
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes_inst = G.run(grid_latents,
                                        grid_labels,
                                        is_validation=True,
                                        minibatch_size=sched.minibatch //
                                        submit_config.num_gpus)
                fake_path = os.path.join(submit_config.run_dir,
                                         'fakes%06d.png' % iter)
                ifake_path = os.path.join(submit_config.run_dir,
                                          'ifakes%06d.png' % iter)
                misc.save_image_grid(grid_fakes,
                                     fake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst,
                                     ifake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                summary_log.add_summary(
                    build_image_summary(fake_path, 'samples/Gs'), iter)
                summary_log.add_summary(
                    build_image_summary(ifake_path, 'samples/G'), iter)

            # Generate/Save Interpolation Visualizations:
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs,
                    interp_z,
                    interp_labels,
                    minibatch_size=sched.minibatch // submit_config.num_gpus,
                    interp_steps=interp_steps,
                    interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000,
                    summary_log=summary_log)

            # Save snapshot and run metrics:
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                    metrics.run(pkl,
                                dataset_args=metrics_dataset_args,
                                mirror_augment=mirror_augment,
                                num_gpus=submit_config.num_gpus,
                                tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir,
                               'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()

    ctx.close()
示例#14
0
def mixing(resume_run_id, resume_snapshot=None):
    """Look up the network pickle for the given run and snapshot.

    The lookup is delegated to ``misc.locate_network_pkl``. The resulting
    value is only bound to a local name; this function returns ``None``.

    Args:
        resume_run_id: Run identifier forwarded to ``misc.locate_network_pkl``.
        resume_snapshot: Snapshot selector forwarded unchanged (default ``None``).
    """
    located_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
示例#15
0
def training_loop(
        submit_config,
        Encoder_args={},
        D_args={},
        G_args={},
        E_opt_args={},
        D_opt_args={},
        E_loss_args=EasyDict(),
        D_loss_args={},
        lr_args=EasyDict(),
        tf_config={},
        dataset_args=EasyDict(),
        decoder_pkl=EasyDict(),
        drange_data=[0, 255],
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        mirror_augment=False,
        filter=64,  # Minimum number of feature maps in any layer.
        filter_max=512,  # Maximum number of feature maps in any layer.
        resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        d_scale=0.1,  # Decide whether to update discriminator.
        pretrained_D=True,  # Whether to use pre trained Discriminator.
        max_iters=150000):
    """Train the encoder ``E`` and, when ``d_scale > 0``, the discriminator ``D``.

    The function builds a per-GPU training graph, copies the weights of
    ``Gs.components.synthesis`` into a freshly constructed ``G_StyleMod``
    network, and then runs an iteration loop that logs progress every 50
    iterations and writes reconstruction images and network snapshots to
    ``submit_config.run_dir``. Only ``E`` and ``D`` trainables are registered
    with an optimizer; the generator networks are not.

    Args:
        submit_config: Run configuration. Reads ``batch_size``,
            ``batch_size_test``, ``image_size``, ``num_gpus`` and ``run_dir``.
        Encoder_args: Extra keyword arguments for the ``E_gpu0`` network.
        D_args: Extra keyword arguments for the ``D_ori`` network.
        G_args: Extra keyword arguments for the ``G_StyleMod`` network.
        E_opt_args: Extra keyword arguments for the ``TrainE`` optimizer.
        D_opt_args: Extra keyword arguments for the ``TrainD`` optimizer.
        E_loss_args: Keyword arguments handed to
            ``dnnlib.util.call_func_by_name`` for the encoder loss; its
            ``perceptual_img_size`` is also used to size the perceptual model.
        D_loss_args: Keyword arguments handed to
            ``dnnlib.util.call_func_by_name`` for the discriminator loss.
        lr_args: Provides ``learning_rate``, ``decay_step``, ``decay_rate``
            and ``stair`` for ``tf.train.exponential_decay``.
        tf_config: Passed to ``tflib.init_tf``.
        dataset_args: Provides ``data_train`` and ``data_test`` directories.
        decoder_pkl: Its ``decoder_pkl`` attribute is the pickle loaded as
            ``(G, PreD, Gs)`` when not resuming.
        drange_data: Passed to ``process_reals`` together with ``drange_net``.
        drange_net: Dynamic range used when feeding image data to the networks.
        mirror_augment: Passed to ``process_reals``.
        filter: Minimum number of feature maps in any encoder layer.
        filter_max: Maximum number of feature maps in any encoder layer.
        resume_run_id: Run ID or network pkl to resume from; ``None`` starts
            from ``decoder_pkl``.
        resume_snapshot: Snapshot index to resume from.
        image_snapshot_ticks: Export a reconstruction image every this many ticks.
        network_snapshot_ticks: Export a network pickle every this many ticks.
        d_scale: The discriminator optimizer is created and run only when
            this is greater than zero.
        pretrained_D: When not resuming, use the discriminator from
            ``decoder_pkl`` instead of the newly built ``D_ori``.
        max_iters: Iteration index at which training stops.
    """

    tflib.init_tf(tf_config)

    with tf.name_scope('Input'):
        real_train = tf.placeholder(tf.float32, [
            submit_config.batch_size, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                    name='real_image_train')
        real_test = tf.placeholder(tf.float32, [
            submit_config.batch_size_test, 3, submit_config.image_size,
            submit_config.image_size
        ],
                                   name='real_image_test')
        # One slice of the training batch per GPU.
        real_split = tf.split(real_train,
                              num_or_size_splits=submit_config.num_gpus,
                              axis=0)

    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            E, G, D, Gs = misc.load_pkl(network_pkl)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            # Snapshots are saved below as 'network-snapshot-<nimg>.pkl', so
            # the trailing number is an image count; convert it to iterations.
            # A 'network-final.pkl' name cannot be parsed this way.
            start = int(network_pkl.split('-')[-1].split('.')
                        [0]) // submit_config.batch_size
            print('Start: ', start)
        else:
            print('Constructing networks...')
            G, PreD, Gs = misc.load_pkl(decoder_pkl.decoder_pkl)
            num_layers = Gs.components.synthesis.input_shape[1]
            E = tflib.Network('E_gpu0',
                              size=submit_config.image_size,
                              filter=filter,
                              filter_max=filter_max,
                              num_layers=num_layers,
                              is_training=True,
                              num_gpus=submit_config.num_gpus,
                              **Encoder_args)
            # NOTE(review): D_ori is built even when pretrained_D is True and
            # is then unused; kept as-is so the constructed graph is unchanged.
            OriD = tflib.Network('D_ori',
                                 resolution=submit_config.image_size,
                                 label_size=0,
                                 **D_args)
            G_style_mod = tflib.Network('G_StyleMod',
                                        resolution=submit_config.image_size,
                                        label_size=0,
                                        **G_args)
            if pretrained_D:
                D = PreD
            else:
                D = OriD
            start = 0

        # Copy the synthesis weights of Gs into G_style_mod. All values are
        # fetched with a single run call and written with a single set_vars
        # call instead of one session round trip per variable.
        synthesis_vars = Gs.components.synthesis.vars
        synthesis_names = list(synthesis_vars.keys())
        synthesis_values = tflib.run(list(synthesis_vars.values()))
        Gs_vars_pairs = dict(zip(synthesis_names, synthesis_values))
        tflib.set_vars({
            g_val: Gs_vars_pairs[g_name]
            for g_name, g_val in G_style_mod.vars.items()
        })

    E.print_layers()
    Gs.print_layers()
    D.print_layers()

    # Learning-rate schedule driven by a dedicated step counter that starts at
    # the resumed iteration.
    global_step0 = tf.Variable(start,
                               trainable=False,
                               name='learning_rate_step')
    learning_rate = tf.train.exponential_decay(lr_args.learning_rate,
                                               global_step0,
                                               lr_args.decay_step,
                                               lr_args.decay_rate,
                                               staircase=lr_args.stair)
    add_global0 = global_step0.assign_add(1)

    E_opt = tflib.Optimizer(name='TrainE',
                            learning_rate=learning_rate,
                            **E_opt_args)
    if d_scale > 0:
        D_opt = tflib.Optimizer(name='TrainD',
                                learning_rate=learning_rate,
                                **D_opt_args)

    # Loss terms are summed over GPUs here and averaged after the loop; they
    # are only fetched for reporting.
    E_loss_rec = 0.
    E_loss_adv = 0.
    D_loss_real = 0.
    D_loss_fake = 0.
    D_loss_grad = 0.
    for gpu in range(submit_config.num_gpus):
        print('Building Graph on GPU %s' % str(gpu))
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the primary networks; other GPUs use clones.
            E_gpu = E if gpu == 0 else E.clone(E.name[:-1] + str(gpu))
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            G_gpu = G_style_mod if gpu == 0 else G_style_mod.clone(
                G_style_mod.name + '_shadow')
            feature_model = PerceptualModel(img_size=[
                E_loss_args.perceptual_img_size,
                E_loss_args.perceptual_img_size
            ],
                                            multi_layers=False)
            real_gpu = process_reals(real_split[gpu], mirror_augment,
                                     drange_data, drange_net)
            with tf.name_scope('E_loss'), tf.control_dependencies(None):
                E_loss, recon_loss, adv_loss = dnnlib.util.call_func_by_name(
                    E=E_gpu,
                    G=G_gpu,
                    D=D_gpu,
                    feature_model=feature_model,
                    reals=real_gpu,
                    **E_loss_args)
                E_loss_rec += recon_loss
                E_loss_adv += adv_loss
            with tf.name_scope('D_loss'), tf.control_dependencies(None):
                D_loss, loss_fake, loss_real, loss_gp = dnnlib.util.call_func_by_name(
                    E=E_gpu, G=G_gpu, D=D_gpu, reals=real_gpu, **D_loss_args)
                D_loss_real += loss_real
                D_loss_fake += loss_fake
                D_loss_grad += loss_gp
            # NOTE(review): both optimizers register gradients under a control
            # dependency on the step increment, and E and D are trained in
            # separate sess.run calls below, so the counter may advance more
            # than once per iteration when d_scale > 0 -- confirm intended.
            with tf.control_dependencies([add_global0]):
                E_opt.register_gradients(E_loss, E_gpu.trainables)
                if d_scale > 0:
                    D_opt.register_gradients(D_loss, D_gpu.trainables)

    E_loss_rec /= submit_config.num_gpus
    E_loss_adv /= submit_config.num_gpus
    D_loss_real /= submit_config.num_gpus
    D_loss_fake /= submit_config.num_gpus
    D_loss_grad /= submit_config.num_gpus

    E_train_op = E_opt.apply_updates()
    if d_scale > 0:
        D_train_op = D_opt.apply_updates()

    print('Building testing graph...')
    fake_X_val = test(E, G_style_mod, real_test, submit_config)

    sess = tf.get_default_session()

    print('Getting training data...')
    image_batch_train = get_train_data(sess,
                                       data_dir=dataset_args.data_train,
                                       submit_config=submit_config,
                                       mode='train')
    image_batch_test = get_train_data(sess,
                                      data_dir=dataset_args.data_test,
                                      submit_config=submit_config,
                                      mode='test')

    summary_log = tf.summary.FileWriter(submit_config.run_dir)

    cur_nimg = start * submit_config.batch_size
    cur_tick = 0
    tick_start_nimg = cur_nimg
    start_time = time.time()

    # The writer is closed in ``finally`` so the event file is released even
    # when training or a snapshot save raises.
    try:
        print('Optimization starts!!!')
        for it in range(start, max_iters):

            batch_images = sess.run(image_batch_train)
            feed_dict = {real_train: batch_images}
            _, recon_, adv_, lr = sess.run(
                [E_train_op, E_loss_rec, E_loss_adv, learning_rate],
                feed_dict)
            if d_scale > 0:
                _, d_r_, d_f_, d_g_ = sess.run(
                    [D_train_op, D_loss_real, D_loss_fake, D_loss_grad],
                    feed_dict)

            cur_nimg += submit_config.batch_size

            # Estimate the remaining time from the mean iteration time of
            # this session (iterations before ``start`` are not counted).
            run_time = time.time() - start_time
            iter_time = run_time / (it - start + 1)
            eta_time = iter_time * (max_iters - it - 1)

            if it % 50 == 0:
                print(
                    'Iter: %06d/%d, lr: %-.8f recon_loss: %-6.4f adv_loss: %-6.4f run_time:%-12s eta_time:%-12s'
                    % (it, max_iters, lr, recon_, adv_,
                       dnnlib.util.format_time(run_time),
                       dnnlib.util.format_time(eta_time)))
                if d_scale > 0:
                    print('d_r_loss: %-6.4f d_f_loss: %-6.4f d_reg: %-6.4f ' %
                          (d_r_, d_f_, d_g_))
                sys.stdout.flush()
                tflib.autosummary.save_summaries(summary_log, it)

            # A tick elapses every 65000 training images.
            if cur_nimg >= tick_start_nimg + 65000:
                cur_tick += 1
                tick_start_nimg = cur_nimg

                if cur_tick % image_snapshot_ticks == 0:
                    batch_images_test = sess.run(image_batch_test)
                    batch_images_test = misc.adjust_dynamic_range(
                        batch_images_test.astype(np.float32), [0, 255],
                        [-1., 1.])
                    recon = sess.run(fake_X_val,
                                     feed_dict={real_test: batch_images_test})
                    orin_recon = np.concatenate([batch_images_test, recon],
                                                axis=0)
                    orin_recon = adjust_pixel_range(orin_recon)
                    orin_recon = fuse_images(orin_recon,
                                             row=2,
                                             col=submit_config.batch_size_test)
                    # save image results during training, first row is original images and the second row is reconstructed images
                    save_image(
                        '%s/iter_%09d.png' % (submit_config.run_dir, cur_nimg),
                        orin_recon)

                if cur_tick % network_snapshot_ticks == 0:
                    pkl = os.path.join(
                        submit_config.run_dir,
                        'network-snapshot-%09d.pkl' % (cur_nimg))
                    misc.save_pkl((E, G, D, Gs), pkl)

        misc.save_pkl((E, G, D, Gs),
                      os.path.join(submit_config.run_dir, 'network-final.pkl'))
    finally:
        summary_log.close()
示例#16
0
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,
        weight_args={},
        train_stage_args={},
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=87,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=2364,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=2364,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,
        **_kwargs
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    PI = 3.1415927
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)
    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)

            #---------------------------------------------------------------
            # Modified by Deng et al.
            G_loss,D_loss = dnnlib.util.call_func_by_name(FaceRender=FaceRender,noise_dim=noise_dim,weight_args=weight_args,\
                G_gpu=G_gpu,D_gpu=D_gpu,G_opt=G_opt,D_opt=D_opt,training_set=training_set,G_loss_args=G_loss_args,D_loss_args=D_loss_args,\
                lod_assign_ops=lod_assign_ops,reals=reals,labels=labels,minibatch_split=minibatch_split,resolution=resolution,\
                drange_net=drange_net,lod_in=lod_in,**train_stage_args)
            #---------------------------------------------------------------

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False)
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range, drange_net)

    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    grid_noise = np.random.randn(np.prod(grid_size), 32)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise],
                                             axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })

            # print('iter')
        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            #---------------------------------------------------------------

            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()


#----------------------------------------------------------------------------
示例#17
0
def _upload_to_gcs(local_path, bucket_url='gs://stylegan_out'):
    """Best-effort copy of ``local_path`` to a GCS bucket using ``gsutil cp``.

    The command is passed to ``subprocess.run`` as an argument list (no shell),
    so a path containing spaces or shell metacharacters is forwarded verbatim
    instead of being re-parsed by a shell.

    Failures are reported on stdout but never raised, so a missing ``gsutil``
    binary or a failed upload cannot abort a training run.

    Args:
        local_path: Path of the file to upload.
        bucket_url: Destination passed to ``gsutil cp``.

    Returns:
        True if ``gsutil`` exited with status 0, False otherwise.
    """
    try:
        result = subprocess.run(['gsutil', 'cp', local_path, bucket_url])
    except OSError as err:  # e.g. gsutil is not installed / not on PATH
        print('Warning: could not run gsutil: %s' % err)
        return False
    if result.returncode != 0:
        print('Warning: "gsutil cp %s %s" exited with code %d' %
              (local_path, bucket_url, result.returncode))
        return False
    return True


def training_loop(
    submit_config,
    G_args={},  # Options for generator network (unused here: networks are always loaded from a pickle).
    D_args={},  # Options for discriminator network (unused here: networks are always loaded from a pickle).
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for training_schedule().
    grid_args={},  # Options for misc.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = load the pretrained network from the hard-coded URL.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Fine-tune a GAN (G, D, Gs) and mirror image snapshots to a GCS bucket.

    Steps performed:
      1. Initialize dnnlib/TensorFlow and load the training set.
      2. Load G, D, Gs either from ``resume_run_id``/``resume_snapshot`` or,
         when ``resume_run_id`` is None, from a pretrained pickle downloaded
         from a hard-coded URL.
      3. Build the per-GPU loss/optimizer graph and the Gs moving-average op.
      4. Save the initial real/fake image grids and upload the fake grid.
      5. Train until ``total_kimg * 1000`` images have been processed. Once per
         tick: report progress, optionally export (and upload) a fake image
         grid, optionally save a network snapshot and run the metrics.
      6. Save ``network-final.pkl`` in ``submit_config.run_dir``.

    Uploads are best-effort: a failing ``gsutil`` only prints a warning.

    The dict/list defaults in the signature are only read by this function
    (unpacked or forwarded), not modified by it.
    """

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            # Instead of constructing fresh networks, start from a pretrained
            # pickle fetched from a fixed URL.
            # NOTE(security): pickle.load executes arbitrary code contained in
            # the file; this is only acceptable as long as the URL below is a
            # trusted source.
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
            print('Loading pretrained FFHQ network')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs running average: 0.5 ** (minibatch / (G_smoothing_kimg * 1000)).
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            # GPU 0 uses the original networks; every other GPU gets a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            # The control dependencies make the loss ops run after the lod
            # variables of this GPU's networks have been assigned.
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # Schedule evaluated at the end of training; used here only for the
    # minibatch size of the initial snapshot.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    initial_fakes_png = os.path.join(submit_config.run_dir,
                                     'fakes%06d.png' % resume_kimg)
    misc.save_image_grid(grid_fakes,
                         initial_fakes_png,
                         drange=drange_net,
                         grid_size=grid_size)

    # Mirror the initial snapshot to the bucket (best effort).
    _upload_to_gcs(initial_fakes_png)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            # Reset when the lod crosses an integer boundary in either direction.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                # Progress is counted in images seen by the discriminator.
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                fakes_png = os.path.join(
                    submit_config.run_dir,
                    'fakes%06d.png' % (cur_nimg // 1000))
                misc.save_image_grid(grid_fakes,
                                     fakes_png,
                                     drange=drange_net,
                                     grid_size=grid_size)
                # Mirror the snapshot to the bucket (best effort).
                _upload_to_gcs(fakes_png)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            # Time spent in this maintenance section, reported on the next tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
示例#18
0
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for training_schedule().
    grid_args={},  # Options for misc.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run between maintenance checks.
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting.
    resume_time=0.0,  # Assumed wallclock time at the beginning. Affects reporting.
):
    """Train a GAN (G, D, Gs) with a fixed schedule and multi-scale snapshots.

    Steps performed:
      1. Initialize dnnlib/TensorFlow and load the training set.
      2. Load G, D, Gs from ``resume_run_id``/``resume_snapshot`` or construct
         them from scratch when ``resume_run_id`` is None.
      3. Build the per-GPU loss/optimizer graph and the Gs moving-average op.
      4. Save the initial real/fake image grids.
      5. Train until ``total_kimg * 1000`` images have been processed. Once per
         tick: report progress, optionally export fake image grids,
         optionally save a network snapshot and run the metrics.
      6. Save ``network-final.pkl`` in ``submit_config.run_dir``.

    The training schedule is evaluated a single time (at
    ``cur_nimg = total_kimg * 1000``) and reused unchanged for the whole run;
    this loop has no lod input.
    """

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device("/gpu:0"):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print("Constructing networks...")
            G = tflib.Network(
                name="G",
                num_inputs=2,  # one for latents and one for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                name="D",
                num_inputs=int(np.log2(training_set.shape[1])) - 1 +
                1,  # (log2(resolution) - 1) inputs, +1 for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            # Gs starts as a copy of G and is later updated as its moving average.
            Gs = G.clone("Gs")
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        # Per-step decay of the Gs running average: 0.5 ** (minibatch / (G_smoothing_kimg * 1000)).
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_in,
                                       tf.float32), G_smoothing_kimg *
                               1000.0) if G_smoothing_kimg > 0.0 else 0.0)

    G_opt = tflib.Optimizer(name="TrainG",
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD",
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):
            # GPU 0 uses the original networks; every other GPU gets a clone.
            G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow")
            D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow")
            reals, labels = training_set.get_minibatch_tf()
            # NOTE(review): presumably process_reals returns the reals at
            # `depth` different scales for the multi-input D -- confirm
            # against process_reals.
            reals = process_reals(
                reals,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )
            with tf.name_scope("G_loss"):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope("D_loss"):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            # Fall back to a constant so progress reporting still works.
            peak_gpu_mem_op = tf.constant(0)

    # Choose training parameters once; this schedule object is reused for the
    # whole run (it is never recomputed inside the training loop).
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch_size // submit_config.num_gpus,
    )

    print("Setting up run dir...")
    # One output directory per entry of grid_fakes, named "4x4", "8x8", ...
    # NOTE(review): this assumes Gs.run returns one image batch per
    # resolution, lowest resolution first -- confirm against the generator.
    fake_multi_scale_dirs = [
        os.path.join(submit_config.run_dir,
                     str(2**res) + "x" + str(2**res))
        for res in range(2, 2 + len(grid_fakes))
    ]
    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, "reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )
    misc.save_image_grids(
        grid_fakes,
        [
            os.path.join(fake_multi_scale_dir, "fakes%06d.png" % resume_kimg)
            for fake_multi_scale_dir in fake_multi_scale_dirs
        ],
        drange=drange_net,
        grid_size=grid_size,
    )
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # Configure the training set to the per-GPU minibatch size (done once,
    # because the schedule does not change during training).
    training_set.configure(sched.minibatch_size // submit_config.num_gpus)

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op],
                    {
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch_size
                    },
                )
                # Progress is counted in images seen by the discriminator.
                cur_nimg += sched.minibatch_size
            tflib.run(
                [G_train_op],
                {
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch_size
                },
            )

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
                % (
                    autosummary("Progress/tick", cur_tick),
                    autosummary("Progress/kimg", cur_nimg / 1000.0),
                    autosummary("Progress/minibatch", sched.minibatch_size),
                    dnnlib.util.format_time(
                        autosummary("Timing/total_sec", total_time)),
                    autosummary("Timing/sec_per_tick", tick_time),
                    autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
                    autosummary("Timing/maintenance_sec", maintenance_time),
                    autosummary("Resources/peak_gpu_mem_gb",
                                peak_gpu_mem_op.eval() / 2**30),
                ))
            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save image snapshots, one file per scale directory.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(
                    grid_latents,
                    grid_labels,
                    is_validation=True,
                    minibatch_size=sched.minibatch_size //
                    submit_config.num_gpus,
                )
                misc.save_image_grids(
                    grid_fakes,
                    [
                        os.path.join(fake_multi_scale_dir, "fakes%06d.png" %
                                     (cur_nimg // 1000))
                        for fake_multi_scale_dir in fake_multi_scale_dirs
                    ],
                    drange=drange_net,
                    grid_size=grid_size,
                )
            # Save a network snapshot (always on the first tick and at the end)
            # and evaluate the metrics on it.
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    "network-snapshot-%06d.pkl" % (cur_nimg // 1000),
                )
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(
                    pkl,
                    run_dir=submit_config.run_dir,
                    num_gpus=submit_config.num_gpus,
                    tf_config=tf_config,
                )

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            # Time spent in this maintenance section, reported on the next tick.
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, "network-final.pkl"))
    summary_log.close()

    ctx.close()