예제 #1
0
    # Build the student depth-estimation network. n_classes=1 gives a single output
    # channel (presumably one depth value per pixel - confirm in models.FCDenseNet57).
    depth_estimation_model_student = models.FCDenseNet57(n_classes=1)
    # Initialize the depth estimation network with Kaiming He initialization
    # (arguments select a normal distribution, fan_in mode and ReLU activation -
    # see utils.init_net for how they are interpreted).
    depth_estimation_model_student = utils.init_net(depth_estimation_model_student, type="kaiming", mode="fan_in",
                                                    activation_mode="relu",
                                                    distribution="normal")
    # Multi-GPU running: wrap the network in torch.nn.DataParallel.
    # NOTE(review): after wrapping, state_dict keys carry a "module." prefix, so any
    # checkpoint loaded into this object must have been saved from a wrapped model.
    depth_estimation_model_student = torch.nn.DataParallel(depth_estimation_model_student)
    # Summary network architecture: print a layer-by-layer summary for an input
    # of shape (3, height, width).
    if display_architecture:
        torchsummary.summary(depth_estimation_model_student, input_size=(3, height, width))
    # Optimizer: SGD with momentum 0.9, initialised at the peak learning rate max_lr.
    optimizer = torch.optim.SGD(depth_estimation_model_student.parameters(), lr=max_lr, momentum=0.9)
    # Cyclic learning-rate schedule between min_lr and max_lr (project-local scheduler
    # module); step_size=num_iter is presumably the half-cycle length in iterations - confirm.
    lr_scheduler = scheduler.CyclicLR(optimizer, base_lr=min_lr, max_lr=max_lr, step_size=num_iter)

    # Custom layers (project-defined in `models`). The epsilon arguments are passed
    # through from the caller; presumably they are numerical-stability terms - confirm in models.
    depth_scaling_layer = models.DepthScalingLayer(epsilon=depth_scaling_epsilon)
    depth_warping_layer = models.DepthWarpingLayer(epsilon=depth_warping_epsilon)
    flow_from_depth_layer = models.FlowfromDepthLayer()
    # Loss functions (project-defined in `losses`). Judging by the class names: a masked
    # L1 loss on sparse flow, and a normalized distance loss for depth consistency that is
    # configured with the image height/width - confirm in losses.
    sparse_flow_loss_function = losses.SparseMaskedL1Loss()
    depth_consistency_loss_function = losses.NormalizedDistanceLoss(height=height, width=width)

    # Optionally resume the student model from a previous checkpoint. The checkpoint
    # is expected to be a dict holding 'step', 'epoch' and 'model' (a state_dict).
    if load_trained_model:
        # Path(...) lets trained_model_path be either a str or a pathlib.Path.
        if Path(trained_model_path).exists():
            # str(...) is required: "{:s}" raises TypeError when given a Path object.
            print("Loading {:s} ...".format(str(trained_model_path)))
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            # No map_location is passed, so tensors return to the device they were saved on.
            state = torch.load(trained_model_path)
            step = state['step']
            epoch = state['epoch']
            # The model is DataParallel-wrapped at this point, so the saved state_dict
            # keys are expected to carry the "module." prefix.
            depth_estimation_model_student.load_state_dict(state['model'])
            print('Restored model, epoch {}, step {}'.format(epoch, step))
예제 #2
0
        # Load trained model weights from the checkpoint at trained_model_path
        # (a path object exposing .exists()). The checkpoint is expected to be a dict
        # holding 'step', 'epoch' and 'model' (a state_dict).
        if trained_model_path.exists():
            print("Loading {:s} ...".format(str(trained_model_path)))
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(str(trained_model_path))
            step = state['step']
            epoch = state['epoch']
            depth_estimation_model.load_state_dict(state['model'])
            print('Restored model, epoch {}, step {}'.format(epoch, step))
        else:
            print("Trained model could not be found")
            # FileNotFoundError subclasses OSError, so existing `except OSError`
            # handlers still catch it; the message now names the missing path.
            raise FileNotFoundError(
                "Trained model could not be found: {}".format(trained_model_path))
        # Unwrap the container to get the underlying network (presumably a
        # torch.nn.DataParallel wrapper, given the .module attribute - confirm).
        depth_estimation_model = depth_estimation_model.module

        # Custom layers (project-defined in `models`), constructed with their default
        # arguments. NOTE(review): none of them is used in the visible lines; presumably
        # they are used inside the evaluation loop that follows.
        depth_scaling_layer = models.DepthScalingLayer()
        depth_warping_layer = models.DepthWarpingLayer()
        flow_from_depth_layer = models.FlowfromDepthLayer()

        with torch.no_grad():
            # Set model to evaluation mode
            depth_estimation_model.eval()
            # Update progress bar
            tq = tqdm.tqdm(total=len(test_loader) * batch_size)
            for batch, (colors_1, colors_2, sparse_depths_1, sparse_depths_2,
                        sparse_depth_masks_1, sparse_depth_masks_2,
                        sparse_flows_1, sparse_flows_2, sparse_flow_masks_1,
                        sparse_flow_masks_2, boundaries, rotations_1_wrt_2,
                        rotations_2_wrt_1, translations_1_wrt_2,
                        translations_2_wrt_1, intrinsics,
                        folders) in enumerate(test_loader):