コード例 #1
0
ファイル: eval.py プロジェクト: foamtsp/CVProject
def eval_ckpt():
    """Generate glyph images from a trained generator checkpoint.

    CLI:
        config_paths: one or more config.yaml files merged over cfgs/defaults.yaml.
        --weight: path to the .pth checkpoint to evaluate.
        --result_dir: directory where generated images are written
            (one sub-directory per font, one PNG per character).

    Unrecognized CLI options are forwarded to ``cfg.argv_update`` as
    config overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    parser.add_argument("--weight", help="path to weight to evaluate.pth")
    parser.add_argument("--result_dir", help="path to save the result file")
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)
    img_dir = Path(args.result_dir)
    img_dir.mkdir(parents=True, exist_ok=True)

    trn_transform, val_transform = setup_transforms(cfg)

    g_kwargs = cfg.get('g_args', {})
    gen = Generator(1, cfg.C, 1, **g_kwargs).cuda()

    weight = torch.load(args.weight)
    # Training checkpoints may carry an EMA copy of the generator; prefer it.
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    # FIX: the original generated in train mode with autograd enabled.
    # Switch to eval mode (affects BatchNorm/Dropout layers, if any) and
    # disable gradient tracking to avoid needless memory use.
    gen.eval()
    test_dset, test_loader = get_test_loader(cfg, val_transform)

    with torch.no_grad():
        for batch in test_loader:
            style_imgs = batch["style_imgs"].cuda()
            char_imgs = batch["source_imgs"].unsqueeze(1).cuda()

            out = gen.gen_from_style_char(style_imgs, char_imgs)
            fonts = batch["fonts"]
            chars = batch["chars"]

            for image, font, char in zip(refine(out), fonts, chars):
                (img_dir / font).mkdir(parents=True, exist_ok=True)
                path = img_dir / font / f"{char}.png"
                save_tensor_to_image(image, path)
コード例 #2
0
def test_loads(tmp_path, train_dic):
    """Config built from a dict, a path, and an open file object are equal."""
    yaml_file = tmp_path / 'train.yaml'
    yaml.dump(train_dic, yaml_file)

    from_dict = Config(train_dic)
    from_path = Config(yaml_file)
    from_fileobj = Config(open(yaml_file, encoding='utf-8'))

    assert from_dict == from_path == from_fileobj
コード例 #3
0
ファイル: test_dumps.py プロジェクト: khanrc/sconf
def test_dumps_quote():
    """quote_str=True quotes string values, keeping 'None' distinct from null."""
    raw = yaml.load("""
        a: null
        b: None
        c: 1
        d: 1.1
    """)
    cfg = Config(raw, colorize_modified_item=True)

    # Only the *string* "None" gains quotes; a real null and numbers do not.
    dumped = cfg.dumps(quote_str=True)
    assert dumped == "a: None\nb: 'None'\nc: 1\nd: 1.1"
コード例 #4
0
ファイル: generator.py プロジェクト: foamtsp/CVProject
    def __init__(self):
        """Build the mxfont generator from its default config and checkpoint."""
        cfg = Config('model/mxfont/cfgs/defaults.yaml')

        # Optional generator overrides live under the 'g_args' config key.
        generator_kwargs = cfg.get('g_args', {})
        self.gen = Generator(1, cfg.C, 1, **generator_kwargs).cuda()

        checkpoint = torch.load('model/mxfont/fontgen.pth')
        # Prefer the EMA weights when the checkpoint contains them.
        if "generator_ema" in checkpoint:
            checkpoint = checkpoint["generator_ema"]
        self.gen.load_state_dict(checkpoint)

        # Only the validation-time transform is needed for inference.
        _, self.val_transform = setup_transforms(cfg)
コード例 #5
0
def test_none_cli():
    """Any casing of 'None' given on the CLI stays a plain string."""
    cfg = Config(yaml.load("""
        a: 1
        b: 2
        c: 3
    """))
    cfg.argv_update(
        ['--a', 'None', '--b', 'NONE', '--c', 'none', '--d', 'NOne'])

    # 'None' is a Python literal, not a YAML null; every variant is a string.
    expected = {'a': 'None', 'b': 'NONE', 'c': 'none', 'd': 'NOne'}
    for key, value in expected.items():
        assert cfg[key] == value
コード例 #6
0
def test_true_cli():
    """Only exact YAML boolean spellings from the CLI parse to True."""
    cfg = Config(yaml.load("""
        test: True
        hmm: true
        a:
            q: TRUE
            w: tRuE
    """))
    cfg.argv_update(
        ['--test', 'true', '--hmm', 'True', '--a.q', 'trUE', '--a.w', 'TRUE'])

    nested = cfg['a']
    assert cfg['test'] is True
    assert cfg['hmm'] is True
    # 'trUE' is not a valid YAML boolean literal, so it stays a string.
    assert nested['q'] == 'trUE'
    assert nested['w'] is True
コード例 #7
0
def test_merge():
    """Later configs override scalars and deep-merge nested mappings."""
    base = yaml.load("""
        test: 1
        hmm: 2
        a:
            q: 1
            w: 2
    """)
    override = yaml.load("""
        test: 2
        hmm: 3
        a:
            w: 8
            c: 6
        b:
            bq: 1
            bw: 2
    """)
    merged = Config(base, override)

    # Scalars take the later value; nested dict 'a' merges key-wise;
    # keys unique to either side ('a.q', whole 'b') are kept.
    expected = {
        'test': 2,
        'hmm': 3,
        'a': {'q': 1, 'w': 8, 'c': 6},
        'b': {'bq': 1, 'bw': 2},
    }
    assert merged == expected
コード例 #8
0
def test_null_cli():
    """Valid YAML null spellings from the CLI become None; others stay strings."""
    cfg = Config(yaml.load("""
        a: 1
        b: 2
        c: 3
    """))
    cfg.argv_update([
        '--a', 'null', '--b', 'Null', '--c', 'NULL', '--d', 'NUll', '--e',
        '\'Null\''
    ])

    # 'null' / 'Null' / 'NULL' are the YAML null literals.
    for key in ('a', 'b', 'c'):
        assert cfg[key] is None
    # Mixed casing is not a null literal.
    assert cfg['d'] == 'NUll'
    # Quoting forces string interpretation even for a valid null literal.
    assert cfg['e'] == 'Null'
コード例 #9
0
def test_false_cli():
    """Exact YAML false spellings from the CLI become bool; near-misses don't."""
    cfg = Config(yaml.load("""
        a: false
        b: FAlse
        c: faLSE
    """))
    cfg.argv_update([
        '--a', 'False', '--b', 'false', '--c', 'FALSE', '--d', 'Fals', '--e',
        'FAlse'
    ])

    for key in ('a', 'b', 'c'):
        assert cfg[key] is False
    # Truncated or mixed-case spellings remain plain strings.
    assert cfg['d'] == 'Fals'
    assert cfg['e'] == 'FAlse'
