Example #1
File: train.py Project: tom-roddick/oft
def main():

    # Parse command line arguments
    args = parse_args()

    # Create experiment
    summary = _make_experiment(args)

    # Create datasets
    train_data = KittiObjectDataset(
        args.root, 'train', args.grid_size, args.grid_res, args.yoffset)
    val_data = KittiObjectDataset(
        args.root, 'val', args.grid_size, args.grid_res, args.yoffset)
    
    # Apply data augmentation
    train_data = oft.AugmentedObjectDataset(
        train_data, args.train_image_size, args.train_grid_size, 
        jitter=args.grid_jitter)

    # Create dataloaders
    train_loader = DataLoader(train_data, args.batch_size, shuffle=True, 
        num_workers=args.workers, collate_fn=oft.utils.collate)
    val_loader = DataLoader(val_data, args.batch_size, shuffle=False, 
        num_workers=args.workers, collate_fn=oft.utils.collate)

    # Build model
    model = OftNet(num_classes=1, frontend=args.frontend, 
                   topdown_layers=args.topdown, grid_res=args.grid_res, 
                   grid_height=args.grid_height)
    if len(args.gpu) > 0:
        torch.cuda.set_device(args.gpu[0])
        model = nn.DataParallel(model, args.gpu).cuda()

    # Create encoder
    encoder = ObjectEncoder()

    # Setup optimizer
    # Pass the hyperparameters by keyword: SGD's fourth positional
    # parameter is dampening, so passing weight_decay positionally
    # would silently set dampening instead of weight decay
    optimizer = optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    for epoch in range(1, args.epochs+1):

        print('\n=== Beginning epoch {} of {} ==='.format(epoch, args.epochs))
        
        # Update and log learning rate
        scheduler.step(epoch-1)
        summary.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # Train model
        train(args, train_loader, model, encoder, optimizer, summary, epoch)

        # Run validation every N epochs
        if epoch % args.val_interval == 0:
            validate(args, val_loader, model, encoder, summary, epoch)

            # Save model checkpoint
            save_checkpoint(args, epoch, model, optimizer, scheduler)
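
These examples all assume a `parse_args` helper that exposes the dataset, model, and optimizer settings as attributes. Below is a minimal sketch of such a parser, reconstructed from the arguments the scripts actually read; the flag names follow the attribute names, but the defaults are illustrative assumptions, not the projects' real values.

import argparse

def parse_args():
    parser = argparse.ArgumentParser(
        description='OFT training options (illustrative sketch)')
    # Dataset options
    parser.add_argument('--root', default='data/kitti')
    parser.add_argument('--grid-size', type=float, nargs=2, default=(80., 80.))
    parser.add_argument('--grid-res', type=float, default=0.5)
    parser.add_argument('--grid-height', type=float, default=4.)
    parser.add_argument('--yoffset', type=float, default=1.74)
    # Data augmentation options
    parser.add_argument('--train-grid-size', type=int, nargs=2,
                        default=(120, 120))
    parser.add_argument('--train-image-size', type=int, nargs=2,
                        default=(1080, 360))
    parser.add_argument('--grid-jitter', type=float, nargs=3,
                        default=[.25, .5, .25])
    # Model options
    parser.add_argument('--frontend', default='resnet18')
    parser.add_argument('--topdown', type=int, default=8)
    # Training options
    parser.add_argument('--gpu', type=int, nargs='*', default=[0])
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=600)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--lr-decay', type=float, default=0.99)
    parser.add_argument('--val-interval', type=int, default=10)
    return parser.parse_args()

Note that argparse converts dashes to underscores, so `--grid-size` becomes `args.grid_size` and so on, matching the attribute accesses in the examples.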
Example #2
File: infer.py Project: prashantraina/oft
def main():

    # Parse command line arguments
    args = parse_args()

    # Load validation dataset to visualise
    dataset = KittiObjectDataset(
        args.root, 'val', args.grid_size, args.grid_res, args.yoffset)
    
    # Build model
    model = OftNet(num_classes=1, frontend=args.frontend, 
                   topdown_layers=args.topdown, grid_res=args.grid_res, 
                   grid_height=args.grid_height)
    if args.gpu >= 0:
        torch.cuda.set_device(args.gpu)
        model.cuda()
    
    # Load checkpoint. The weights were saved from a DataParallel-wrapped
    # model, so the state dict keys carry a 'module.' prefix; wrapping the
    # model the same way makes the keys match, and .module then unwraps it
    # again for single-image inference
    ckpt = torch.load(args.model_path)
    model = nn.DataParallel(model, [args.gpu]).cuda()
    model.load_state_dict(ckpt['model'])
    model = model.module

    # Create encoder
    encoder = ObjectEncoder(nms_thresh=args.nms_thresh)

    # Set up plots
    _, (ax1, ax2) = plt.subplots(nrows=2)
    plt.ion()

    # Iterate over validation images
    for _, image, calib, objects, grid in dataset:

        # Move tensors to gpu
        image = to_tensor(image)
        if args.gpu >= 0:
            image, calib, grid = image.cuda(), calib.cuda(), grid.cuda()

        # Run model forwards
        pred_encoded = model(image[None], calib[None], grid[None])
        
        # Decode predictions
        pred_encoded = [t[0].cpu() for t in pred_encoded]
        detections = encoder.decode(*pred_encoded, grid.cpu())

        image = image.cpu()
        calib = calib.cpu()

        # Visualize predictions
        visualize_objects(image, calib, detections, ax=ax1)
        ax1.set_title('Detections')
        visualize_objects(image, calib, objects, ax=ax2)
        ax2.set_title('Ground truth')

        plt.draw()
        plt.pause(0.01)
        plt.waitforbuttonpress()
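
Wrapping the model in DataParallel just to make the checkpoint keys line up requires a visible GPU. A common alternative, sketched here under the assumption that the checkpoint stores the weights under the 'model' key as above, is to strip the 'module.' prefix from the state dict keys instead; this also works for CPU-only inference.

import torch

def load_unwrapped(model, ckpt_path, device='cpu'):
    # Load a checkpoint saved from a DataParallel model into a bare model
    # by removing the 'module.' prefix that DataParallel adds to every key
    ckpt = torch.load(ckpt_path, map_location=device)
    state = {(k[len('module.'):] if k.startswith('module.') else k): v
             for k, v in ckpt['model'].items()}
    model.load_state_dict(state)
    return model.to(device)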
Example #3
File: train.py Project: pchitale1/oft
def main():
    # Parse command line arguments
    args = parse_args()

    # DLProf - Init PyProf
    if args.dlprof:
        pyprof.init(enable_function_stack=True)
        # Set num epochs to 1 if DLProf is enabled
        args.epochs = 1

    # Create experiment
    summary = _make_experiment(args)

    # Create datasets
    train_data = KittiObjectDataset(args.root, 'train', args.grid_size,
                                    args.grid_res, args.yoffset)
    val_data = KittiObjectDataset(args.root, 'val', args.grid_size,
                                  args.grid_res, args.yoffset)

    # Apply data augmentation
    # train_data = oft.AugmentedObjectDataset(
    #     train_data, args.train_image_size, args.train_grid_size,
    #     jitter=args.grid_jitter)

    # Create dataloaders
    train_loader = DataLoader(train_data,
                              args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              collate_fn=oft.utils.collate)
    val_loader = DataLoader(val_data,
                            args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            collate_fn=oft.utils.collate)

    # Build model
    model = OftNet(num_classes=1,
                   frontend=args.frontend,
                   topdown_layers=args.topdown,
                   grid_res=args.grid_res,
                   grid_height=args.grid_height)
    if len(args.gpu) > 0:
        torch.cuda.set_device(args.gpu[0])
        model = nn.DataParallel(model, args.gpu).cuda()

    # Create encoder
    encoder = ObjectEncoder()

    # Setup optimizer
    # weight_decay must be passed by keyword (SGD's fourth positional
    # parameter is dampening, not weight_decay)
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    # Create a GradScaler once at the start of training for AMP
    # (instantiated even when AMP is not in use, which is harmless)
    scaler = GradScaler()

    for epoch in range(1, args.epochs + 1):

        print('\n=== Beginning epoch {} of {} ==='.format(epoch, args.epochs))
        # Update and log learning rate
        scheduler.step(epoch - 1)
        summary.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # Train model
        if args.dlprof:
            with torch.autograd.profiler.emit_nvtx():
                train(args, train_loader, model, encoder, optimizer, summary,
                      epoch, scaler)
        else:
            train(args, train_loader, model, encoder, optimizer, summary,
                  epoch, scaler)

        # Run validation every N epochs
        if epoch % args.val_interval == 0:
            if args.dlprof:
                with torch.autograd.profiler.emit_nvtx():
                    validate(args, val_loader, model, encoder, summary, epoch)
            else:
                validate(args, val_loader, model, encoder, summary, epoch)

            # Save model checkpoint
            save_checkpoint(args, epoch, model, optimizer, scheduler)
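
The `train` function itself is not shown in these listings. Below is a minimal sketch of the mixed-precision step it presumably performs with the GradScaler, following the standard torch.cuda.amp pattern; `compute_loss` is a hypothetical stand-in for the encoder-based loss these projects compute.

from torch.cuda.amp import autocast

def train_step(model, optimizer, scaler, image, calib, objects, grid):
    optimizer.zero_grad()
    # Forward pass and loss run under autocast so eligible ops use float16
    with autocast():
        pred_encoded = model(image, calib, grid)
        loss = compute_loss(pred_encoded, objects, grid)  # hypothetical helper
    # Scale the loss to avoid gradient underflow, backpropagate, then
    # unscale, apply the update, and adjust the scale factor for next time
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()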