Example #1
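A unit test for the two UnDeepVO sub-networks: the depth net must return a positive, single-channel depth map, and the pose net a 3-DoF rotation and translation for each batch item.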
    def test_undeepvo(self):
        model = UnDeepVO()
        # Depth net: expect a single-channel depth map with strictly positive values
        input_data = torch.rand(2, 3, 384, 128)
        output = model.depth_net(input_data)
        self.assertEqual(output.shape, torch.Size([2, 1, 384, 128]))
        self.assertTrue(torch.all(output > 0))
        # Pose net: expect a 3-DoF rotation and a 3-DoF translation per batch item
        input_data = torch.rand(2, 3, 384, 128)
        rotation, translation = model.pose_net(input_data)
        self.assertEqual(rotation.shape, torch.Size([2, 3]))
        self.assertEqual(translation.shape, torch.Size([2, 3]))
Example #2
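A test of the temporal (photometric) consistency loss: stereo images, predicted depths, and poses for two consecutive frames are fed to TemporalImageLosses, and the resulting loss must be a finite, positive scalar.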
    def test_temporal_loss(self):
        model = UnDeepVO().to(device)

        left_current_img, right_current_img, left_next_img, right_next_img, \
        left_current_depth, right_current_depth, left_next_depth, right_next_depth, \
        left_current_rotation, left_current_position, right_current_rotation, right_current_position, \
        left_next_rotation, left_next_position, right_next_rotation, right_next_position, \
        src_trans_dst, left_camera_matrix, right_camera_matrix = self.prepare_data_for_tests(model)

        camera_baseline = 0.54  # stereo baseline of the KITTI rig in meters
        src_trans_dst[0, 3] = camera_baseline
        focal_length = left_camera_matrix[0, 0, 0]
        transform_from_left_to_right = src_trans_dst
        lambda_position, lambda_angle, lambda_s = 1e-3, 1e-3, 1e-2

        temporal_losses = TemporalImageLosses(left_camera_matrix,
                                              right_camera_matrix)

        out, _ = temporal_losses(left_current_img, left_next_img,
                                 left_current_depth, left_next_depth,
                                 right_current_img, right_next_img,
                                 right_current_depth, right_next_depth,
                                 left_current_position, right_current_position,
                                 left_current_rotation, right_current_rotation,
                                 left_next_position, right_next_position,
                                 left_next_rotation, right_next_rotation)

        self.assertEqual(out.shape, torch.Size([]))
        self.assertFalse(torch.isnan(out))
        self.assertTrue(out > 0)
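Example #3
An end-to-end smoke test on the CPU: it downloads KITTI odometry sequence 08 if needed, builds the full training pipeline, and trains UnsupervisedDepthProblem for a single epoch.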
    def test_unsupervised_depth_problem_cpu(self):
        device = "cpu"
        sequence_8 = Downloader('08')
        if not os.path.exists("./dataset/poses"):
            print("Download dataset")
            sequence_8.download_sequence()
        lengths = (200, 30, 30)  # train/validation/test split sizes
        dataset = pykitti.odometry(sequence_8.main_dir,
                                   sequence_8.sequence_id,
                                   frames=range(0, 260, 1))
        dataset_manager = DatasetManagerMock(dataset,
                                             lengths=lengths,
                                             num_workers=WORKERS_COUNT)
        model = UnDeepVO(max_depth=2., min_depth=1.0).to(device)
        optimizer_manager = OptimizerManager()
        criterion = UnsupervisedCriterion(
            dataset_manager.get_cameras_calibration(device), 0.1, 1, 0.85)
        handler = TrainingProcessHandler(mlflow_tags={"name": "test"})
        problem = UnsupervisedDepthProblem(model,
                                           criterion,
                                           optimizer_manager,
                                           dataset_manager,
                                           handler,
                                           batch_size=1,
                                           device=device)
        problem.train(1)  # one epoch is enough for a smoke test
Example #4
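A command-line training script for the unsupervised method: it parses hyperparameters, builds the KITTI dataset manager, the UnDeepVO model, the unsupervised criterion, and an MLflow-enabled training handler.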
parser.add_argument('-supervised_lambda',
                    default=0.1,
                    type=float,
                    help='lambda of the supervised method')

args = parser.parse_args()

MAIN_DIR = args.main_dir
lengths = args.split
if args.method == "unsupervised":
    dataset = pykitti.odometry(MAIN_DIR,
                               '08',
                               frames=range(*args.frames_range))
    dataset_manager = UnsupervisedDatasetManager(dataset, lengths=lengths)

    model = UnDeepVO(args.max_depth).cuda()

    criterion = UnsupervisedCriterion(
        dataset_manager.get_cameras_calibration("cuda:0"),
        args.lambda_position, args.lambda_rotation, args.lambda_s,
        args.lambda_disparity)
    handler = TrainingProcessHandler(
        enable_mlflow=True,
        mlflow_tags={"name": args.mlflow_tags_name},
        mlflow_parameters={
            "image_step": args.frames_range[2],
            "max_depth": args.max_depth,
            "epoch": args.epoch,
            "lambda_position": args.lambda_position,
            "lambda_rotation": args.lambda_rotation,
            "lambda_s": args.lambda_s,
Example #5
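A later variant of the training script: it additionally downloads the dataset when missing, exposes min_depth and ResNet-encoder options, can resume from a saved checkpoint, and passes a registration-loss weight to the criterion.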
                    help='whether to use resnet or not')

args = parser.parse_args()

MAIN_DIR = args.main_dir
lengths = args.split
problem = None
if args.method == "unsupervised":
    sequence_8 = Downloader('08')
    if not os.path.exists("./dataset/poses"):
        print("Download dataset")
        sequence_8.download_sequence()
    dataset = pykitti.odometry(MAIN_DIR, '08', frames=range(*args.frames_range))
    dataset_manager = UnsupervisedDatasetManager(dataset, lengths=lengths)

    model = UnDeepVO(args.max_depth, args.min_depth, args.resnet).to(args.device)

    if args.model_path != "":
        model.load_state_dict(torch.load(args.model_path, map_location=args.device))
    criterion = UnsupervisedCriterion(dataset_manager.get_cameras_calibration(args.device),
                                      args.lambda_position,
                                      args.lambda_rotation,
                                      args.lambda_s,
                                      args.lambda_disparity,
                                      args.lambda_registration)
    handler = TrainingProcessHandler(enable_mlflow=True, mlflow_tags={"name": args.mlflow_tags_name},
                                     mlflow_parameters={"image_step": args.frames_range[2],
                                                        "max_depth": args.max_depth,
                                                        "epoch": args.epoch,
                                                        "lambda_position": args.lambda_position,
                                                        "lambda_rotation": args.lambda_rotation,