コード例 #10
0
def test_load_from_filepath(tmp_path, train_dic, train_cfg, data_dic, data_cfg,
                            merge_cfg):
    """Config loads single files, merges several, and honours `default`."""
    train_path = tmp_path / 'train.yaml'
    data_path = tmp_path / 'data.yaml'
    yaml.dump(train_dic, train_path)
    yaml.dump(data_dic, data_path)

    # Single-file loads reproduce the corresponding fixture configs.
    assert train_cfg == Config(train_path)
    assert data_cfg == Config(data_path)

    # Merging two files, and loading one on top of a `default`, are equivalent.
    assert merge_cfg == Config(train_path, data_path)
    assert merge_cfg == Config(data_path, default=train_path)
コード例 #11
0
def test_false():
    """Only strict YAML casings of false parse to bool in a loaded mapping."""
    cfg = Config(yaml.load("""
        a: false
        b: FAlse
        c: faLSE
    """))

    assert cfg['a'] is False
    # Mixed-case spellings are not YAML booleans and remain strings.
    assert cfg['b'] == 'FAlse'
    assert cfg['c'] == 'faLSE'
コード例 #12
0
def test_none():
    """'None' in any casing is a string in YAML — only 'null' spellings are null."""
    cfg = Config(yaml.load("""
        a: None
        b: none
        c: NONE
    """))

    for key, literal in (('a', 'None'), ('b', 'none'), ('c', 'NONE')):
        assert cfg[key] == literal
