예제 #1
0
파일: eval.py 프로젝트: foamtsp/CVProject
def eval_ckpt():
    """Generate images with a trained generator checkpoint and save them as PNGs.

    Command-line arguments (read from ``sys.argv``):
        config_paths: one or more config YAML paths, layered over
            ``cfgs/defaults.yaml``.
        --weight: checkpoint file to load (required).
        --result_dir: directory that receives the generated images (required).

    Arguments the parser does not recognise are forwarded to
    ``cfg.argv_update`` as config overrides.  Every generated image is written
    to ``<result_dir>/<font>/<char>.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    # Both options are dereferenced unconditionally below, so let argparse
    # reject a missing value with a usage message instead of failing later
    # inside Path(None) / torch.load(None).
    parser.add_argument("--weight", required=True,
                        help="path to weight to evaluate.pth")
    parser.add_argument("--result_dir", required=True,
                        help="path to save the result file")
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)
    img_dir = Path(args.result_dir)
    img_dir.mkdir(parents=True, exist_ok=True)

    # Only the second (validation) transform is used here.
    _, val_transform = setup_transforms(cfg)

    g_kwargs = cfg.get('g_args', {})
    gen = Generator(1, cfg.C, 1, **g_kwargs).cuda()

    weight = torch.load(args.weight)
    # If the loaded object has a "generator_ema" entry use it; otherwise the
    # loaded object itself is passed to load_state_dict.
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    _, test_loader = get_test_loader(cfg, val_transform)

    # NOTE(review): the generator is never switched with gen.eval(); confirm
    # whether that is intentional before changing it, since it may alter the
    # generated images.
    # Pure inference: no backward pass follows, so skip autograd bookkeeping.
    with torch.no_grad():
        for batch in test_loader:
            style_imgs = batch["style_imgs"].cuda()
            # unsqueeze(1) inserts a new axis at position 1 before the
            # tensor is handed to the generator.
            char_imgs = batch["source_imgs"].unsqueeze(1).cuda()

            out = gen.gen_from_style_char(style_imgs, char_imgs)
            fonts = batch["fonts"]
            chars = batch["chars"]

            for image, font, char in zip(refine(out), fonts, chars):
                font_dir = img_dir / font
                font_dir.mkdir(parents=True, exist_ok=True)
                save_tensor_to_image(image, font_dir / f"{char}.png")
예제 #2
0
    def __init__(self):
        """Create the GPU generator, load its weights and keep the validation transform.

        The config is read from ``model/mxfont/cfgs/defaults.yaml`` and the
        checkpoint from ``model/mxfont/fontgen.pth``.
        """
        config = Config('model/mxfont/cfgs/defaults.yaml')
        gen_kwargs = config.get('g_args', {})

        generator = Generator(1, config.C, 1, **gen_kwargs)
        self.gen = generator.cuda()

        checkpoint = torch.load('model/mxfont/fontgen.pth')

        # Use the "generator_ema" entry when the checkpoint has one; otherwise
        # the loaded object itself is handed to load_state_dict.
        state_dict = (checkpoint["generator_ema"]
                      if "generator_ema" in checkpoint else checkpoint)
        self.gen.load_state_dict(state_dict)

        # setup_transforms returns a pair; only the second element is kept.
        transform_pair = setup_transforms(config)
        _, self.val_transform = transform_pair
예제 #3
0
def eval_ckpt():
    """Load a generator checkpoint and save generated images for a test meta file.

    Command-line arguments (read from ``sys.argv``):
        config_paths: one or more config YAML paths, layered over
            ``cfgs/defaults.yaml``.
        --weight: checkpoint file to load (required).
        --img_dir: directory passed to the evaluator as ``save_dir`` (required).
        --test_meta: JSON meta file providing ``ref_unis``, ``gen_unis`` and
            ``gen_fonts`` (required).

    Arguments the parser does not recognise are forwarded to
    ``cfg.argv_update`` as config overrides.
    """
    import argparse
    from models import generator_dispatch
    from sconf import Config
    from train import setup_transforms
    from datasets import load_json, get_fact_test_loader

    logger = Logger.get()

    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    # All three options are used unconditionally below.  Marking them required
    # makes a missing one fail at parse time with a usage message, rather than
    # after the LMDB environment and the model have already been set up.
    parser.add_argument("--weight", required=True,
                        help="path to weight to evaluate.pth")
    parser.add_argument("--img_dir", required=True,
                        help="path to save images for evaluation")
    parser.add_argument(
        "--test_meta",
        required=True,
        help=
        "path to metafile: contains (font, chars (in unicode)) to generate and reference chars (in unicode)"
    )
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)
    # Only the second (validation) transform is used here.
    _, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        """Read the ``'img'`` field stored under key ``f'{x}_{y}'`` and apply *transform*."""
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    test_meta = load_json(args.test_meta)
    dec_dict = load_json(cfg.dec_dict)

    g_kwargs = cfg.get('g_args', {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()

    weight = torch.load(args.weight)
    # If the loaded object has a "generator_ema" entry use it; otherwise the
    # loaded object itself is passed to load_state_dict.
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    logger.info(f"Resumed checkpoint from {args.weight}")
    # No writer is supplied to the evaluator in this entry point.
    writer = None

    evaluator = Evaluator(env, env_get, logger, writer, cfg["batch_size"],
                          val_transform, content_font)

    img_dir = Path(args.img_dir)
    ref_unis = test_meta["ref_unis"]
    gen_unis = test_meta["gen_unis"]
    gen_fonts = test_meta["gen_fonts"]
    # Every font in gen_fonts is mapped to the same gen_unis object.
    target_dict = {f: gen_unis for f in gen_fonts}

    # get_fact_test_loader returns a sequence; index 1 is used as the loader.
    loader = get_fact_test_loader(env,
                                  env_get,
                                  target_dict,
                                  ref_unis,
                                  cfg,
                                  None,
                                  dec_dict,
                                  val_transform,
                                  ret_targets=False,
                                  num_workers=cfg.n_workers,
                                  shuffle=False)[1]

    logger.info("Save CV results to {} ...".format(img_dir))
    evaluator.save_each_imgs(gen,
                             loader,
                             save_dir=img_dir,
                             phase="fact",
                             reduction='mean')
예제 #4
0
def eval_ckpt():
    """Evaluate a trained generator checkpoint in the mode chosen by ``--mode``.

    Positional command-line arguments: ``name``, ``resume`` (checkpoint path),
    ``img_dir`` (output directory) and one or more ``config_paths``.
    Arguments the parser does not recognise are forwarded to
    ``cfg.argv_update`` as config overrides.

    Modes handled:
        eval: run ``evaluator.validation``.
        user-study: generate a comparable grid for 20 sampled target chars
            from ``meta/kor-unrefined.json``.
        user-study-save: generate and save all target chars from that file.
        cv-save: run cross-validation on every validation loader, saving the
            images below ``<img_dir>/cv_images_<step>``.

    Raises:
        ValueError: if ``--mode`` is none of the modes above.
        SystemExit: if ``--show`` is given (exits right after config loading).
    """
    from train import (setup_language_dependent, setup_data,
                       setup_cv_dset_loader, get_dset_loader)

    logger = Logger.get()

    parser = argparse.ArgumentParser('MaHFG-eval')
    parser.add_argument(
        "name",
        help=
        "name is used for directory name of the user-study generation results")
    parser.add_argument("resume")
    parser.add_argument("img_dir")
    parser.add_argument("config_paths", nargs="+")
    parser.add_argument("--show", action="store_true", default=False)
    parser.add_argument(
        "--mode",
        default="eval",
        help="eval (default) / cv-save / user-study / user-study-save. "
        "`eval` generates comparable grid and computes pixel-level CV scores. "
        "`cv-save` generates and saves all target characters in CV. "
        "`user-study` generates comparable grid for the ramdomly sampled target characters. "
        "`user-study-save` generates and saves all target characters in user-study."
    )
    parser.add_argument("--deterministic", default=False, action="store_true")
    parser.add_argument("--debug", default=False, action="store_true")
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths)
    cfg.argv_update(left_argv)

    # Default; switched off below when --deterministic is given.
    torch.backends.cudnn.benchmark = True

    cfg['data_dir'] = Path(cfg['data_dir'])

    if args.show:
        # Same effect as the interactive-only `exit()` helper, without
        # relying on the `site` module having injected it.
        raise SystemExit

    # seed
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    if args.deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        cfg['n_workers'] = 0
        logger.info("#" * 80)
        logger.info("# Deterministic option is activated !")
        logger.info(
            "# Deterministic evaluator only ensure the deterministic cross-validation"
        )
        logger.info("#" * 80)

    if args.mode.startswith('mix'):
        assert cfg['g_args']['style_enc']['use'], \
                "Style mixing is only available with style encoder model"

    #####################################
    # Dataset
    ####################################
    # setup language dependent values
    content_font, n_comp_types, n_comps = setup_language_dependent(cfg)

    # setup transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # setup data
    hdf5_data, meta = setup_data(cfg, transform)

    # setup dataset; only the dataset (for its `avails`) is needed here
    trn_dset, _ = get_dset_loader(hdf5_data,
                                  meta['train']['fonts'],
                                  meta['train']['chars'],
                                  transform,
                                  True,
                                  cfg,
                                  content_font=content_font)

    val_loaders = setup_cv_dset_loader(hdf5_data, meta, transform,
                                       n_comp_types, content_font, cfg)

    #####################################
    # Model
    ####################################
    # setup generator only
    g_kwargs = cfg.get('g_args', {})
    gen = MACore(1,
                 cfg['C'],
                 1,
                 **g_kwargs,
                 n_comps=n_comps,
                 n_comp_types=n_comp_types,
                 language=cfg['language'])
    gen.cuda()

    ckpt = torch.load(args.resume)
    logger.info("Use EMA generator as default")
    gen.load_state_dict(ckpt['generator_ema'])

    # The checkpoint's 'epoch' entry is what this function calls the step.
    step = ckpt['epoch']
    loss = ckpt['loss']

    logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
        args.resume, step, loss))

    writer = utils.DiskWriter(args.img_dir, 0.6)

    evaluator = Evaluator(hdf5_data,
                          trn_dset.avails,
                          logger,
                          writer,
                          cfg['batch_size'],
                          content_font=content_font,
                          transform=transform,
                          language=cfg['language'],
                          val_loaders=val_loaders,
                          meta=meta)
    evaluator.n_cv_batches = -1
    logger.info("Update n_cv_batches = -1 to evaluate about full data")
    if args.debug:
        evaluator.n_cv_batches = 10
        logger.info("!!! DEBUG MODE: n_cv_batches = 10 !!!")

    if args.mode == 'eval':
        logger.info("Start validation ...")
        evaluator.validation(gen, step)
        logger.info("Validation is done. Result images are saved to {}".format(
            args.img_dir))
    elif args.mode in ('user-study', 'user-study-save'):
        # Exact match instead of startswith('user-study'): an unknown
        # 'user-study-*' mode now reaches the ValueError below rather than
        # doing nothing and still logging that validation is done.
        # JSON is UTF-8 by specification; do not depend on the locale default,
        # and close the file deterministically.
        with open('meta/kor-unrefined.json', encoding="utf-8") as meta_file:
            # Separate name so the dataset `meta` above is not shadowed.
            study_meta = json.load(meta_file)
        target_chars = study_meta['target_chars']
        style_chars = study_meta['style_chars']
        fonts = study_meta['fonts']

        if args.mode == 'user-study':
            sampled_target_chars = uniform_sample(target_chars, 20)
            logger.info("Start generation kor-unrefined ...")
            logger.info("Sampled chars = {}".format(sampled_target_chars))

            evaluator.handwritten_validation_2stage(gen,
                                                    step,
                                                    fonts,
                                                    style_chars,
                                                    sampled_target_chars,
                                                    comparable=True,
                                                    tag='userstudy-{}'.format(
                                                        args.name))
        else:  # 'user-study-save'
            logger.info("Start generation & saving kor-unrefined ...")
            save_dir = Path(args.img_dir) / "{}-{}".format(args.name, step)
            evaluator.handwritten_validation_2stage(gen,
                                                    step,
                                                    fonts,
                                                    style_chars,
                                                    target_chars,
                                                    comparable=True,
                                                    save_dir=save_dir)
        logger.info("Validation is done. Result images are saved to {}".format(
            args.img_dir))
    elif args.mode == 'cv-save':
        save_dir = Path(args.img_dir) / "cv_images_{}".format(step)
        logger.info("Save CV results to {} ...".format(save_dir))
        # Remove the target directory via utils.rm before regenerating.
        utils.rm(save_dir)
        for tag, loader in val_loaders.items():
            # The returned scores are not used in this mode.
            evaluator.cross_validation(
                gen,
                step,
                loader,
                tag,
                n_batches=evaluator.n_cv_batches,
                save_dir=(save_dir / tag))
    else:
        raise ValueError(args.mode)