# Training setup for the student depth estimation network: builds the model,
# the optimizer / learning-rate scheduler, the custom geometry layers and the
# loss functions, then optionally restores a previously saved checkpoint.
# Relies on names configured earlier in the script (height, width, min_lr,
# max_lr, num_iter, display_architecture, load_trained_model, ...).

# Student depth estimation network
depth_estimation_model_student = models.FCDenseNet57(n_classes=1)
# Initialize the depth estimation network with Kaiming He initialization
depth_estimation_model_student = utils.init_net(
    depth_estimation_model_student,
    type="kaiming",
    mode="fan_in",
    activation_mode="relu",
    distribution="normal",
)
# Multi-GPU running
depth_estimation_model_student = torch.nn.DataParallel(depth_estimation_model_student)

# Summary network architecture
if display_architecture:
    torchsummary.summary(depth_estimation_model_student, input_size=(3, height, width))

# Optimizer: SGD starts at max_lr; the scheduler is given the
# [min_lr, max_lr] range and num_iter as its step size.
optimizer = torch.optim.SGD(
    depth_estimation_model_student.parameters(), lr=max_lr, momentum=0.9
)
lr_scheduler = scheduler.CyclicLR(
    optimizer, base_lr=min_lr, max_lr=max_lr, step_size=num_iter
)

# Custom layers
depth_scaling_layer = models.DepthScalingLayer(epsilon=depth_scaling_epsilon)
depth_warping_layer = models.DepthWarpingLayer(epsilon=depth_warping_epsilon)
flow_from_depth_layer = models.FlowfromDepthLayer()

# Loss functions
sparse_flow_loss_function = losses.SparseMaskedL1Loss()
depth_consistency_loss_function = losses.NormalizedDistanceLoss(height=height, width=width)

# Load previous student model weights together with the saved step and epoch
# counters. NOTE(review): only 'model', 'step' and 'epoch' are read from the
# checkpoint here; optimizer / lr scheduler state is not restored in this
# excerpt.
if load_trained_model:
    if Path(trained_model_path).exists():
        # str() keeps the "{:s}" format spec valid when the path is a
        # pathlib.Path (Path objects reject a non-empty format spec).
        print("Loading {:s} ...".format(str(trained_model_path)))
        state = torch.load(trained_model_path)
        step = state['step']
        epoch = state['epoch']
        depth_estimation_model_student.load_state_dict(state['model'])
        print('Restored model, epoch {}, step {}'.format(epoch, step))
# Evaluation setup: restores the trained depth estimation model from
# trained_model_path, unwraps it, builds the custom geometry layers and
# prepares the no-grad evaluation context with a progress bar.

# Load trained model
if trained_model_path.exists():
    print("Loading {:s} ...".format(str(trained_model_path)))
    state = torch.load(str(trained_model_path))
    step = state['step']
    epoch = state['epoch']
    depth_estimation_model.load_state_dict(state['model'])
    print('Restored model, epoch {}, step {}'.format(epoch, step))
else:
    print("Trained model could not be found")
    # Same exception type as before, now carrying the offending path so the
    # failure is diagnosable from the traceback alone.
    raise OSError("Trained model could not be found: {}".format(str(trained_model_path)))

# Keep only the inner network stored on `.module`.
# NOTE(review): presumably the model was wrapped in torch.nn.DataParallel
# before this point (the checkpoint is loaded into the wrapped model) --
# confirm against the code above this excerpt.
depth_estimation_model = depth_estimation_model.module

# Custom layers
depth_scaling_layer = models.DepthScalingLayer()
depth_warping_layer = models.DepthWarpingLayer()
flow_from_depth_layer = models.FlowfromDepthLayer()

with torch.no_grad():
    # Set model to evaluation mode
    depth_estimation_model.eval()
    # Update progress bar
    tq = tqdm.tqdm(total=len(test_loader) * batch_size)
    # NOTE(review): the evaluation loop header below is truncated in this
    # excerpt -- its body is not visible -- so it is preserved verbatim as a
    # comment instead of being given an invented body. Re-enable it together
    # with the original loop body.
    # for batch, (colors_1, colors_2, sparse_depths_1, sparse_depths_2, sparse_depth_masks_1, sparse_depth_masks_2, sparse_flows_1, sparse_flows_2, sparse_flow_masks_1, sparse_flow_masks_2, boundaries, rotations_1_wrt_2, rotations_2_wrt_1, translations_1_wrt_2, translations_2_wrt_1, intrinsics, folders) in enumerate(test_loader):