import argparse


def test_dump_args():
    parser = argparse.ArgumentParser("TEST")
    parser.add_argument("a")
    parser.add_argument("b")
    parser.add_argument("--beta")
    parser.add_argument("B")
    parser.add_argument("AA")
    parser.add_argument("--alpha")

    args = parser.parse_args(args="1 2 3 4 --beta 5 --alpha 6".split())
    dumps = dump_args(args)
    assert dumps == "\n".join([
        "a = 1",
        "b = 2",
        "beta = 5",
        "B = 3",
        "AA = 4",
        "alpha = 6",
    ])
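# For reference, a minimal dump_args sketch that would satisfy the test above.
# This is an assumption, not necessarily the repo's actual implementation: it
# relies on argparse populating Namespace.__dict__ in add_argument() order
# (defaults are set in definition order before parsing), which is exactly the
# ordering the expected string in the test implies.
def dump_args(args):
    """Render an argparse.Namespace as one "key = value" line per argument."""
    return "\n".join("{} = {}".format(k, v) for k, v in vars(args).items())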
def main():
    ############################
    # argument setup
    ############################
    args, cfg = setup_args_and_config()

    if args.show:
        print("### Run Argv:\n> {}".format(' '.join(sys.argv)))
        print("### Run Arguments:")
        s = dump_args(args)
        print(s + '\n')
        print("### Configs:")
        print(cfg.dumps())
        sys.exit()

    timestamp = utils.timestamp()
    unique_name = "{}_{}".format(timestamp, args.name)
    cfg['unique_name'] = unique_name  # used for the save directory
    cfg['name'] = args.name

    utils.makedirs('logs')
    utils.makedirs(Path('checkpoints', unique_name))

    # logger
    logger_path = Path('logs', f"{unique_name}.log")
    logger = Logger.get(file_path=logger_path, level=args.log_lv, colorize=True)

    # writer
    image_scale = 0.6
    writer_path = Path('runs', unique_name)
    if args.tb_image:
        writer = utils.TBWriter(writer_path, scale=image_scale)
    else:
        image_path = Path('images', unique_name)
        writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    # log default information
    args_str = dump_args(args)
    logger.info("Run Argv:\n> {}".format(' '.join(sys.argv)))
    logger.info("Args:\n{}".format(args_str))
    logger.info("Configs:\n{}".format(cfg.dumps()))
    logger.info("Unique name: {}".format(unique_name))

    # seed
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    if args.deterministic:
        # https://discuss.pytorch.org/t/how-to-get-deterministic-behavior/18177/16
        # https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        cfg['n_workers'] = 0
        logger.info("#" * 80)
        logger.info("# Deterministic option is activated!")
        logger.info("#" * 80)
    else:
        torch.backends.cudnn.benchmark = True

    ############################
    # setup dataset & loader
    ############################
    logger.info("Get dataset ...")

    # setup language-dependent values
    content_font, n_comp_types, n_comps = setup_language_dependent(cfg)

    # setup transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # setup data
    hdf5_data, meta = setup_data(cfg, transform)

    # setup dataset
    trn_dset, loader = get_dset_loader(hdf5_data,
                                       meta['train']['fonts'],
                                       meta['train']['chars'],
                                       transform,
                                       True,
                                       cfg,
                                       content_font=content_font)

    logger.info("### Training dataset ###")
    logger.info("# of avail fonts = {}".format(trn_dset.n_fonts))
    logger.info(f"Total {len(loader)} iterations per epoch")
    logger.info("# of avail items = {}".format(trn_dset.n_avails))
    logger.info(f"#fonts = {trn_dset.n_fonts}, #chars = {trn_dset.n_chars}")

    val_loaders = setup_cv_dset_loader(hdf5_data, meta, transform, n_comp_types,
                                       content_font, cfg)
    sfuc_loader = val_loaders['SeenFonts-UnseenChars']
    sfuc_dset = sfuc_loader.dataset
    ufsc_loader = val_loaders['UnseenFonts-SeenChars']
    ufsc_dset = ufsc_loader.dataset
    ufuc_loader = val_loaders['UnseenFonts-UnseenChars']
    ufuc_dset = ufuc_loader.dataset

    logger.info("### Cross-validation datasets ###")
    logger.info("Seen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(sfuc_dset), len(sfuc_dset.fonts), len(sfuc_dset.chars),
                    len(sfuc_loader)))
    logger.info("Unseen fonts, Seen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufsc_dset), len(ufsc_dset.fonts), len(ufsc_dset.chars),
                    len(ufsc_loader)))
    logger.info("Unseen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufuc_dset), len(ufuc_dset.fonts), len(ufuc_dset.chars),
                    len(ufuc_loader)))

    ############################
    # build model
    ############################
    logger.info("Build model ...")

    # generator
    g_kwargs = cfg.get('g_args', {})
    gen = MACore(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps,
                 n_comp_types=n_comp_types, language=cfg['language'])
    gen.cuda()
    gen.apply(weights_init(cfg['init']))

    # discriminator
    d_kwargs = cfg.get('d_args', {})
    disc = Discriminator(cfg['C'], trn_dset.n_fonts, trn_dset.n_chars, **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg['init']))

    # auxiliary component classifier (only if the ac loss is used)
    if cfg['ac_w'] > 0.:
        C = gen.mem_shape[0]
        aux_clf = AuxClassifier(C, n_comps, **cfg['ac_args'])
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg['init']))
    else:
        aux_clf = None
        assert cfg['ac_gen_w'] == 0., "ac_gen loss is only available with ac loss"

    # setup optimizer
    g_optim = optim.Adam(gen.parameters(), lr=cfg['g_lr'], betas=cfg['adam_betas'])
    d_optim = optim.Adam(disc.parameters(), lr=cfg['d_lr'], betas=cfg['adam_betas'])
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg['g_lr'], betas=cfg['adam_betas']) \
        if aux_clf is not None else None

    # resume checkpoint
    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))

    if args.finetune:
        load_gen_checkpoint(args.finetune, gen)

    ############################
    # setup validation
    ############################
    evaluator = Evaluator(hdf5_data,
                          trn_dset.avails,
                          logger,
                          writer,
                          cfg['batch_size'],
                          content_font=content_font,
                          transform=transform,
                          language=cfg['language'],
                          val_loaders=val_loaders,
                          meta=meta)

    if args.debug:
        evaluator.n_cv_batches = 10
        logger.info("Change CV batches to 10 for debugging")

    ############################
    # start training
    ############################
    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                      writer, logger, evaluator, cfg)
    trainer.train(loader, st_step)
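# main() above calls weights_init(cfg['init']) and hands the result to
# nn.Module.apply(). A minimal sketch of such an initializer factory follows;
# the body is an assumption (the supported init names and layer filter may
# differ in the actual repo), shown only to make the apply() pattern concrete.
import torch.nn as nn


def weights_init(init_type="xavier"):
    """Return a per-module initializer for use with nn.Module.apply()."""
    def init_fn(m):
        # only (de)conv and linear layers carry weights we want to re-init
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            if init_type == "xavier":
                nn.init.xavier_uniform_(m.weight)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            else:
                raise ValueError(init_type)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
    return init_fn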
def train(args, cfg, ddp_gpu=-1):
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "log.log"
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.5
    image_path = cfg.work_dir / "images"
    writer = utils.DiskWriter(image_path, scale=image_scale)
    cfg.tb_freq = -1

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))

    logger.info("Get dataset ...")

    trn_transform, val_transform = setup_transforms(cfg)

    # primals define the component set; the decomposition maps chars to components
    with open(cfg.primals) as f:
        primals = json.load(f)
    with open(cfg.decomposition) as f:
        decomposition = json.load(f)
    n_comps = len(primals)

    trn_dset, trn_loader = get_trn_loader(cfg.dset.train,
                                          primals,
                                          decomposition,
                                          trn_transform,
                                          use_ddp=cfg.use_ddp,
                                          batch_size=cfg.batch_size,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    test_dset, test_loader = get_val_loader(cfg.dset.val,
                                            val_transform,
                                            batch_size=cfg.batch_size,
                                            num_workers=cfg.n_workers,
                                            shuffle=False)

    logger.info("Build model ...")

    # generator
    g_kwargs = cfg.get("g_args", {})
    gen = Generator(1, cfg.C, 1, **g_kwargs)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # discriminator
    d_kwargs = cfg.get("d_args", {})
    disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_chars, **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg.init))

    # auxiliary component classifier
    aux_clf = aux_clf_builder(gen.feat_shape["last"], trn_dset.n_fonts, n_comps,
                              **cfg.ac_args)
    aux_clf.cuda()
    aux_clf.apply(weights_init(cfg.init))

    # optimizers
    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas)
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas)

    # resume checkpoint
    st_step = 0
    if cfg.resume:
        st_step, loss, g_optim, d_optim, ac_optim = load_checkpoint(
            cfg.resume, gen, disc, aux_clf, cfg)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            cfg.resume, st_step, loss))

    evaluator = Evaluator(writer)

    trainer = FactTrainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                          writer, logger, evaluator, test_loader, cfg)
    trainer.train(trn_loader, st_step, cfg.max_iter)
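# Both DDP entry points guard their logging with is_main_worker(). A minimal
# sketch of that guard is below; it assumes the convention that a
# single-process run passes ddp_gpu=-1 and that under DDP only local rank 0
# should emit logs. The repo's actual helper may check the global rank instead.
def is_main_worker(ddp_gpu):
    """True for single-process runs (ddp_gpu=-1) and for the rank-0 DDP worker."""
    return ddp_gpu <= 0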
def train(args, cfg, ddp_gpu=-1):
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.6
    writer_path = cfg.work_dir / "runs" / cfg.unique_name
    image_path = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        # fetch the glyph image keyed by "{font}_{char}" and apply the transform
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    # phase selects the training stage: combined vs. factorization
    if cfg.phase == "comb":
        get_trn_loader = get_comb_trn_loader
        get_cv_loaders = get_cv_comb_loaders
        Trainer = CombinedTrainer
    elif cfg.phase == "fact":
        get_trn_loader = get_fact_trn_loader
        get_cv_loaders = get_cv_fact_loaders
        Trainer = FactorizeTrainer
    else:
        raise ValueError(cfg.phase)

    trn_dset, trn_loader = get_trn_loader(env,
                                          env_get,
                                          cfg,
                                          data_meta["train"],
                                          dec_dict,
                                          trn_transform,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # cross-validation loaders are only needed on the main worker
    if is_main_worker(ddp_gpu):
        cv_loaders = get_cv_loaders(env,
                                    env_get,
                                    cfg,
                                    data_meta,
                                    dec_dict,
                                    val_transform,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)
    else:
        cv_loaders = None

    logger.info("Build model ...")

    # generator
    g_kwargs = cfg.get("g_args", {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg.C, 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # discriminator (only if the GAN loss is used)
    if cfg.gan_w > 0.:
        d_kwargs = cfg.get("d_args", {})
        disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **d_kwargs)
        disc.cuda()
        disc.apply(weights_init(cfg.init))
    else:
        disc = None

    # auxiliary component classifier (only if the ac loss is used)
    if cfg.ac_w > 0.:
        aux_clf = aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg.init))
    else:
        aux_clf = None
        assert cfg.ac_gen_w == 0., "ac_gen loss is only available with ac loss"

    # optimizers
    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas) \
        if disc is not None else None
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas) \
        if aux_clf is not None else None

    # resume checkpoint
    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim, cfg.overwrite)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(env,
                          env_get,
                          logger,
                          writer,
                          cfg.batch_size,
                          val_transform,
                          content_font,
                          use_half=cfg.use_half)

    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                      writer, logger, evaluator, cv_loaders, cfg)
    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])
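# The LMDB access pair used above (load_lmdb / read_data_from_lmdb) is defined
# elsewhere in the repo. A minimal sketch follows; the read-only open flags are
# standard for the lmdb package, but the stored value format (a raw encoded
# image under the "{font}_{char}" key, returned as {'img': PIL image}) is an
# assumption inferred from how env_get indexes the result.
import io

import lmdb
from PIL import Image


def load_lmdb(lmdb_path):
    """Open an LMDB environment read-only for multi-worker data loading."""
    return lmdb.open(lmdb_path, readonly=True, lock=False,
                     readahead=False, meminit=False)


def read_data_from_lmdb(env, lmdb_key):
    """Fetch one record; returns None if the key is missing."""
    with env.begin(write=False) as txn:
        data = txn.get(lmdb_key.encode("utf-8"))
    if data is None:
        return None
    return {"img": Image.open(io.BytesIO(data))}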