Example No. 1
def run_pickle(submit_config, metric_args, network_pkl, dataset_args, mirror_augment):
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on network_pkl "%s"...' % (metric_args.name, network_pkl))
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()
    metric.run(network_pkl, dataset_args=dataset_args, mirror_augment=mirror_augment, num_gpus=submit_config.num_gpus)
    print()
    ctx.close()
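
A minimal sketch of how run_pickle() might be invoked, assuming the dnnlib.EasyDict-style configs that dnnlib.util.call_func_by_name() expects; the metric func_name, dataset fields, paths, and submit_config attributes below are illustrative assumptions, not taken from this listing:

import dnnlib

metric_args = dnnlib.EasyDict(func_name='metrics.frechet_inception_distance.FID',
                              name='fid50k', num_images=50000, minibatch_per_gpu=8)
dataset_args = dnnlib.EasyDict(tfrecord_dir='datasets/my_dataset')
submit_config = dnnlib.SubmitConfig()
submit_config.num_gpus = 1
submit_config.run_dir = 'results/00000-eval'  # RunContext may also expect a run dir

run_pickle(submit_config, metric_args,
           network_pkl='results/network-snapshot-010000.pkl',
           dataset_args=dataset_args, mirror_augment=False)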
Example No. 2
def run_snapshot(submit_config, metric_args, run_id, snapshot):
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on run_id %s, snapshot %s...' % (metric_args.name, run_id, snapshot))
    run_dir = misc.locate_run_dir(run_id)
    network_pkl = misc.locate_network_pkl(run_dir, snapshot)
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()
    metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()
    ctx.close()
Example No. 3
def test_d(submit_config,
           resume_run_id,
           dataset_args,
           tf_config={},
           resume_snapshot=None):

    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    latents_1 = tf.placeholder(tf.float32)
    labels_1 = None

    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)
    w_1 = Gs.components.mapping.get_output_for(latents_1,
                                               labels_1,
                                               is_validation=True)
    fake_image_1_op = Gs.components.synthesis.get_output_for(
        w_1, is_validation=True, randomize_noise=False)

    reals, labels = training_set.get_minibatch_tf()

    lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])

    reals = process_reals(reals, lod_in, False, training_set.dynamic_range,
                          [-1, 1])

    d_pred_real = D.get_output_for(reals, labels_1)
    d_pred_fake = D.get_output_for(fake_image_1_op, labels_1)

    training_set.configure(1, 0)

    for i in range(15):
        latents_1_val = np.random.randn(1, *G.input_shape[1:])

        # d_pred, fake_image_1 = tflib.run([d_pred_op, fake_image_1_op], feed_dict={latents_1: latents_1_val, lod_in: 0})
        d_pred_real_, d_pred_fake_, real_image = tflib.run(
            [d_pred_real, d_pred_fake, reals],
            feed_dict={
                latents_1: latents_1_val,
                lod_in: 0
            })

        print(d_pred_real_, d_pred_fake_)
        misc.save_mri_image(real_image,
                            os.path.join(submit_config.run_dir,
                                         'real_{}.nii.gz'.format(i)),
                            drange=[-1, 1])
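
One optional tweak, not part of the original listing: declaring the latent placeholder with an explicit shape lets TensorFlow validate the fed arrays and propagate static shapes through the mapping and synthesis graphs.

# Sketch: same placeholder as above, but with an explicit shape derived from G.
latents_1 = tf.placeholder(tf.float32,
                           shape=[None] + list(G.input_shape[1:]),
                           name='latents_1')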
Example No. 4
def run_all_snapshots(submit_config, metric_args, run_id):
    ctx = dnnlib.RunContext(submit_config)
    tflib.init_tf()
    print('Evaluating %s metric on all snapshots of run_id %s...' % (metric_args.name, run_id))
    run_dir = misc.locate_run_dir(run_id)
    network_pkls = misc.list_network_pkls(run_dir)
    metric = dnnlib.util.call_func_by_name(**metric_args)
    print()
    for idx, network_pkl in enumerate(network_pkls):
        ctx.update('', idx, len(network_pkls))
        metric.run(network_pkl, run_dir=run_dir, num_gpus=submit_config.num_gpus)
    print()
    ctx.close()
Example No. 5
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict, network_snapshot: str):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**dataset)

    ctx = dnnlib.RunContext(submit_config, config)

    tfutil.init_tf(config.tf_config)

    with tf.device("/gpu:0"):
        net = util.load_snapshot(network_snapshot)
        validation_set.evaluate(net, 0, noise_augmenter.add_validation_noise_np)
    ctx.close()
Example No. 6
def mixing(submit_config, resume_run_id, tf_config = {}, resume_snapshot=None):

    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    print('Loading networks from "%s"...' % network_pkl)
    G, D, Gs = misc.load_pkl(network_pkl)

    latents_1_val = np.random.randn(1,*G.input_shape[1:])
    latents_2_val = np.random.randn(1,*G.input_shape[1:])

    # latents_2_val = latents_1_val

    latents_1 = tf.placeholder(tf.float32)
    labels_1 = tf.constant([[0,0,0,0,1,0]])

    latents_2 = tf.placeholder(tf.float32)
    labels_2 = tf.constant([[0,0,0,0,1,0]])

    w_1 = Gs.components.mapping.get_output_for(latents_1, labels_1, is_validation=True)
    w_2 = Gs.components.mapping.get_output_for(latents_2, labels_2, is_validation=True)

    # w_1_val = tflib.run(w_1)
    # w_2_val = tflib.run(w_2)

    fake_image_1_op = Gs.components.synthesis.get_output_for(w_1, is_validation=True, randomize_noise=False)
    fake_image_2_op = Gs.components.synthesis.get_output_for(w_2, is_validation=True, randomize_noise=False)

    fake_image_1 = tflib.run(fake_image_1_op, feed_dict={latents_1: latents_1_val, latents_2: latents_2_val})
    fake_image_2 = tflib.run(fake_image_2_op, feed_dict={latents_1: latents_1_val, latents_2: latents_2_val})

    misc.save_image(fake_image_1[0], os.path.join(submit_config.run_dir,'fake_image_1.png'), drange=[-1,1])
    misc.save_image(fake_image_2[0], os.path.join(submit_config.run_dir,'fake_image_2.png'), drange=[-1,1])

    for i in range(15):
        w_mix = tf.concat([w_1[:,:i],w_2[:,i:]], axis=1)
        fake_mix_op = Gs.components.synthesis.get_output_for(w_mix, is_validation=True, randomize_noise=False)
        fake_mix_image = tflib.run(fake_mix_op, feed_dict={latents_1: latents_1_val, latents_2: latents_2_val})
        misc.save_image(fake_mix_image[0], os.path.join(submit_config.run_dir,'fake_mix_12_{}.png'.format(i)), drange=[-1,1])

    for i in range(15):
        w_mix = tf.concat([w_2[:,:i],w_1[:,i:]], axis=1)
        fake_mix_op = Gs.components.synthesis.get_output_for(w_mix, is_validation=True, randomize_noise=False)
        fake_mix_image = tflib.run(fake_mix_op, feed_dict={latents_1: latents_1_val, latents_2: latents_2_val})
        misc.save_image(fake_mix_image[0], os.path.join(submit_config.run_dir,'fake_mix_21_{}.png'.format(i)), drange=[-1,1])
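
The two loops above add a fresh synthesis subgraph for every crossover index. An alternative sketch (an assumption, not from the original): evaluate the dlatents once, mix them in NumPy, and reuse a single synthesis op fed through a placeholder.

# Evaluate both dlatent tensors once.
w_1_val, w_2_val = tflib.run([w_1, w_2], feed_dict={latents_1: latents_1_val,
                                                    latents_2: latents_2_val})
# One synthesis op, driven by a dlatents placeholder of the same shape.
w_in = tf.placeholder(tf.float32, shape=w_1_val.shape, name='w_in')
fake_mix_op = Gs.components.synthesis.get_output_for(w_in, is_validation=True, randomize_noise=False)
for i in range(15):
    w_mix_val = np.concatenate([w_1_val[:, :i], w_2_val[:, i:]], axis=1)
    fake_mix_image = tflib.run(fake_mix_op, feed_dict={w_in: w_mix_val})
    misc.save_image(fake_mix_image[0],
                    os.path.join(submit_config.run_dir, 'fake_mix_12_{}.png'.format(i)),
                    drange=[-1, 1])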
Example No. 7
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict,
             network_snapshot: str):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**dataset)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = load_snapshot(network_snapshot)
        validation_set.evaluate(net, 0,
                                noise_augmenter.add_validation_noise_np)
    ctx.close()
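
The noise and dataset arguments are plain dicts dispatched through dnnlib.util.call_func_by_name(). A sketch of how validate() might be driven; the func_name path, parameter names, and file paths are illustrative assumptions, not taken from this listing:

noise = dnnlib.EasyDict(func_name='train.AugmentGaussian',
                        train_stddev_rng_range=(0.0, 50.0),
                        validation_stddev=25.0)
dataset = dnnlib.EasyDict(dataset_dir='datasets/kodak')
validate(submit_config, noise=noise, dataset=dataset,
         network_snapshot='results/network_final.pickle')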
Example No. 8
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    # noinspection PyTypeChecker
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)

    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])

        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)

    # Define the loss function using the Optimizer helper class, this will take care of multi GPU
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised = net_gpu.get_output_for(noisy_input_split[gpu])

            if noise2noise:
                meansq_error = tf.reduce_mean(
                    tf.square(noisy_target_split[gpu] - denoised))
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised))
            # Create an autosummary that will average over all GPUs
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # ***********************************
    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break

        # Dump training status
        if i % eval_interval == 0:
            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()

            # Evaluate 'x' to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.png".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.png".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.png".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s eta %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec',
                                   (time_train / eval_interval) *
                                   (iteration_count - i))),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

        # Training epoch
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    # End of training
    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
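
The loop above calls compute_ramped_down_lrate(), which is not reproduced in this listing. A plausible sketch, assuming the intent is to hold the learning rate constant and then roll it off smoothly over the final ramp_down_perc fraction of the iterations (not necessarily the original implementation):

import numpy as np

def compute_ramped_down_lrate(i, iteration_count, ramp_down_perc, learning_rate):
    # Constant learning rate until the ramp-down window starts.
    ramp_down_start = iteration_count * (1.0 - ramp_down_perc)
    if i < ramp_down_start:
        return learning_rate
    # Smooth (cosine-shaped) decay towards zero over the remaining iterations.
    t = (i - ramp_down_start) / (iteration_count - ramp_down_start)
    return learning_rate * (0.5 + 0.5 * np.cos(t * np.pi)) ** 2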
Example No. 9
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 1,        # How often to export network snapshots?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    #resume_run_id           = 'results/00011-sgan-custom_datasets-1gpu/network-snapshot-005645.pkl',
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    #resume_kimg             = 5645,
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
Example No. 10
def joint_train(
    submit_config,
    opt,
    metric_arg_list,
    sched_args              = {},       # Options for the training schedule.
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    dataset_args            = {},       # Options for the dataset.
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    mirror_augment          = False,    # Enable mirror augment?
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    *args,
    **kwargs
    ):

    output_dir = opt.output_dir

    graph_kwargs = util.set_graph_kwargs(opt)

    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')

    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    # create training samples
    #num_samples = opt.num_samples
    # if opt.model == 'biggan' and opt.biggan.category is not None:
    #     graph_inputs = graph_util.graph_input(g, num_samples, seed=0, category=opt.biggan.category)
    # else:
    #     graph_inputs = graph_util.graph_input(g, num_samples, seed=0)



    w_snapshot_ticks = opt.model_save_freq

    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    
    # Set up the snapshot image grid
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
    # Set up the run dir
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms(); g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)
    # Training
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                g.G_opt.reset_optimizer_state(); # D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training.
        for _mb_repeat in range(minibatch_repeats):
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                feed_dict_out = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                out_zs = g.sess.run(g.outputs_orig, feed_dict_out)

                target_fn, mask_out = g.get_target_np(out_zs, at)
                feed_dict = feed_dict_out
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                curr_loss, _, Gs_op, G_op = g.sess.run([g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op], feed_dict=feed_dict)
                loss_values.append(curr_loss)
            
            cur_nimg += sched.minibatch
            #tflib.run([g.Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
            #tflib.run([g.G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)
            if cur_tick % w_snapshot_ticks == 0 or done:
                g.saver.save(g.sess, './{}/model_{}.ckpt'.format(
                    output_dir, (cur_nimg // 1000)),
                    write_meta_graph=False, write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final results.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()

    loss_values = np.array(loss_values)
    np.save('./{}/loss_values.npy'.format(output_dir), loss_values)
    f, ax  = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig('./{}/loss_values.png'.format(output_dir))
Example No. 11
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0,
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device("/gpu:0"):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print("Constructing networks...")
            G = tflib.Network(
                name="G",
                num_inputs=2,  # one for latents and one for labels
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **G_args)
            D = tflib.Network(
                name="D",
                num_inputs=int(np.log2(training_set.shape[1])) - 1 +
                1,  # +1 for labels :)
                num_channels=training_set.shape[0],
                resolution=training_set.shape[1],
                label_size=training_set.label_size,
                **D_args)
            Gs = G.clone("Gs")
    G.print_layers()
    D.print_layers()

    print("Building TensorFlow graph...")
    with tf.name_scope("Inputs"), tf.device("/cpu:0"):
        lrate_in = tf.placeholder(tf.float32, name="lrate_in", shape=[])
        minibatch_in = tf.placeholder(tf.int32, name="minibatch_in", shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = (0.5**tf.div(tf.cast(minibatch_in,
                                       tf.float32), G_smoothing_kimg *
                               1000.0) if G_smoothing_kimg > 0.0 else 0.0)

    G_opt = tflib.Optimizer(name="TrainG",
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name="TrainD",
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope("GPU%d" % gpu), tf.device("/gpu:%d" % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + "_shadow")
            D_gpu = D if gpu == 0 else D.clone(D.name + "_shadow")
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(
                reals,
                mirror_augment,
                training_set.dynamic_range,
                drange_net,
                depth=training_set.resolution_log2 - 1,
            )
            with tf.name_scope("G_loss"):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope("D_loss"):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device("/gpu:0"):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Choose training parameters and configure training ops.
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              **sched_args)

    print("Setting up snapshot image grid...")
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_fakes = Gs.run(
        grid_latents,
        grid_labels,
        is_validation=True,
        minibatch_size=sched.minibatch_size // submit_config.num_gpus,
    )

    print("Setting up run dir...")
    fake_multi_scale_dirs = [
        os.path.join(submit_config.run_dir,
                     str(2**res) + "x" + str(2**res))
        for res in range(2, 2 + len(grid_fakes))
    ]
    misc.save_image_grid(
        grid_reals,
        os.path.join(submit_config.run_dir, "reals.png"),
        drange=training_set.dynamic_range,
        grid_size=grid_size,
    )
    misc.save_image_grids(
        grid_fakes,
        [
            os.path.join(fake_multi_scale_dir, "fakes%06d.png" % resume_kimg)
            for fake_multi_scale_dir in fake_multi_scale_dirs
        ],
        drange=drange_net,
        grid_size=grid_size,
    )
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print("Training...\n")
    ctx.update("", cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg

    # configure the training_set to a proper minibatch size
    training_set.configure(sched.minibatch_size // submit_config.num_gpus)

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op],
                    {
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch_size
                    },
                )
                cur_nimg += sched.minibatch_size
            tflib.run(
                [G_train_op],
                {
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch_size
                },
            )

        # Perform maintenance tasks once per tick.
        done = cur_nimg >= total_kimg * 1000
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                "tick %-5d kimg %-8.1f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f"
                % (
                    autosummary("Progress/tick", cur_tick),
                    autosummary("Progress/kimg", cur_nimg / 1000.0),
                    autosummary("Progress/minibatch", sched.minibatch_size),
                    dnnlib.util.format_time(
                        autosummary("Timing/total_sec", total_time)),
                    autosummary("Timing/sec_per_tick", tick_time),
                    autosummary("Timing/sec_per_kimg", tick_time / tick_kimg),
                    autosummary("Timing/maintenance_sec", maintenance_time),
                    autosummary("Resources/peak_gpu_mem_gb",
                                peak_gpu_mem_op.eval() / 2**30),
                ))
            autosummary("Timing/total_hours", total_time / (60.0 * 60.0))
            autosummary("Timing/total_days", total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(
                    grid_latents,
                    grid_labels,
                    is_validation=True,
                    minibatch_size=sched.minibatch_size //
                    submit_config.num_gpus,
                )
                misc.save_image_grids(
                    grid_fakes,
                    [
                        os.path.join(fake_multi_scale_dir, "fakes%06d.png" %
                                     (cur_nimg // 1000))
                        for fake_multi_scale_dir in fake_multi_scale_dirs
                    ],
                    drange=drange_net,
                    grid_size=grid_size,
                )
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    "network-snapshot-%06d.pkl" % (cur_nimg // 1000),
                )
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(
                    pkl,
                    run_dir=submit_config.run_dir,
                    num_gpus=submit_config.num_gpus,
                    tf_config=tf_config,
                )

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update(cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, "network-final.pkl"))
    summary_log.close()

    ctx.close()
Example No. 12
def training_loop(
    submit_config,
    G_args                  = {},       # Options for the generator network.
    D_args                  = {},       # Options for the discriminator network.
    G_opt_args              = {},       # Options for the generator optimizer.
    D_opt_args              = {},       # Options for the discriminator optimizer.
    G_loss_args             = {},       # Options for the generator loss.
    D_loss_args             = {},       # Options for the discriminator loss.
    dataset_args            = {},       # Options for the dataset.
    sched_args              = {},       # Options for the training schedule.
    grid_args               = {},       # Options for setup_snapshot_image_grid().
    metric_arg_list         = [],       # Options for the metrics.
    tf_config               = {},       # Options for tflib.init_tf().
    G_smoothing_kimg        = 10.0,     # Half-life of the running average of generator weights.
    D_repeats               = 1,        # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,    # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,    # Enable mirror augment?
    drange_net              = [-1,1],   # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots?
    network_snapshot_ticks  = 10,       # How often to export network snapshots?
    save_tf_graph           = False,    # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0):     # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load the training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            Gs = G.clone('Gs')
    G.print_layers(); D.print_layers()
    # Build the TensorFlow graph and optimizers.
    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in          = tf.placeholder(tf.float32, name='lod_in', shape=[])
        # tf.placeholder: like a formal parameter, it only defines the graph input here; a concrete value is fed in at execution time.
        lrate_in        = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in    = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta         = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG', learning_rate=lrate_in, **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD', learning_rate=lrate_in, **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals, labels=labels, **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)
    # Set up the snapshot image grid
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
    # Set up the run dir
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)
    # Training
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tflib.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final results.
    misc.save_pkl((G, D, Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
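
Several of these training loops call process_reals(), which is not shown in this listing. A simplified sketch of its dynamic-range and mirror-augment stages (the progressive-growing LOD fade/upscale stage is omitted, and misc.adjust_dynamic_range is assumed to exist as in the surrounding code):

def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            # Rescale pixel values from the dataset range to the network range.
            x = tf.cast(x, tf.float32)
            x = misc.adjust_dynamic_range(x, drange_data, drange_net)
        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                # Randomly flip roughly half of the images left-right.
                s = tf.shape(x)
                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
                mask = tf.tile(mask, [1, s[1], s[2], s[3]])
                x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3]))
        # The full version also cross-fades between levels of detail using `lod`;
        # that stage is omitted from this sketch.
        return x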
Example No. 13
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=10,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)

    # ajay - move init to after graph creation?
    tflib.init_tf(tf_config)

    # Load training set.
    print('ajay - config data dir', config.data_dir)
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        num_hosts=hvd.size(),
                                        index=hvd.rank(),
                                        **dataset_args)

    # Construct networks.
    print('Constructing networks...')
    G = tflib.Network('G',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **G_args)
    D = tflib.Network('D',
                      num_channels=training_set.shape[0],
                      resolution=training_set.shape[1],
                      label_size=training_set.label_size,
                      **D_args)
    Gs = G.clone('Gs')
    # with tf.device('/gpu:0'):
    #     if resume_run_id is not None:
    #         network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
    #         print('Loading networks from "%s"...' % network_pkl)
    #         G, D, Gs = misc.load_pkl(network_pkl)
    #     else:
    #         print('Constructing networks...')
    #         G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
    #         D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
    #         Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tf.train.AdamOptimizer(learning_rate=lrate_in,
                                   beta1=0.0,
                                   beta2=0.99,
                                   epsilon=1e-8)
    G_opt = hvd.DistributedOptimizer(G_opt)
    D_opt = tf.train.AdamOptimizer(learning_rate=lrate_in,
                                   beta1=0.0,
                                   beta2=0.99,
                                   epsilon=1e-8)
    D_opt = hvd.DistributedOptimizer(D_opt)
    G_gpu = G
    D_gpu = D
    lod_assign_ops = [
        tf.assign(G_gpu.find_var('lod'), lod_in),
        tf.assign(D_gpu.find_var('lod'), lod_in)
    ]
    # ajay - check if unique minibatch is guaranteed i.e sharding is done right!
    reals, labels = training_set.get_minibatch_tf()
    reals = process_reals(reals, lod_in, mirror_augment,
                          training_set.dynamic_range, drange_net)
    with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
        G_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=G_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               **G_loss_args)
    with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
        D_loss = dnnlib.util.call_func_by_name(G=G_gpu,
                                               D=D_gpu,
                                               opt=D_opt,
                                               training_set=training_set,
                                               minibatch_size=minibatch_split,
                                               reals=reals,
                                               labels=labels,
                                               **D_loss_args)
    G_grads = G_opt.compute_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
    D_grads = D_opt.compute_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)

    G_train_op = G_opt.apply_gradients(G_grads)
    D_train_op = D_opt.apply_gradients(D_grads)
    # Horovod
    init_op = tf.initialize_all_variables()
    bcast_op = hvd.broadcast_global_variables(0)
    # ajay
    tf.get_default_session().run([init_op])
    tflib.run([bcast_op])

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    # todo: ajay - note num_gpus need to change to hvd size when going multi-node
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    if hvd.rank() == 0:
        print('Setting up run dir...')
        misc.save_image_grid(grid_reals,
                             os.path.join(submit_config.run_dir, 'reals.png'),
                             drange=training_set.dynamic_range,
                             grid_size=grid_size)
        misc.save_image_grid(grid_fakes,
                             os.path.join(submit_config.run_dir,
                                          'fakes%06d.png' % resume_kimg),
                             drange=drange_net,
                             grid_size=grid_size)
        summary_log = tf.summary.FileWriter(submit_config.run_dir)
        if save_tf_graph:
            summary_log.add_graph(tf.get_default_graph())
        if save_weight_histograms:
            G.setup_weight_histograms()
            D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0

    while cur_nimg < (total_kimg * 1000):
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        # todo: ajay - find a way to manually resetoptimizer
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                tflib.assert_tf_initialized()
                G_opt_reset_op = [var.initializer for var in G_opt.variables()]
                D_opt_reset_op = [var.initializer for var in D_opt.variables()]
                tflib.run(G_opt_reset_op)
                tflib.run(D_opt_reset_op)
                # G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod
        # grp_train_op = tf.group(D_train_op, [Gs_update_op])
        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch  #// submit_config.num_gpus
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= (total_kimg * 1000))
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            # ajay
            #ajay mod
            if hvd.rank() == 0:
                print(
                    'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f '
                    %
                    (autosummary('Progress/tick', cur_tick),
                     autosummary('Progress/kimg', cur_nimg / 1000.0),
                     autosummary('Progress/lod', sched.lod),
                     autosummary('Progress/minibatch', sched.minibatch),
                     dnnlib.util.format_time(
                         autosummary('Timing/total_sec', total_time)),
                     autosummary('Timing/sec_per_tick', tick_time),
                     autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                     autosummary('Timing/maintenance_sec', maintenance_time)))
                autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
                autosummary('Timing/total_days',
                            total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                # Note: metrics are evaluated on a single GPU here.
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=1,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            if hvd.rank() == 0:
                tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    if hvd.rank() == 0:
        misc.save_pkl((G, D, Gs),
                      os.path.join(submit_config.run_dir, 'network-final.pkl'))
        summary_log.close()

    ctx.close()
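
A note on the optimizer reset above: since this Horovod-enabled variant avoids tflib's reset_optimizer_state(), it re-runs the initializers of the optimizer's own variables instead. A minimal standalone sketch of that pattern with a plain tf.compat.v1 Adam optimizer (an assumed setup, not the tflib.Optimizer wrapper used above):

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

x = tf.Variable(3.0, name='x')
loss = tf.square(x)
opt = tf.train.AdamOptimizer(learning_rate=0.1)
train_op = opt.minimize(loss)

# opt.variables() returns the optimizer's internal state (Adam moments, beta-power
# accumulators); running their initializers resets that state without touching x.
opt_reset_op = [var.initializer for var in opt.variables()]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(10):
        sess.run(train_op)
    sess.run(opt_reset_op)  # e.g. whenever a new lod level is introduced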
Exemplo n.º 14
0
def training_loop(
    submit_config,
    HP_args={},  # Options for the Hessian Penalty.
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    interp_snapshot_ticks=20,  # How often to generate interpolation visualizations in TensorBoard?
    network_snapshot_ticks=5,  # How often to export network snapshots?
    network_metric_ticks=5,  # How often to evaluate network snapshots on specified metrics?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=0.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Create a copy of dataset_args for running the metrics:
    metrics_dataset_args = deepcopy(dataset_args)
    metrics_dataset_args.shuffle_mb = 0

    print('Saving interp videos every %s ticks' % interp_snapshot_ticks)
    print('Saving network snapshot every %s ticks' % network_snapshot_ticks)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **G_args)
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    # G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    # The loss weighting of the Hessian Penalty can be dynamic over training, so we need to use a placeholder:
    hessian_weight = tf.placeholder(tf.float32,
                                    name='hessian_weight',
                                    shape=[])

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    reg_ops = []  # Values of the Hessian Penalty / InfoGAN losses, returned so they can be logged to TensorBoard.
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss, G_hessian_penalty = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    hp_lambda=hessian_weight,
                    HP_args=HP_args,
                    gpu_ix=gpu,
                    lod_in=lod_in,
                    max_lod=training_set.resolution_log2,
                    **G_loss_args)
                if HP_args.hp_lambda > 0:
                    reg_ops += [G_hessian_penalty]
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss, mutual_info = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    gpu_ix=gpu,
                    infogan_nz=D_args.infogan_nz,
                    **D_loss_args)
                # print([name for name in D_gpu.trainables.keys()])
                # gps = [weight for name, weight in G_gpu.trainables.items()][0]
                # dps = [weight for name, weight in D_gpu.trainables.items() if 'Q_Encoder' in name][0]
                # gg = autosummary('Loss/G_info_grad', tf.reduce_sum(tf.gradients(mutual_info, gps)[0]**2))
                # dg = autosummary('Loss/D_info_grad', tf.reduce_sum(tf.gradients(mutual_info, dps)[0]**2))
                # reg_ops.extend([dg, gg, dps, gps])
                if G_args.infogan_lambda > 0 or D_args.infogan_lambda > 0:
                    reg_ops += [mutual_info]
            # Note, even though we are adding mutual_info loss here, the only time the loss is non-zero
            # is when infogan_lambda > 0 (in Hessian Penalty experiments, we always set it to zero):
            G_opt.register_gradients(
                G_loss + G_args.infogan_lambda * mutual_info, G_gpu.trainables)
            D_opt.register_gradients(
                tf.reduce_mean(D_loss) + D_args.infogan_lambda * mutual_info,
                D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up snapshot interpolation...')
    nz = G.input_shapes[0][1]
    interp_steps = 24  # Number of frames in the visualization
    interp_batch_size = 8  # Number of gifs per row of visualization
    interp_z = vis_tools.sample_interp_zs(nz, interp_batch_size, interp_steps)
    interp_labels = np.zeros(
        [interp_steps * interp_batch_size * nz, training_set.label_size],
        dtype=training_set.label_dtype)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    summary_log.add_summary(
        build_image_summary(os.path.join(submit_config.run_dir, 'reals.png'),
                            'samples/real'), 0)
    summary_log.add_summary(
        build_image_summary(
            os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg),
            'samples/Gs'), resume_kimg)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    if interp_snapshot_ticks != -1 and interp_snapshot_ticks < 9999:
        print('Generating initial interpolations...')
        vis_tools.make_and_save_interpolation_gifs(
            Gs,
            interp_z,
            interp_labels,
            minibatch_size=sched.minibatch // submit_config.num_gpus,
            interp_steps=interp_steps,
            interp_batch_size=interp_batch_size,
            cur_kimg=resume_kimg,
            summary_log=summary_log)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    num_G_grad_steps = 0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            cur_hessian_weight = get_current_hessian_penalty_loss_weight(
                HP_args.hp_lambda, HP_args.hp_start_nimg, cur_nimg,
                HP_args.warmup_nimg)
            tflib.run(
                [G_train_op] + reg_ops, {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    hessian_weight: cur_hessian_weight
                })
            num_G_grad_steps += 1

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d hessian_weight %s time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   autosummary('Progress/hessian_weight', cur_hessian_weight),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            autosummary('Progress/G_grad_steps', num_G_grad_steps)

            # Save snapshots and fake image samples (for both Gs and G):
            if cur_tick % image_snapshot_ticks == 0 or done:
                cur_kimg = cur_nimg // 1000  # avoid shadowing the built-in iter()
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes_inst = G.run(grid_latents,
                                        grid_labels,
                                        is_validation=True,
                                        minibatch_size=sched.minibatch //
                                        submit_config.num_gpus)
                fake_path = os.path.join(submit_config.run_dir,
                                         'fakes%06d.png' % cur_kimg)
                ifake_path = os.path.join(submit_config.run_dir,
                                          'ifakes%06d.png' % cur_kimg)
                misc.save_image_grid(grid_fakes,
                                     fake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                misc.save_image_grid(grid_fakes_inst,
                                     ifake_path,
                                     drange=drange_net,
                                     grid_size=grid_size)
                summary_log.add_summary(
                    build_image_summary(fake_path, 'samples/Gs'), cur_kimg)
                summary_log.add_summary(
                    build_image_summary(ifake_path, 'samples/G'), cur_kimg)

            # Generate/Save Interpolation Visualizations:
            if interp_snapshot_ticks != -1 and cur_tick % interp_snapshot_ticks == 0:
                vis_tools.make_and_save_interpolation_gifs(
                    Gs,
                    interp_z,
                    interp_labels,
                    minibatch_size=sched.minibatch // submit_config.num_gpus,
                    interp_steps=interp_steps,
                    interp_batch_size=interp_batch_size,
                    cur_kimg=cur_nimg // 1000,
                    summary_log=summary_log)

            # Save snapshot and run metrics:
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                if cur_tick % network_metric_ticks == 0 or done or cur_tick == 1:
                    metrics.run(pkl,
                                dataset_args=metrics_dataset_args,
                                mirror_augment=mirror_augment,
                                num_gpus=submit_config.num_gpus,
                                tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir,
                               'network-snapshot-%06d.pkl' % total_kimg))
    summary_log.close()

    ctx.close()
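
The hessian_weight placeholder above lets the Hessian Penalty weight ramp in over training; the helper get_current_hessian_penalty_loss_weight is defined elsewhere, so the following is only a plausible sketch of a linear warm-up that starts at hp_start_nimg and reaches hp_lambda after warmup_nimg additional images:

def hessian_weight_schedule(hp_lambda, hp_start_nimg, cur_nimg, warmup_nimg):
    # Hypothetical linear warm-up; the actual helper may differ.
    if cur_nimg < hp_start_nimg:
        return 0.0
    if warmup_nimg <= 0:
        return hp_lambda
    progress = (cur_nimg - hp_start_nimg) / float(warmup_nimg)
    return hp_lambda * min(progress, 1.0)

# Example: a weight of 0.1 ramping in between 1M and 2M images.
print(hessian_weight_schedule(0.1, 1000000, 500000, 1000000))   # 0.0
print(hessian_weight_schedule(0.1, 1000000, 1500000, 1000000))  # 0.05
print(hessian_weight_schedule(0.1, 1000000, 3000000, 1000000))  # 0.1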
Exemplo n.º 15
0
def training_loop(
        submit_config,
        #---------------------------------------------------------------
        # Modified by Deng et al.
        noise_dim=32,
        weight_args={},
        train_stage_args={},
        #---------------------------------------------------------------
        G_args={},  # Options for generator network.
        D_args={},  # Options for discriminator network.
        G_opt_args={},  # Options for generator optimizer.
        D_opt_args={},  # Options for discriminator optimizer.
        G_loss_args={},  # Options for generator loss.
        D_loss_args={},  # Options for discriminator loss.
        dataset_args={},  # Options for dataset.load_dataset().
        sched_args={},  # Options for train.TrainingSchedule.
        grid_args={},  # Options for train.setup_snapshot_image_grid().
        metric_arg_list=[],  # Options for MetricGroup.
        tf_config={},  # Options for tflib.init_tf().
        G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
        D_repeats=1,  # How many times the discriminator is trained per G iteration.
        minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
        reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
        total_kimg=15000,  # Total length of the training, measured in thousands of real images.
        mirror_augment=True,  # Enable mirror augment?
        drange_net=[
            -1, 1
        ],  # Dynamic range used when feeding image data to the networks.
        image_snapshot_ticks=1,  # How often to export image snapshots?
        network_snapshot_ticks=10,  # How often to export network snapshots?
        save_tf_graph=True,  # Include full TensorFlow computation graph in the tfevents file?
        save_weight_histograms=False,  # Include weight histograms in the tfevents file?
        resume_run_id=87,  # Run ID or network pkl to resume training from, None = start from scratch.
        resume_snapshot=2364,  # Snapshot index to resume training from, None = autodetect.
        resume_kimg=2364,  # Assumed training progress at the beginning. Affects reporting and training schedule.
        resume_time=0.0,
        **_kwargs
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    PI = 3.1415927
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)
    # Create 3d face reconstruction block
    FaceRender = Face3D()

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            #---------------------------------------------------------------
            # Modified by Deng et al.
            G = tflib.Network('G',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              latent_size=254 + noise_dim,
                              **G_args)
            #---------------------------------------------------------------
            D = tflib.Network('D',
                              num_channels=training_set.shape[0],
                              resolution=training_set.shape[1],
                              label_size=training_set.label_size,
                              **D_args)
            Gs = G.clone('Gs')
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        resolution = tf.placeholder(tf.float32, name='resolution', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % (gpu)):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)

            #---------------------------------------------------------------
            # Modified by Deng et al.
            G_loss, D_loss = dnnlib.util.call_func_by_name(
                FaceRender=FaceRender, noise_dim=noise_dim, weight_args=weight_args,
                G_gpu=G_gpu, D_gpu=D_gpu, G_opt=G_opt, D_opt=D_opt,
                training_set=training_set, G_loss_args=G_loss_args,
                D_loss_args=D_loss_args, lod_assign_ops=lod_assign_ops,
                reals=reals, labels=labels, minibatch_split=minibatch_split,
                resolution=resolution, drange_net=drange_net, lod_in=lod_in,
                **train_stage_args)
            #---------------------------------------------------------------

            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    #---------------------------------------------------------------
    # Modified by Deng et al.
    restore_weights_and_initialize(train_stage_args)

    print('Setting up snapshot image grid...')
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)

    grid_size, grid_reals, grid_labels = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    grid_latents = tf.random_normal([np.prod(grid_size), 128 + 32 + 16 + 3])
    grid_INPUTcoeff = z_to_lambda_mapping(grid_latents)
    grid_INPUTcoeff_w_t = tf.concat(
        [grid_INPUTcoeff, tf.zeros([np.prod(grid_size), 3])], axis=1)
    with tf.name_scope('FaceRender'):
        grid_render_img, _, _, _ = FaceRender.Reconstruction_Block(
            grid_INPUTcoeff_w_t, 256, np.prod(grid_size), progressive=False)
        grid_render_img = tf.transpose(grid_render_img, perm=[0, 3, 1, 2])
        grid_render_img = process_reals(grid_render_img, lod_in, False,
                                        training_set.dynamic_range, drange_net)

    grid_INPUTcoeff_, grid_renders = tflib.run(
        [grid_INPUTcoeff, grid_render_img], {lod_in: sched.lod})
    grid_noise = np.random.randn(np.prod(grid_size), 32)
    grid_INPUTcoeff_w_noise = np.concatenate([grid_INPUTcoeff_, grid_noise],
                                             axis=1)

    grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)
    grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    #---------------------------------------------------------------

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0

    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch,
                        resolution: sched.resolution
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch,
                    resolution: sched.resolution
                })

            # print('iter')
        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                #---------------------------------------------------------------
                # Modified by Deng et al.
                grid_fakes = Gs.run(grid_INPUTcoeff_w_noise,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                grid_fakes = np.concatenate([grid_fakes, grid_renders], axis=3)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
            #---------------------------------------------------------------

            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
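
As in the other loops, Gs_beta = 0.5 ** (minibatch / (G_smoothing_kimg * 1000)) converts the configured half-life (in thousands of images) into a per-update EMA decay for the Gs weights. A small numeric sketch of that relationship, using assumed values rather than this run's actual schedule:

def ema_beta(minibatch_size, smoothing_kimg=10.0):
    # Per-update decay so that old weights decay by 0.5 every smoothing_kimg * 1000 images.
    return 0.5 ** (minibatch_size / (smoothing_kimg * 1000.0))

beta = ema_beta(minibatch_size=32, smoothing_kimg=10.0)
updates_per_half_life = 10.0 * 1000.0 / 32
print(beta)                           # ~0.99778 per update
print(beta ** updates_per_half_life)  # ~0.5: half the old weight mass remains after 10k images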


#----------------------------------------------------------------------------
Exemplo n.º 16
0
def training_loop(
    submit_config,
    G_args={},  # Options for generator network.
    D_args={},  # Options for discriminator network.
    G_opt_args={},  # Options for generator optimizer.
    D_opt_args={},  # Options for discriminator optimizer.
    G_loss_args={},  # Options for generator loss.
    D_loss_args={},  # Options for discriminator loss.
    dataset_args={},  # Options for dataset.load_dataset().
    sched_args={},  # Options for train.TrainingSchedule.
    grid_args={},  # Options for train.setup_snapshot_image_grid().
    metric_arg_list=[],  # Options for MetricGroup.
    tf_config={},  # Options for tflib.init_tf().
    G_smoothing_kimg=10.0,  # Half-life of the running average of generator weights.
    D_repeats=1,  # How many times the discriminator is trained per G iteration.
    minibatch_repeats=4,  # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod=True,  # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg=15000,  # Total length of the training, measured in thousands of real images.
    mirror_augment=False,  # Enable mirror augment?
    drange_net=[
        -1, 1
    ],  # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks=1,  # How often to export image snapshots?
    network_snapshot_ticks=10,  # How often to export network snapshots?
    save_tf_graph=False,  # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms=False,  # Include weight histograms in the tfevents file?
    resume_run_id=None,  # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot=None,  # Snapshot index to resume training from, None = autodetect.
    resume_kimg=10000.0,  # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time=0.0
):  # Assumed wallclock time at the beginning. Affects reporting.

    # Initialize dnnlib and TensorFlow.
    ctx = dnnlib.RunContext(submit_config, train)
    tflib.init_tf(tf_config)

    # Load training set.
    training_set = dataset.load_dataset(data_dir=config.data_dir,
                                        verbose=True,
                                        **dataset_args)

    # Construct networks.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id,
                                                  resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            #print('Constructing networks...')
            #G = tflib.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **G_args)
            #D = tflib.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **D_args)
            #Gs = G.clone('Gs')
            print('Loading pretrained FFHQ network...')
            url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ'
            with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
                G, D, Gs = pickle.load(f)
    G.print_layers()
    D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device('/cpu:0'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // submit_config.num_gpus
        Gs_beta = 0.5**tf.div(tf.cast(minibatch_in,
                                      tf.float32), G_smoothing_kimg *
                              1000.0) if G_smoothing_kimg > 0.0 else 0.0

    G_opt = tflib.Optimizer(name='TrainG',
                            learning_rate=lrate_in,
                            **G_opt_args)
    D_opt = tflib.Optimizer(name='TrainD',
                            learning_rate=lrate_in,
                            **D_opt_args)
    for gpu in range(submit_config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [
                tf.assign(G_gpu.find_var('lod'), lod_in),
                tf.assign(D_gpu.find_var('lod'), lod_in)
            ]
            reals, labels = training_set.get_minibatch_tf()
            reals = process_reals(reals, lod_in, mirror_augment,
                                  training_set.dynamic_range, drange_net)
            with tf.name_scope('G_loss'), tf.control_dependencies(
                    lod_assign_ops):
                G_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=G_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    **G_loss_args)
            with tf.name_scope('D_loss'), tf.control_dependencies(
                    lod_assign_ops):
                D_loss = dnnlib.util.call_func_by_name(
                    G=G_gpu,
                    D=D_gpu,
                    opt=D_opt,
                    training_set=training_set,
                    minibatch_size=minibatch_split,
                    reals=reals,
                    labels=labels,
                    **D_loss_args)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    Gs_update_op = Gs.setup_as_moving_average_of(G, beta=Gs_beta)
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(
        G, training_set, **grid_args)
    sched = training_schedule(cur_nimg=total_kimg * 1000,
                              training_set=training_set,
                              num_gpus=submit_config.num_gpus,
                              **sched_args)
    grid_fakes = Gs.run(grid_latents,
                        grid_labels,
                        is_validation=True,
                        minibatch_size=sched.minibatch //
                        submit_config.num_gpus)

    print('Setting up run dir...')
    misc.save_image_grid(grid_reals,
                         os.path.join(submit_config.run_dir, 'reals.png'),
                         drange=training_set.dynamic_range,
                         grid_size=grid_size)
    misc.save_image_grid(grid_fakes,
                         os.path.join(submit_config.run_dir,
                                      'fakes%06d.png' % resume_kimg),
                         drange=drange_net,
                         grid_size=grid_size)

    cmd = "gsutil cp " + os.path.join(submit_config.run_dir, 'fakes%06d.png' %
                                      resume_kimg) + "  gs://stylegan_out"
    response = subprocess.run(cmd, shell=True)

    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms()
        D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_schedule(cur_nimg=cur_nimg,
                                  training_set=training_set,
                                  num_gpus=submit_config.num_gpus,
                                  **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus,
                               sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(
                    sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state()
                D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            for _D_repeat in range(D_repeats):
                tflib.run(
                    [D_train_op, Gs_update_op], {
                        lod_in: sched.lod,
                        lrate_in: sched.D_lrate,
                        minibatch_in: sched.minibatch
                    })
                cur_nimg += sched.minibatch
            tflib.run(
                [G_train_op], {
                    lod_in: sched.lod,
                    lrate_in: sched.G_lrate,
                    minibatch_in: sched.minibatch
                })

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print(
                'tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f'
                % (autosummary('Progress/tick', cur_tick),
                   autosummary('Progress/kimg', cur_nimg / 1000.0),
                   autosummary('Progress/lod', sched.lod),
                   autosummary('Progress/minibatch', sched.minibatch),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', total_time)),
                   autosummary('Timing/sec_per_tick', tick_time),
                   autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                   autosummary('Timing/maintenance_sec', maintenance_time),
                   autosummary('Resources/peak_gpu_mem_gb',
                               peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents,
                                    grid_labels,
                                    is_validation=True,
                                    minibatch_size=sched.minibatch //
                                    submit_config.num_gpus)
                misc.save_image_grid(grid_fakes,
                                     os.path.join(
                                         submit_config.run_dir,
                                         'fakes%06d.png' % (cur_nimg // 1000)),
                                     drange=drange_net,
                                     grid_size=grid_size)
                cmd = "gsutil cp " + os.path.join(
                    submit_config.run_dir, 'fakes%06d.png' %
                    (cur_nimg // 1000)) + "  gs://stylegan_out"
                response = subprocess.run(cmd, shell=True)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(
                    submit_config.run_dir,
                    'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((G, D, Gs), pkl)
                metrics.run(pkl,
                            run_dir=submit_config.run_dir,
                            num_gpus=submit_config.num_gpus,
                            tf_config=tf_config)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod,
                       cur_epoch=cur_nimg // 1000,
                       max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Write final results.
    misc.save_pkl((G, D, Gs),
                  os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()
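
The loop above mirrors each image snapshot to a GCS bucket by shelling out to `gsutil cp` and ignores the result. A slightly more defensive sketch (assuming gsutil is on PATH; the bucket name below is just the placeholder from the example) avoids shell=True and surfaces failures:

import subprocess

def upload_to_gcs(local_path, bucket='gs://stylegan_out'):
    # Copy a local file to the bucket with gsutil and raise if the copy fails.
    result = subprocess.run(['gsutil', 'cp', local_path, bucket],
                            capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError('gsutil cp failed: %s' % result.stderr.strip())
    return result

# upload_to_gcs('results/00000-run/fakes000000.png')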
Exemplo n.º 17
0
def train(submit_config: dnnlib.SubmitConfig, iteration_count: int,
          eval_interval: int, minibatch_size: int, learning_rate: float,
          ramp_down_perc: float, noise: dict, validation_config: dict,
          train_tfrecords: str, noise2noise: bool):
    noise_augmenter = dnnlib.util.call_func_by_name(**noise)
    validation_set = ValidationSet(submit_config)
    validation_set.load(**validation_config)

    # Create a run context (hides low level details, exposes simple API to manage the run)
    ctx = dnnlib.RunContext(submit_config, config)

    # Initialize TensorFlow graph and session using good default settings
    tfutil.init_tf(config.tf_config)

    dataset_iter = create_dataset(train_tfrecords, minibatch_size,
                                  noise_augmenter.add_train_noise_tf)
    # Construct the network using the Network helper class and a function defined in config.net_config
    with tf.device("/gpu:0"):
        net = tflib.Network(**config.net_config)

    # Optionally print layer information
    net.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'), tf.device("/cpu:0"):
        lrate_in = tf.compat.v1.placeholder(tf.float32,
                                            name='lrate_in',
                                            shape=[])

        #print("DEBUG train:", "dataset iter got called")
        noisy_input, noisy_target, clean_target = dataset_iter.get_next()
        noisy_input_split = tf.split(noisy_input, submit_config.num_gpus)
        noisy_target_split = tf.split(noisy_target, submit_config.num_gpus)
        print(len(noisy_input_split), noisy_input_split)
        clean_target_split = tf.split(clean_target, submit_config.num_gpus)
        # Split [?, 3, 256, 256] across num_gpus over axis 0 (i.e. the batch)

    # Define the loss function using the Optimizer helper class; it takes care of multi-GPU gradient aggregation.
    opt = tflib.Optimizer(learning_rate=lrate_in, **config.optimizer_config)
    radii = np.arange(128).reshape(128, 1)  #image size 256, binning = 3
    radial_masks = np.apply_along_axis(radial_mask, 1, radii, 128, 128,
                                       np.arange(0, 256), np.arange(0, 256),
                                       20)
    print("RN SHAPE!!!!!!!!!!:", radial_masks.shape)
    radial_masks = np.expand_dims(radial_masks, 1)  # (128, 1, 256, 256)
    #radial_masks = np.squeeze(np.stack((radial_masks,) * 3, -1)) # 43, 3, 256, 256
    #radial_masks = radial_masks.transpose([0, 3, 1, 2])
    radial_masks = radial_masks.astype(np.complex64)
    radial_masks = tf.expand_dims(radial_masks, 1)

    rn = tf.compat.v1.placeholder_with_default(radial_masks,
                                               [128, None, 1, 256, 256])
    rn_split = tf.split(rn, submit_config.num_gpus, axis=1)
    freq_nyq = int(np.floor(int(256) / 2.0))

    spatial_freq = radii.astype(np.float32) / freq_nyq
    spatial_freq = spatial_freq / max(spatial_freq)

    for gpu in range(submit_config.num_gpus):
        with tf.device("/gpu:%d" % gpu):
            net_gpu = net if gpu == 0 else net.clone()

            denoised_1 = net_gpu.get_output_for(noisy_input_split[gpu])
            denoised_2 = net_gpu.get_output_for(noisy_target_split[gpu])
            print(noisy_input_split[gpu].get_shape(),
                  rn_split[gpu].get_shape())
            if noise2noise:
                meansq_error = fourier_ring_correlation(
                    noisy_target_split[gpu], denoised_1, rn_split[gpu],
                    spatial_freq) - fourier_ring_correlation(
                        noisy_target_split[gpu] - denoised_2,
                        noisy_input_split[gpu] - denoised_1, rn_split[gpu],
                        spatial_freq)
            else:
                meansq_error = tf.reduce_mean(
                    tf.square(clean_target_split[gpu] - denoised_1))
            # Create an autosummary that will average over all GPUs
            #tf.summary.histogram(name, var)
            with tf.control_dependencies([autosummary("Loss", meansq_error)]):
                opt.register_gradients(meansq_error, net_gpu.trainables)

    train_step = opt.apply_updates()

    # Create a log file for Tensorboard
    summary_log = tf.compat.v1.summary.FileWriter(submit_config.run_dir)
    summary_log.add_graph(tf.compat.v1.get_default_graph())

    print('Training...')
    time_maintenance = ctx.get_time_since_last_update()
    ctx.update(loss='run %d' % submit_config.run_id,
               cur_epoch=0,
               max_epoch=iteration_count)

    # The actual training loop
    for i in range(iteration_count):
        # Whether to stop the training or not should be asked from the context
        if ctx.should_stop():
            break
        # Dump training status
        if i % eval_interval == 0:

            time_train = ctx.get_time_since_last_update()
            time_total = ctx.get_time_since_start()
            print("DEBUG TRAIN!", noisy_input.dtype, noisy_input[0][0].dtype)
            # Evaluate the dataset tensors to draw a batch of inputs
            [source_mb, target_mb] = tfutil.run([noisy_input, clean_target])
            denoised = net.run(source_mb)
            save_image(submit_config, denoised[0],
                       "img_{0}_y_pred.tif".format(i))
            save_image(submit_config, target_mb[0], "img_{0}_y.tif".format(i))
            save_image(submit_config, source_mb[0],
                       "img_{0}_x_aug.tif".format(i))

            validation_set.evaluate(net, i,
                                    noise_augmenter.add_validation_noise_np)

            print(
                'iter %-10d time %-12s sec/eval %-7.1f sec/iter %-7.2f maintenance %-6.1f'
                % (autosummary('Timing/iter', i),
                   dnnlib.util.format_time(
                       autosummary('Timing/total_sec', time_total)),
                   autosummary('Timing/sec_per_eval', time_train),
                   autosummary('Timing/sec_per_iter',
                               time_train / eval_interval),
                   autosummary('Timing/maintenance_sec', time_maintenance)))

            dnnlib.tflib.autosummary.save_summaries(summary_log, i)
            ctx.update(loss='run %d' % submit_config.run_id,
                       cur_epoch=i,
                       max_epoch=iteration_count)
            time_maintenance = ctx.get_last_update_interval() - time_train

            save_snapshot(submit_config, net, str(i))
        lrate = compute_ramped_down_lrate(i, iteration_count, ramp_down_perc,
                                          learning_rate)
        tfutil.run([train_step], {lrate_in: lrate})

    print("Elapsed time: {0}".format(
        util.format_time(ctx.get_time_since_start())))
    save_snapshot(submit_config, net, 'final')

    # Summary log and context should be closed at the end
    summary_log.close()
    ctx.close()
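
compute_ramped_down_lrate above is not shown in this snippet; a plausible sketch keeps the base learning rate until the final ramp_down_perc fraction of iterations and then ramps it smoothly to zero (the real helper may use a different shape):

import numpy as np

def ramped_down_lrate(i, iteration_count, ramp_down_perc, base_lrate):
    # Hypothetical schedule: constant, then a cosine ramp-down over the last ramp_down_perc of training.
    if ramp_down_perc <= 0.0:
        return base_lrate
    ramp_start = iteration_count * (1.0 - ramp_down_perc)
    if i < ramp_start:
        return base_lrate
    t = (i - ramp_start) / (iteration_count - ramp_start)  # 0 -> 1 over the ramp
    return base_lrate * 0.5 * (1.0 + np.cos(t * np.pi))

print(ramped_down_lrate(0, 1000, 0.3, 3e-4))     # 0.0003
print(ramped_down_lrate(850, 1000, 0.3, 3e-4))   # ~0.00015, halfway through the ramp
print(ramped_down_lrate(1000, 1000, 0.3, 3e-4))  # ~0.0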