def eval_ckpt():
    """Generate evaluation images from a trained checkpoint (CLI entry point).

    Parses config paths, a weight file, an output image directory, and a test
    metafile from the command line; rebuilds the generator, loads the
    (EMA-preferred) weights, and saves generated images for every
    (font, char) pair listed in the metafile.
    """
    import argparse

    from models import generator_dispatch
    from sconf import Config
    from train import setup_transforms
    from datasets import load_json, get_fact_test_loader

    logger = Logger.get()

    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    parser.add_argument("--weight", help="path to weight to evaluate.pth")
    parser.add_argument("--img_dir", help="path to save images for evaluation")
    parser.add_argument(
        "--test_meta",
        help="path to metafile: contains (font, chars (in unicode)) to generate and reference chars (in unicode)"
    )
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    # Only the validation transform is needed for evaluation; the training
    # transform returned by setup_transforms is intentionally discarded.
    _, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        # Fetch the image stored under the LMDB key "<x>_<y>" and apply `transform`.
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    test_meta = load_json(args.test_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Build the generator and load checkpoint weights, preferring the EMA copy.
    g_kwargs = cfg.get('g_args', {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()

    weight = torch.load(args.weight)
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    logger.info(f"Resumed checkpoint from {args.weight}")

    writer = None  # no TensorBoard writer is needed for offline evaluation
    evaluator = Evaluator(env, env_get, logger, writer, cfg["batch_size"],
                          val_transform, content_font)

    img_dir = Path(args.img_dir)

    ref_unis = test_meta["ref_unis"]
    gen_unis = test_meta["gen_unis"]
    gen_fonts = test_meta["gen_fonts"]
    # Generate the same character set for every target font.
    target_dict = {f: gen_unis for f in gen_fonts}

    # get_fact_test_loader returns (dataset, loader); only the loader is used.
    loader = get_fact_test_loader(env, env_get, target_dict, ref_unis, cfg,
                                  None, dec_dict, val_transform,
                                  ret_targets=False,
                                  num_workers=cfg.n_workers,
                                  shuffle=False)[1]

    logger.info("Save CV results to {} ...".format(img_dir))
    evaluator.save_each_imgs(gen, loader, save_dir=img_dir, phase="fact",
                             reduction='mean')
def train(args, cfg, ddp_gpu=-1):
    """Train a font-generation model in either the "comb" or "fact" phase.

    Args:
        args: parsed CLI arguments; must provide ``resume`` (checkpoint path
            or falsy to start fresh).
        cfg: sconf Config carrying data paths, model/optimizer
            hyperparameters, loss weights, and the training phase.
        ddp_gpu: GPU index for this worker; used as the CUDA device and to
            decide whether this is the main (logging/eval) worker.

    Raises:
        ValueError: if ``cfg.phase`` is neither "comb" nor "fact".
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.6
    writer_path = cfg.work_dir / "runs" / cfg.unique_name
    image_path = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    args_str = dump_args(args)
    # Only the main DDP worker logs the run configuration.
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        # Fetch the image stored under the LMDB key "<x>_<y>" and apply `transform`.
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Select loaders and trainer class for the requested training phase.
    if cfg.phase == "comb":
        get_trn_loader = get_comb_trn_loader
        get_cv_loaders = get_cv_comb_loaders
        Trainer = CombinedTrainer
    elif cfg.phase == "fact":
        get_trn_loader = get_fact_trn_loader
        get_cv_loaders = get_cv_fact_loaders
        Trainer = FactorizeTrainer
    else:
        raise ValueError(cfg.phase)

    trn_dset, trn_loader = get_trn_loader(env,
                                          env_get,
                                          cfg,
                                          data_meta["train"],
                                          dec_dict,
                                          trn_transform,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # Cross-validation loaders are only needed on the main worker; other
    # workers skip evaluation.
    if is_main_worker(ddp_gpu):
        cv_loaders = get_cv_loaders(env,
                                    env_get,
                                    cfg,
                                    data_meta,
                                    dec_dict,
                                    val_transform,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)
    else:
        cv_loaders = None

    logger.info("Build model ...")

    # generator
    g_kwargs = cfg.get("g_args", {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg.C, 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # Discriminator is built only when the adversarial loss is enabled.
    if cfg.gan_w > 0.:
        d_kwargs = cfg.get("d_args", {})
        disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **d_kwargs)
        disc.cuda()
        disc.apply(weights_init(cfg.init))
    else:
        disc = None

    # Auxiliary component classifier is built only when its loss is enabled;
    # the generator-side AC loss requires it.
    if cfg.ac_w > 0.:
        aux_clf = aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg.init))
    else:
        aux_clf = None
        assert cfg.ac_gen_w == 0., "ac_gen loss is only available with ac loss"

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas) \
        if disc is not None else None
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas) \
        if aux_clf is not None else None

    # Optionally resume from a checkpoint; `overwrite` keeps the restored
    # weights but restarts the step counter from 1.
    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim,
                                        cfg.overwrite)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(env,
                          env_get,
                          logger,
                          writer,
                          cfg.batch_size,
                          val_transform,
                          content_font,
                          use_half=cfg.use_half)

    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                      writer, logger, evaluator, cv_loaders, cfg)

    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])