コード例 #13
0
ファイル: test_dumps.py プロジェクト: khanrc/sconf
def test_dumps():
    """dumps() round-trips: re-loading the dump reproduces the source mapping."""
    source = yaml.load("""
        test: 1
        hmm: 2
        a:
            q: 1
            w: 2
        b:
            - 1
            - 2
            - 3
        c:
            - a: 10
              b: 10
            - q: 20
              w: 20
    """)
    cfg = Config(source, colorize_modified_item=False)

    # Coloring is off, so the dump is plain YAML that loads back unchanged.
    round_tripped = yaml.load(cfg.dumps())

    assert source == round_tripped
コード例 #14
0
def test_null():
    """'Null'/'null'/'NULL' parse to None; 'NUll' is not a null literal."""
    cfg = Config(yaml.load("""
        d: Null
        e: null
        f: NULL
        g: NUll
    """))

    for key in ('d', 'e', 'f'):
        assert cfg[key] is None
    assert cfg['g'] == 'NUll'
コード例 #15
0
def test_true():
    """'True'/'true'/'TRUE' parse to bool; 'tRuE' stays a string."""
    cfg = Config(yaml.load("""
        test: True
        hmm: true
        a:
            q: TRUE
            w: tRuE
    """))

    nested = cfg['a']
    assert cfg['test'] is True
    assert cfg['hmm'] is True
    assert nested['q'] is True
    assert nested['w'] == 'tRuE'
コード例 #16
0
def test_merge_order(train_dic, data_dic):
    """Merged content is order-independent, but key order follows arguments."""
    forward = Config(train_dic, data_dic)
    backward = Config(data_dic, train_dic)

    # Same key/value content regardless of merge direction ...
    assert dict(forward) == dict(backward)

    # ... yet key insertion order reflects the argument order.
    assert list(forward.keys()) != list(backward.keys())
コード例 #17
0
def setup_args_and_config():
    """Parse CLI arguments, build the merged config, apply debug overrides.

    Returns:
        (args, cfg): the argparse namespace and the sconf Config.
    """
    parser = argparse.ArgumentParser('MaHFG')
    parser.add_argument("name")
    parser.add_argument("config_paths", nargs="+")
    parser.add_argument("--show", action="store_true", default=False)
    parser.add_argument("--resume", default=None)
    parser.add_argument("--finetune", default=None)
    parser.add_argument("--log_lv", default='info')
    parser.add_argument("--debug", default=False, action="store_true")
    parser.add_argument("--tb-image",
                        default=False,
                        action="store_true",
                        help="Write image log to tensorboard")
    parser.add_argument("--deterministic", default=False, action="store_true")

    args, left_argv = parser.parse_known_args()
    # Guard against a config file being passed where the run name belongs.
    assert not args.name.endswith(".yaml")

    # Leftover CLI options become config overrides.
    cfg = Config(*args.config_paths, colorize_modified_item=True)
    cfg.argv_update(left_argv)

    if args.debug:
        # Shrink every frequency/iteration knob so a debug run finishes fast.
        debug_overrides = {
            'print_freq': 1,
            'tb_freq': 1,
            'max_iter': 10,
            'val_freq': 5,
            'save_freq': 10,
        }
        for key, value in debug_overrides.items():
            cfg[key] = value
        args.name += "_debug"
        args.tb_image = True
        args.log_lv = 'debug'

    cfg['data_dir'] = Path(cfg['data_dir'])

    # Saving must align with validation so every checkpoint has a score.
    assert cfg['save_freq'] % cfg['val_freq'] == 0

    return args, cfg
