Example #1
def main():
    args = options()
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    testset = RegistrationData("PCRNet",
                               ModelNet40Data(train=False, ),
                               is_testing=True)
    test_loader = DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    if not torch.cuda.is_available():
        args.device = "cpu"
    args.device = torch.device(args.device)

    # Create PointNet Model.
    ptnet = PointNet(emb_dims=args.emb_dims)
    model = iPCRNet(feature_model=ptnet)

    # Create sampler
    if args.sampler == "samplenet":
        sampler = SampleNet(
            args.num_out_points,
            args.bottleneck_size,
            args.projection_group_size,
            skip_projection=args.skip_projection,
            input_shape="bnc",
            output_shape="bnc",
        )
    elif args.sampler == "fps":
        sampler = FPSSampler(args.num_out_points,
                             permute=True,
                             input_shape="bnc",
                             output_shape="bnc")
    else:
        sampler = None

    model.sampler = sampler
    model = model.to(args.device)

    if args.pretrained:
        assert os.path.isfile(args.pretrained)
        model.load_state_dict(torch.load(args.pretrained, map_location="cpu"))
    model.to(args.device)

    test(args, model, test_loader)
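The options() call above is assumed to return an argparse namespace. A minimal sketch of the arguments this example actually reads might look like the following; the option names mirror the attributes used in main(), but the defaults and help strings are assumptions, not the original parser:

import argparse

def options():
    # Hypothetical reconstruction of the argument parser this example expects.
    parser = argparse.ArgumentParser(
        description="Evaluate iPCRNet with an optional SampleNet/FPS sampler")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--emb_dims", type=int, default=1024)
    parser.add_argument("--sampler", choices=["samplenet", "fps", "none"], default="samplenet")
    parser.add_argument("--num_out_points", type=int, default=64)
    parser.add_argument("--bottleneck_size", type=int, default=128)
    parser.add_argument("--projection_group_size", type=int, default=8)
    parser.add_argument("--skip_projection", action="store_true")
    parser.add_argument("--pretrained", type=str, default="")
    return parser.parse_args()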
Example #2
def init_inference():
    global device
    if args.model == 'resnet18':
        model = models.resnet18()
        model.fc = torch.nn.Linear(512, 3)
    elif args.model == 'samplenet':
        model = SampleNet()
    elif args.model == 'simplenet':
        model = SimpleNet()
    else:
        raise NotImplementedError()
    model.eval()

    # Load the pretrained weights, move the model to the GPU, convert it to a
    # TensorRT engine for a fixed (1, 3, 240, 320) input, and save the engine.
    model.load_state_dict(torch.load(args.pretrained_model))
    model = model.cuda()
    x = torch.ones((1, 3, 240, 320)).cuda()
    from torch2trt import torch2trt
    #model_trt = torch2trt(model, [x], max_batch_size=100, fp16_mode=True)
    model_trt = torch2trt(model, [x], max_batch_size=100)
    torch.save(model_trt.state_dict(), args.trt_model)
Example #3
def init_inference():
    global model
    global device
    if args.model == 'resnet18':
        model = models.resnet18()
        model.fc = torch.nn.Linear(512, 3)
    elif args.model == 'samplenet':
        model = SampleNet()
    elif args.model == 'simplenet':
        model = SimpleNet()
    else:
        raise NotImplementedError()
    model.eval()
    #model.load_state_dict(torch.load(args.pretrained_model))

    if args.trt_module:
        from torch2trt import TRTModule
        if args.trt_conversion:
            # One-off conversion: load the PyTorch weights, build a TensorRT
            # engine for a (1, 3, 240, 320) input, save it, and exit.
            model.load_state_dict(torch.load(args.pretrained_model))
            model = model.cuda()
            x = torch.ones((1, 3, 240, 320)).cuda()
            from torch2trt import torch2trt
            model_trt = torch2trt(model, [x],
                                  max_batch_size=100,
                                  fp16_mode=True)
            #model_trt = torch2trt(model, [x], max_batch_size=100)
            torch.save(model_trt.state_dict(), args.trt_model)
            exit()
        # Load the previously converted TensorRT engine for inference.
        model_trt = TRTModule()
        #model_trt.load_state_dict(torch.load('road_following_model_trt_half.pth'))
        model_trt.load_state_dict(torch.load(args.trt_model))
        model = model_trt.to(device)
    else:
        model.load_state_dict(torch.load(args.pretrained_model))
        model = model.to(device)
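With the model set up as a global, inference on a single frame is just a forward pass. A minimal sketch, assuming RGB frames already resized to 240x320 and the 3-class head used above; the helper name and preprocessing details are assumptions:

import torch
import torchvision.transforms.functional as TF

def classify_frame(frame):
    # frame: HxWx3 uint8 array (e.g. an RGB camera frame at 240x320).
    x = TF.to_tensor(frame).unsqueeze(0).to(device)  # (1, 3, 240, 320), float in [0, 1]
    with torch.no_grad():
        logits = model(x)                            # (1, 3) class scores
    return int(logits.argmax(dim=1))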
Example #4
def main():
    # Parse arguments.
    args = parse_args()

    # Set device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Prepare dataset.
    np.random.seed(seed=0)
    image_dataframe = pd.read_csv(args.data_csv, engine='python', header=None)
    image_dataframe = image_dataframe.reindex(
        np.random.permutation(image_dataframe.index))
    test_num = int(len(image_dataframe) * 0.2)
    train_dataframe = image_dataframe[test_num:]
    test_dataframe = image_dataframe[:test_num]
    train_data = MyDataset(train_dataframe, transform=transforms.ToTensor())
    test_data = MyDataset(test_dataframe, transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=20,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=20)

    print('data set')
    # Set a model.
    if args.model == 'resnet18':
        model = models.resnet18()
        model.fc = torch.nn.Linear(512, 3)
    elif args.model == 'samplenet':
        model = SampleNet()
    elif args.model == 'simplenet':
        model = SimpleNet()
    else:
        raise NotImplementedError()
    model.train()
    model = model.to(device)

    print('model set')
    # Set loss function and optimization function.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    #optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    print('optimizer set')

    # Train and test.
    print('Train starts')
    for epoch in range(args.n_epoch):
        # Train and test a model.
        train_acc, train_loss = train(model, device, train_loader, criterion,
                                      optimizer)

        # Output score.
        if (epoch % args.test_interval == 0):
            test_acc, test_loss = test(model, device, test_loader, criterion)

            stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}'
            print(
                stdout_temp.format(epoch + 1, train_acc, train_loss, test_acc,
                                   test_loss))

        else:
            stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}'  #, test acc: {:<8}, test loss: {:<8}'
            print(stdout_temp.format(epoch + 1, train_acc,
                                     train_loss))  #, test_acc, test_loss))

        # Save a model checkpoint.
        if (epoch % args.save_model_interval == 0
                or epoch + 1 == args.n_epoch):
            model_ckpt_path = args.model_ckpt_path_temp.format(
                args.dataset_name, args.model_name, epoch + 1)
            torch.save(model.state_dict(), model_ckpt_path)
            print('Saved a model checkpoint at {}'.format(model_ckpt_path))
            print('')
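The train and test helpers are not shown in this example. A minimal sketch matching the call signatures used above, returning per-epoch accuracy and mean loss; the original helpers may compute or format these metrics differently:

import torch

def train(model, device, train_loader, criterion, optimizer):
    # One training epoch; returns (accuracy, mean loss). A sketch, not the original helper.
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item() * labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total, loss_sum / total


def test(model, device, test_loader, criterion):
    # Evaluation pass; same return convention as train().
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total, loss_sum / total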
Example #5
def main():
    # Parse arguments.
    args = parse_args()

    # Set device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ROOT_DIR = ""
    imgDataset = MyDataset(args.data_csv,
                           ROOT_DIR,
                           transform=transforms.ToTensor())
    # Load dataset.
    train_data, test_data = train_test_split(imgDataset, test_size=0.2)
    pd.to_pickle(test_data, "test_data.pkl")
    del test_data
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=20,
                                               shuffle=True)

    print('data set')
    # Set a model.
    if args.model == 'resnet18':
        model = models.resnet18()
        model.fc = torch.nn.Linear(512, 3)
    elif args.model == 'samplenet':
        model = SampleNet()
    else:
        raise NotImplementedError()
    model.train()
    model = model.to(device)

    print('model set')
    # Set loss function and optimization function.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    #optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    print('optimizer set')

    # Train and test.
    print('Train starts')
    for epoch in range(args.n_epoch):
        # Train and test a model.
        train_acc, train_loss = train(model, device, train_loader, criterion,
                                      optimizer)

        # Output score.
        if (epoch % args.test_interval == 0):
            pd.to_pickle(train_data, "train_data.pkl")
            del train_data

            test_data = pd.read_pickle("test_data.pkl")
            test_loader = torch.utils.data.DataLoader(test_data,
                                                      batch_size=20,
                                                      shuffle=True)
            del test_data
            test_acc, test_loss = test(model, device, test_loader, criterion)
            del test_loader

            stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}, test acc: {:<8}, test loss: {:<8}'
            print(
                stdout_temp.format(epoch + 1, train_acc, train_loss, test_acc,
                                   test_loss))

            train_data = pd.read_pickle("train_data.pkl")
        else:
            stdout_temp = 'epoch: {:>3}, train acc: {:<8}, train loss: {:<8}'  #, test acc: {:<8}, test loss: {:<8}'
            print(stdout_temp.format(epoch + 1, train_acc,
                                     train_loss))  #, test_acc, test_loss))

        # Save a model checkpoint.
        if (epoch % args.save_model_interval == 0
                or epoch + 1 == args.n_epoch):
            model_ckpt_path = args.model_ckpt_path_temp.format(
                args.dataset_name, args.model_name, epoch + 1)
            torch.save(model.state_dict(), model_ckpt_path)
            print('Saved a model checkpoint at {}'.format(model_ckpt_path))
            print('')
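Because the test split is pickled to test_data.pkl at startup, it can also be evaluated later in a separate process. A small sketch, assuming a saved checkpoint path (hypothetical) and the same test() helper sketched for Example #4:

import pandas as pd
import torch
import torch.nn as nn

# Standalone evaluation from the pickled split; the checkpoint path is an assumption.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
test_data = pd.read_pickle("test_data.pkl")
test_loader = torch.utils.data.DataLoader(test_data, batch_size=20)

model = SampleNet()  # same model class as in the example
model.load_state_dict(torch.load("checkpoints/samplenet_epoch_10.pth", map_location=device))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
test_acc, test_loss = test(model, device, test_loader, criterion)
print('test acc: {}, test loss: {}'.format(test_acc, test_loss))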
Example #6
def main():
    parser = sputils.get_parser()
    args = options(parser)

    torch.backends.cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    boardio = SummaryWriter(log_dir="checkpoints/" + args.exp_name)
    _init_(args)

    textio = IOStream("checkpoints/" + args.exp_name + "/run.log")
    textio.cprint(str(args))

    # trainset = RegistrationData("PCRNet", ModelNet40Data(train=True, download=True))
    # testset = RegistrationData("PCRNet", ModelNet40Data(train=False, download=True))

    transforms = torchvision.transforms.Compose(
        [PointcloudToTensor(), OnUnitCube()])
    traindata = ModelNetCls(
        args.num_in_points,
        transforms=transforms,
        train=True,
        download=False,
        folder=args.datafolder,
    )
    testdata = ModelNetCls(
        args.num_in_points,
        transforms=transforms,
        train=False,
        download=False,
        folder=args.datafolder,
    )

    trainset = RegistrationData("PCRNet", traindata)
    testset = RegistrationData("PCRNet", testdata)

    train_loader = DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=args.workers,
    )
    test_loader = DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.workers,
    )

    if not torch.cuda.is_available():
        args.device = "cpu"
    args.device = torch.device(args.device)

    # Create PointNet Model.
    ptnet = PointNet(emb_dims=args.emb_dims)
    model = iPCRNet(feature_model=ptnet)

    if args.train_pcrnet:
        model.requires_grad_(True)
        model.train()
    else:
        model.requires_grad_(False)
        model.eval()

    # Create sampler
    if args.sampler == "samplenet":
        sampler = SampleNet(
            args.num_out_points,
            args.bottleneck_size,
            args.projection_group_size,
            skip_projection=args.skip_projection,
            input_shape="bnc",
            output_shape="bnc",
        )
        if args.train_samplenet:
            sampler.requires_grad_(True)
            sampler.train()
        else:
            sampler.requires_grad_(False)
            sampler.eval()
    else:
        sampler = None

    model.sampler = sampler
    model = model.to(args.device)

    checkpoint = None
    if args.resume:
        assert os.path.isfile(args.resume)
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint["model"])

    if args.pretrained:
        assert os.path.isfile(args.pretrained)
        missing_keys, unexpected_keys = model.load_state_dict(
            torch.load(args.pretrained, map_location="cpu"), strict=False)

        if len(missing_keys) != 0:
            print(f"Found missing keys in checkpoint: {missing_keys}")
        if len(unexpected_keys) != 0:
            raise RuntimeError(
                f"Found unexpected keys in checkpoint: {unexpected_keys}")

        filtered_missing_keys = [
            x for x in missing_keys if not x.startswith("sampler")
        ]
        if len(filtered_missing_keys) != 0:
            raise RuntimeError(
                f"Found missing keys in checkpoint: {filtered_missing_keys}")

    model.to(args.device)

    if args.eval:
        test(args, model, test_loader, textio)
    else:
        train(args, model, train_loader, test_loader, boardio, textio,
              checkpoint)
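The resume branch above expects a checkpoint dictionary with "epoch" and "model" entries. A sketch of how such a checkpoint might be written from inside the (not shown) train() loop; the key names follow the resume code, while the helper name and file path are assumptions:

import torch

def save_checkpoint(model, epoch, path):
    # Matches the keys read when resuming: checkpoint["epoch"], checkpoint["model"].
    snapshot = {
        "epoch": epoch + 1,           # epoch to resume from
        "model": model.state_dict(),  # includes sampler weights when a sampler is attached
    }
    torch.save(snapshot, path)

# e.g. inside train():
#   save_checkpoint(model, epoch, "checkpoints/" + args.exp_name + "/model_snap.t7")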