def setUp(self):
    """Build the fixture used by the tests: split names, params, loader, logger."""
    # Names of the train/val/test splits as declared in the project config.
    self.trainset = config.train_name
    self.valset = config.val_name
    self.testset = config.test_name

    self.tag_padding = -1

    # Load the dataset parameters, then pin both model sizes to 50 before
    # the loader is constructed with them.
    params = Params(config.datasets_params_file)
    params.embedding_dim = 50
    params.lstm_hidden_dim = 50
    self.params = params

    self.loader = DataLoader(config.data_dir, self.params)
    self.logger = Logger.get()
def __init__(self,
             datasets_params_file,
             data_factor=1.0,
             train_factor=0.7,
             val_factor=0.15,
             test_factor=0.15,
             train_name='train',
             val_name='val',
             test_name='test'):
    """Record the dataset configuration and start with an empty sample list.

    Every argument is stored unchanged on the instance under the same name.

    Args:
        datasets_params_file: the file that will be used to save datasets
            parameters.
        data_factor: defaults to 1.0. NOTE(review): presumably the share of
            the available data to use -- confirm against the code that
            reads it.
        train_factor: defaults to 0.7. NOTE(review): presumably the share of
            samples assigned to the training split -- confirm.
        val_factor: defaults to 0.15 (validation counterpart of the above).
        test_factor: defaults to 0.15 (test counterpart of the above).
        train_name: name of the training split, 'train' by default.
        val_name: name of the validation split, 'val' by default.
        test_name: name of the test split, 'test' by default.
    """
    self.datasets_params_file = datasets_params_file

    # Size factors.
    self.data_factor = data_factor
    self.train_factor = train_factor
    self.val_factor = val_factor
    self.test_factor = test_factor

    # Split names.
    self.train_name = train_name
    self.val_name = val_name
    self.test_name = test_name

    # No samples are loaded at construction time.
    self.samples = []
    self.logger = Logger.get()
def evaluate(model, dateset, loss_fn, metrics):
    """Run ``model`` over every batch of ``dateset`` and log averaged metrics.

    Args:
        model: callable network exposing ``eval()``; it is switched to
            evaluation mode and left in that mode.
        dateset: iterable yielding ``(inputs, targets)`` pairs.
        loss_fn: callable ``(outputs, targets)`` whose result has ``.item()``.
        metrics: mapping of metric name to a callable ``(outputs, targets)``
            whose result has ``.item()``.

    Returns:
        The mapping obtained by calling the ``RunningAvg`` accumulator after
        all batches were fed to it; it holds an entry for ``'loss'`` and one
        per metric name.
    """
    # set model to evaluation mode
    model.eval()

    log = Logger.get()
    averager = RunningAvg()

    for inputs, targets in dateset:
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            outputs = model(inputs)
            batch_stat = {'loss': loss_fn(outputs, targets).item()}
            batch_stat.update(
                (name, metric(outputs, targets).item())
                for name, metric in metrics.items()
            )
        averager.step(batch_stat)

    metrics_mean = averager()

    log.info("- Evaluation metrics:")
    for name, value in metrics_mean.items():
        log.info(' * {}: {:05.3f}'.format(name, value))

    return metrics_mean
