예제 #1
0
def eval_ckpt():
    """Generate and save evaluation images from a saved generator checkpoint.

    Command-line driven: reads one or more config paths plus ``--weight``,
    ``--img_dir`` and ``--test_meta`` from ``sys.argv``. Arguments the parser
    does not recognise are forwarded to the config via ``cfg.argv_update``.

    The test metafile must provide the keys ``ref_unis``, ``gen_unis`` and
    ``gen_fonts``; every font in ``gen_fonts`` is paired with the full
    ``gen_unis`` list as its generation target. If the checkpoint contains a
    ``"generator_ema"`` entry, that state dict is loaded; otherwise the
    checkpoint itself is treated as the generator state dict.
    """
    import argparse
    from models import generator_dispatch
    from sconf import Config
    from train import setup_transforms
    from datasets import load_json, get_fact_test_loader

    logger = Logger.get()

    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    # The three options below are dereferenced unconditionally further down,
    # so a missing one is reported as a usage error here rather than as an
    # obscure exception from torch.load / Path / load_json.
    parser.add_argument("--weight",
                        required=True,
                        help="path to weight to evaluate.pth")
    parser.add_argument("--img_dir",
                        required=True,
                        help="path to save images for evaluation")
    parser.add_argument(
        "--test_meta",
        required=True,
        help=
        "path to metafile: contains (font, chars (in unicode)) to generate and reference chars (in unicode)"
    )
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)
    # Only the validation transform is used for evaluation.
    _, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)
    # Records are looked up under the key "<x>_<y>"; the 'img' field of the
    # record is passed through the given transform.
    env_get = lambda env, x, y, transform: transform(
        read_data_from_lmdb(env, f'{x}_{y}')['img'])

    test_meta = load_json(args.test_meta)
    dec_dict = load_json(cfg.dec_dict)

    g_kwargs = cfg.get('g_args', {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()

    # Deserialize onto the CPU so a checkpoint saved from a different CUDA
    # device index still loads; load_state_dict then copies the tensors into
    # the generator's (already CUDA) parameters.
    weight = torch.load(args.weight, map_location="cpu")
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    logger.info(f"Resumed checkpoint from {args.weight}")
    writer = None  # no summary writer is used for offline evaluation

    evaluator = Evaluator(env, env_get, logger, writer, cfg["batch_size"],
                          val_transform, content_font)

    img_dir = Path(args.img_dir)
    ref_unis = test_meta["ref_unis"]
    gen_unis = test_meta["gen_unis"]
    gen_fonts = test_meta["gen_fonts"]
    # Every requested font shares the same list of characters to generate.
    target_dict = {f: gen_unis for f in gen_fonts}

    # get_fact_test_loader returns a sequence; element [1] is what gets
    # passed to the evaluator as the loader.
    loader = get_fact_test_loader(env,
                                  env_get,
                                  target_dict,
                                  ref_unis,
                                  cfg,
                                  None,
                                  dec_dict,
                                  val_transform,
                                  ret_targets=False,
                                  num_workers=cfg.n_workers,
                                  shuffle=False)[1]

    logger.info("Save CV results to {} ...".format(img_dir))
    evaluator.save_each_imgs(gen,
                             loader,
                             save_dir=img_dir,
                             phase="fact",
                             reduction='mean')
예제 #2
0
파일: train.py 프로젝트: yqGANs/lffont
def train(args, cfg, ddp_gpu=-1):
    """Set up data, models, optimizers and the trainer, then run training.

    Args:
        args: Parsed command-line arguments. ``args.resume`` (a checkpoint
            path, falsy to start from scratch) is read here, and the whole
            object is dumped to the log on the main worker.
        cfg: Run configuration. ``cfg.phase`` must be ``"comb"`` or ``"fact"``.
        ddp_gpu: GPU index for this worker; stored in ``cfg.gpu`` and passed
            to ``torch.cuda.set_device``.

    Raises:
        ValueError: If ``cfg.phase`` is neither ``"comb"`` nor ``"fact"``.
        AssertionError: If the auxiliary classifier is disabled (``cfg.ac_w``
            not greater than zero) while ``cfg.ac_gen_w`` is non-zero.
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    # Logging and summary/image output all live under the work directory,
    # keyed by the run's unique name.
    log_file = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=log_file, level="info", colorize=True)

    summary_dir = cfg.work_dir / "runs" / cfg.unique_name
    image_dir = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(summary_dir, image_dir, scale=0.6)

    dumped_args = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(dumped_args))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        # Records are stored under the key "<x>_<y>"; return the transformed image.
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Choose the loader factories and the trainer class for the training phase.
    if cfg.phase not in ("comb", "fact"):
        raise ValueError(cfg.phase)
    if cfg.phase == "comb":
        make_trn_loader = get_comb_trn_loader
        make_cv_loaders = get_cv_comb_loaders
        trainer_cls = CombinedTrainer
    else:
        make_trn_loader = get_fact_trn_loader
        make_cv_loaders = get_cv_fact_loaders
        trainer_cls = FactorizeTrainer

    trn_dset, trn_loader = make_trn_loader(
        env,
        env_get,
        cfg,
        data_meta["train"],
        dec_dict,
        trn_transform,
        num_workers=cfg.n_workers,
        shuffle=True,
    )

    # Validation loaders are built on the main worker only.
    cv_loaders = None
    if is_main_worker(ddp_gpu):
        cv_loaders = make_cv_loaders(
            env,
            env_get,
            cfg,
            data_meta,
            dec_dict,
            val_transform,
            num_workers=cfg.n_workers,
            shuffle=False,
        )

    logger.info("Build model ...")

    # Generator (always built).
    gen_kwargs = cfg.get("g_args", {})
    gen_cls = generator_dispatch()
    gen = gen_cls(1, cfg.C, 1, **gen_kwargs, n_comps=n_comps)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # Discriminator (only when the GAN loss weight is positive).
    disc = None
    if cfg.gan_w > 0.:
        disc_kwargs = cfg.get("d_args", {})
        disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **disc_kwargs)
        disc.cuda()
        disc.apply(weights_init(cfg.init))

    # Auxiliary classifier (only when its loss weight is positive).
    if cfg.ac_w > 0.:
        aux_clf = aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg.init))
    else:
        aux_clf = None
        assert cfg.ac_gen_w == 0., "ac_gen loss is only available with ac loss"

    # One Adam optimizer per network that exists; absent networks get None.
    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)

    d_optim = None
    if disc is not None:
        d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas)

    ac_optim = None
    if aux_clf is not None:
        ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas)

    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(
            args.resume, gen, disc, aux_clf, g_optim, d_optim, ac_optim, cfg.overwrite
        )
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                args.resume, st_step - 1, loss
            )
        )
        # With overwrite set, the step counter restarts from 1 after loading.
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(
        env,
        env_get,
        logger,
        writer,
        cfg.batch_size,
        val_transform,
        content_font,
        use_half=cfg.use_half,
    )

    trainer = trainer_cls(
        gen, disc, g_optim, d_optim,
        aux_clf, ac_optim,
        writer, logger,
        evaluator, cv_loaders,
        cfg,
    )

    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])