コード例 #18
0
ファイル: test_dotaccess.py プロジェクト: khanrc/sconf
def test_duplicated_key_for_container():
    """Keys shadowing container method names work via item access only."""
    # With the dot-access interface, keys that collide with method names
    # (e.g. `get`, `items`) cannot be reached as attributes.
    clashing = {
        "get": 1,
        "items": 2,
    }
    cfg = Config(clashing)

    # Bracket access always returns the stored values.
    assert cfg['get'] == 1
    assert cfg['items'] == 2

    # Dot access resolves to the container methods, not the stored values.
    assert cfg.get != 1 and callable(cfg.get)
    assert cfg.items != 2 and callable(cfg.items)
コード例 #19
0
ファイル: test_dumps.py プロジェクト: khanrc/sconf
def test_dumps_coloring():
    """Modified items are wrapped in ANSI cyan unless coloring is disabled."""
    cfg = Config(yaml.load("""
        a: 10
        b: 20
    """), colorize_modified_item=True)
    cfg.argv_update(['--a', '20'])

    colored = cfg.dumps()
    # Octal and hex escapes spell the same ESC byte; both forms must match.
    assert colored == "\033[36ma: 20\n\033[0mb: 20"
    assert colored == "\x1b[36ma: 20\n\x1b[0mb: 20"
    # Passing modified_color=None disables the ANSI wrapping entirely.
    assert cfg.dumps(modified_color=None) == "a: 20\nb: 20"
コード例 #20
0
ファイル: test_dotaccess.py プロジェクト: khanrc/sconf
def test_duplicated_key_for_dotdict():
    """Nested keys shadowing dict method names also require item access."""
    clashing = {
        't': {
            "get": 1,
            "items": 2,
            "update": 3,
        }
    }
    cfg = Config(clashing)

    inner = cfg['t']
    assert inner['get'] == 1
    assert inner['items'] == 2
    assert inner['update'] == 3

    # Dot access on the nested node resolves to the mapping's methods.
    node = cfg.t
    assert node.get != 1 and callable(node.get)
    assert node.items != 2 and callable(node.items)
    assert node.update != 3 and callable(node.update)
コード例 #21
0
ファイル: train.py プロジェクト: yqGANs/lffont
def setup_args_and_config():
    """Parse CLI args, build the config, and prepare working directories.

    Returns:
        (args, cfg): the argparse namespace and the merged sconf Config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("name")
    parser.add_argument("config_paths", nargs="+", help="path/to/config.yaml")
    parser.add_argument("--resume", default=None, help="path/to/saved/.pth")
    parser.add_argument("--use_unique_name", default=False, action="store_true", help="whether to use name with timestamp")

    args, left_argv = parser.parse_known_args()
    # Guard against a config file being passed where the run name belongs.
    assert not args.name.endswith(".yaml")

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml",
                 colorize_modified_item=True)
    cfg.argv_update(left_argv)

    # When DDP is enabled, DataLoader worker processes are disabled.
    if cfg.use_ddp:
        cfg.n_workers = 0

    cfg.work_dir = Path(cfg.work_dir)
    cfg.work_dir.mkdir(parents=True, exist_ok=True)

    # Optionally prefix the run name with a timestamp for uniqueness.
    unique_name = args.name
    if args.use_unique_name:
        unique_name = "{}_{}".format(utils.timestamp(), args.name)

    cfg.unique_name = unique_name
    cfg.name = args.name

    (cfg.work_dir / "logs").mkdir(parents=True, exist_ok=True)
    (cfg.work_dir / "checkpoints" / unique_name).mkdir(parents=True, exist_ok=True)

    if cfg.save_freq % cfg.val_freq:
        raise ValueError("save_freq has to be multiple of val_freq.")

    return args, cfg
コード例 #22
0
def setup_args_and_config():
    """Parse config paths from the CLI and build the merged Config.

    Returns:
        (args, cfg)
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("config_paths", nargs="+", help="path/to/config.yaml")

    known_args, extra_argv = cli.parse_known_args()

    # Defaults first, then user configs, then leftover CLI overrides.
    cfg = Config(*known_args.config_paths,
                 default="cfgs/defaults.yaml",
                 colorize_modified_item=True)
    cfg.argv_update(extra_argv)

    # When DDP is enabled, DataLoader worker processes are disabled.
    if cfg.use_ddp:
        cfg.n_workers = 0

    cfg.work_dir = Path(cfg.work_dir)
    (cfg.work_dir / "checkpoints").mkdir(parents=True, exist_ok=True)

    return known_args, cfg
