def parse_config_txt(run_id) -> dict:
    """Parse the ``dataset`` and ``train`` settings from a run's ``config.txt``.

    Each line of ``<result_subdir>/config.txt`` that starts with ``dataset =``
    or ``train =`` is executed with ``exec`` into a fresh ``config.EasyDict``.
    The returned mapping therefore holds ``dataset`` and/or ``train``, plus the
    ``__builtins__`` entry that ``exec`` inserts into its globals.

    A ``train`` line may contain the repr of a ``BlurScheduleType`` member,
    e.g. ``<BlurScheduleType.LINEAR: 'linear'>``, which is not valid Python.
    In that case:

    * the ``'blur_schedule_type'`` entry is removed;
    * the remainder of the line is executed;
    * ``train.blur_schedule_type`` is set back to the enum member looked up
      by name.

    Any other line that fails to parse is skipped with a printed warning.

    Args:
        run_id: Run identifier accepted by ``locate_result_subdir``.

    Returns:
        A ``config.EasyDict`` with the parsed ``dataset`` / ``train`` values.

    Raises:
        KeyError: If the recovered member name is not a ``BlurScheduleType``
            member.
    """
    # NOTE: exec() evaluates file contents as code. config.txt is expected to
    # be written by this project's own training runs; never point this at a
    # result directory from an untrusted source.
    #
    # The member name is restricted to an identifier (\w+) and the value part
    # stops at the first '>', so the match cannot spill into later entries of
    # the same line (the previous greedy (.+) groups could).
    blur_entry_re = re.compile(
        r", 'blur_schedule_type': <BlurScheduleType\.(\w+):[^>]*>")

    result_subdir = locate_result_subdir(run_id)
    parsed_cfg = config.EasyDict()
    with open(os.path.join(result_subdir, 'config.txt'), 'rt') as f:
        for line in f:
            if not line.startswith(('dataset =', 'train =')):
                continue
            try:
                exec(line, parsed_cfg, parsed_cfg)
            except SyntaxError:
                blur_match = blur_entry_re.search(line)
                if blur_match is None or not line.startswith('train ='):
                    # Best effort: keep going, but don't lose the line silently.
                    print("WARNING: skipping unparseable config.txt line: %s"
                          % line.rstrip())
                    continue
                exec(line.replace(blur_match.group(0), ""), parsed_cfg, parsed_cfg)
                blur_name = blur_match.group(1)
                if blur_name == "NONE":
                    # Configs that spell the member NONE are mapped to NOBLUR
                    # (presumably a renamed member -- confirm in BlurScheduleType).
                    blur_name = "NOBLUR"
                parsed_cfg.train = config.EasyDict(parsed_cfg.train)
                parsed_cfg.train.blur_schedule_type = BlurScheduleType[blur_name]
    return parsed_cfg
def restore_config(resume_run_id, config):
    """Update ``config`` in place so training can resume from ``resume_run_id``.

    Nothing is returned. The steps are:

    1. Take the latest non-final network snapshot of the run and read its
       kimg from the snapshot id string (the part after the last ``-``, with
       ``.pkl`` removed).
    2. Merge the ``train`` / ``dataset`` settings parsed from the run's
       ``config.txt`` into ``config.train`` / ``config.dataset``.
    3. Scan the run's ``log.txt`` for the first line whose kimg equals the
       snapshot's kimg. From that line, store ``resume_tick``,
       ``resume_run_id``, ``resume_kimg`` and ``resume_time`` (seconds) in
       ``config.train``. If no such line exists, a warning is printed and
       these fields are left untouched.

    Args:
        resume_run_id: Run identifier accepted by ``list_network_pkls`` and
            ``locate_result_subdir``.
        config: Configuration object with ``desc``, ``train`` and ``dataset``
            attributes; ``train`` / ``dataset`` must support ``update`` and
            attribute assignment.

    Raises:
        IndexError: If the run has no non-final network snapshot.
    """
    from datetime import timedelta

    network_pkls = list_network_pkls(resume_run_id, include_final=False)
    print("Network snapshots available: (except the final)", network_pkls)
    if not network_pkls:
        # Same exception type that the bare network_pkls[-1] used to raise,
        # but with a message that says what is missing.
        raise IndexError("no network snapshots (excluding the final one) "
                         "found for run %r" % (resume_run_id,))
    network_pkl = network_pkls[-1]
    id_string = get_id_string_for_network_pkl(network_pkl)
    kimg = int(id_string.split("-")[-1].replace(".pkl", ""))

    parsed_config = parse_config_txt(resume_run_id)
    # NOTE(review): parse_config_txt only evaluates the 'dataset =' and
    # 'train =' lines of config.txt, so 'desc' is normally absent here and
    # config.desc keeps its current value.
    config.desc = parsed_config.get('desc', config.desc)
    config.train.update(parsed_config.get('train', dict()))
    config.dataset.update(parsed_config.get('dataset', dict()))

    # Patterns that pick the tick number, kimg and elapsed time out of a
    # log.txt line. The first occurrence on the line is used.
    tick_re = re.compile(r'tick ([\d]+) ')
    kimg_re = re.compile(r'kimg ([\d\.]+) ')
    time_re = re.compile(r'time (\d+d)? *(\d+h)? *(\d+m)? *(\d+s)? ')

    src_result_subdir = locate_result_subdir(resume_run_id)
    with open(os.path.join(src_result_subdir, 'log.txt'), 'rt') as log:
        for line in log:
            tick_m = tick_re.search(line)
            kimg_m = kimg_re.search(line)
            time_m = time_re.search(line)
            if not (tick_m and kimg_m and time_m):
                continue
            # kimg may be fractional in the log; truncate it to compare with
            # the integer kimg taken from the snapshot name.
            if int(float(kimg_m.group(1))) != kimg:
                continue
            # Each time group looks like "3h": drop the unit letter; a
            # missing group counts as 0.
            days, hours, minutes, seconds = (
                int(part[:-1]) if part else 0 for part in time_m.groups())
            elapsed = timedelta(days=days, hours=hours,
                                minutes=minutes, seconds=seconds)
            print("PREVIOUS RUN FOUND.")
            config.train.resume_tick = int(tick_m.group(1))
            config.train.resume_run_id = resume_run_id
            config.train.resume_kimg = kimg
            config.train.resume_time = elapsed.total_seconds()
            break
        else:
            # Loop finished without a break: no log line matched.
            print("WARNING: no log.txt entry matches snapshot kimg %d; "
                  "resume tick/time were not restored." % kimg)
