Example 1
def main():
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parge_config()
    dist = init_dist(cfg)
    synchronize()

    # init logging file
    logger = Logger(cfg.work_dir / "log_test.txt")
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build model
    model = build_gan_model(cfg, only_generator=True)['G']
    model.cuda()

    if dist:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.gpu],
            output_device=cfg.gpu,
            find_unused_parameters=True,
        )
    elif cfg.total_gpus > 1:
        model = torch.nn.DataParallel(model)

    # load checkpoint
    state_dict = load_checkpoint(args.resume)
    copy_state_dict(state_dict["state_dict"], model)

    # load data_loader
    test_loader, _ = build_val_dataloader(cfg,
                                          for_clustering=True,
                                          all_datasets=True)
    # start testing
    infer_gan(
        cfg,
        model,
        test_loader[0],  # source dataset
        # dataset_name=list(cfg.TRAIN.datasets.keys())[0]
        dataset_name=cfg.TRAIN.data_names[0])

    # print time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))
Example 2
def main():
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parge_config()
    dist = init_dist(cfg)
    set_random_seed(cfg.TRAIN.seed, cfg.TRAIN.deterministic)
    synchronize()

    # init logging file
    logger = Logger(cfg.work_dir / 'log.txt', debug=False)
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build train loader
    train_loader, _ = build_train_dataloader(cfg, joint=False)
    # build model
    model = build_gan_model(cfg)
    for key in model.keys():
        model[key].cuda()

    if dist:
        ddp_cfg = {
            "device_ids": [cfg.gpu],
            "output_device": cfg.gpu,
            "find_unused_parameters": True,
        }
        for key in model.keys():
            model[key] = torch.nn.parallel.DistributedDataParallel(
                model[key], **ddp_cfg)
    elif cfg.total_gpus > 1:
        for key in model.keys():
            model[key] = torch.nn.DataParallel(model[key])

    # build optimizer
    optimizer = {}
    optimizer['G'] = build_optimizer([model['G_A'], model['G_B']],
                                     **cfg.TRAIN.OPTIM)
    optimizer['D'] = build_optimizer([model['D_A'], model['D_B']],
                                     **cfg.TRAIN.OPTIM)

    # build lr_scheduler
    if cfg.TRAIN.SCHEDULER.lr_scheduler is not None:
        lr_scheduler = [
            build_lr_scheduler(optimizer[key], **cfg.TRAIN.SCHEDULER)
            for key in optimizer.keys()
        ]
    else:
        lr_scheduler = None

    # build loss functions
    criterions = build_loss(cfg.TRAIN.LOSS, cuda=True)

    # build runner
    runner = GANBaseRunner(cfg,
                           model,
                           optimizer,
                           criterions,
                           train_loader,
                           lr_scheduler=lr_scheduler,
                           meter_formats={"Time": ":.3f"})

    # resume
    if args.resume_from:
        runner.resume(args.resume_from)

    # start training
    runner.run()

    # load the latest model
    # runner.resume(cfg.work_dir)

    # final inference
    test_loader, _ = build_val_dataloader(cfg,
                                          for_clustering=True,
                                          all_datasets=True)
    # source to target
    infer_gan(cfg,
              model['G_A'],
              test_loader[0],
              dataset_name=list(cfg.TRAIN.datasets.keys())[0])
    # target to source
    infer_gan(cfg,
              model['G_B'],
              test_loader[1],
              dataset_name=list(cfg.TRAIN.datasets.keys())[1])

    # print time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))