Exemplo n.º 1
0
def test_dump_args():
    """Check that dump_args renders one ``name = value`` line per parsed argument.

    The expected output lists the arguments in the order they were added to the
    parser. Each name is left-padded to the width of the longest name
    ("alpha"), so the ``=`` signs line up.
    """
    parser = argparse.ArgumentParser("TEST")
    for arg_name in ("a", "b", "--beta", "B", "AA", "--alpha"):
        parser.add_argument(arg_name)

    argv = ["1", "2", "3", "4", "--beta", "5", "--alpha", "6"]
    parsed = parser.parse_args(args=argv)

    expected_lines = [
        "a     = 1",
        "b     = 2",
        "beta  = 5",
        "B     = 3",
        "AA    = 4",
        "alpha = 6",
    ]
    assert dump_args(parsed) == "\n".join(expected_lines)
Exemplo n.º 2
0
def main() -> None:
    """Train the font-generation model end to end from CLI arguments and config.

    The steps run in this order:

    1. Parse args and config. With ``args.show``, print them and exit.
    2. Create the log/checkpoint directories, the logger and the image writer.
    3. Seed the RNGs.
    4. Build the training dataset/loader and the three cross-validation loaders.
    5. Build the generator, the discriminator and, when ``cfg['ac_w'] > 0``,
       the auxiliary classifier, all on CUDA.
    6. Create Adam optimizers.
    7. Optionally resume (``args.resume``) or fine-tune (``args.finetune``)
       from a checkpoint.
    8. Run ``Trainer.train``.

    Side effects:
        - Creates ``logs/`` and ``checkpoints/<unique_name>``.
        - Gives the writer ``runs/<unique_name>``, plus ``images/<unique_name>``
          when ``args.tb_image`` is false.
        - Sets ``cfg['unique_name']`` and ``cfg['name']``.
        - In deterministic mode, forces ``cfg['n_workers'] = 0``.

    Raises:
        AssertionError: if ``cfg['ac_w'] <= 0`` while ``cfg['ac_gen_w'] != 0``.

    Note: statement order matters. The RNGs are seeded before the networks are
    constructed and initialised, and ``cfg`` is mutated before the loaders
    are built from it.
    """
    ############################
    # argument setup
    ############################
    args, cfg = setup_args_and_config()

    # --show: print the run setup and exit without training.
    if args.show:
        print("### Run Argv:\n> {}".format(' '.join(sys.argv)))
        print("### Run Arguments:")
        s = dump_args(args)
        print(s + '\n')
        print("### Configs:")
        print(cfg.dumps())
        sys.exit()

    # "<timestamp>_<name>" identifies this run in logs/, checkpoints/ and the writer paths.
    timestamp = utils.timestamp()
    unique_name = "{}_{}".format(timestamp, args.name)
    cfg['unique_name'] = unique_name  # for save directory
    cfg['name'] = args.name

    utils.makedirs('logs')
    utils.makedirs(Path('checkpoints', unique_name))

    # logger
    logger_path = Path('logs', f"{unique_name}.log")
    logger = Logger.get(file_path=logger_path,
                        level=args.log_lv,
                        colorize=True)

    # writer: with args.tb_image the writer only gets the TensorBoard run dir;
    # otherwise it also gets images/<unique_name> (presumably for saving images to disk).
    image_scale = 0.6
    writer_path = Path('runs', unique_name)
    if args.tb_image:
        writer = utils.TBWriter(writer_path, scale=image_scale)
    else:
        image_path = Path('images', unique_name)
        writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    # log default informations
    args_str = dump_args(args)
    logger.info("Run Argv:\n> {}".format(' '.join(sys.argv)))
    logger.info("Args:\n{}".format(args_str))
    logger.info("Configs:\n{}".format(cfg.dumps()))
    logger.info("Unique name: {}".format(unique_name))

    # seed: numpy, torch and Python's random are all seeded from cfg['seed']
    # before any network is constructed below.
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    random.seed(cfg['seed'])

    if args.deterministic:
        #  https://discuss.pytorch.org/t/how-to-get-deterministic-behavior/18177/16
        #  https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Presumably disables loader worker processes for reproducible loading;
        # set here, before cfg is passed to the loader setup below.
        cfg['n_workers'] = 0
        logger.info("#" * 80)
        logger.info("# Deterministic option is activated !")
        logger.info("#" * 80)
    else:
        torch.backends.cudnn.benchmark = True

    ############################
    # setup dataset & loader
    ############################
    logger.info("Get dataset ...")

    # setup language dependent values
    content_font, n_comp_types, n_comps = setup_language_dependent(cfg)

    # setup transform: ToTensor, then per-channel (x - 0.5) / 0.5 normalisation
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

    # setup data
    hdf5_data, meta = setup_data(cfg, transform)

    # setup dataset (the positional True is get_dset_loader's 5th argument;
    # its meaning is defined by that helper -- TODO confirm, likely shuffle/train flag)
    trn_dset, loader = get_dset_loader(hdf5_data,
                                       meta['train']['fonts'],
                                       meta['train']['chars'],
                                       transform,
                                       True,
                                       cfg,
                                       content_font=content_font)

    logger.info("### Training dataset ###")
    logger.info("# of avail fonts = {}".format(trn_dset.n_fonts))
    logger.info(f"Total {len(loader)} iterations per epochs")
    logger.info("# of avail items = {}".format(trn_dset.n_avails))
    logger.info(f"#fonts = {trn_dset.n_fonts}, #chars = {trn_dset.n_chars}")

    # Three cross-validation splits: seen/unseen fonts x seen/unseen chars.
    val_loaders = setup_cv_dset_loader(hdf5_data, meta, transform,
                                       n_comp_types, content_font, cfg)
    sfuc_loader = val_loaders['SeenFonts-UnseenChars']
    sfuc_dset = sfuc_loader.dataset
    ufsc_loader = val_loaders['UnseenFonts-SeenChars']
    ufsc_dset = ufsc_loader.dataset
    ufuc_loader = val_loaders['UnseenFonts-UnseenChars']
    ufuc_dset = ufuc_loader.dataset

    logger.info("### Cross-validation datasets ###")
    logger.info("Seen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(sfuc_dset), len(sfuc_dset.fonts), len(sfuc_dset.chars),
                    len(sfuc_loader)))
    logger.info("Unseen fonts, Seen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufsc_dset), len(ufsc_dset.fonts), len(ufsc_dset.chars),
                    len(ufsc_loader)))
    logger.info("Unseen fonts, Unseen chars | "
                "#items = {}, #fonts = {}, #chars = {}, #steps = {}".format(
                    len(ufuc_dset), len(ufuc_dset.fonts), len(ufuc_dset.chars),
                    len(ufuc_loader)))

    ############################
    # build model
    ############################
    logger.info("Build model ...")
    # generator: extra constructor kwargs come from cfg['g_args'] (default none)
    g_kwargs = cfg.get('g_args', {})
    gen = MACore(1,
                 cfg['C'],
                 1,
                 **g_kwargs,
                 n_comps=n_comps,
                 n_comp_types=n_comp_types,
                 language=cfg['language'])
    gen.cuda()
    gen.apply(weights_init(cfg['init']))

    # discriminator is sized by the number of training fonts and chars
    d_kwargs = cfg.get('d_args', {})
    disc = Discriminator(cfg['C'], trn_dset.n_fonts, trn_dset.n_chars,
                         **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg['init']))

    # Auxiliary component classifier only exists when its loss weight is positive;
    # the generator-side ac loss (ac_gen_w) depends on it.
    if cfg['ac_w'] > 0.:
        C = gen.mem_shape[0]
        aux_clf = AuxClassifier(C, n_comps, **cfg['ac_args'])
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg['init']))
    else:
        aux_clf = None
        assert cfg[
            'ac_gen_w'] == 0., "ac_gen loss is only available with ac loss"

    # setup optimizer (the aux classifier reuses the generator learning rate)
    g_optim = optim.Adam(gen.parameters(),
                         lr=cfg['g_lr'],
                         betas=cfg['adam_betas'])
    d_optim = optim.Adam(disc.parameters(),
                         lr=cfg['d_lr'],
                         betas=cfg['adam_betas'])
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg['g_lr'], betas=cfg['adam_betas']) \
               if aux_clf is not None else None

    # resume checkpoint: load_checkpoint restores the networks and optimizers and
    # returns (step, loss); the log shows step - 1, presumably the last finished step.
    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                args.resume, st_step - 1, loss))
    # finetune: loads a checkpoint into the generator only (step counter unchanged)
    if args.finetune:
        load_gen_checkpoint(args.finetune, gen)

    ############################
    # setup validation
    ############################
    evaluator = Evaluator(hdf5_data,
                          trn_dset.avails,
                          logger,
                          writer,
                          cfg['batch_size'],
                          content_font=content_font,
                          transform=transform,
                          language=cfg['language'],
                          val_loaders=val_loaders,
                          meta=meta)
    if args.debug:
        evaluator.n_cv_batches = 10
        logger.info("Change CV batches to 10 for debugging")

    ############################
    # start training
    ############################
    trainer = Trainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim, writer,
                      logger, evaluator, cfg)
    trainer.train(loader, st_step)
