Ejemplo n.º 1
0
 def setUp(self):
     """Prepare split names, dataset parameters, a data loader and a logger.

     Split names and file locations come from the module-level ``config``;
     the embedding and LSTM hidden sizes are overridden to 50.
     """
     # Split identifiers as configured for the project.
     self.trainset, self.valset, self.testset = (
         config.train_name, config.val_name, config.test_name)
     # NOTE(review): by its name, the value used to pad tag sequences — confirm.
     self.tag_padding = -1

     params = Params(config.datasets_params_file)
     params.embedding_dim = params.lstm_hidden_dim = 50
     self.params = params

     self.loader = DataLoader(config.data_dir, self.params)
     self.logger = Logger.get()
Ejemplo n.º 2
0
 def __init__(self, datasets_params_file, data_factor=1.0,
              train_factor=0.7, val_factor=0.15, test_factor=0.15,
              train_name='train', val_name='val',
              test_name='test'):
     """Record the options used to build and split the datasets.

     Args:
         datasets_params_file: path of the file that will be used to save
             the datasets parameters.
         data_factor: stored as ``self.data_factor``; presumably the
             fraction of available data to use (not used here — confirm).
         train_factor, val_factor, test_factor: stored as-is; the defaults
             (0.7 / 0.15 / 0.15) sum to 1, so these are presumably split
             proportions — confirm against the splitting code.
         train_name, val_name, test_name: names stored for the three splits.
     """
     self.datasets_params_file = datasets_params_file
     self.data_factor = data_factor
     self.train_factor, self.val_factor, self.test_factor = (
         train_factor, val_factor, test_factor)
     self.train_name, self.val_name, self.test_name = (
         train_name, val_name, test_name)
     # Starts empty; filled elsewhere.
     self.samples = []
     self.logger = Logger.get()
Ejemplo n.º 3
0
def evaluate(model, dateset, loss_fn, metrics):
    """Run ``model`` over a dataset without gradients and log averaged metrics.

    Args:
        model: module to evaluate; it is switched to eval mode and left there.
        dateset: iterable of ``(inputs, targets)`` batches. (The misspelled
            parameter name is kept so keyword callers keep working.)
        loss_fn: callable ``loss_fn(outputs, targets)`` returning an object
            with ``.item()``.
        metrics: mapping of metric name to callable ``fn(outputs, targets)``
            returning an object with ``.item()``.

    Returns:
        The value of calling the ``RunningAvg`` accumulator. It is iterated
        with ``.items()``, so presumably a name -> mean-value mapping covering
        ``'loss'`` and every key of ``metrics``.
    """
    model.eval()

    log = Logger.get()
    averager = RunningAvg()
    for inputs, targets in dateset:
        with torch.no_grad():
            outputs = model(inputs)
            # Loss first, then each metric; a metric named 'loss' overrides it.
            batch_stats = {'loss': loss_fn(outputs, targets).item()}
            batch_stats.update(
                (key, fn(outputs, targets).item()) for key, fn in metrics.items()
            )
        averager.step(batch_stats)
    averages = averager()

    log.info("- Evaluation metrics:")
    for key, val in averages.items():
        log.info('    * {}: {:05.3f}'.format(key, val))

    return averages
