Example #1
File: train.py Project: tom-roddick/oft
def main():

    # Parse command line arguments
    args = parse_args()

    # Create experiment
    summary = _make_experiment(args)

    # Create datasets
    train_data = KittiObjectDataset(
        args.root, 'train', args.grid_size, args.grid_res, args.yoffset)
    val_data = KittiObjectDataset(
        args.root, 'val', args.grid_size, args.grid_res, args.yoffset)
    
    # Apply data augmentation
    train_data = oft.AugmentedObjectDataset(
        train_data, args.train_image_size, args.train_grid_size, 
        jitter=args.grid_jitter)

    # Create dataloaders
    train_loader = DataLoader(train_data, args.batch_size, shuffle=True, 
        num_workers=args.workers, collate_fn=oft.utils.collate)
    val_loader = DataLoader(val_data, args.batch_size, shuffle=False, 
        num_workers=args.workers, collate_fn=oft.utils.collate)

    # Build model
    model = OftNet(num_classes=1, frontend=args.frontend, 
                   topdown_layers=args.topdown, grid_res=args.grid_res, 
                   grid_height=args.grid_height)
    if len(args.gpu) > 0:
        torch.cuda.set_device(args.gpu[0])
        model = nn.DataParallel(model, args.gpu).cuda()

    # Create encoder
    encoder = ObjectEncoder()

    # Setup optimizer
    # weight_decay is passed by keyword so it is not taken as SGD's dampening argument
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    for epoch in range(1, args.epochs+1):

        print('\n=== Beginning epoch {} of {} ==='.format(epoch, args.epochs))
        
        # Update and log learning rate
        scheduler.step(epoch-1)
        summary.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # Train model
        train(args, train_loader, model, encoder, optimizer, summary, epoch)

        # Run validation every N epochs
        if epoch % args.val_interval == 0:

            validate(args, val_loader, model, encoder, summary, epoch)

            # Save model checkpoint
            save_checkpoint(args, epoch, model, optimizer, scheduler)
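
A note on the learning-rate schedule above: ExponentialLR multiplies the current rate by lr_decay (its gamma) on each step() call, so after k steps the rate is lr * lr_decay**k. Calling scheduler.step(epoch - 1) with an explicit epoch, as in this example, is deprecated in recent PyTorch versions; a minimal self-contained sketch of the now-preferred pattern (hyperparameter values below are illustrative, not taken from the project) is:

from torch import nn, optim

# Illustrative stand-ins; the real values come from parse_args() in train.py
model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

for epoch in range(1, 4):
    # ... one epoch of training (forward/backward and optimizer.step() calls) runs here ...
    optimizer.step()   # placeholder step so the scheduler does not warn
    scheduler.step()   # preferred over scheduler.step(epoch) in recent PyTorch
    print(epoch, optimizer.param_groups[0]['lr'])   # lr * 0.99 ** epoch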
Example #2
File: train.py Project: pchitale1/oft
def main():
    # Parse command line arguments
    args = parse_args()

    # DLProf - Init PyProf
    if args.dlprof:
        pyprof.init(enable_function_stack=True)
        # Set num epochs to 1 if DLProf is enabled
        args.epochs = 1

    # Create experiment
    summary = _make_experiment(args)

    # Create datasets
    train_data = KittiObjectDataset(args.root, 'train', args.grid_size,
                                    args.grid_res, args.yoffset)
    val_data = KittiObjectDataset(args.root, 'val', args.grid_size,
                                  args.grid_res, args.yoffset)

    # Apply data augmentation
    # train_data = oft.AugmentedObjectDataset(
    #     train_data, args.train_image_size, args.train_grid_size,
    #     jitter=args.grid_jitter)

    # Create dataloaders
    train_loader = DataLoader(train_data,
                              args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              collate_fn=oft.utils.collate)
    val_loader = DataLoader(val_data,
                            args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            collate_fn=oft.utils.collate)

    # Build model
    model = OftNet(num_classes=1,
                   frontend=args.frontend,
                   topdown_layers=args.topdown,
                   grid_res=args.grid_res,
                   grid_height=args.grid_height)
    if len(args.gpu) > 0:
        torch.cuda.set_device(args.gpu[0])
        model = nn.DataParallel(model, args.gpu).cuda()

    # Create encoder
    encoder = ObjectEncoder()

    # Setup optimizer
    # weight_decay is passed by keyword so it is not taken as SGD's dampening argument
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, args.lr_decay)

    # Create a GradScaler once at the start of training for AMP.
    # It is created even when AMP is not actually used.
    scaler = GradScaler()

    for epoch in range(1, args.epochs + 1):

        print('\n=== Beginning epoch {} of {} ==='.format(epoch, args.epochs))
        # Update and log learning rate
        scheduler.step(epoch - 1)
        summary.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)

        # Train model
        if args.dlprof:
            with torch.autograd.profiler.emit_nvtx():
                train(args, train_loader, model, encoder, optimizer, summary,
                      epoch, scaler)
        else:
            train(args, train_loader, model, encoder, optimizer, summary,
                  epoch, scaler)

        # Run validation every N epochs
        if epoch % args.val_interval == 0:
            if args.dlprof:
                with torch.autograd.profiler.emit_nvtx():
                    validate(args, val_loader, model, encoder, summary, epoch)
            else:
                validate(args, val_loader, model, encoder, summary, epoch)

            # Save model checkpoint
            save_checkpoint(args, epoch, model, optimizer, scheduler)
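
Example #2 only constructs the GradScaler and passes it into train(); the project's train() itself is not shown in this snippet. A minimal sketch of how a training step typically uses the scaler with torch.cuda.amp (the model, criterion, and batch names below are placeholders, not the project's) looks like:

from torch.cuda.amp import autocast

def train_step(model, images, targets, criterion, optimizer, scaler):
    optimizer.zero_grad()
    with autocast():                   # forward pass runs in mixed precision
        outputs = model(images)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()      # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales gradients, then calls optimizer.step()
    scaler.update()                    # adjust the scale factor for the next iteration
    return loss.item()

When AMP is disabled, GradScaler(enabled=False) and autocast(enabled=False) reduce to no-ops, so the same step can serve both mixed- and full-precision runs; whether the project's train() does exactly this is not visible from the snippet.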