def parse_config_txt(run_id) -> dict:
    """Parse the ``dataset =`` / ``train =`` lines of a run's config.txt.

    Each such line of ``<result_subdir>/config.txt`` is executed as Python
    source into a fresh ``config.EasyDict``, so the result holds the names
    those assignments bind (normally ``dataset`` and ``train``).

    A ``train`` line may contain the repr of a ``BlurScheduleType`` member
    (e.g. ``<BlurScheduleType.LINEAR: 'linear'>``) under the
    ``'blur_schedule_type'`` key. Such a repr is not valid Python. It is
    replaced by a placeholder before execution, and the real enum member is
    stored afterwards. In that case ``train`` is wrapped in
    ``config.EasyDict``.

    Lines that cannot be parsed even after that substitution are skipped with
    a printed warning.

    Args:
        run_id: Run identifier accepted by ``locate_result_subdir``.

    Returns:
        An ``EasyDict`` of the names bound by the executed lines.

    Raises:
        KeyError: if the blur schedule name is not a ``BlurScheduleType``
            member.
    """
    # Match only the enum repr stored under 'blur_schedule_type'. The member
    # name is limited to word characters and the value to non-'>' characters,
    # so the match cannot run past the repr's closing '>' into later entries.
    blur_repr = re.compile(
        r"'blur_schedule_type':\s*<BlurScheduleType\.(\w+):[^>]*>")

    result_subdir = locate_result_subdir(run_id)
    parsed_cfg = config.EasyDict()
    with open(os.path.join(result_subdir, 'config.txt'), 'rt') as f:
        for line in f:
            if not line.startswith(('dataset =', 'train =')):
                continue
            # NOTE: exec() fully trusts config.txt. Only point this at result
            # directories that were written by this code.
            try:
                exec(line, parsed_cfg, parsed_cfg)
                continue
            except SyntaxError:
                pass

            match = blur_repr.search(line) if line.startswith('train =') else None
            if match is None:
                print("WARNING: could not parse config.txt line, skipping: %s"
                      % line.rstrip())
                continue

            blur_type_str = match.group(1)
            if blur_type_str == "NONE":
                # config.txt may name the member NONE; the enum lookup uses
                # NOBLUR (presumably a renamed member).
                blur_type_str = "NOBLUR"
            # Swap the unparsable repr for a literal placeholder so the line
            # is valid Python wherever the key sits in the dict.
            patched = (line[:match.start()]
                       + "'blur_schedule_type': None"
                       + line[match.end():])
            exec(patched, parsed_cfg, parsed_cfg)
            parsed_cfg.train = config.EasyDict(parsed_cfg.train)
            parsed_cfg.train.blur_schedule_type = BlurScheduleType[blur_type_str]

    # exec() inserts a '__builtins__' entry into the globals mapping it is
    # given. That entry is not configuration, so drop it.
    parsed_cfg.pop('__builtins__', None)
    return parsed_cfg
def restore_config(resume_run_id, config):
    """Update ``config`` in place so training can resume a previous run.

    Steps:
      1. Pick the last non-final network snapshot of ``resume_run_id`` and
         take its kimg from the snapshot id string.
      2. Merge the ``train`` / ``dataset`` settings from that run's
         config.txt into ``config``.
      3. Scan the run's log.txt for the first tick whose kimg (truncated to
         int) equals the snapshot kimg.

    When such a tick is found, ``config.train`` receives ``resume_tick``,
    ``resume_run_id``, ``resume_kimg`` and ``resume_time`` (elapsed seconds).
    Otherwise a warning is printed and those fields are left untouched.

    Args:
        resume_run_id: Identifier of the run to resume.
        config: Object exposing ``EasyDict``, ``desc``, ``train`` and
            ``dataset``. This is presumably the project's ``config`` module
            (the name shadows it); verify against callers.

    Returns:
        None. All results are written into ``config``.

    Raises:
        IndexError: if the run has no non-final network snapshot.
    """
    from datetime import timedelta

    network_pkls = list_network_pkls(resume_run_id, include_final=False)
    print("Network snapshots available: (except the final)", network_pkls)
    if not network_pkls:
        raise IndexError("No non-final network snapshots found for run %r"
                         % (resume_run_id,))
    network_pkl = network_pkls[-1]
    id_string = get_id_string_for_network_pkl(network_pkl)
    # The last '-'-separated field of the id (minus any '.pkl') is the kimg.
    kimg = int(id_string.split("-")[-1].replace(".pkl", ""))

    parsed_config = config.EasyDict(parse_config_txt(resume_run_id))
    # NOTE(review): parse_config_txt only executes 'dataset =' / 'train ='
    # lines, so 'desc' appears never to be present and config.desc is kept.
    # Confirm whether desc is meant to be restored.
    config.desc = parsed_config.get('desc', config.desc)
    config.train.update(parsed_config.get('train', dict()))
    config.dataset.update(parsed_config.get('dataset', dict()))

    # Compile once instead of on every log line.
    tick_re = re.compile(r'tick ([\d]+) ')
    kimg_re = re.compile(r'kimg ([\d\.]+) ')
    time_re = re.compile(r'time (\d+d)? *(\d+h)? *(\d+m)? *(\d+s)? ')

    src_result_subdir = locate_result_subdir(resume_run_id)
    with open(os.path.join(src_result_subdir, 'log.txt'), 'rt') as log:
        for line in log:
            tick_m = tick_re.search(line)
            kimg_m = kimg_re.search(line)
            time_m = time_re.search(line)
            if not (tick_m and kimg_m and time_m):
                continue
            if int(float(kimg_m.group(1))) != kimg:
                continue
            # Each present component carries a one-letter unit suffix
            # ('d', 'h', 'm', 's'); absent components count as 0.
            days, hours, minutes, seconds = (
                int(g[:-1]) if g else 0 for g in time_m.groups())
            elapsed = timedelta(days=days, hours=hours,
                                minutes=minutes, seconds=seconds)
            print("PREVIOUS RUN FOUND.")
            config.train.resume_tick = int(tick_m.group(1))
            config.train.resume_run_id = resume_run_id
            config.train.resume_kimg = kimg
            config.train.resume_time = elapsed.total_seconds()
            break
        else:
            print("WARNING: no log.txt entry matches snapshot kimg %d; "
                  "resume fields were not set." % kimg)