Exemplo n.º 3
0
def train(args, cfg, ddp_gpu=-1):
    """Build the generator, discriminator and auxiliary classifier and train them.

    The function:

    - sets up a file logger and an on-disk image writer under ``cfg.work_dir``;
    - loads the component metadata JSON files (``cfg.primals`` and
      ``cfg.decomposition``);
    - builds the train and validation loaders;
    - creates one Adam optimizer per network;
    - optionally resumes from ``cfg.resume``;
    - runs ``FactTrainer.train`` up to ``cfg.max_iter``.

    Args:
        args: Parsed command-line arguments. They are only dumped to the log here.
        cfg: Experiment config. It is accessed both as attributes and through
            ``cfg.get``, and is mutated (``cfg.gpu``, ``cfg.tb_freq``).
        ddp_gpu: CUDA device index for this worker. It is stored in ``cfg.gpu``
            and passed to ``torch.cuda.set_device``. Only the main worker logs
            the argv/args/config dump.
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "log.log"
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.5
    image_path = cfg.work_dir / "images"
    writer = utils.DiskWriter(image_path, scale=image_scale)
    # Only a disk writer is created, so presumably -1 disables TensorBoard
    # logging in the trainer (TODO confirm against FactTrainer).
    cfg.tb_freq = -1

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))

    logger.info("Get dataset ...")

    trn_transform, val_transform = setup_transforms(cfg)

    # Close the metadata files as soon as they are parsed. Decode them as UTF-8
    # (the encoding JSON requires) rather than the platform default, which
    # breaks on non-UTF-8 locales when the files contain non-ASCII characters.
    with open(cfg.primals, encoding="utf-8") as primals_file:
        primals = json.load(primals_file)
    with open(cfg.decomposition, encoding="utf-8") as decomposition_file:
        decomposition = json.load(decomposition_file)
    n_comps = len(primals)

    trn_dset, trn_loader = get_trn_loader(cfg.dset.train,
                                          primals,
                                          decomposition,
                                          trn_transform,
                                          use_ddp=cfg.use_ddp,
                                          batch_size=cfg.batch_size,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # Only the loader is needed; the validation dataset object is unused here.
    _, test_loader = get_val_loader(cfg.dset.val,
                                    val_transform,
                                    batch_size=cfg.batch_size,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)

    logger.info("Build model ...")
    # generator: extra constructor kwargs come from cfg "g_args" (default none)
    g_kwargs = cfg.get("g_args", {})
    gen = Generator(1, cfg.C, 1, **g_kwargs)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # discriminator is sized by the number of training fonts and chars
    d_kwargs = cfg.get("d_args", {})
    disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_chars, **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg.init))

    # auxiliary classifier over the generator's "last" feature shape
    aux_clf = aux_clf_builder(gen.feat_shape["last"], trn_dset.n_fonts,
                              n_comps, **cfg.ac_args)
    aux_clf.cuda()
    aux_clf.apply(weights_init(cfg.init))

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas)
    ac_optim = optim.Adam(aux_clf.parameters(),
                          lr=cfg.ac_lr,
                          betas=cfg.adam_betas)

    # On resume, load_checkpoint returns fresh optimizer objects that replace
    # the ones created above.
    st_step = 0
    if cfg.resume:
        st_step, loss, g_optim, d_optim, ac_optim = load_checkpoint(
            cfg.resume, gen, disc, aux_clf, cfg)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                cfg.resume, st_step, loss))

    evaluator = Evaluator(writer)

    trainer = FactTrainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                          writer, logger, evaluator, test_loader, cfg)

    trainer.train(trn_loader, st_step, cfg.max_iter)
Exemplo n.º 4
0
def train(args, cfg, ddp_gpu=-1):
    """Set up data, models and optimizers for one training phase and run it.

    ``cfg.phase`` selects the pipeline:

    - ``"comb"`` uses the combined train/CV loaders and ``CombinedTrainer``.
    - ``"fact"`` uses the factorize loaders and ``FactorizeTrainer``.

    Training runs for ``cfg[f"{cfg.phase}_iter"]`` steps. The discriminator is
    only built when ``cfg.gan_w > 0``. The auxiliary classifier is only built
    when ``cfg.ac_w > 0``. When a network is absent, it and its optimizer are
    passed to the trainer as ``None``.

    Args:
        args: Parsed command-line arguments. They are dumped to the log, and
            ``args.resume`` names a checkpoint to resume from.
        cfg: Experiment config. ``cfg.gpu`` is set from ``ddp_gpu``. With
            ``cfg.overwrite``, a resumed run restarts its step counter at 1.
        ddp_gpu: CUDA device index for this worker. Only the main worker
            logs the run setup and builds the cross-validation loaders; other
            workers pass ``None`` for them.

    Raises:
        ValueError: If ``cfg.phase`` is neither ``"comb"`` nor ``"fact"``, or
            if ``cfg.ac_gen_w`` is non-zero while the auxiliary classifier is
            disabled (``cfg.ac_w <= 0``).
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.6
    writer_path = cfg.work_dir / "runs" / cfg.unique_name
    image_path = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    args_str = dump_args(args)
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)

    def env_get(env, x, y, transform):
        """Return ``transform`` applied to the 'img' entry of record ``"<x>_<y>"``.

        x and y are presumably font and character identifiers (TODO confirm
        against the loaders that call this).
        """
        return transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Pick the loader factories and trainer class for the requested phase.
    if cfg.phase == "comb":
        get_trn_loader = get_comb_trn_loader
        get_cv_loaders = get_cv_comb_loaders
        trainer_cls = CombinedTrainer
    elif cfg.phase == "fact":
        get_trn_loader = get_fact_trn_loader
        get_cv_loaders = get_cv_fact_loaders
        trainer_cls = FactorizeTrainer
    else:
        raise ValueError(cfg.phase)

    trn_dset, trn_loader = get_trn_loader(env,
                                          env_get,
                                          cfg,
                                          data_meta["train"],
                                          dec_dict,
                                          trn_transform,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # Cross-validation loaders are built on the main worker only.
    if is_main_worker(ddp_gpu):
        cv_loaders = get_cv_loaders(env,
                                    env_get,
                                    cfg,
                                    data_meta,
                                    dec_dict,
                                    val_transform,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)
    else:
        cv_loaders = None

    logger.info("Build model ...")
    # generator: extra constructor kwargs come from cfg "g_args" (default none)
    g_kwargs = cfg.get("g_args", {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg.C, 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # The discriminator is only needed when the GAN loss is enabled.
    if cfg.gan_w > 0.:
        d_kwargs = cfg.get("d_args", {})
        disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **d_kwargs)
        disc.cuda()
        disc.apply(weights_init(cfg.init))
    else:
        disc = None

    # The auxiliary classifier is only needed when its loss weight is positive.
    # The generator-side ac loss (ac_gen_w) depends on it.
    if cfg.ac_w > 0.:
        aux_clf = aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg.init))
    else:
        aux_clf = None
        # Explicit check rather than `assert`, so it still fires under `python -O`.
        if cfg.ac_gen_w != 0.:
            raise ValueError("ac_gen loss is only available with ac loss")

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas) \
        if disc is not None else None
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas) \
        if aux_clf is not None else None

    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim,
                                        cfg.overwrite)
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))
        # Overwrite mode keeps the loaded weights but restarts step counting.
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(env,
                          env_get,
                          logger,
                          writer,
                          cfg.batch_size,
                          val_transform,
                          content_font,
                          use_half=cfg.use_half
                          )

    trainer = trainer_cls(gen, disc, g_optim, d_optim,
                          aux_clf, ac_optim,
                          writer, logger,
                          evaluator, cv_loaders,
                          cfg)

    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])