Ejemplo n.º 4
0
def eval_ckpt():
    """Load a generator checkpoint and save generated test images (fact phase).

    Command-line arguments (parsed from ``sys.argv``):
        config_paths: one or more config files, loaded on top of
            ``cfgs/defaults.yaml``. Unrecognised ``argv`` entries are applied
            as config overrides via ``cfg.argv_update``.
        --weight: checkpoint file. If it contains a ``"generator_ema"`` entry,
            that is used as the generator state dict; otherwise the loaded
            object itself is.
        --img_dir: output directory handed to ``Evaluator.save_each_imgs``.
        --test_meta: JSON file read for the keys ``ref_unis``, ``gen_unis``
            and ``gen_fonts``.

    ``--weight``, ``--img_dir`` and ``--test_meta`` are required. The function
    cannot run without them, so argparse rejects such invocations up front
    instead of failing later on a ``None`` path.
    """
    import argparse
    from models import generator_dispatch
    from sconf import Config
    from train import setup_transforms
    from datasets import load_json, get_fact_test_loader

    logger = Logger.get()

    parser = argparse.ArgumentParser()
    parser.add_argument("config_paths", nargs="+", help="path to config.yaml")
    parser.add_argument("--weight", required=True,
                        help="path to weight to evaluate.pth")
    parser.add_argument("--img_dir", required=True,
                        help="path to save images for evaluation")
    parser.add_argument(
        "--test_meta",
        required=True,
        help=
        "path to metafile: contains (font, chars (in unicode)) to generate and reference chars (in unicode)"
    )
    args, left_argv = parser.parse_known_args()

    cfg = Config(*args.config_paths, default="cfgs/defaults.yaml")
    cfg.argv_update(left_argv)

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)
    # Only the validation transform is used for evaluation.
    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)
    # Reads the record stored under key "<x>_<y>" and transforms its 'img' entry.
    env_get = lambda env, x, y, transform: transform(
        read_data_from_lmdb(env, f'{x}_{y}')['img'])

    test_meta = load_json(args.test_meta)
    dec_dict = load_json(cfg.dec_dict)

    g_kwargs = cfg.get('g_args', {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg['C'], 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()

    weight = torch.load(args.weight)
    # Full training checkpoints hold the EMA generator under this key.
    if "generator_ema" in weight:
        weight = weight["generator_ema"]
    gen.load_state_dict(weight)
    logger.info(f"Resumed checkpoint from {args.weight}")
    writer = None

    evaluator = Evaluator(env, env_get, logger, writer, cfg["batch_size"],
                          val_transform, content_font)

    img_dir = Path(args.img_dir)
    ref_unis = test_meta["ref_unis"]
    gen_unis = test_meta["gen_unis"]
    gen_fonts = test_meta["gen_fonts"]
    # Every target font is asked to generate the same set of characters.
    target_dict = {f: gen_unis for f in gen_fonts}

    # NOTE(review): index [1] presumably selects the loader from a
    # (dataset, loader) pair — confirm against get_fact_test_loader.
    loader = get_fact_test_loader(env,
                                  env_get,
                                  target_dict,
                                  ref_unis,
                                  cfg,
                                  None,
                                  dec_dict,
                                  val_transform,
                                  ret_targets=False,
                                  num_workers=cfg.n_workers,
                                  shuffle=False)[1]

    logger.info("Save CV results to {} ...".format(img_dir))
    evaluator.save_each_imgs(gen,
                             loader,
                             save_dir=img_dir,
                             phase="fact",
                             reduction='mean')
Ejemplo n.º 5
0
def train(args, cfg, ddp_gpu=-1):
    """Build data, models and optimizers for ``cfg.phase`` and run training.

    Args:
        args: parsed command-line namespace. ``args.resume`` (checkpoint path
            or a falsy value) is read, and the namespace is logged through
            ``dump_args``.
        cfg: experiment config. Mutated in place: ``cfg.gpu`` is set to
            ``ddp_gpu``.
        ddp_gpu: CUDA device index of this worker. It is passed to
            ``torch.cuda.set_device``, which is a no-op for negative values.
            It is also passed to ``is_main_worker``, which gates setup logging
            and validation-loader construction.

    Raises:
        ValueError: if ``cfg.phase`` is neither ``"comb"`` nor ``"fact"``, or
            if ``cfg.ac_gen_w`` is non-zero while the auxiliary classifier is
            disabled (``cfg.ac_w`` not positive).
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "logs" / "{}.log".format(cfg.unique_name)
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.6
    writer_path = cfg.work_dir / "runs" / cfg.unique_name
    image_path = cfg.work_dir / "images" / cfg.unique_name
    writer = utils.TBDiskWriter(writer_path, image_path, scale=image_scale)

    args_str = dump_args(args)
    # Run setup is logged by the main worker only.
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))
        logger.info("Unique name: {}".format(cfg.unique_name))

    logger.info("Get dataset ...")

    content_font = cfg.content_font
    n_comps = int(cfg.n_comps)

    trn_transform, val_transform = setup_transforms(cfg)

    env = load_lmdb(cfg.data_path)
    # Reads the record stored under key "<x>_<y>" and transforms its 'img' entry.
    env_get = lambda env, x, y, transform: transform(read_data_from_lmdb(env, f'{x}_{y}')['img'])

    data_meta = load_json(cfg.data_meta)
    dec_dict = load_json(cfg.dec_dict)

    # Pick the loader builders and trainer class for the training phase.
    if cfg.phase == "comb":
        get_trn_loader = get_comb_trn_loader
        get_cv_loaders = get_cv_comb_loaders
        Trainer = CombinedTrainer

    elif cfg.phase == "fact":
        get_trn_loader = get_fact_trn_loader
        get_cv_loaders = get_cv_fact_loaders
        Trainer = FactorizeTrainer

    else:
        raise ValueError(cfg.phase)

    trn_dset, trn_loader = get_trn_loader(env,
                                          env_get,
                                          cfg,
                                          data_meta["train"],
                                          dec_dict,
                                          trn_transform,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    # Validation loaders exist only on the main worker; other workers get None.
    if is_main_worker(ddp_gpu):
        cv_loaders = get_cv_loaders(env,
                                    env_get,
                                    cfg,
                                    data_meta,
                                    dec_dict,
                                    val_transform,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)
    else:
        cv_loaders = None

    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get("g_args", {})
    g_cls = generator_dispatch()
    gen = g_cls(1, cfg.C, 1, **g_kwargs, n_comps=n_comps)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    # The discriminator is only built when its loss weight is positive.
    if cfg.gan_w > 0.:
        d_kwargs = cfg.get("d_args", {})
        disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_unis, **d_kwargs)
        disc.cuda()
        disc.apply(weights_init(cfg.init))
    else:
        disc = None

    # Likewise for the auxiliary component classifier.
    if cfg.ac_w > 0.:
        aux_clf = aux_clf_builder(gen.mem_shape, n_comps, **cfg.ac_args)
        aux_clf.cuda()
        aux_clf.apply(weights_init(cfg.init))
    else:
        aux_clf = None
        # Explicit check instead of `assert`, which `python -O` strips.
        if cfg.ac_gen_w != 0.:
            raise ValueError("ac_gen loss is only available with ac loss")

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas) \
        if disc is not None else None
    ac_optim = optim.Adam(aux_clf.parameters(), lr=cfg.ac_lr, betas=cfg.adam_betas) \
        if aux_clf is not None else None

    st_step = 1
    if args.resume:
        st_step, loss = load_checkpoint(args.resume, gen, disc, aux_clf, g_optim, d_optim, ac_optim, cfg.overwrite)
        # NOTE(review): load_checkpoint presumably returns the next step to
        # run, hence st_step - 1 is logged as the checkpoint's step — confirm.
        logger.info("Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
            args.resume, st_step - 1, loss))
        # In overwrite mode the step counter restarts after loading.
        if cfg.overwrite:
            st_step = 1

    evaluator = Evaluator(env,
                          env_get,
                          logger,
                          writer,
                          cfg.batch_size,
                          val_transform,
                          content_font,
                          use_half=cfg.use_half
                          )

    trainer = Trainer(gen, disc, g_optim, d_optim,
                      aux_clf, ac_optim,
                      writer, logger,
                      evaluator, cv_loaders,
                      cfg)

    # The last iteration is taken from the phase-specific config key
    # ("comb_iter" or "fact_iter").
    trainer.train(trn_loader, st_step, cfg[f"{cfg.phase}_iter"])
Ejemplo n.º 6
0
def train(args, cfg, ddp_gpu=-1):
    """Build data, generator, discriminator and auxiliary classifier, then train.

    Args:
        args: parsed command-line namespace; it is only logged (via
            ``dump_args``). Resuming is controlled by ``cfg.resume``.
        cfg: experiment config. Mutated in place: ``cfg.gpu`` is set to
            ``ddp_gpu`` and ``cfg.tb_freq`` to ``-1``.
        ddp_gpu: CUDA device index of this worker. It is passed to
            ``torch.cuda.set_device``, which is a no-op for negative values.
            It is also passed to ``is_main_worker``, which gates setup logging.
    """
    cfg.gpu = ddp_gpu
    torch.cuda.set_device(ddp_gpu)
    cudnn.benchmark = True

    logger_path = cfg.work_dir / "log.log"
    logger = Logger.get(file_path=logger_path, level="info", colorize=True)

    image_scale = 0.5
    image_path = cfg.work_dir / "images"
    writer = utils.DiskWriter(image_path, scale=image_scale)
    # Images go through DiskWriter; tb_freq is forced to -1 here.
    cfg.tb_freq = -1

    args_str = dump_args(args)
    # Run setup is logged by the main worker only.
    if is_main_worker(ddp_gpu):
        logger.info("Run Argv:\n> {}".format(" ".join(sys.argv)))
        logger.info("Args:\n{}".format(args_str))
        logger.info("Configs:\n{}".format(cfg.dumps()))

    logger.info("Get dataset ...")

    trn_transform, val_transform = setup_transforms(cfg)

    # Context managers close the files deterministically. JSON is UTF-8 by
    # spec, so don't depend on the platform's default locale encoding.
    with open(cfg.primals, encoding="utf-8") as primals_file:
        primals = json.load(primals_file)
    with open(cfg.decomposition, encoding="utf-8") as decomposition_file:
        decomposition = json.load(decomposition_file)
    # The number of components equals the number of entries in the primals file.
    n_comps = len(primals)

    trn_dset, trn_loader = get_trn_loader(cfg.dset.train,
                                          primals,
                                          decomposition,
                                          trn_transform,
                                          use_ddp=cfg.use_ddp,
                                          batch_size=cfg.batch_size,
                                          num_workers=cfg.n_workers,
                                          shuffle=True)

    _, test_loader = get_val_loader(cfg.dset.val,
                                    val_transform,
                                    batch_size=cfg.batch_size,
                                    num_workers=cfg.n_workers,
                                    shuffle=False)

    logger.info("Build model ...")
    # generator
    g_kwargs = cfg.get("g_args", {})
    gen = Generator(1, cfg.C, 1, **g_kwargs)
    gen.cuda()
    gen.apply(weights_init(cfg.init))

    d_kwargs = cfg.get("d_args", {})
    disc = disc_builder(cfg.C, trn_dset.n_fonts, trn_dset.n_chars, **d_kwargs)
    disc.cuda()
    disc.apply(weights_init(cfg.init))

    aux_clf = aux_clf_builder(gen.feat_shape["last"], trn_dset.n_fonts,
                              n_comps, **cfg.ac_args)
    aux_clf.cuda()
    aux_clf.apply(weights_init(cfg.init))

    g_optim = optim.Adam(gen.parameters(), lr=cfg.g_lr, betas=cfg.adam_betas)
    d_optim = optim.Adam(disc.parameters(), lr=cfg.d_lr, betas=cfg.adam_betas)
    ac_optim = optim.Adam(aux_clf.parameters(),
                          lr=cfg.ac_lr,
                          betas=cfg.adam_betas)

    st_step = 0
    if cfg.resume:
        st_step, loss = load_checkpoint(cfg.resume, gen, disc, aux_clf,
                                        g_optim, d_optim, ac_optim,
                                        cfg.force_resume)
        logger.info(
            "Resumed checkpoint from {} (Step {}, Loss {:7.3f})".format(
                cfg.resume, st_step, loss))

    evaluator = Evaluator(writer)

    trainer = FactTrainer(gen, disc, g_optim, d_optim, aux_clf, ac_optim,
                          writer, logger, evaluator, test_loader, cfg)

    trainer.train(trn_loader, st_step, cfg.max_iter)