コード例 #23
0
def test_is_maintained():
    """The default registry retains mutations made by an earlier test."""
    default_cfg = Config.get_default()
    assert default_cfg.lr == 0.1
コード例 #24
0
def test_reg(train_cfg):
    """The registered default config matches the train fixture."""
    default_cfg = Config.get_default()

    assert train_cfg == default_cfg
コード例 #25
0
def test_reg2(train_cfg, data_cfg):
    """The registered default matches the train config but not the data config."""
    default_cfg = Config.get_default()

    assert train_cfg == default_cfg
    assert data_cfg != default_cfg
コード例 #26
0
def test_modify_data():
    """Mutating the default config is allowed (and visible to later tests)."""
    default_cfg = Config.get_default()
    assert default_cfg.lr == 0.001
    default_cfg.lr = 0.1
コード例 #27
0
def eval_ckpt():
    """Generate evaluation images from a factorization-model checkpoint.

    Parses config paths plus --weight/--img_dir/--test_meta from the CLI,
    restores the generator (preferring EMA weights when present), and saves
    one image per (font, character) pair listed in the test metafile.
    """
    import argparse
    from models import generator_dispatch
    from sconf import Config
    from train import setup_transforms
    from datasets import load_json, get_fact_test_loader

    logger = Logger.get()

    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    parser.add_argument("--weight", help="path to weight to evaluate.pth")
    parser.add_argument("--img_dir", help="path to save images for evaluation")
    parser.add_argument(
        "--test_meta",
        help=
        "path to metafile: contains (font, chars (in unicode)) to generate and reference chars (in unicode)"
    )
    args, left_argv = parser.parse_known_args()

    # Unrecognized CLI options become config overrides via argv_update.
    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)
    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)
    # Fetch the image stored under LMDB key "<font>_<char>" and transform it.
    env_get = lambda env, x, y, transform: transform(
        read_data_from_lmdb(env, f'{x}_{y}')['img'])

    test_meta = load_json(args.test_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Generator class is selected by the project's dispatch helper; extra
    # constructor kwargs come from the optional 'g_args' config section.
    g_kwargs = cfg.get('g_args', {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()

    weight = torch.load(args.weight)
    # Training checkpoints may carry an EMA copy of the generator; prefer it.
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    logger.info(f"Resumed checkpoint from {args.weight}")
    # No tensorboard writer during standalone evaluation.
    writer = None

    evaluator = Evaluator(env, env_get, logger, writer, cfg["batch_size"],
                          val_transform, content_font)

    img_dir = Path(args.img_dir)
    ref_unis = test_meta["ref_unis"]
    gen_unis = test_meta["gen_unis"]
    gen_fonts = test_meta["gen_fonts"]
    # Generate every target character for every target font.
    target_dict = {f: gen_unis for f in gen_fonts}

    # get_fact_test_loader returns (dataset, loader); only the loader is used.
    loader = get_fact_test_loader(env,
                                  env_get,
                                  target_dict,
                                  ref_unis,
                                  cfg,
                                  None,
                                  dec_dict,
                                  val_transform,
                                  ret_targets=False,
                                  num_workers=cfg.n_workers,
                                  shuffle=False)[1]

    logger.info("Save CV results to {} ...".format(img_dir))
    evaluator.save_each_imgs(gen,
                             loader,
                             save_dir=img_dir,
                             phase="fact",
                             reduction='mean')
コード例 #28
0
ファイル: evaluator.py プロジェクト: peternara/dmfont-gan-ocr
def eval_ckpt():
    """Evaluate a trained MACore checkpoint in one of several modes.

    --mode selects the behavior (see the argparse help below):
    'eval' runs cross-validation, 'cv-save' saves all CV images,
    'user-study' / 'user-study-save' generate handwriting-style samples
    from the kor-unrefined meta file.
    """
    from train import (setup_language_dependent, setup_data,
                       setup_cv_dset_loader, get_dset_loader)

    logger = Logger.get()

    parser = argparse.ArgumentParser('MaHFG-eval')
    parser.add_argument(
        "name",
        help=
        "name is used for directory name of the user-study generation results")
    parser.add_argument("resume")
    parser.add_argument("img_dir")
    parser.add_argument("config_paths", nargs="+")
    parser.add_argument("--show", action="store_true", default=False)
    parser.add_argument(
        "--mode",
        default="eval",
        help="eval (default) / cv-save / user-study / user-study-save. "
        "`eval` generates comparable grid and computes pixel-level CV scores. "
        "`cv-save` generates and saves all target characters in CV. "
        "`user-study` generates comparable grid for the ramdomly sampled target characters. "
        "`user-study-save` generates and saves all target characters in user-study."
    )
    parser.add_argument("--deterministic", default=False, action="store_true")
    parser.add_argument("--debug", default=False, action="store_true")
    args, left_argv = parser.parse_known_args()

    # Unrecognized CLI options become config overrides.
    cfg = Config(*args.config_paths)
    cfg.argv_update(left_argv)

    # NOTE(review): benchmark is enabled here and then set again in the
    # deterministic branch below; this first assignment is redundant.
    torch.backends.cudnn.benchmark = True

    cfg['data_dir'] = Path(cfg['data_dir'])

    # --show only prints/validates the config (via argv_update) and exits.
    if args.show:
        exit()

    # seed all RNGs used downstream (numpy, torch, python random)
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    if args.deterministic:
        # Trade speed for reproducible cuDNN kernels and single-process loading.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        cfg['n_workers'] = 0
        logger.info("#" * 80)
        logger.info("# Deterministic option is activated !")
        logger.info(
            "# Deterministic evaluator only ensure the deterministic cross-validation"
        )
        logger.info("#" * 80)
    else:
        torch.backends.cudnn.benchmark = True

    # NOTE(review): no 'mix*' mode is handled below; this guard looks vestigial.
    if args.mode.startswith('mix'):
        assert cfg['g_args']['style_enc']['use'], \
                "Style mixing is only available with style encoder model"

    #####################################
    # Dataset
    ####################################
    # setup language dependent values
    content_font, n_comp_types, n_comps = setup_language_dependent(cfg)

    # setup transform: tensor + [-1, 1] normalization (mean 0.5, std 0.5)
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # setup data
    hdf5_data, meta = setup_data(cfg, transform)

    # setup dataset; trn_dset is only needed for its `avails` attribute below
    trn_dset, loader = get_dset_loader(hdf5_data,
                                       meta['train']['fonts'],
                                       meta['train']['chars'],
                                       transform,
                                       True,
                                       cfg,
                                       content_font=content_font)

    val_loaders = setup_cv_dset_loader(hdf5_data, meta, transform,
                                       n_comp_types, content_font, cfg)

    #####################################
    # Model
    ####################################
    # setup generator only (no discriminator is needed for evaluation)
    g_kwargs = cfg.get('g_args', {})
    gen = MACore(1,
                 cfg['C'],
                 1,
                 **g_kwargs,
                 n_comps=n_comps,
                 n_comp_types=n_comp_types,
                 language=cfg['language'])
    gen.cuda()

    ckpt = torch.load(args.resume)
    logger.info("Use EMA generator as default")
    gen.load_state_dict(ckpt['generator_ema'])

    step = ckpt['epoch']
    loss = ckpt['loss']

    logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
        args.resume, step, loss))

    writer = utils.DiskWriter(args.img_dir, 0.6)

    evaluator = Evaluator(hdf5_data,
                          trn_dset.avails,
                          logger,
                          writer,
                          cfg['batch_size'],
                          content_font=content_font,
                          transform=transform,
                          language=cfg['language'],
                          val_loaders=val_loaders,
                          meta=meta)
    # -1 means "use every batch" so evaluation covers the full data.
    evaluator.n_cv_batches = -1
    logger.info("Update n_cv_batches = -1 to evaluate about full data")
    if args.debug:
        evaluator.n_cv_batches = 10
        logger.info("!!! DEBUG MODE: n_cv_batches = 10 !!!")

    if args.mode == 'eval':
        # Full cross-validation with comparable grids and pixel-level scores.
        logger.info("Start validation ...")
        dic = evaluator.validation(gen, step)
        logger.info("Validation is done. Result images are saved to {}".format(
            args.img_dir))
    elif args.mode.startswith('user-study'):
        # User-study generation driven by the kor-unrefined meta file.
        meta = json.load(open('meta/kor-unrefined.json'))
        target_chars = meta['target_chars']
        style_chars = meta['style_chars']
        fonts = meta['fonts']

        if args.mode == 'user-study':
            # Comparable grid for a random sample of 20 target characters.
            sampled_target_chars = uniform_sample(target_chars, 20)
            logger.info("Start generation kor-unrefined ...")
            logger.info("Sampled chars = {}".format(sampled_target_chars))

            evaluator.handwritten_validation_2stage(gen,
                                                    step,
                                                    fonts,
                                                    style_chars,
                                                    sampled_target_chars,
                                                    comparable=True,
                                                    tag='userstudy-{}'.format(
                                                        args.name))
        elif args.mode == 'user-study-save':
            # Save every target character under <img_dir>/<name>-<step>.
            logger.info("Start generation & saving kor-unrefined ...")
            save_dir = Path(args.img_dir) / "{}-{}".format(args.name, step)
            evaluator.handwritten_validation_2stage(gen,
                                                    step,
                                                    fonts,
                                                    style_chars,
                                                    target_chars,
                                                    comparable=True,
                                                    save_dir=save_dir)
        logger.info("Validation is done. Result images are saved to {}".format(
            args.img_dir))
    elif args.mode == 'cv-save':
        # Save all CV images per validation tag, wiping any previous output.
        save_dir = Path(args.img_dir) / "cv_images_{}".format(step)
        logger.info("Save CV results to {} ...".format(save_dir))
        utils.rm(save_dir)
        for tag, loader in val_loaders.items():
            l1, ssim, msssim = evaluator.cross_validation(
                gen,
                step,
                loader,
                tag,
                n_batches=evaluator.n_cv_batches,
                save_dir=(save_dir / tag))
    else:
        raise ValueError(args.mode)
コード例 #29
0
def test_new_registration(train_cfg, data_cfg):
    """A config registered under a key is retrievable via from_registry."""
    registry.register(data_cfg, 'data')
    fetched = Config.from_registry('data')

    assert train_cfg != fetched
    assert data_cfg == fetched
コード例 #30
0
def test_is_diff_to_given(train_cfg):
    """Default differs from the fixture until lr is restored; never same object."""
    default_cfg = Config.get_default()
    assert train_cfg != default_cfg
    # Restoring lr makes the contents equal again ...
    default_cfg.lr = 0.001
    assert train_cfg == default_cfg
    # ... but the registry object is a distinct instance from the fixture.
    assert id(train_cfg) != id(default_cfg)