num_workers=4)

    val_dataset = DatasetKITTIVal(kitti_depth_path=kitti_depth_path)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=val_batch_size,
                                             shuffle=False,
                                             num_workers=1)

    criterion = MaskedL2Gauss().cuda()
    rmse_criterion = RMSE().cuda()

    model = DepthCompletionNet().cuda()
    model = torch.nn.DataParallel(model)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=weight_decay)
    optimizer.zero_grad()

    train_losses = []
    batch_train_losses = []
    val_losses = []
    train_rmses = []
    batch_train_rmses = []
    val_rmses = []
    for i_iter, batch in enumerate(train_loader):
        imgs, sparses, targets, file_ids = batch
        imgs = Variable(imgs.cuda())  # (shape: (batch_size, h, w))
        sparses = Variable(sparses.cuda())  # (shape: (batch_size, h, w))
        targets = Variable(targets.cuda())  # (shape: (batch_size, h, w))
Example #2
0
        os.makedirs(snapshot_dir)

    train_dataset = DatasetVirtualKITTIAugmentation(virtualkitti_path=virtualkitti_path, max_iters=num_steps*batch_size, crop_size=(352, 352))
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    val_dataset = DatasetVirtualKITTIVal(virtualkitti_path=virtualkitti_path)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=1)

    criterion = MaskedL2Gauss().cuda()
    rmse_criterion = RMSE().cuda()

    model = DepthCompletionNet().cuda()
    model = torch.nn.DataParallel(model)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    optimizer.zero_grad()

    train_losses = []
    batch_train_losses = []
    val_losses = []
    train_rmses = []
    batch_train_rmses = []
    val_rmses = []
    for i_iter, batch in enumerate(train_loader):
        imgs, sparses, targets, file_ids = batch
        imgs = Variable(imgs.cuda()) # (shape: (batch_size, h, w))
        sparses = Variable(sparses.cuda()) # (shape: (batch_size, h, w))
        targets = Variable(targets.cuda()) # (shape: (batch_size, h, w))

        means, log_vars = model(imgs, sparses) # (both of shape: (batch_size, 1, h, w))