parser.add_argument('--model_path', type=str,
                    help='path of the pretrained model .pkl file')
parser.add_argument('--out_dir', type=str,
                    help='path for saving the generated data '
                         '(default: save to model dir)')
parser.add_argument('--num_samples', '-ns', type=int, default=20000,
                    help='number of samples')
parser.add_argument('--gen_seed', type=int, default=1000,
                    help='random seed')
args = parser.parse_args()

# Validate up front.
# - An unknown --app used to leave `app` unbound, so the dispatch at the
#   bottom failed with a NameError.
# - 'gen' without --model_path crashed in os.path.dirname(None), and only
#   after TensorFlow had been initialized.
if args.app not in ('train', 'gen'):
    parser.error("--app must be 'train' or 'gen', got %r" % (args.app,))
if args.app == 'gen' and args.model_path is None:
    parser.error("--model_path is required with --app gen")

# Setup shared by both apps, in the original order; only the RNG seed differs.
misc.init_output_logging()
np.random.seed(config.random_seed if args.app == 'train' else args.gen_seed)
print('Initializing TensorFlow...')
os.environ.update(config.env)
tfutil.init_tf(config.tf_config)

if args.app == 'train':
    print('Running %s()...' % config.train['func'])
    app = config.train
else:  # 'gen'
    # Write next to the model file unless --out_dir is given.
    out_dir = (os.path.dirname(args.model_path)
               if args.out_dir is None else args.out_dir)
    app = config.EasyDict(func='util_scripts.generate_fake_images',
                          model_path=args.model_path, out_dir=out_dir,
                          num_samples=args.num_samples,
                          random_seed=args.gen_seed)

# `app` carries the dotted name of the function to run under 'func'; the
# other entries are presumably forwarded as keyword arguments
# (see tfutil.call_func_by_name).
tfutil.call_func_by_name(**app)
# ----------------------------------------------------------------------------
'--gen_seed', type=int, default=9999) # The random seed to differentiate generation instances args = parser.parse_args() if args.app == 'train': assert args.training_data_dir != ' ' and args.out_model_dir != ' ' misc.init_output_logging() np.random.seed(args.training_seed) print('Initializing TensorFlow...') os.environ.update(config.env) tfutil.init_tf(config.tf_config) if args.training_data_dir[-1] == '/': args.training_data_dir = args.training_data_dir[:-1] idx = args.training_data_dir.rfind('/') config.data_dir = args.training_data_dir[:idx] config.dataset = config.EasyDict( tfrecord_dir=args.training_data_dir[idx + 1:]) app = config.EasyDict(func='run.train_progressive_gan', mirror_augment=False, total_kimg=12000) config.result_dir = args.out_model_dir elif args.app == 'gen': assert args.model_path != ' ' and args.out_image_dir != ' ' misc.init_output_logging() np.random.seed(args.gen_seed) print('Initializing TensorFlow...') os.environ.update(config.env) tfutil.init_tf(config.tf_config) app = config.EasyDict(func='util_scripts.generate_fake_images', pkl_path=args.model_path, out_dir=args.out_image_dir, num_pngs=args.num_pngs,
args = parser.parse_args() if args.app == 'train': assert args.training_data_dir != ' ' and args.out_model_dir != ' ' if args.validation_data_dir == ' ': args.validation_data_dir = args.training_data_dir misc.init_output_logging() np.random.seed(args.training_seed) print('Initializing TensorFlow...') os.environ.update(config.env) tfutil.init_tf(config.tf_config) if args.training_data_dir[-1] == '/': args.training_data_dir = args.training_data_dir[:-1] idx = args.training_data_dir.rfind('/') config.data_dir = args.training_data_dir[:idx] config.training_set = config.EasyDict( tfrecord_dir=args.training_data_dir[idx + 1:], max_label_size='full') if args.validation_data_dir[-1] == '/': args.validation_data_dir = args.validation_data_dir[:-1] idx = args.validation_data_dir.rfind('/') config.validation_set = config.EasyDict( tfrecord_dir=args.validation_data_dir[idx + 1:], max_label_size='full') app = config.EasyDict(func='run.train_classifier', lr_mirror_augment=True, ud_mirror_augment=False, total_kimg=25000) config.result_dir = args.out_model_dir elif args.app == 'test': assert args.model_path != ' ' and args.testing_data_path != ' ' and args.out_fingerprint_dir != ' ' misc.init_output_logging()