def eval_ckpt():
    """Load a generator checkpoint and save the images it generates.

    Driven entirely by the command line:

    * ``config_paths`` (positional, one or more): config files layered on
      top of ``cfgs/defaults.yaml``; unrecognised arguments are forwarded to
      the config through ``cfg.argv_update``.
    * ``--weight``: checkpoint to load.  If the loaded object contains a
      ``"generator_ema"`` entry, that entry is used as the state dict.
    * ``--img_dir``: directory the generated images are saved to.
    * ``--test_meta``: JSON file providing ``ref_unis``, ``gen_unis`` and
      ``gen_fonts``.

    The three options are mandatory: each of them is used unconditionally
    below, so they are declared ``required`` and argparse reports a missing
    one up front instead of failing later on a ``None`` path.

    Raises:
        SystemExit: raised by argparse when an argument is missing/invalid.
    """
    import argparse
    from models import generator_dispatch
    from sconf import Config
    from train import setup_transforms
    from datasets import load_json, get_fact_test_loader

    logger = Logger.get()

    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    parser.add_argument("--weight", required=True,
                        help="path to weight to evaluate.pth")
    parser.add_argument("--img_dir", required=True,
                        help="path to save images for evaluation")
    parser.add_argument(
        "--test_meta",
        required=True,
        help="path to metafile: contains (font, chars (in unicode)) to "
             "generate and reference chars (in unicode)",
    )
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    # Only the validation transform is needed for evaluation.
    _, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        """Read the image stored under key ``'{x}_{y}'`` and transform it."""
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    test_meta = load_json(args.test_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Build the generator and restore its weights.
    g_kwargs = cfg.get('g_args', {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()

    weight = torch.load(args.weight)
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    logger.info(f"Resumed checkpoint from {args.weight}")

    # No writer is attached for this offline evaluation.
    writer = None
    evaluator = Evaluator(env, env_get, logger, writer, cfg["batch_size"],
                          val_transform, content_font)

    img_dir = Path(args.img_dir)
    ref_unis = test_meta["ref_unis"]
    gen_unis = test_meta["gen_unis"]
    gen_fonts = test_meta["gen_fonts"]

    # Every font listed in the meta file is asked for the same characters.
    target_dict = {f: gen_unis for f in gen_fonts}

    # The helper returns a sequence; its second element is used as the loader.
    loader = get_fact_test_loader(env, env_get, target_dict, ref_unis, cfg,
                                  None, dec_dict, val_transform,
                                  ret_targets=False,
                                  num_workers=cfg.n_workers,
                                  shuffle=False)[1]

    logger.info("Save CV results to {} ...".format(img_dir))
    evaluator.save_each_imgs(gen, loader, save_dir=img_dir, phase="fact",
                             reduction='mean')
def train(args, cfg, ddp_gpu=-1):
    """Set up data, models and optimizers for one worker and run training.

    Args:
        args: parsed command-line arguments; ``args.resume`` (checkpoint
            path, falsy to start fresh) is read here and the whole object
            is dumped to the log.
        cfg: run configuration.  ``cfg.gpu`` is overwritten with
            ``ddp_gpu``; ``cfg.phase`` must be ``"comb"`` or ``"fact"``.
        ddp_gpu: device index passed to ``torch.cuda.set_device`` and to
            ``is_main_worker``.

    Raises:
        ValueError: if ``cfg.phase`` is neither ``"comb"`` nor ``"fact"``.
        AssertionError: if the auxiliary classifier is disabled
            (``cfg.ac_w <= 0``) while ``cfg.ac_gen_w`` is non-zero.
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    # Logging / visualisation outputs all live under cfg.work_dir.
    log_file = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=log_file, level="info", colorize=True)

    img_scale = 0.6
    tb_dir = cfg.work_dir / "runs" / cfg.unique_name
    img_dir = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(tb_dir, img_dir, scale=img_scale)

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        """Read the image stored under key ``'{x}_{y}'`` and transform it."""
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Pick the loader factories and trainer class matching the phase.
    phase = cfg.phase
    if phase not in ("comb", "fact"):
        raise ValueError(phase)
    if phase == "comb":
        get_trn_loader = get_comb_trn_loader
        get_cv_loaders = get_cv_comb_loaders
        Trainer = CombinedTrainer
    else:
        get_trn_loader = get_fact_trn_loader
        get_cv_loaders = get_cv_fact_loaders
        Trainer = FactorizeTrainer

    trn_dset, trn_loader = get_trn_loader(env, env_get, cfg,
                                          data_meta["train"], dec_dict,
                                          trn_transform,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # Validation loaders are only built on the main worker.
    cv_loaders = None
    if is_main_worker(ddp_gpu):
        cv_loaders = get_cv_loaders(env, env_get, cfg, data_meta, dec_dict,
                                    val_transform,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)

    logger.info("Build model ...")

    def to_gpu_and_init(module):
        """Move ``module`` to the GPU, apply the configured init, return it."""
        module.cuda()
        module.apply(weights_init(cfg.init))
        return module

    # generator
    g_kwargs = cfg.get("g_args", {})
    g_cls = generator_dispatch()
    gen = to_gpu_and_init(g_cls(1, cfg.C, 1, **g_kwargs, n_comps=n_comps))

    # The discriminator exists only when the GAN loss weight is positive.
    disc = None
    if cfg.gan_w > 0.:
        d_kwargs = cfg.get("d_args", {})
        disc = to_gpu_and_init(
            disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **d_kwargs)
        )

    # Same for the auxiliary classifier and its loss weight.
    if cfg.ac_w > 0.:
        aux_clf = to_gpu_and_init(
            aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        )
    else:
        aux_clf = None
        assert cfg.ac_gen_w == 0., "ac_gen loss is only available with ac loss"

    # One Adam optimizer per existing network; absent networks get None.
    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)

    d_optim = None
    if disc is not None:
        d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr,
                             betas=cfg.adam_betas)

    ac_optim = None
    if aux_clf is not None:
        ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr,
                              betas=cfg.adam_betas)

    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim,
                                        cfg.overwrite)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))
        # With cfg.overwrite set, the step counter restarts from 1 even
        # though a checkpoint was loaded.
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(env, env_get, logger, writer, cfg.batch_size,
                          val_transform, content_font, use_half=cfg.use_half)

    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                      writer, logger, evaluator, cv_loaders, cfg)
    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])
def train(args, cfg, ddp_gpu=-1):
    """Set up data, models and optimizers for one worker and run training.

    Args:
        args: parsed command-line arguments; only dumped to the log here.
        cfg: run configuration.  ``cfg.gpu`` is overwritten with ``ddp_gpu``
            and ``cfg.tb_freq`` is forced to ``-1``.  ``cfg.primals`` and
            ``cfg.decomposition`` are paths to JSON files; ``cfg.resume`` is
            a checkpoint path (falsy to start from step 0).
        ddp_gpu: device index passed to ``torch.cuda.set_device`` and to
            ``is_main_worker``.
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "log.log"
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.5
    image_path = cfg.work_dir / "images"
    writer = utils.DiskWriter(image_path, scale=image_scale)
    # NOTE(review): presumably -1 turns tensorboard logging off, since only
    # a disk writer is created here -- confirm against the trainer.
    cfg.tb_freq = -1

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))

    logger.info("Get dataset ...")

    trn_transform, val_transform = setup_transforms(cfg)

    # Context managers close the JSON files as soon as they are parsed
    # (the previous json.load(open(...)) left both handles open).
    with open(cfg.primals) as primals_file:
        primals = json.load(primals_file)
    with open(cfg.decomposition) as decomposition_file:
        decomposition = json.load(decomposition_file)
    n_comps = len(primals)

    trn_dset, trn_loader = get_trn_loader(cfg.dset.train,
                                          primals,
                                          decomposition,
                                          trn_transform,
                                          use_ddp=cfg.use_ddp,
                                          batch_size=cfg.batch_size,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # The validation dataset object itself is not needed, only its loader.
    _, test_loader = get_val_loader(cfg.dset.val,
                                    val_transform,
                                    batch_size=cfg.batch_size,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)

    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get("g_args", {})
    gen = Generator(1, cfg.C, 1, **g_kwargs)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # discriminator
    d_kwargs = cfg.get("d_args", {})
    disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_chars, **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg.init))

    # auxiliary classifier, sized from the generator's "last" feature shape
    aux_clf = aux_clf_builder(gen.feat_shape["last"], trn_dset.n_fonts,
                              n_comps, **cfg.ac_args)
    aux_clf.cuda()
    aux_clf.apply(weights_init(cfg.init))

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas)
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr,
                          betas=cfg.adam_betas)

    st_step = 0
    if cfg.resume:
        st_step, loss = load_checkpoint(cfg.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim,
                                        cfg.force_resume)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                cfg.resume, st_step, loss))

    evaluator = Evaluator(writer)

    trainer = FactTrainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                          writer, logger, evaluator, test_loader, cfg)
    trainer.train(trn_loader, st_step, cfg.max_iter)