Example #1
    def forward(self, dset, epoch, phase):
        dev = self.device
        start = time.time()
        self.depth_model.eval()  # keep the depth network in eval mode
        
        if phase == 'train':
            self.plane_model.train()
        else:
            self.plane_model.eval()

        dset_size = len(dset.dataset)
        running_loss = 0.0

        # Iterate over data.
        for data in dset:
            target_img, source_imgs, lie_alg, intrinsics, _ = data
            target_img_aug = target_img['color_aug_left'].to(dev)
            lie_alg = lie_alg['color_aug']

            source_img_list = []
            source_img_aug_list = []
            gt_lie_alg_list = []
            vo_lie_alg_list = []
            dt_list = []
            for i, im in enumerate(source_imgs['color_aug_left']):
                source_img_aug_list.append(im.to(dev))
                source_img_list.append(source_imgs['color_left'][i].to(dev))
                gt_lie_alg_list.append(lie_alg[i][0].type(torch.FloatTensor).to(dev))
                vo_lie_alg_list.append(lie_alg[i][1].type(torch.FloatTensor).to(dev))
                dt_list.append(lie_alg[i][3].type(torch.FloatTensor).to(dev).expand_as(vo_lie_alg_list[-1][:,0:3]))

            intrinsics_aug = intrinsics['color_aug_left'].type(torch.FloatTensor).to(dev)[:,0,:,:] #only need one matrix since it's constant across the training sample
            intrinsics = intrinsics['color_left'].type(torch.FloatTensor).to(dev)[:,0,:,:]

            
            disparity = self.depth_model(target_img_aug, epoch=epoch)
            _, depth = disp_to_depth(disparity[0], self.config['min_depth'], self.config['max_depth'])
            self.optimizer.zero_grad()
            minibatch_loss = 0
            with torch.set_grad_enabled(phase == 'train'):
                plane_est = self.plane_model(target_img_aug, epoch=epoch)
                plane_est = plane_est[0]

                ones_var = torch.ones(1).expand_as(plane_est).type_as(plane_est)
                # BCE against an all-ones target encourages a larger plane estimate
                reg_loss = torch.nn.functional.binary_cross_entropy(plane_est, ones_var, reduction='none')

                # the higher weight on the plane loss makes the plane estimate more conservative
                minibatch_loss = 0.05 * reg_loss.mean() + 25 * self.plane_loss(plane_est, depth, intrinsics_aug.inverse())
                if phase == 'train':
                    minibatch_loss.backward()        
                    self.optimizer.step()
            running_loss += minibatch_loss.item()
            

        epoch_loss = running_loss / float(dset_size)

        print('{} Loss: {:.6f}'.format(phase, epoch_loss))
        print('{} epoch completed in {} seconds.'.format(phase, timeSince(start)))
        return epoch_loss
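
Both examples in this listing rely on a disp_to_depth helper that is not shown. In monodepth2-style codebases this converts the network's sigmoid disparity output into a depth map within the configured range; a minimal sketch, assuming that convention holds here:

def disp_to_depth(disp, min_depth, max_depth):
    # map a sigmoid disparity in [0, 1] to a disparity in [1/max_depth, 1/min_depth],
    # then invert it to obtain depth (monodepth2-style convention)
    min_disp = 1.0 / max_depth
    max_disp = 1.0 / min_depth
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1.0 / scaled_disp
    return scaled_disp, depth

The two return values explain the `_, depth = disp_to_depth(...)` unpacking above: the scaled disparity is discarded and only the depth map is used.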
Example #2
    def forward(self, dset, epoch, phase):
        dev = self.device
        start = time.time()
        self.depth_model.eval()  # keep the depth network in eval mode
        
        if phase == 'train':
            self.plane_model.train()
        else:
            self.plane_model.eval()

        dset_size = len(dset.dataset)
        running_loss = 0.0

        # Iterate over data.
        for data in dset:
            (target_img, source_img_list, gt_lie_alg_list, vo_lie_alg_list,
             flow_imgs, intrinsics, target_img_aug, source_img_aug_list,
             gt_lie_alg_aug_list, vo_lie_alg_aug_list,
             intrinsics_aug) = process_sample_batch(data, self.config)

            disparity = self.depth_model(target_img_aug, epoch=epoch)
            _, depth = disp_to_depth(disparity[0], self.config['min_depth'], self.config['max_depth'])
            self.optimizer.zero_grad()
            minibatch_loss = 0
            with torch.set_grad_enabled(phase == 'train'):
                plane_est = self.plane_model(target_img_aug, epoch=epoch)
                plane_est = plane_est[0]

                ones_var = torch.ones(1).expand_as(plane_est).type_as(plane_est)
                # BCE against an all-ones target encourages a larger plane estimate
                reg_loss = torch.nn.functional.binary_cross_entropy(plane_est, ones_var, reduction='none')

                # the higher weight on the plane loss makes the plane estimate more conservative
                minibatch_loss = 0.05 * reg_loss.mean() + 25 * self.plane_loss(plane_est, depth, intrinsics_aug.inverse())
                if phase == 'train':
                    minibatch_loss.backward()        
                    self.optimizer.step()
            running_loss += minibatch_loss.item()
            

        epoch_loss = running_loss / float(dset_size)

        print('{} Loss: {:.6f}'.format(phase, epoch_loss))
        print('{} epoch completed in {} seconds.'.format(phase, timeSince(start)))
        return epoch_loss
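
Example #2 condenses Example #1's manual batch unpacking into a single process_sample_batch call. The helper's definition is not included; the sketch below is hypothetical, reconstructed from the unpacking logic of Example #1, and omits the flow images and augmented pose targets that the real helper also returns:

def process_sample_batch(data, config):
    # hypothetical sketch: move one dataloader batch to the target device,
    # mirroring the manual unpacking in Example #1
    dev = config['device']
    target_img, source_imgs, lie_alg, intrinsics, _ = data
    target_img_aug = target_img['color_aug_left'].to(dev)
    lie_alg = lie_alg['color_aug']

    source_img_list, source_img_aug_list = [], []
    gt_lie_alg_list, vo_lie_alg_list = [], []
    for i, im in enumerate(source_imgs['color_aug_left']):
        source_img_aug_list.append(im.to(dev))
        source_img_list.append(source_imgs['color_left'][i].to(dev))
        gt_lie_alg_list.append(lie_alg[i][0].float().to(dev))
        vo_lie_alg_list.append(lie_alg[i][1].float().to(dev))

    # intrinsics are constant across the training sample, so one matrix suffices
    intrinsics_aug = intrinsics['color_aug_left'].float().to(dev)[:, 0, :, :]
    intrinsics = intrinsics['color_left'].float().to(dev)[:, 0, :, :]

    return (target_img_aug, source_img_list, source_img_aug_list,
            gt_lie_alg_list, vo_lie_alg_list, intrinsics, intrinsics_aug)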
Example #3
    def forward(self,
                source_imgs,
                target_imgs,
                poses,
                disparity,
                intrinsics,
                pose_vec_weight=None,
                validate=False,
                epoch=5):
        '''Adapted from https://github.com/JiawangBian/SC-SfMLearner-Release/blob/master/loss_functions.py'''
        zero = torch.zeros(1).type_as(intrinsics)
        losses = {'l_reconstruct_inverse': zero.clone(),
                  'l_reconstruct_forward': zero.clone(),
                  'l_depth': zero.clone(),
                  'l_smooth': zero.clone(),
                  'l_plane': zero.clone(),
                  'l_left_right_consist': zero.clone(),
                  'l_brightness': zero.clone()}
        # separate the disparity list into target and source disparities
        disparity, source_disparities = disparity[0], disparity[1:]
        # separate the forward and inverse pose change predictions
        poses, poses_inv = poses[0], poses[1]
        target_img = target_imgs['color_left'].to(self.config['device'])
        B, _, h, w = target_img.size()

        # kept out of the scale loop: this only needs to be computed once
        if self.config['l_camera_height'] and epoch > 0:
            plane_est = self.plane_model(target_img, epoch=epoch)[0].detach()
            plane_est = nn.functional.interpolate(plane_est, (int(h / 4), int(w / 4)), mode='bilinear')
            # scale the intrinsics to the quarter-resolution grid before inverting
            int_inv = intrinsics.clone()
            int_inv[:, 0:2, :] = int_inv[:, 0:2, :] / 4
            int_inv = int_inv.inverse()

        for scale, disp in enumerate(disparity):
            # upsample to full resolution and convert to depth
            if scale != 0:
                disp = nn.functional.interpolate(disp, (h, w), mode='nearest')
            _, d = disp_to_depth(disp, self.config['min_depth'], self.config['max_depth'])

            if self.config['l_left_right_consist']:
                target_img_right = target_imgs['color_right'].to(self.config['device'])
                losses['l_left_right_consist'] += self.l_left_right_consist_weight * \
                    self.l_lr_consist(target_img, target_img_right, d, intrinsics)

            ## Disparity Smoothness Loss
            if self.config['l_smooth']:
                losses['l_smooth'] += (self.l_smooth_weight * get_smooth_loss(disp, target_img)) / (2**scale)
            '''Ground Plane Loss (experimental)'''
            if self.config['l_camera_height'] and epoch > 0:
                depth_down = nn.functional.interpolate(d, (int(h / 4), int(w / 4)), mode='bilinear')
                scale_factor, plane_loss = self.plane_loss(plane_est, depth_down, int_inv, disp)
                self.scale_factor_list[epoch].append(scale_factor.mean().item())
                losses['l_plane'] += self.l_camera_height_weight * plane_loss

                # pull the predicted translations towards their (detached) rescaled values
                for pose in poses:
                    target_pose = (pose[:, 0:3].clone() * scale_factor.reshape((-1, 1)).expand_as(pose[:, 0:3])).detach()
                    losses['l_plane'] += 0.6 * (pose[:, 0:3] - target_pose).abs().mean()

            reconstruction_errors = []
            masks = []
            proj_depths = []
            if self.config['l_reconstruction']:
                for j, source_img in enumerate(source_imgs):
                    pose, pose_inv = poses[j], poses_inv[j]
                    source_disparity = source_disparities[j][scale]
                    if scale != 0:
                        source_disparity = nn.functional.interpolate(source_disparity, (h, w), mode='nearest')
                    _, source_d = disp_to_depth(source_disparity, self.config['min_depth'], self.config['max_depth'])

                    ## Disparity Smoothness Loss
                    if self.config['l_smooth']:
                        losses['l_smooth'] += (self.l_smooth_weight * get_smooth_loss(source_disparity, source_img)) / (2**scale)

                    '''inverse reconstruction - reproject target frame to source frames'''
                    if self.config['l_inverse']:
                        l_reprojection, l_depth, _, _ = self.compute_pairwise_loss(
                            source_img, target_img, source_d, d, -pose_inv.clone(), intrinsics, epoch)
                        if self.config['l_depth_consist']:
                            losses['l_depth'] += self.l_depth_consist_weight * l_depth
                        losses['l_reconstruct_inverse'] += 0.3 * l_reprojection

                    '''forward reconstruction - reproject source frames to target frame'''
                    l_reprojection, l_depth, diff_img, valid_mask = self.compute_pairwise_loss(
                        target_img, source_img, d, source_d, -pose.clone(), intrinsics, epoch)
                    if self.config['l_depth_consist']:
                        losses['l_depth'] += self.l_depth_consist_weight * l_depth

                    reconstruction_errors.append(diff_img)
                    masks.append(valid_mask)

                # take the per-pixel minimum over source images to handle occlusions
                reconstruction_errors = torch.cat(reconstruction_errors, 1)
                reconstruction_errors, idx = torch.min(reconstruction_errors, 1)
                losses['l_reconstruct_forward'] += reconstruction_errors.mean()

        # average each term over the scales and accumulate the total
        losses['total'] = 0
        for key, value in losses.items():
            if key != 'total':
                losses[key] = value / self.num_scales
                losses['total'] += losses[key]

        return losses
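
Example #3 applies get_smooth_loss to the disparity at every scale but does not define it. In monodepth2/SC-SfMLearner-style code this is an edge-aware smoothness penalty that suppresses disparity gradients except where the image itself has strong gradients; a sketch under that assumption:

import torch

def get_smooth_loss(disp, img):
    # first differences of the disparity in x and y
    grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
    # mean image gradients, used to down-weight smoothing across edges
    grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
    grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
    grad_disp_x *= torch.exp(-grad_img_x)
    grad_disp_y *= torch.exp(-grad_img_y)
    return grad_disp_x.mean() + grad_disp_y.mean()

Dividing this term by 2**scale in Example #3 halves its weight at each coarser scale, so the full-resolution disparity dominates the smoothness budget.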
Example #4
                        None for i in range(0, len(source_img_list))
                    ] for i in range(0, 2)]  # annoying but necessary
                    flow_imgs_fwd_list = [None for i in range(0, len(source_img_list))]
                    flow_imgs_back_list = [None for i in range(0, len(source_img_list))]

                # only one matrix is needed since it's constant across the training sample
                intrinsics = intrinsics.type(torch.FloatTensor).to(device)[:, 0, :, :]
                disparity = depth_model(target_img)
                disp = disparity[0]
                _, depth = disp_to_depth(disp, config['min_depth'], config['max_depth'])

                poses = [
                    compute_pose(pose_model,
                                 [target_img, source_img, flow_img_fwd], vo,
                                 dpc, mode, 50)
                    for source_img, vo, flow_img_fwd in zip(
                        source_img_list, vo_lie_alg_list, flow_imgs_fwd_list)
                ]
                poses_inv = [
                    compute_pose(pose_model,
                                 [source_img, target_img, flow_img_back], -vo,
                                 dpc, mode, 50)
                    for source_img, vo, flow_img_back in zip(
                        source_img_list, vo_lie_alg_list, flow_imgs_back_list)
                ]
                pose_results = {'source1': {}, 'source2': {}}

                batch_size = target_img.shape[0]
                imgs = torch.cat([target_img, source_img_list[0]], 0)
                disparities = depth_model(imgs, epoch=50)

                target_disparities = [disp[0:batch_size] for disp in disparities]
                source_disp_1 = [disp[batch_size:(2 * batch_size)] for disp in disparities]

                disparities = [target_disparities, source_disp_1]
                depths = [disp_to_depth(disp[0], config['min_depth'], config['max_depth'])[1]
                          for disp in disparities]

                flow_imgs_fwd_list, flow_imgs_back_list = flow_imgs
                poses, poses_inv = solve_pose(pose_model, target_img, source_img_list, flow_imgs)
                fwd_pose_vec1, inv_pose_vec1 = poses[0].clone(), poses_inv[0].clone()

                # apply the fixed 30x scale convention (cf. h_gt = cam_height / 30. below)
                depth = 30 * depths[0]
                fwd_pose_vec1[:, 0:3] = 30 * fwd_pose_vec1[:, 0:3]
                inv_pose_vec1[:, 0:3] = 30 * inv_pose_vec1[:, 0:3]

                if plane_rescaling:
                    plane_est = plane_model(target_img, epoch=50)[0].detach()
        for k, data in enumerate(test_dset_loaders):
            target_img, source_imgs, lie_alg, intrinsics, flow_imgs = data
            target_img = target_img['color_left']
            source_imgs = source_imgs['color_left']
            intrinsics = intrinsics['color_left']
            target_img = target_img.to(device)
            B = target_img.shape[0]

            if post_process:
                # post-processed results require two forward passes per image:
                # one on the original and one on its horizontal flip
                target_img = torch.cat((target_img, torch.flip(target_img, [3])), 0)

            disparities = depth_model(target_img, epoch=50)

            disps, depths = disp_to_depth(disparities[0], config['min_depth'],
                                          config['max_depth'])

            if plane_rescaling:
                plane_est = plane_model(target_img[0:B], epoch=50)[0].detach()
                intrinsics = intrinsics[:, 0].type(torch.FloatTensor).to(device).clone()
                scale_factor = scale_recovery(plane_est, depths[0:B], intrinsics,
                                              h_gt=cam_height / 30.)
                scale_factor_list.append(scale_factor.cpu().numpy())

            pred_disp = disps.cpu()[:, 0].numpy()

            if post_process:
                N = pred_disp.shape[0] // 2
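
The fragment ends just as the flipped-image post-processing begins: after the doubled forward pass, the first N disparity maps belong to the original images and the last N to their horizontal flips. The truncated tail presumably merges the two in the usual monodepth style; a sketch of that standard routine, assuming NumPy arrays of shape (2N, H, W):

import numpy as np

def batch_post_process_disparity(l_disp, r_disp):
    # blend each disparity map with the re-flipped disparity of its mirror image,
    # feathering the blend near the left and right borders
    _, h, w = l_disp.shape
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
    r_mask = l_mask[:, :, ::-1]
    return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp

A plausible continuation of the truncated branch would be
pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1]).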