示例#3
0
    parser.add_argument('--model_path', type=str)  # path of the pretrained model .pkl file
    parser.add_argument('--out_dir', type=str)  # path for saving the generated data (default: save to model dir)
    parser.add_argument('--num_samples', '-ns', type=int, default=20000)  # number of samples
    parser.add_argument('--gen_seed', type=int, default=1000)  # random seed

    args = parser.parse_args()
    if args.app == 'train':
        misc.init_output_logging()
        np.random.seed(config.random_seed)
        print('Initializing TensorFlow...')
        os.environ.update(config.env)
        tfutil.init_tf(config.tf_config)
        print('Running %s()...' % config.train['func'])
        app = config.train

    elif args.app == 'gen':
        misc.init_output_logging()
        np.random.seed(args.gen_seed)
        print('Initializing TensorFlow...')
        os.environ.update(config.env)
        tfutil.init_tf(config.tf_config)
        out_dir = os.path.dirname(args.model_path) if args.out_dir is None else args.out_dir
        app = config.EasyDict(func='util_scripts.generate_fake_images',
                              model_path=args.model_path,
                              out_dir=out_dir,
                              num_samples=args.num_samples,
                              random_seed=args.gen_seed)

    tfutil.call_func_by_name(**app)
# ----------------------------------------------------------------------------
示例#4
0
        '--gen_seed', type=int,
        default=9999)  # The random seed to differentiate generation instances

    args = parser.parse_args()
    if args.app == 'train':
        assert args.training_data_dir != ' ' and args.out_model_dir != ' '
        misc.init_output_logging()
        np.random.seed(args.training_seed)
        print('Initializing TensorFlow...')
        os.environ.update(config.env)
        tfutil.init_tf(config.tf_config)
        if args.training_data_dir[-1] == '/':
            args.training_data_dir = args.training_data_dir[:-1]
        idx = args.training_data_dir.rfind('/')
        config.data_dir = args.training_data_dir[:idx]
        config.dataset = config.EasyDict(
            tfrecord_dir=args.training_data_dir[idx + 1:])
        app = config.EasyDict(func='run.train_progressive_gan',
                              mirror_augment=False,
                              total_kimg=12000)
        config.result_dir = args.out_model_dir
    elif args.app == 'gen':
        assert args.model_path != ' ' and args.out_image_dir != ' '
        misc.init_output_logging()
        np.random.seed(args.gen_seed)
        print('Initializing TensorFlow...')
        os.environ.update(config.env)
        tfutil.init_tf(config.tf_config)
        app = config.EasyDict(func='util_scripts.generate_fake_images',
                              pkl_path=args.model_path,
                              out_dir=args.out_image_dir,
                              num_pngs=args.num_pngs,
示例#5
0
 args = parser.parse_args()
 if args.app == 'train':
     assert args.training_data_dir != ' ' and args.out_model_dir != ' '
     if args.validation_data_dir == ' ':
         args.validation_data_dir = args.training_data_dir
     misc.init_output_logging()
     np.random.seed(args.training_seed)
     print('Initializing TensorFlow...')
     os.environ.update(config.env)
     tfutil.init_tf(config.tf_config)
     if args.training_data_dir[-1] == '/':
         args.training_data_dir = args.training_data_dir[:-1]
     idx = args.training_data_dir.rfind('/')
     config.data_dir = args.training_data_dir[:idx]
     config.training_set = config.EasyDict(
         tfrecord_dir=args.training_data_dir[idx + 1:],
         max_label_size='full')
     if args.validation_data_dir[-1] == '/':
         args.validation_data_dir = args.validation_data_dir[:-1]
     idx = args.validation_data_dir.rfind('/')
     config.validation_set = config.EasyDict(
         tfrecord_dir=args.validation_data_dir[idx + 1:],
         max_label_size='full')
     app = config.EasyDict(func='run.train_classifier',
                           lr_mirror_augment=True,
                           ud_mirror_augment=False,
                           total_kimg=25000)
     config.result_dir = args.out_model_dir
 elif args.app == 'test':
     assert args.model_path != ' ' and args.testing_data_path != ' ' and args.out_fingerprint_dir != ' '
     misc.init_output_logging()