Code example #1
0
if __name__ == '__main__':

    # Parse training options and pin the visible GPU(s) before any CUDA work.
    opt = TrainOptions().parse()
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu

    # Idempotent directory creation (avoids the exists()/makedirs() race).
    output_dir = opt.output_dir
    os.makedirs(os.path.join(output_dir, 'results'), exist_ok=True)

    # set attrTable
    graph_kwargs = util.set_graph_kwargs(opt)

    # Model-specific helper modules are resolved dynamically by model name.
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)

    g = model(**graph_kwargs)

    num_samples = opt.num_samples
    graph_inputs = graph_util.graph_input(g, num_samples, seed=0)

    # Optional run-name suffix; None when not supplied.
    name = opt.suffix if opt.suffix else None

    attrList = graph_kwargs['attrList']
    layers = opt.layers

    print('attrlist: ', attrList)
Code example #2
0
def joint_train(
    submit_config,
    opt,
    metric_arg_list,
    sched_args              = None,     # Training schedule settings (dict).
    grid_args               = None,     # Settings for setup_snapshot_image_grid() (dict).
    dataset_args            = None,     # Dataset settings (dict).
    total_kimg              = 15000,    # Total length of training, counted in thousands of real images.
    drange_net              = (-1, 1),  # Dynamic range used when feeding image data into the networks.
    image_snapshot_ticks    = 1,        # How often to export image snapshots.
    network_snapshot_ticks  = 10,       # How often to export network model snapshots.
    D_repeats               = 1,        # How many times to train the discriminator per G iteration.
    minibatch_repeats       = 4,        # Number of minibatches to run before adjusting training parameters.
    mirror_augment          = False,    # Enable mirror augmentation?
    reset_opt_for_new_lod   = True,     # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    save_tf_graph           = False,    # Include the full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,    # Include weight histograms in the tfevents file?
    resume_run_id           = None,     # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,     # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,      # Assumed training progress at the beginning. Affects reporting and training schedule.
    resume_time             = 0.0,      # Assumed wallclock time at the beginning. Affects reporting.
    *args,
    **kwargs
    ):
    """Jointly train the latent transform graph together with the GAN.

    Follows the ProgGAN/StyleGAN-style progressive training loop: builds the
    model graph, runs minibatches according to the training schedule, and
    periodically saves image grids, network pkls, and checkpoint weights.
    Loss values are accumulated, saved as .npy, and plotted at the end.
    """
    # Avoid shared mutable default arguments: bind fresh dicts per call.
    sched_args = {} if sched_args is None else sched_args
    grid_args = {} if grid_args is None else grid_args
    dataset_args = {} if dataset_args is None else dataset_args

    output_dir = opt.output_dir

    graph_kwargs = util.set_graph_kwargs(opt)

    # Model-specific helper modules are resolved dynamically by model name.
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')

    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    w_snapshot_ticks = opt.model_save_freq

    # NOTE(review): `train` here and `tf_config` further below are not defined
    # inside this function -- presumably module-level names; verify upstream.
    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)

    # Peak-GPU-memory op is optional (tf.contrib may be unavailable).
    with tf.device('/gpu:0'):
        try:
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms(); g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)
    # Training loop.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop(): break

        # Choose training parameters and configure training ops.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # Reset optimizer state whenever the (floored or ceiled) LOD changes,
            # i.e. a new resolution layer starts fading in or finishes.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                g.G_opt.reset_optimizer_state(); # D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            # Normalize to lists so single-alpha and multi-alpha models share one path.
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                # Sample inputs, compute the editing target from the original
                # outputs, then take one joint optimization step.
                feed_dict_out = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                out_zs = g.sess.run(g.outputs_orig, feed_dict_out)

                target_fn, mask_out = g.get_target_np(out_zs, at)
                feed_dict = feed_dict_out
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                curr_loss, _, Gs_op, G_op = g.sess.run([g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op], feed_dict=feed_dict)
                loss_values.append(curr_loss)

            cur_nimg += sched.minibatch

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)
            if cur_tick % w_snapshot_ticks == 0 or done:
                g.saver.save(g.sess, './{}/model_{}.ckpt'.format(
                    output_dir, (cur_nimg // 1000)),
                    write_meta_graph=False, write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final result.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()

    ctx.close()

    # Persist and plot the collected loss curve.
    loss_values = np.array(loss_values)
    np.save('./{}/loss_values.npy'.format(output_dir), loss_values)
    f, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig('./{}/loss_values.png'.format(output_dir))
Code example #3
0
File: eval.py  Project: KelestZ/Latent2im
    # NOTE(review): this is the tail of a larger function from eval.py; `opt`,
    # `conf`, `util`, `graphs`, `os`, and `importlib` are bound outside this view.
    print('Load face recognition model')

    # Prefer an explicit output dir from `opt`; otherwise nest under conf.output_dir.
    if opt.output_dir:
        output_dir = opt.output_dir
    else:
        output_dir = os.path.join(conf.output_dir, 'images')
    os.makedirs(output_dir, exist_ok=True)

    graph_kwargs = util.set_graph_kwargs(conf)
    print('Load utils and constants: %s' % conf.model)
    # Model-specific helper modules are resolved dynamically by model name.
    graph_util = importlib.import_module('graphs.' + conf.model +
                                         '.graph_util')
    constants = importlib.import_module('graphs.' + conf.model + '.constants')

    print('Find_model_using_name')
    model = graphs.find_model_using_name(conf.model, conf.transform)

    print('Model initialization')
    g = model(**graph_kwargs)
    print('Load multi models')
    # When the GAN itself is being updated, pass its checkpoint path as well;
    # otherwise only the w-space weights are restored.
    if opt.updateGAN:
        g.load_multi_models(opt.save_path_w,
                            opt.save_path_gan,
                            trainEmbed=opt.trainEmbed,
                            updateGAN=opt.updateGAN)
    else:
        g.load_multi_models(opt.save_path_w,
                            None,
                            trainEmbed=opt.trainEmbed,
                            updateGAN=opt.updateGAN)
Code example #4
0
def main():
    tOption = TrainOptions()

    for key, val in Params().__dict__.items():
        tOption.parser.add_argument('--{}'.format(key),
                                    type=type(val),
                                    default=val)

    tOption.parser.add_argument('--args',
                                type=str,
                                default=None,
                                help='json with all arguments')
    tOption.parser.add_argument('--out', type=str, default='./output')
    tOption.parser.add_argument('--gan_type',
                                type=str,
                                choices=WEIGHTS.keys(),
                                default='StyleGAN')
    tOption.parser.add_argument('--gan_weights', type=str, default=None)
    tOption.parser.add_argument('--target_class', type=int, default=239)
    tOption.parser.add_argument('--json', type=str)

    tOption.parser.add_argument('--deformator',
                                type=str,
                                default='proj',
                                choices=DEFORMATOR_TYPE_DICT.keys())
    tOption.parser.add_argument('--deformator_random_init',
                                type=bool,
                                default=False)

    tOption.parser.add_argument('--shift_predictor_size', type=int)
    tOption.parser.add_argument('--shift_predictor',
                                type=str,
                                choices=['ResNet', 'LeNet'],
                                default='ResNet')
    tOption.parser.add_argument('--shift_distribution_key',
                                type=str,
                                choices=SHIFT_DISTRIDUTION_DICT.keys())

    tOption.parser.add_argument('--seed', type=int, default=2)
    tOption.parser.add_argument('--device', type=int, default=0)

    tOption.parser.add_argument('--continue_train', type=bool, default=False)
    tOption.parser.add_argument('--deformator_path',
                                type=str,
                                default='output/models/deformator_90000.pt')
    tOption.parser.add_argument(
        '--shift_predictor_path',
        type=str,
        default='output/models/shift_predictor_190000.pt')

    args = tOption.parse()
    torch.cuda.set_device(args.device)
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    if args.args is not None:
        with open(args.args) as args_json:
            args_dict = json.load(args_json)
            args.__dict__.update(**args_dict)

    # save run params
    #if not os.path.isdir(args.out):
    #    os.makedirs(args.out)
    #with open(os.path.join(args.out, 'args.json'), 'w') as args_file:
    #    json.dump(args.__dict__, args_file)
    #with open(os.path.join(args.out, 'command.sh'), 'w') as command_file:
    #    command_file.write(' '.join(sys.argv))
    #    command_file.write('\n')

    # init models
    if args.gan_weights is not None:
        weights_path = args.gan_weights
    else:
        weights_path = WEIGHTS[args.gan_type]

    if args.gan_type == 'BigGAN':
        G = make_big_gan(weights_path, args.target_class).eval()
    elif args.gan_type == 'StyleGAN':
        G = make_stylegan(
            weights_path,
            net_info[args.stylegan.dataset]['resolution']).eval()
    elif args.gan_type == 'ProgGAN':
        G = make_proggan(weights_path).eval()
    else:
        G = make_external(weights_path).eval()

    #判断是对z还是w做latent code
    if args.model == 'stylegan':
        assert (args.stylegan.latent in ['z', 'w']), 'unknown latent space'
        if args.stylegan.latent == 'z':
            target_dim = G.dim_z
        else:
            target_dim = G.dim_w

    if args.shift_predictor == 'ResNet':
        shift_predictor = ResNetShiftPredictor(
            args.direction_size, args.shift_predictor_size).cuda()
    elif args.shift_predictor == 'LeNet':
        shift_predictor = LeNetShiftPredictor(
            args.direction_size,
            1 if args.gan_type == 'SN_MNIST' else 3).cuda()
    if args.continue_train:
        deformator = LatentDeformator(
            direction_size=args.direction_size,
            out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator]).cuda()
        deformator.load_state_dict(
            torch.load(args.deformator_path, map_location=torch.device('cpu')))

        shift_predictor.load_state_dict(
            torch.load(args.shift_predictor_path,
                       map_location=torch.device('cpu')))
    else:
        deformator = LatentDeformator(
            direction_size=args.direction_size,
            out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator],
            random_init=args.deformator_random_init).cuda()

    # transform
    graph_kwargs = util.set_graph_kwargs(args)

    transform_type = ['zoom', 'shiftx', 'color', 'shifty']
    transform_model = EasyDict()
    for a_type in transform_type:
        model = graphs.find_model_using_name(args.model, a_type)
        g = model(**graph_kwargs)
        transform_model[a_type] = EasyDict(model=g)

    # training
    args.shift_distribution = SHIFT_DISTRIDUTION_DICT[
        args.shift_distribution_key]
    trainer = Trainer(params=Params(**args.__dict__),
                      out_dir=args.out,
                      out_json=args.json,
                      continue_train=args.continue_train)
    trainer.train(G, deformator, shift_predictor, transform_model)