Exemplo n.º 1
0
def init_inference():
    """Build the network named by ``args.model`` and store it in the global ``model``.

    Reads the module-level ``args`` and ``device``. Three modes are handled:

    * ``args.trt_module`` is false: load ``args.pretrained_model`` into the
      network and move it to ``device``.
    * ``args.trt_module`` and ``args.trt_conversion`` are both true: convert
      the network with ``torch2trt`` (``fp16_mode=True``,
      ``max_batch_size=100``) using a 1x3x240x320 all-ones sample input,
      save the converted module's state dict to ``args.trt_model`` and
      terminate the process.
    * only ``args.trt_module`` is true: load the state dict stored at
      ``args.trt_model`` into a fresh ``TRTModule`` and move it to ``device``.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'resnet18'``,
            ``'samplenet'`` or ``'simplenet'``.
        SystemExit: once a TensorRT conversion has been saved.
    """
    global model
    global device

    if args.model == 'resnet18':
        model = models.resnet18()
        # Swap the classifier head for one with 3 outputs.
        model.fc = torch.nn.Linear(512, 3)
    elif args.model == 'samplenet':
        model = SampleNet()
    elif args.model == 'simplenet':
        model = SimpleNet()
    else:
        raise NotImplementedError()
    model.eval()

    if not args.trt_module:
        # Deserialize onto the CPU first so a checkpoint written on a GPU
        # still loads on a CPU-only host (or one with a different GPU index);
        # the module is moved to its final device right after.
        state = torch.load(args.pretrained_model, map_location="cpu")
        model.load_state_dict(state)
        model = model.to(device)
        return

    # torch2trt is only needed on the TensorRT paths, so import it lazily.
    from torch2trt import TRTModule

    if args.trt_conversion:
        import sys

        from torch2trt import torch2trt

        state = torch.load(args.pretrained_model, map_location="cpu")
        model.load_state_dict(state)
        model = model.cuda()
        # Sample input handed to torch2trt for the conversion.
        x = torch.ones((1, 3, 240, 320)).cuda()
        model_trt = torch2trt(model, [x],
                              max_batch_size=100,
                              fp16_mode=True)
        torch.save(model_trt.state_dict(), args.trt_model)
        # Conversion is a one-shot mode: stop here. sys.exit() is used
        # because the bare exit() helper is injected by the ``site`` module
        # and is not guaranteed to exist in every interpreter setup.
        sys.exit()

    model_trt = TRTModule()
    model_trt.load_state_dict(torch.load(args.trt_model))
    model = model_trt.to(device)
Exemplo n.º 2
0
def init_inference():
    """Convert the network named by ``args.model`` to TensorRT and save it.

    The network is put in eval mode, loaded with the weights stored at
    ``args.pretrained_model`` and moved to CUDA. It is then converted with
    ``torch2trt`` (``max_batch_size=100``, no ``fp16_mode`` requested) using
    a 1x3x240x320 all-ones sample input. The converted module's state dict
    is written to ``args.trt_model``.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'resnet18'``,
            ``'samplenet'`` or ``'simplenet'``.
    """
    global device

    def _build_network():
        # Guard-style selection: each known name returns immediately.
        if args.model == 'resnet18':
            resnet = models.resnet18()
            # Swap the classifier head for one with 3 outputs.
            resnet.fc = torch.nn.Linear(512, 3)
            return resnet
        if args.model == 'samplenet':
            return SampleNet()
        if args.model == 'simplenet':
            return SimpleNet()
        raise NotImplementedError()

    net = _build_network()
    net.eval()

    weights = torch.load(args.pretrained_model)
    net.load_state_dict(weights)
    net = net.cuda()

    # Sample input handed to torch2trt for the conversion.
    sample = torch.ones((1, 3, 240, 320)).cuda()

    from torch2trt import torch2trt
    converted = torch2trt(net, [sample], max_batch_size=100)
    torch.save(converted.state_dict(), args.trt_model)
Exemplo n.º 3
0
def _seed_everything(seed):
    """Seed the torch (CPU and CUDA) and NumPy RNGs and request deterministic cuDNN."""
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


def _make_loaders(args):
    """Return ``(train_loader, test_loader)`` over the ModelNet registration data.

    Both splits are read from ``args.datafolder`` (no download) with
    ``args.num_in_points`` points per cloud, wrapped in
    ``RegistrationData("PCRNet", ...)``. Only the training loader shuffles
    and drops the last incomplete batch.
    """
    transforms = torchvision.transforms.Compose(
        [PointcloudToTensor(), OnUnitCube()])
    traindata = ModelNetCls(
        args.num_in_points,
        transforms=transforms,
        train=True,
        download=False,
        folder=args.datafolder,
    )
    testdata = ModelNetCls(
        args.num_in_points,
        transforms=transforms,
        train=False,
        download=False,
        folder=args.datafolder,
    )

    trainset = RegistrationData("PCRNet", traindata)
    testset = RegistrationData("PCRNet", testdata)

    train_loader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    test_loader = DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )
    return train_loader, test_loader


