if __name__ == '__main__':
    # Parse CLI options and pin the visible GPU(s) before any graph is built.
    opt = TrainOptions().parse()
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu
    output_dir = opt.output_dir
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    # (and matches how this file creates directories elsewhere).
    os.makedirs(os.path.join(output_dir, 'results'), exist_ok=True)

    # set attrTable
    graph_kwargs = util.set_graph_kwargs(opt)
    # Model-specific helpers live under graphs.<model>.{graph_util,constants}.
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(**graph_kwargs)

    # Fixed seed so the sample batch is reproducible across runs.
    num_samples = opt.num_samples
    graph_inputs = graph_util.graph_input(g, num_samples, seed=0)

    # Optional run suffix; None when not supplied.
    name = opt.suffix if opt.suffix else None

    attrList = graph_kwargs['attrList']
    layers = opt.layers
    print('attrlist: ', attrList)
def joint_train(
    submit_config,                      # dnnlib submit config for this run (provides run_dir, num_gpus).
    opt,                                # Parsed options (model, transform, output_dir, model_save_freq, ...).
    metric_arg_list,                    # Metric configs forwarded to metric_base.MetricGroup.
    sched_args = {},                    # Training schedule options. NOTE(review): mutable default, shared across calls — left unchanged in this doc-only pass.
    grid_args = {},                     # Options for setup_snapshot_image_grid().
    dataset_args = {},                  # Dataset options.
    total_kimg = 15000,                 # Total length of training, measured in thousands of real images.
    drange_net = [-1,1],                # Dynamic range used when feeding image data into the network.
    image_snapshot_ticks = 1,           # How often to export image snapshots.
    network_snapshot_ticks = 10,        # How often to export network model snapshots.
    D_repeats = 1,                      # How many times to train the discriminator per G iteration.
    minibatch_repeats = 4,              # Number of minibatches to run before adjusting training parameters.
    mirror_augment = False,             # Enable mirror augmentation?
    reset_opt_for_new_lod = True,       # Reset optimizer internal state (e.g. Adam moments) when a new layer is introduced?
    save_tf_graph = False,              # Include the full TensorFlow computation graph in the tfevents file?
    save_weight_histograms = False,     # Include weight histograms in the tfevents file?
    resume_run_id = None,               # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot = None,             # Snapshot index to resume training from, None = autodetect.
    resume_kimg = 0.0,                  # Assumed training progress at the start. Affects reporting and the training schedule.
    resume_time = 0.0,                  # Assumed wallclock time at the start. Affects reporting only.
    *args,
    **kwargs):
    """Jointly train the transform graph alongside a progressive-GAN-style generator.

    Builds the model graph for ``opt.model``/``opt.transform``, runs the
    lod-scheduled training loop, periodically saves image grids, network pkls
    and weight checkpoints under ``submit_config.run_dir`` / ``opt.output_dir``,
    and finally dumps the loss curve as .npy and .png.

    NOTE(review): this function references ``train`` (passed to
    dnnlib.RunContext) and ``tf_config`` (passed to metrics.run) which are not
    defined here — presumably module-level names; confirm they exist in the
    enclosing module.
    """
    output_dir = opt.output_dir
    graph_kwargs = util.set_graph_kwargs(opt)
    # Model-specific helpers live under graphs.<model>.{graph_util,constants}.
    graph_util = importlib.import_module('graphs.' + opt.model + '.graph_util')
    constants = importlib.import_module('graphs.' + opt.model + '.constants')
    model = graphs.find_model_using_name(opt.model, opt.transform)
    g = model(submit_config=submit_config, dataset_args=dataset_args, **graph_kwargs, **kwargs)
    g.initialize_graph()

    # create training samples
    #num_samples = opt.num_samples
    # if opt.model == 'biggan' and opt.biggan.category is not None:
    #     graph_inputs = graph_util.graph_input(g, num_samples, seed=0, category=opt.biggan.category)
    # else:
    #     graph_inputs = graph_util.graph_input(g, num_samples, seed=0)

    w_snapshot_ticks = opt.model_save_freq

    ctx = dnnlib.RunContext(submit_config, train)
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **dataset_args)
    with tf.device('/gpu:0'):
        try:
            # Available only when TF was built with memory_stats support.
            peak_gpu_mem_op = tf.contrib.memory_stats.MaxBytesInUse()
        except tf.errors.NotFoundError:
            peak_gpu_mem_op = tf.constant(0)

    # Set up the snapshot image grid.
    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = misc.setup_snapshot_image_grid(g.G, training_set, **grid_args)
    # Schedule at the final nimg just to get a valid minibatch size for the grid render.
    sched = training_loop.training_schedule(cur_nimg=total_kimg*1000, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
    grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)

    # Set up the run directory.
    print('Setting up run dir...')
    misc.save_image_grid(grid_reals, os.path.join(submit_config.run_dir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % resume_kimg), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(submit_config.run_dir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        g.G.setup_weight_histograms(); g.D.setup_weight_histograms()
    metrics = metric_base.MetricGroup(metric_arg_list)

    # Training loop.
    print('Training...\n')
    ctx.update('', cur_epoch=resume_kimg, max_epoch=total_kimg)
    maintenance_time = ctx.get_last_update_interval()
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    prev_lod = -1.0
    loss_values = []
    while cur_nimg < total_kimg * 1000:
        if ctx.should_stop():
            break

        # Choose training parameters and configure training ops.
        sched = training_loop.training_schedule(cur_nimg=cur_nimg, training_set=training_set, num_gpus=submit_config.num_gpus, **sched_args)
        training_set.configure(sched.minibatch // submit_config.num_gpus, sched.lod)
        if reset_opt_for_new_lod:
            # A change in floor/ceil of lod means a new layer was (de)activated.
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                g.G_opt.reset_optimizer_state();
                # D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for _mb_repeat in range(minibatch_repeats):
            # Alphas may come back as scalars or lists; normalize to lists.
            alpha_for_graph, alpha_for_target = g.get_train_alpha(constants.BATCH_SIZE)
            if not isinstance(alpha_for_graph, list):
                alpha_for_graph = [alpha_for_graph]
                alpha_for_target = [alpha_for_target]
            for ag, at in zip(alpha_for_graph, alpha_for_target):
                # Fixed seed: the same latent batch is drawn each repeat.
                feed_dict_out = graph_util.graph_input(g, constants.BATCH_SIZE, seed=0)
                out_zs = g.sess.run(g.outputs_orig, feed_dict_out)
                target_fn, mask_out = g.get_target_np(out_zs, at)
                feed_dict = feed_dict_out
                feed_dict[g.alpha] = ag
                feed_dict[g.target] = target_fn
                feed_dict[g.mask] = mask_out
                feed_dict[g.lod_in] = sched.lod
                feed_dict[g.lrate_in] = sched.D_lrate
                feed_dict[g.minibatch_in] = sched.minibatch
                # Joint step: transform loss, walk-step, Gs moving-average and G train op.
                curr_loss, _, Gs_op, G_op = g.sess.run([g.joint_loss, g.train_step, g.Gs_update_op, g.G_train_op], feed_dict=feed_dict)
                loss_values.append(curr_loss)
            cur_nimg += sched.minibatch
            #tflib.run([g.Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
            #tflib.run([g.G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = ctx.get_time_since_last_update()
            total_time = ctx.get_time_since_start() + resume_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %-6.1f gpumem %-4.1f' % (
                autosummary('Progress/tick', cur_tick),
                autosummary('Progress/kimg', cur_nimg / 1000.0),
                autosummary('Progress/lod', sched.lod),
                autosummary('Progress/minibatch', sched.minibatch),
                dnnlib.util.format_time(autosummary('Timing/total_sec', total_time)),
                autosummary('Timing/sec_per_tick', tick_time),
                autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                autosummary('Timing/maintenance_sec', maintenance_time),
                autosummary('Resources/peak_gpu_mem_gb', peak_gpu_mem_op.eval() / 2**30)))
            autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = g.Gs.run(grid_latents, grid_labels, is_validation=True, minibatch_size=sched.minibatch//submit_config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(submit_config.run_dir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done or cur_tick == 1:
                pkl = os.path.join(submit_config.run_dir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000))
                misc.save_pkl((g.G, g.D, g.Gs), pkl)
                metrics.run(pkl, run_dir=submit_config.run_dir, num_gpus=submit_config.num_gpus, tf_config=tf_config)
            if cur_tick % w_snapshot_ticks == 0 or done:
                # Walk/transform weights checkpoint (separate from the GAN pkls).
                g.saver.save(g.sess, './{}/model_{}.ckpt'.format(
                    output_dir, (cur_nimg // 1000)),
                    write_meta_graph=False, write_state=False)

            # Update summaries and RunContext.
            metrics.update_autosummaries()
            tflib.autosummary.save_summaries(summary_log, cur_nimg)
            ctx.update('%.2f' % sched.lod, cur_epoch=cur_nimg // 1000, max_epoch=total_kimg)
            maintenance_time = ctx.get_last_update_interval() - tick_time

    # Save the final results.
    misc.save_pkl((g.G, g.D, g.Gs), os.path.join(submit_config.run_dir, 'network-final.pkl'))
    summary_log.close()
    ctx.close()

    # Dump the loss curve for offline inspection.
    loss_values = np.array(loss_values)
    np.save('./{}/loss_values.npy'.format(output_dir), loss_values)
    f, ax = plt.subplots(figsize=(10, 4))
    ax.plot(loss_values)
    f.savefig('./{}/loss_values.png'.format(output_dir))
# --- Script fragment: build the model and load pretrained multi-model weights ---
# NOTE(review): this fragment mixes two option objects, `opt` and `conf`
# (opt.output_dir vs conf.output_dir, set_graph_kwargs(conf) vs opt elsewhere).
# Presumably both refer to the same parsed options — confirm upstream.
print('Load face recognition model')
if opt.output_dir:
    output_dir = opt.output_dir
else:
    # Fall back to a default images directory under the configured output root.
    output_dir = os.path.join(conf.output_dir, 'images')
os.makedirs(output_dir, exist_ok=True)

graph_kwargs = util.set_graph_kwargs(conf)
print('Load utils and constants: %s' % conf.model)
# Model-specific helpers live under graphs.<model>.{graph_util,constants}.
graph_util = importlib.import_module('graphs.' + conf.model + '.graph_util')
constants = importlib.import_module('graphs.' + conf.model + '.constants')
print('Find_model_using_name')
model = graphs.find_model_using_name(conf.model, conf.transform)
print('Model initialization')
g = model(**graph_kwargs)

print('Load multi models')
# The GAN checkpoint path is only passed when the GAN itself is being updated;
# otherwise only the w/embedding weights are restored.
if opt.updateGAN:
    g.load_multi_models(opt.save_path_w, opt.save_path_gan, trainEmbed=opt.trainEmbed, updateGAN=opt.updateGAN)
else:
    g.load_multi_models(opt.save_path_w, None, trainEmbed=opt.trainEmbed, updateGAN=opt.updateGAN)
def main():
    """Parse CLI options, build the GAN, deformator and shift predictor, and run training.

    Side effects: selects the CUDA device, seeds `random` and torch, optionally
    overrides parsed args from a JSON file, and launches `Trainer.train`.
    """
    def _str2bool(v):
        # argparse `type=bool` is a known pitfall: bool('False') is truthy, so
        # boolean flags could never be turned off from the command line.
        # Parse the usual textual spellings explicitly; raising ValueError makes
        # argparse report a normal "invalid value" error.
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', '0'):
            return False
        raise ValueError('boolean value expected, got %r' % v)

    tOption = TrainOptions()
    # Mirror every field of Params as a CLI flag so any of them can be overridden.
    for key, val in Params().__dict__.items():
        tOption.parser.add_argument('--{}'.format(key), type=type(val), default=val)
    tOption.parser.add_argument('--args', type=str, default=None, help='json with all arguments')
    tOption.parser.add_argument('--out', type=str, default='./output')
    tOption.parser.add_argument('--gan_type', type=str, choices=WEIGHTS.keys(), default='StyleGAN')
    tOption.parser.add_argument('--gan_weights', type=str, default=None)
    tOption.parser.add_argument('--target_class', type=int, default=239)
    tOption.parser.add_argument('--json', type=str)
    tOption.parser.add_argument('--deformator', type=str, default='proj', choices=DEFORMATOR_TYPE_DICT.keys())
    tOption.parser.add_argument('--deformator_random_init', type=_str2bool, default=False)
    tOption.parser.add_argument('--shift_predictor_size', type=int)
    tOption.parser.add_argument('--shift_predictor', type=str, choices=['ResNet', 'LeNet'], default='ResNet')
    tOption.parser.add_argument('--shift_distribution_key', type=str, choices=SHIFT_DISTRIDUTION_DICT.keys())
    tOption.parser.add_argument('--seed', type=int, default=2)
    tOption.parser.add_argument('--device', type=int, default=0)
    tOption.parser.add_argument('--continue_train', type=_str2bool, default=False)
    tOption.parser.add_argument('--deformator_path', type=str, default='output/models/deformator_90000.pt')
    tOption.parser.add_argument(
        '--shift_predictor_path', type=str,
        default='output/models/shift_predictor_190000.pt')

    args = tOption.parse()
    torch.cuda.set_device(args.device)
    random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Optional JSON file overrides any parsed arguments wholesale.
    if args.args is not None:
        with open(args.args) as args_json:
            args_dict = json.load(args_json)
            args.__dict__.update(**args_dict)

    # Init the generator: explicit CLI weights win over the per-type defaults.
    if args.gan_weights is not None:
        weights_path = args.gan_weights
    else:
        weights_path = WEIGHTS[args.gan_type]
    if args.gan_type == 'BigGAN':
        G = make_big_gan(weights_path, args.target_class).eval()
    elif args.gan_type == 'StyleGAN':
        G = make_stylegan(
            weights_path, net_info[args.stylegan.dataset]['resolution']).eval()
    elif args.gan_type == 'ProgGAN':
        G = make_proggan(weights_path).eval()
    else:
        G = make_external(weights_path).eval()

    # Decide whether the latent code lives in z-space or w-space.
    # NOTE(review): target_dim is only assigned for the 'stylegan' model; any
    # other args.model would hit a NameError when building the deformator
    # below — confirm this entry point is stylegan-only before relying on it.
    if args.model == 'stylegan':
        assert (args.stylegan.latent in ['z', 'w']), 'unknown latent space'
        if args.stylegan.latent == 'z':
            target_dim = G.dim_z
        else:
            target_dim = G.dim_w

    # argparse `choices` restricts shift_predictor to exactly these two values.
    if args.shift_predictor == 'ResNet':
        shift_predictor = ResNetShiftPredictor(
            args.direction_size, args.shift_predictor_size).cuda()
    elif args.shift_predictor == 'LeNet':
        shift_predictor = LeNetShiftPredictor(
            args.direction_size, 1 if args.gan_type == 'SN_MNIST' else 3).cuda()

    if args.continue_train:
        # Resume: rebuild the deformator and restore both checkpoints on CPU
        # first (map_location), letting .cuda() move them to the chosen device.
        deformator = LatentDeformator(
            direction_size=args.direction_size,
            out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator]).cuda()
        deformator.load_state_dict(
            torch.load(args.deformator_path, map_location=torch.device('cpu')))
        shift_predictor.load_state_dict(
            torch.load(args.shift_predictor_path, map_location=torch.device('cpu')))
    else:
        deformator = LatentDeformator(
            direction_size=args.direction_size,
            out_dim=target_dim,
            type=DEFORMATOR_TYPE_DICT[args.deformator],
            random_init=args.deformator_random_init).cuda()

    # Build one transform graph per supported edit direction.
    graph_kwargs = util.set_graph_kwargs(args)
    transform_type = ['zoom', 'shiftx', 'color', 'shifty']
    transform_model = EasyDict()
    for a_type in transform_type:
        model = graphs.find_model_using_name(args.model, a_type)
        g = model(**graph_kwargs)
        transform_model[a_type] = EasyDict(model=g)

    # Training.
    args.shift_distribution = SHIFT_DISTRIDUTION_DICT[
        args.shift_distribution_key]
    trainer = Trainer(params=Params(**args.__dict__),
                      out_dir=args.out,
                      out_json=args.json,
                      continue_train=args.continue_train)
    trainer.train(G, deformator, shift_predictor, transform_model)