Example #1
def main(args):
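    # Build the Distiller from the config, optionally restore weights from a
    # checkpoint, then either train it with PyTorch Lightning or export to ONNX.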
    cfg = load_cfg(args.cfg)
    distiller = Distiller(cfg)
    if args.ckpt is not None:
        ckpt = model_zoo(args.ckpt)
        load_weights(distiller, ckpt["state_dict"])
    logger = build_logger(cfg.logger)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=os.getcwd() if args.checkpoint_dir is None else args.checkpoint_dir,
        save_top_k=True,
        save_last=True,
        verbose=True,
        monitor=cfg.trainer.monitor,
        mode=cfg.trainer.monitor_mode,
        prefix=''
    )
    trainer = pl.Trainer(
        gpus=args.gpus,
        max_epochs=cfg.trainer.max_epochs,
        accumulate_grad_batches=args.grad_batches,
        distributed_backend=args.distributed_backend,
        checkpoint_callback=checkpoint_callback,
        val_check_interval=args.val_check_interval,
        logger=logger
    )
    if args.to_onnx is None:
        trainer.fit(distiller)
    else:
        distiller.to_onnx(args.to_onnx)
Example #2
def main(args):
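    # Interactive demo: repeatedly sample a random style vector, run the
    # distiller, and display the generated image until 'q' is pressed.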
    cfg = load_cfg(args.cfg)
    distiller = Distiller(cfg)
    if args.ckpt is not None:
        ckpt = model_zoo(args.ckpt)
        load_weights(distiller, ckpt["state_dict"])

    while True:
        var = torch.randn(1, distiller.mapping_net.style_dim)
        img_s = distiller(var, truncated=args.truncated)
        cv2.imshow("demo", tensor_to_img(img_s[0].cpu()))
        key = chr(cv2.waitKey() & 255)
        if key == 'q':
            break
Example #3
def main(args):
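    # Offline generation: write args.n_output generated images to disk as JPEGs.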
    cfg = load_cfg(args.cfg)
    distiller = Distiller(cfg)
    if args.ckpt is not None:
        ckpt = model_zoo(args.ckpt)
        load_weights(distiller, ckpt["state_dict"])

    for i in range(args.n_output):
        var = torch.randn(1, distiller.mapping_net.style_dim)
        img_s = distiller(var, truncated=args.truncated)
        # save to disk; no GUI window is opened, so no waitKey handling is needed
        path = os.path.join(args.output_path, f"pic{i}.jpg")
        cv2.imwrite(path, tensor_to_img(img_s[0].cpu()))
Example #4
def main(args):
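    # Batched generation on args.device: sample style vectors in batches of
    # args.batch_size and save every generated image as a PNG in args.output_path.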
    cfg = load_cfg(args.cfg)
    distiller = Distiller(cfg)
    if args.ckpt is not None:
        ckpt = model_zoo(args.ckpt)
        load_weights(distiller, ckpt["state_dict"])

    distiller = distiller.to(args.device)
    for i in tqdm(range(args.n_batches)):
        var = torch.randn(args.batch_size,
                          distiller.mapping_net.style_dim).to(args.device)
        img_s = distiller(var, truncated=args.truncated)
        for j in range(img_s.size(0)):
            cv2.imwrite(
                os.path.join(args.output_path, f"{i*args.batch_size + j}.png"),
                tensor_to_img(img_s[j].cpu()))
Example #5
def main(args):
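    # Side-by-side comparison: simultaneous_forward returns two images
    # (img_s and img_t, presumably the student and teacher outputs), which are
    # stacked horizontally and shown in one window until 'q' is pressed.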
    cfg = load_cfg(args.cfg)
    distiller = Distiller(cfg)
    if args.ckpt is not None:
        ckpt = model_zoo(args.ckpt)
        load_weights(distiller, ckpt["state_dict"])

    while True:
        var = torch.randn(1, distiller.mapping_net.style_dim)
        img_s, img_t = distiller.simultaneous_forward(var, truncated=args.truncated)
        img = np.hstack([
            tensor_to_img(img_t[0].cpu()),
            tensor_to_img(img_s[0].cpu())
        ])
        cv2.imshow("compare", img)
        key = chr(cv2.waitKey() & 255)
        if key == 'q':
            break
Example #6
def main(args):
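    # NMT entry point: depending on the flags, run the interactive CLI,
    # export the model to ONNX, or train / evaluate with PyTorch Lightning.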
    cfg = load_cfg(args.cfg)
    cfg.distributed_backend = args.distributed_backend
    nmt_trainer = NMTTrainer(cfg)
    if args.ckpt is not None:
        nmt_trainer.load_state_dict(torch.load(args.ckpt, map_location="cpu")["state_dict"])

    if args.cli_mode:
        nmt_trainer.cli_mode()
    elif args.to_onnx:
        nmt_trainer.to_onnx(args.onnx_path, args.onnx_denominator)
    else:
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            filepath=args.output_dir,
            save_top_k=True,
            verbose=True,
            monitor='bleu',
            mode='max'
        )
        if args.log_path is not None:
            logger = JSONLogger(args.log_path)
        else:
            logger = None
        trainer = pl.Trainer(
            gpus=args.gpus,
            logger=logger,
            max_epochs=cfg.trainer.max_epochs,
            accumulate_grad_batches=args.grad_batches,
            distributed_backend=args.distributed_backend,
            checkpoint_callback=checkpoint_callback,
            val_check_interval=args.val_check_interval,
        )
        if not args.eval:
            trainer.fit(nmt_trainer)
        else:
            trainer.test(nmt_trainer)