Example #1
import torch
from easydict import EasyDict as edict

# build_config, build_backbone, FramesFolder and parse_args are SimDeblur helpers;
# adjust the import paths to match your SimDeblur installation.


def frames_folder_demo():
    args = parse_args()
    config = build_config(args.config)
    config.args = args
    model = build_backbone(config.model).cuda()

    ckpt = torch.load(args.ckpt, map_location="cuda:0")

    # Strip the "module." prefix left by (Distributed)DataParallel checkpoints.
    model_ckpt = ckpt["model"]
    model_ckpt = {k[7:]: v for k, v in model_ckpt.items()}
    model.load_state_dict(model_ckpt)

    # Describe the folder of input frames to run inference on.
    data_config = edict({
        "root_input":
        "/home/cmd/projects/simdeblur/datasets/DVD/qualitative_datasets/alley/input",
        "num_frames": 5,
        "overlapping": True,
        "sampling": "n_c"
    })
    frames_data = FramesFolder(data_config)
    frames_dataloader = torch.utils.data.DataLoader(frames_data, 1)

    model.eval()
    with torch.no_grad():
        for i, batch_data in enumerate(frames_dataloader):
            out = model(batch_data["input_frames"].cuda())
            print(batch_data["gt_names"], out.shape)
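
The demo only reads args.config and args.ckpt, so parse_args needs to expose just those two values. A minimal sketch of such a parser and of invoking the demo, assuming hypothetical --config and --ckpt flags, could look like this:

import argparse

def parse_args():
    # Minimal sketch; the real SimDeblur script may expose more options.
    parser = argparse.ArgumentParser(description="SimDeblur frames-folder demo")
    parser.add_argument("--config", type=str, required=True,
                        help="path to the model config file")
    parser.add_argument("--ckpt", type=str, required=True,
                        help="path to the trained checkpoint")
    return parser.parse_args()

if __name__ == "__main__":
    frames_folder_demo()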
Example #2
@classmethod
def build_model(cls, cfg):
    """
    Build a model and wrap it for distributed training when requested.
    Requires `import logging` and `import torch.nn as nn` at module level.
    """
    # TODO: change build_backbone to build_model
    model = build_backbone(cfg.model)
    if cfg.args.gpus > 1:
        # Wrap the model with DDP, pinning it to this process's local GPU.
        rank = cfg.args.local_rank
        model = nn.parallel.DistributedDataParallel(model.cuda(),
                                                    device_ids=[rank],
                                                    output_device=rank)
    if cfg.args.local_rank == 0:
        # Only the main process logs the model architecture.
        logger = logging.getLogger("simdeblur")
        logger.info("Model:\n{}".format(model))
    return model
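
Before calling build_model with more than one GPU, the launcher has to select the local device and join the process group. The sketch below shows that setup under assumed conventions (NCCL backend, environment-variable rendezvous, and a hypothetical owning Trainer class); the real SimDeblur launcher may differ.

import torch
import torch.distributed as dist

def init_distributed(cfg):
    # Hypothetical helper: run once per process before build_model.
    if cfg.args.gpus > 1:
        torch.cuda.set_device(cfg.args.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")

# Usage (assumed class name):
# init_distributed(cfg)
# model = Trainer.build_model(cfg)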