def _build_registration_model(args):
    """Create the iPCRNet (PointNet features) with its optional sampler on ``args.device``.

    ``args.train_pcrnet`` / ``args.train_samplenet`` decide independently
    whether the registration network and the sampler are trainable
    (``train()`` with gradients) or frozen (``eval()`` without gradients).
    The sampler is ``None`` unless ``args.sampler == "samplenet"``.
    """
    ptnet = PointNet(emb_dims=args.emb_dims)
    model = iPCRNet(feature_model=ptnet)

    if args.train_pcrnet:
        model.requires_grad_(True)
        model.train()
    else:
        model.requires_grad_(False)
        model.eval()

    sampler = None
    if args.sampler == "samplenet":
        sampler = SampleNet(
            args.num_out_points,
            args.bottleneck_size,
            args.projection_group_size,
            skip_projection=args.skip_projection,
            input_shape="bnc",
            output_shape="bnc",
        )
        if args.train_samplenet:
            sampler.requires_grad_(True)
            sampler.train()
        else:
            sampler.requires_grad_(False)
            sampler.eval()

    # The sampler is attached only after the registration network's mode has
    # been set, so each part keeps its own train/eval configuration.
    model.sampler = sampler
    return model.to(args.device)


def _load_pretrained_weights(model, path):
    """Load the weights at ``path`` into ``model`` non-strictly and validate the key sets.

    Keys missing from the checkpoint are tolerated only when they start with
    ``"sampler"``; keys the model does not know are never tolerated.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        RuntimeError: on unexpected keys, or on missing non-sampler keys.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Pretrained weights file not found: {path}")

    missing_keys, unexpected_keys = model.load_state_dict(
        torch.load(path, map_location="cpu"), strict=False)

    if missing_keys:
        print(f"Found missing keys in checkpoint: {missing_keys}")
    if unexpected_keys:
        raise RuntimeError(
            f"Found missing keys in model: {unexpected_keys}")

    filtered_missing_keys = [
        key for key in missing_keys if not key.startswith("sampler")
    ]
    if filtered_missing_keys:
        raise RuntimeError(
            f"Found missing keys in checkpoint: {filtered_missing_keys}")


def main():
    """Parse options, build data/model, then evaluate or train the registration network.

    Steps: seed the RNGs, open the summary writer and run log under
    ``checkpoints/<exp_name>``, build the data loaders, resolve
    ``args.device`` (falling back to CPU when CUDA is unavailable), build the
    model, optionally restore ``args.resume`` and/or ``args.pretrained``, and
    finally call ``test`` when ``args.eval`` is set or ``train`` otherwise.

    Raises:
        FileNotFoundError: if ``args.resume`` or ``args.pretrained`` is set
            but does not point to an existing file.
        RuntimeError: if the pretrained weights do not match the model
            (see ``_load_pretrained_weights``).
    """
    parser = sputils.get_parser()
    args = options(parser)

    _seed_everything(args.seed)

    boardio = SummaryWriter(log_dir="checkpoints/" + args.exp_name)
    try:
        _init_(args)

        textio = IOStream("checkpoints/" + args.exp_name + "/run.log")
        textio.cprint(str(args))

        train_loader, test_loader = _make_loaders(args)

        if not torch.cuda.is_available():
            args.device = "cpu"
        args.device = torch.device(args.device)

        model = _build_registration_model(args)

        checkpoint = None
        if args.resume:
            # Explicit check instead of ``assert``: assertions are stripped
            # when Python runs with -O.
            if not os.path.isfile(args.resume):
                raise FileNotFoundError(
                    f"Resume checkpoint not found: {args.resume}")
            # Remap storages to the active device so a checkpoint written on
            # a GPU can still be resumed on a CPU-only host.
            checkpoint = torch.load(args.resume, map_location=args.device)
            args.start_epoch = checkpoint["epoch"]
            model.load_state_dict(checkpoint["model"])

        if args.pretrained:
            _load_pretrained_weights(model, args.pretrained)

        model.to(args.device)

        if args.eval:
            test(args, model, test_loader, textio)
        else:
            train(args, model, train_loader, test_loader, boardio, textio,
                  checkpoint)
    finally:
        # Flush and release the summary writer on every exit path.
        boardio.close()