Example #1
    def set_layers(self):
        """properly handle layer initialization under multiple dataset situation
        """
        self.backproject_depth = {}
        self.project_3d = {}
        if self.opt.selfocclu:
            self.selfOccluMask = SelfOccluMask().cuda()

        for n, scale in enumerate(self.opt.scales):
            h = self.opt.height // (2**scale)
            w = self.opt.width // (2**scale)

            self.backproject_depth[scale] = BackprojectDepth(
                self.opt.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        if self.opt.bnMorphLoss:
            from bnmorph.bnmorph import BNMorph
            self.tool = grad_computation_tools(batch_size=self.opt.batch_size,
                                               height=self.opt.height,
                                               width=self.opt.width).cuda()

            self.auto_morph = BNMorph(height=self.opt.height,
                                      width=self.opt.width,
                                      senseRange=20).cuda()
            self.textureMeasure = TextureIndicatorM().cuda()
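The per-scale BackprojectDepth / Project3D pairs built here follow the monodepth2 view-synthesis pattern: BackprojectDepth lifts a depth map into a point cloud with the inverse intrinsics, and Project3D reprojects that cloud into the other view with the intrinsics and relative pose, giving normalized sampling coordinates for grid_sample. A minimal sketch of how they are consumed, mirroring the stereo warp in Example #10's supervised_with_morph (context assumed: inside the Trainer after set_layers(); F is torch.nn.functional):

        # sketch: synthesize the target view from the stereo frame at scale 0
        scaled_disp, depth = disp_to_depth(outputs[("disp", 0)], self.opt.min_depth, self.opt.max_depth)
        cam_points = self.backproject_depth[0](depth, inputs[("inv_K", 0)])
        pix_coords = self.project_3d[0](cam_points, inputs[("K", 0)], inputs["stereo_T"])
        warped = F.grid_sample(inputs[("color", "s", 0)], pix_coords, padding_mode="border")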
Example #2
 def __init__(self, height, width):
     self.height = height
     self.width = width
     self.tool = grad_computation_tools(batch_size=1, height=self.height,
                                        width=self.width).cuda()
     self.auto_morph = BNMorph(height=self.height, width=self.width, senseRange=20).cuda()
     self.tool.disparityTh = 0.07
     self.dismetric_seman = list()
     self.dismetric_disp = list()
Example #3
    def set_layers(self):
        """properly handle layer initialization under multiple dataset situation
        """
        self.backproject_depth = {}
        self.project_3d = {}
        self.selfOccluMask = SelfOccluMask().cuda()

        height = self.opt.height
        width = self.opt.width
        for n, scale in enumerate(self.opt.scales):
            h = height // (2 ** scale)
            w = width // (2 ** scale)

            self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        if self.opt.bnMorphLoss:
            from bnmorph.bnmorph import BNMorph
            self.tool = grad_computation_tools(batch_size=self.opt.batch_size, height=self.opt.height,
                                               width=self.opt.width).cuda()

            self.bnmorph = BNMorph(height=self.opt.height, width=self.opt.width, senseRange=20).cuda()
            self.foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17,
                                   18]  # pole, traffic light, traffic sign, person, rider, car, truck, bus, train, motorcycle, bicycle
            self.textureMeasure = TextureIndicatorM().cuda()

        if self.opt.Dloss:
            from eppl_render.eppl_render import EpplRender
            self.epplrender = EpplRender(height=self.opt.height, width=self.opt.width, batch_size=self.opt.batch_size, sampleNum=self.opt.eppsm).cuda()
Example #4
def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80
    selfOccluMask = SelfOccluMask().cuda()
    selfOccluMask.th = 0
    if opt.isCudaMorphing and opt.borderMorphLoss:
        bnmorph = BNMorph(height=opt.height, width=opt.width,
                          sparsityRad=2).cuda()
    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        filenames = readlines(
            os.path.join(splits_dir, opt.split_name,
                         opt.appendix_name + ".txt"))
        encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
        decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")

        encoder_dict = torch.load(encoder_path)

        dataset = datasets.KITTIRAWDataset(
            opt.data_path,
            filenames,
            encoder_dict['height'],
            encoder_dict['width'], [0, 's'],
            4,
            is_train=False,
            tag=opt.dataset,
            img_ext='png',
            load_meta=opt.load_meta,
            is_load_semantics=opt.use_kitti_gt_semantics,
            is_predicted_semantics=opt.is_predicted_semantics)

        dataloader = DataLoader(dataset,
                                2,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                drop_last=False)

        encoder = networks.ResnetEncoder(opt.num_layers,
                                         False,
                                         num_input_images=2)
        depth_decoder = networks.DepthDecoder(
            encoder.num_ch_enc,
            isSwitch=(opt.switchMode == 'on'),
            isMulChannel=opt.isMulChannel,
            outputtwoimage=(opt.outputtwoimage == True))

        model_dict = encoder.state_dict()
        encoder.load_state_dict(
            {k: v
             for k, v in encoder_dict.items() if k in model_dict})
        depth_decoder.load_state_dict(torch.load(decoder_path))

        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

        pred_disps = []
        mergeDisp = Merge_MultDisp(opt.scales, batchSize=opt.batch_size)

        count = 0
        tottime = 0

        if not os.path.isdir(opt.output_dir):
            os.mkdir(opt.output_dir)

        with torch.no_grad():
            for data in dataloader:
                # input_colorl = torch.cat([data[("color", 0, 0)], data[("color", 's', 0)]], dim=1).cuda()
                # input_colorr = torch.cat([data[("color", 's', 0)], data[("color", 0, 0)]], dim=1).cuda()
                # input_color = torch.cat([input_colorl, input_colorr], dim=0)
                start = time.time()
                input_color = torch.cat(
                    [data[("color", 0, 0)], data[("color", 's', 0)]],
                    dim=1).cuda()
                # tensor2rgb(input_color[:,0:3,:,:], ind=0).show()
                # tensor2rgb(input_color[:, 3:6, :, :], ind=0).show()
                # tensor2rgb(input_color[:, 0:3, :, :], ind=1).show()

                features = encoder(input_color)
                outputs = dict()
                outputs.update(
                    depth_decoder(features,
                                  computeSemantic=False,
                                  computeDepth=True))

                mergeDisp(data, outputs, eval=True)

                count = count + 1
                scaled_disp, _ = disp_to_depth(outputs[("disp", 0)],
                                               opt.min_depth, opt.max_depth)
                pred_disp = scaled_disp
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                real_scale_disp = scaled_disp * (torch.abs(
                    data[("K", 0)][:, 0, 0] * data["stereo_T"][:, 0, 3]).view(
                        opt.batch_size, 1, 1,
                        1).expand_as(scaled_disp)).cuda()
                SSIMMask = selfOccluMask(real_scale_disp,
                                         data["stereo_T"][:, 0, 3].cuda())

                store_path = filenames[data['idx'][0].numpy()].split(' ')
                folder1 = os.path.join(opt.output_dir,
                                       store_path[0].split('/')[0])
                folder2 = os.path.join(opt.output_dir, store_path[0])
                folder3 = os.path.join(folder2, 'image_02')
                folder4 = os.path.join(folder2, 'image_03')
                if not os.path.isdir(folder1):
                    os.mkdir(folder1)
                if not os.path.isdir(folder2):
                    os.mkdir(folder2)
                if not os.path.isdir(folder3):
                    os.mkdir(folder3)
                if not os.path.isdir(folder4):
                    os.mkdir(folder4)
                if opt.outputvisualizaiton:
                    folder5 = os.path.join(folder2, 'image_02_compose')
                    folder6 = os.path.join(folder2, 'image_03_compose')
                    if not os.path.isdir(folder5):
                        os.mkdir(folder5)
                    if not os.path.isdir(folder6):
                        os.mkdir(folder6)
                    a = outputs[("disp", 0)] * (1 - SSIMMask)
                    fig1 = tensor2disp(a, ind=0, vmax=0.15)
                    fig2 = tensor2disp(a, ind=1, vmax=0.15)
                    fig1.save(
                        os.path.join(folder5,
                                     store_path[1].zfill(10) + '.png'))
                    fig2.save(
                        os.path.join(folder6,
                                     store_path[1].zfill(10) + '.png'))
                pathl = os.path.join(folder3, store_path[1].zfill(10) + '.png')
                pathr = os.path.join(folder4, store_path[1].zfill(10) + '.png')

                # fig1 = tensor2disp(outputs[("disp", 0)], ind=1, vmax=0.1)
                # fig2 = tensor2disp(outputs[("disp", 0)] * (1 - SSIMMask), ind=1, vmax=0.1)
                # fig_combined = np.concatenate([np.array(fig1), np.array(fig2)], axis=0)
                # pil.fromarray(fig_combined).show()
                real_scale_disp = real_scale_disp * (1 - SSIMMask)
                stored_disp = real_scale_disp / 960
                save_loss(stored_disp[0, 0, :, :].cpu().numpy(), pathl)
                save_loss(stored_disp[1, 0, :, :].cpu().numpy(), pathr)

                duration = time.time() - start
                tottime = tottime + duration
                print("left time %f hours" %
                      (tottime / count * (len(filenames) - count) / 60 / 60))
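Two conversions above are easy to miss: disp_to_depth is used only for its first return value (the network disparity rescaled into [1/max_depth, 1/min_depth]), and real_scale_disp then turns that normalized disparity into pixel units by multiplying with |fx * baseline| (K[0,0] * stereo_T[0,3]). For reference, a sketch of disp_to_depth assuming the standard monodepth2 definition this codebase appears to follow:

def disp_to_depth(disp, min_depth, max_depth):
    """Reference sketch (monodepth2 convention): map sigmoid output to (scaled_disp, depth)."""
    min_disp = 1 / max_depth
    max_disp = 1 / min_depth
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth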
Example #5
def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80
    if opt.isCudaMorphing and opt.borderMorphLoss:
        bnmorph = BNMorph(height=opt.height, width=opt.width, sparsityRad=2).cuda()
    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        # print("-> Loading weights from {}".format(opt.load_weights_folder))
        if not opt.UseCustTest:
            filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
        else:
            filenames = readlines(os.path.join(splits_dir, "eigen_test_toy", "val_files.txt"))
        encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
        decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")

        encoder_dict = torch.load(encoder_path)

        dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
                                           encoder_dict['height'], encoder_dict['width'],
                                           [0], 4, is_train=False, tag=opt.dataset,
                                           img_ext='png', load_meta=opt.load_meta,
                                           is_load_semantics=opt.use_kitti_gt_semantics,
                                           is_predicted_semantics=opt.is_predicted_semantics)

        dataloader = DataLoader(dataset, opt.batch_size, shuffle=False, num_workers=opt.num_workers, drop_last=True)

        encoder = networks.ResnetEncoder(opt.num_layers, False)
        depth_decoder = networks.DepthDecoder(encoder.num_ch_enc, isSwitch=(opt.switchMode == 'on'), isMulChannel=opt.isMulChannel)

        if opt.borderMorphLoss:
            tool = grad_computation_tools(batch_size=opt.batch_size, height=opt.height, width=opt.width).cuda()
            auto_morph = AutoMorph(height=opt.height, width=opt.width)
            foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18]  # pole, traffic light, traffic sign, person, rider, car, truck, bus, train, motorcycle, bicycle
            MorphitNum = 5


        model_dict = encoder.state_dict()
        encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
        depth_decoder.load_state_dict(torch.load(decoder_path))

        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

        if opt.set_eval_train:
            encoder.train()
            depth_decoder.train()

        # encoder.train()
        # depth_decoder.train()

        pred_disps = []
        mergeDisp = Merge_MultDisp(opt.scales, batchSize = opt.batch_size)

        # print("-> Computing predictions with size {}x{}".format(
        #     encoder_dict['width'], encoder_dict['height']))
        count = 0
        with torch.no_grad():
            for data in dataloader:
                input_color = data[("color", 0, 0)].cuda()

                if opt.post_process:
                    # Post-processed results require each image to have two forward passes
                    input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)

                features = encoder(input_color)
                outputs = dict()
                # outputs.update(depth_decoder(features, computeSemantic=True, computeDepth=False))
                outputs.update(depth_decoder(features, computeSemantic=False, computeDepth=True))

                mergeDisp(data, outputs, eval=True)
                # outputs['disp', 0] = F.interpolate(outputs['disp', 0], [opt.height, opt.width], mode='bilinear', align_corners=True)
                # pickle.dump(outputs, open("eval_outputs.p", "wb"))
                if opt.borderMorphLoss:
                    for key, ipt in data.items():
                        if not (key == 'height' or key == 'width' or key == 'tag' or key == 'cts_meta' or key == 'file_add'):
                            data[key] = ipt.to(torch.device("cuda"))

                    foregroundMapGt = torch.ones([opt.batch_size, 1, opt.height, opt.width],
                                                 dtype=torch.uint8, device=torch.device("cuda"))
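                    # build a binary foreground mask: zero out pixels belonging to any
                    # foreground class, then invert so foreground pixels become 1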
                    for m in foregroundType:
                        foregroundMapGt = foregroundMapGt * (data['seman_gt'] != m)
                    foregroundMapGt = (1 - foregroundMapGt).float()

                    disparity_grad = torch.abs(tool.convDispx(outputs['disp', 0])) + torch.abs(
                        tool.convDispy(outputs['disp', 0]))
                    semantics_grad = torch.abs(tool.convDispx(foregroundMapGt)) + torch.abs(
                        tool.convDispy(foregroundMapGt))
                    disparity_grad = disparity_grad * tool.zero_mask
                    semantics_grad = semantics_grad * tool.zero_mask

                    disparity_grad_bin = disparity_grad > tool.disparityTh
                    semantics_grad_bin = semantics_grad > tool.semanticsTh

                    if opt.isCudaMorphing:
                        morphedx, morphedy, coeff = bnmorph.find_corresponding_pts(disparity_grad_bin, semantics_grad_bin)
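                        # F.grid_sample expects sampling coordinates normalized to [-1, 1],
                        # so rescale the morphed pixel coordinates before building the grid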
                        morphedx = (morphedx / (opt.width - 1) - 0.5) * 2
                        morphedy = (morphedy / (opt.height - 1) - 0.5) * 2
                        grid = torch.cat([morphedx, morphedy], dim=1).permute(0, 2, 3, 1)
                        dispMaps_morphed = F.grid_sample(outputs['disp', 0], grid, padding_mode="border")
                    else:
                        disparity_grad_bin = disparity_grad_bin.detach().cpu().numpy()
                        semantics_grad_bin = semantics_grad_bin.detach().cpu().numpy()

                        disparityMap_to_processed = outputs['disp', 0].detach().cpu().numpy()
                        dispMaps_morphed = list()
                        changeingRecs = list()
                        for mm in range(opt.batch_size):
                            dispMap_morphed, changeingRec = auto_morph.automorph(
                                disparity_grad_bin[mm, 0, :, :], semantics_grad_bin[mm, 0, :, :],
                                disparityMap_to_processed[mm, 0, :, :])
                            dispMaps_morphed.append(dispMap_morphed)
                            changeingRecs.append(changeingRec)
                        dispMaps_morphed = torch.from_numpy(np.stack(dispMaps_morphed, axis=0)).unsqueeze(1).cuda()
                    outputs[("disp", 0)] = dispMaps_morphed
                    # tensor2disp(dispMaps_morphed, ind=0, vmax=0.09).show()

                # print(count)
                count = count + 1
                pred_disp, _ = disp_to_depth(outputs[("disp", 0)], opt.min_depth, opt.max_depth)
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                # Some check:
                # with open('train_outputs.p', 'rb') as handle:
                #     train_outputs = pickle.load(handle)
                #     pred_disp, pdepth = disp_to_depth(outputs[("disp", 0)], opt.min_depth, opt.max_depth)
                #     torch.mean(torch.abs(train_outputs[('disp', 0)] - outputs[("disp", 0)]))
                #     torch.mean(torch.abs(train_outputs[('depth', 0, 0)] - pdepth))

                if opt.post_process:
                    N = pred_disp.shape[0] // 2
                    pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])

                pred_disps.append(pred_disp)

        pred_disps = np.concatenate(pred_disps)

    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(
            opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        quit()

    elif opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx in range(len(pred_disps)):
            disp_resized = cv2.resize(pred_disps[idx], (1216, 352))
            depth = STEREO_SCALE_FACTOR / disp_resized
            depth = np.clip(depth, 0, 80)
            depth = np.uint16(depth * 256)
            save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
            cv2.imwrite(save_path, depth)

        print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
        quit()
    if not opt.UseCustTest:
        gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
        gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1', allow_pickle = True)["data"]
    else:
        gt_depths = np.load("/media/shengjie/other/sceneUnderstanding/SDNET/splits/eigen_test_toy/gt_depths.npz", fix_imports=True, encoding='latin1', allow_pickle=True)["data"]

    print("-> Evaluating")

    if opt.eval_stereo:
        print("   Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
        if opt.EnableMedianScaleInEval:
            opt.disable_median_scaling = False
    else:
        print("   Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for i in range(pred_disps.shape[0]):

        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        # Some check:
        # with open('recompare.p', 'rb') as handle:
        #     train_outputs = pickle.load(handle)
        #     calib_dir = '/media/shengjie/other/sceneUnderstanding/monodepth2/kitti_data/kitti_raw/2011_09_26'
        #     velo_filename = '/media/shengjie/other/sceneUnderstanding/monodepth2/kitti_data/kitti_raw/2011_09_26/2011_09_26_drive_0002_sync/velodyne_points/data/0000000069.bin'
        #     gt_depth2 = kitti_utils.generate_depth_map(calib_dir, velo_filename, 2, True)
        #
        #     np.mean(np.abs(train_outputs['depth_gt'][0,0,:,:].cpu().numpy() - gt_depth))
        #     np.mean(np.abs(train_outputs['depth_pred'][0, 0, :, :].cpu().numpy() - pred_depth))
        #     pred_depth = pred_depth * train_outputs['scaleRation'].cpu().numpy()
        #
        #     train_depth = F.interpolate(train_outputs[('depth', 0, 1)], [gt_height, gt_width], mode='bilinear', align_corners=True)
        #     np.mean(np.abs(train_depth[0,0,:,:].cpu().numpy() - pred_depth))

        if opt.eval_split == "eigen" or opt.UseCustTest:
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
                             0.03594771 * gt_width,  0.96405229 * gt_width]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)

        else:
            mask = gt_depth > 0

        # Some check:
        # with open('recompare.p', 'rb') as handle:
        #     eval_outputs = pickle.load(handle)
        #     np.mean(np.abs(eval_outputs['depth_gt'][0,0,:,:].cpu().numpy() -gt_depth ))

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH

        errors.append(compute_errors(gt_depth, pred_depth, UseGtMedianScaling = (opt.UseGtMedianScaling == True)))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    print("\n  " + ("{:>8} | " * 8).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3", "abs_shift"))
    print(("&{: 8.3f}  " * 8).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")
    if opt.isCudaMorphing and opt.borderMorphLoss:
        bnmorph.print_params()
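batch_post_process_disparity, called when --post_process is set, is not shown in these excerpts. Assuming it is the monodepth2 helper (this script closely follows monodepth2's evaluate_depth.py), a reference sketch: each prediction is blended with the prediction from the horizontally flipped input, weighting each map away from its occluded image border.

import numpy as np

def batch_post_process_disparity(l_disp, r_disp):
    """Reference sketch of the monodepth2 post-processing blend."""
    _, h, w = l_disp.shape
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
    r_mask = l_mask[:, :, ::-1]
    return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp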
Example #6
                                 help="which split to run eval on")
        self.parser.add_argument("--save_pred_disps",
                                 help="if set saves predicted disparities",
                                 action="store_true")

    def parse(self):
        self.options = self.parser.parse_args()
        return self.options


if __name__ == "__main__":
    options = MonodepthOptions()
    opt = options.parse()

    if opt.isCudaMorphing and opt.borderMorphLoss:
        bnmorph = BNMorph(height=opt.height, width=opt.width,
                          sparsityRad=2).cuda()

    outputdir = opt.outputdir
    if not os.path.isdir(outputdir):
        os.mkdir(outputdir)
    splits_dir = os.path.join(os.path.dirname(__file__), "splits")
    opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

    assert os.path.isdir(opt.load_weights_folder), \
        "Cannot find a folder at {}".format(opt.load_weights_folder)

    print("-> Loading weights from {}".format(opt.load_weights_folder))

    filenames = readlines(
        os.path.join(splits_dir, opt.split, "train_files.txt"))
    # filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
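readlines, used throughout these scripts to load split files into a list of relative paths, is assumed to be the usual small utility; a sketch of that assumption:

def readlines(filename):
    """Read all the lines in a text file and return them as a list (assumed helper)."""
    with open(filename, 'r') as f:
        lines = f.read().splitlines()
    return lines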
Example #7
class Morph_semantics():
    def __init__(self, height, width):
        self.height = height
        self.width = width
        self.tool = grad_computation_tools(batch_size=1, height=self.height,
                                           width=self.width).cuda()
        self.auto_morph = BNMorph(height=self.height, width=self.width, senseRange=20).cuda()
        self.tool.disparityTh = 0.07
        self.dismetric_seman = list()
        self.dismetric_disp = list()

    def compute_edge_distance(self, depth, semantics, semantics_gt, isdisp = True):
        depth = depth.cuda()
        if not isdisp:
            depth = 1 / depth
        else:
            depth = depth.float()
        semantics = semantics.cuda()
        foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18]
        batch_size = semantics.shape[0]
        height = semantics.shape[2]
        width = semantics.shape[3]
        foregroundMapGt = torch.ones([batch_size, 1, height, width],
                                     dtype=torch.uint8, device=torch.device("cuda"))
        for m in foregroundType:
            foregroundMapGt = foregroundMapGt * (semantics != m)
        foregroundMapGt = (1 - foregroundMapGt).float()

        disparity_grad = torch.abs(self.tool.convDispx(depth)) + torch.abs(
            self.tool.convDispy(depth))
        semantics_grad = torch.abs(self.tool.convDispx(foregroundMapGt)) + torch.abs(
            self.tool.convDispy(foregroundMapGt))
        disparity_grad = disparity_grad * self.tool.zero_mask
        semantics_grad = semantics_grad * self.tool.zero_mask


        if not isdisp:
            disparity_grad_bin = disparity_grad > self.tool.disparityTh
        else:
            disparity_grad_bin = disparity_grad > self.tool.disparityTh * 150
        semantics_grad_bin = semantics_grad > self.tool.semanticsTh

        semantics_gt = semantics_gt.cuda()
        foregroundMapGt = torch.ones([batch_size, 1, height, width],
                                     dtype=torch.uint8, device=torch.device("cuda"))
        for m in foregroundType:
            foregroundMapGt = foregroundMapGt * (semantics_gt != m)
        foregroundMapGt = (1 - foregroundMapGt).float()
        semantics_grad = torch.abs(self.tool.convDispx(foregroundMapGt)) + torch.abs(
            self.tool.convDispy(foregroundMapGt))
        semantics_grad = semantics_grad * self.tool.zero_mask
        semantics_grad_bin_gt = semantics_grad > self.tool.semanticsTh

        # tensor2disp(semantics_grad_bin, vmax=1, ind=0).show()
        # tensor2disp(semantics_grad_bin_gt, vmax=1, ind=0).show()
        morphedx, morphedy, ocoeff_seman = self.auto_morph.find_corresponding_pts(semantics_grad_bin, semantics_grad_bin_gt)
        dists = (ocoeff_seman['orgpts_x'] - ocoeff_seman['correspts_x']) ** 2 + \
                (ocoeff_seman['orgpts_y'] - ocoeff_seman['correspts_y']) ** 2
        dists = dists[ocoeff_seman['orgpts_x'] > 1e-3]
        self.dismetric_seman = self.dismetric_seman + list(dists.cpu())
        morphedx, morphedy, ocoeff_disp = self.auto_morph.find_corresponding_pts(disparity_grad_bin, semantics_grad_bin_gt)
        distc = (ocoeff_disp['orgpts_x'] - ocoeff_disp['correspts_x']) ** 2 + \
                (ocoeff_disp['orgpts_y'] - ocoeff_disp['correspts_y']) ** 2
        distc = distc[ocoeff_disp['orgpts_x'] > 1e-3]
        self.dismetric_disp = self.dismetric_disp + list(distc.cpu())

    def show_dis_comp(self):
        seman = np.array(self.dismetric_seman).mean()
        disp = np.array(self.dismetric_disp).mean()
        print("semantic %f, disp %f" % (seman, disp))
    def morh_semantics(self, depth, semantics, isdisp = True):
        depth = depth.cuda()
        if not isdisp:
            depth = 1 / depth
        else:
            depth = depth.float()
        semantics = semantics.cuda()
        foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18]
        batch_size = semantics.shape[0]
        height = semantics.shape[2]
        width = semantics.shape[3]
        foregroundMapGt = torch.ones([batch_size, 1, height, width],
                                     dtype=torch.uint8, device=torch.device("cuda"))
        for m in foregroundType:
            foregroundMapGt = foregroundMapGt * (semantics != m)
        foregroundMapGt = (1 - foregroundMapGt).float()

        disparity_grad = torch.abs(self.tool.convDispx(depth)) + torch.abs(
            self.tool.convDispy(depth))
        semantics_grad = torch.abs(self.tool.convDispx(foregroundMapGt)) + torch.abs(
            self.tool.convDispy(foregroundMapGt))
        disparity_grad = disparity_grad * self.tool.zero_mask
        semantics_grad = semantics_grad * self.tool.zero_mask


        if not isdisp:
            disparity_grad_bin = disparity_grad > self.tool.disparityTh
        else:
            disparity_grad_bin = disparity_grad > self.tool.disparityTh * 150
        semantics_grad_bin = semantics_grad > self.tool.semanticsTh
        # tensor2disp(disparity_grad, percentile=95, ind=0).show()
        # tensor2disp(disparity_grad, ind=0, vmax=0.1).show()
        # tensor2disp(disparity_grad_bin, ind=0, vmax=1).show()
        morphedx, morphedy, ocoeff = self.auto_morph.find_corresponding_pts(semantics_grad_bin, disparity_grad_bin)
        # fig_seman = tensor2disp(semantics_grad_bin, ind=0, vmax=1)
        # morphedx, morphedy, ocoeff = self.auto_morph.find_corresponding_pts_debug(semantics_grad_bin, disparity_grad_bin, disparityMap=depth, semantic_figure = tensor2semantic(semantics, ind=0))
        morphedx = (morphedx / (self.width - 1) - 0.5) * 2
        morphedy = (morphedy / (self.height - 1) - 0.5) * 2
        grid = torch.cat([morphedx, morphedy], dim=1).permute(0, 2, 3, 1)
        seman_morphed = F.grid_sample(semantics.detach().float(), grid, mode = 'nearest', padding_mode="border")
        # tensor2semantic(seman_morphed, ind=0).show()
        # tensor2semantic(semantics, ind=0).show()
        # joint = torch.zeros([height, width, 3])
        # joint[:,:,0] = disparity_grad_bin[0,0,:,:] * 255
        # joint[:, :, 1] = semantics_grad_bin[0, 0, :, :] * 255
        # joint = joint.cpu().numpy().astype(np.uint8)
        # pil.fromarray(joint).show()



        # foregroundMapGt = torch.ones([batch_size, 1, height, width],
        #                              dtype=torch.uint8, device=torch.device("cuda"))
        # for m in foregroundType:
        #     foregroundMapGt = foregroundMapGt * (seman_morphed != m)
        # foregroundMapGt = (1 - foregroundMapGt).float()
        # semantics_grad_morphed = torch.abs(self.tool.convDispx(foregroundMapGt)) + torch.abs(
        #     self.tool.convDispy(foregroundMapGt)) * self.tool.zero_mask
        # semantics_grad_bin_morphed = semantics_grad_morphed > self.tool.semanticsTh
        # joint = torch.zeros([height, width, 3])
        # joint[:,:,0] = disparity_grad_bin[0,0,:,:] * 255
        # joint[:, :, 1] = semantics_grad_bin_morphed[0, 0, :, :] * 255
        # joint = joint.cpu().numpy().astype(np.uint8)
        # pil.fromarray(joint).show()
        #
        # selector = ocoeff['orgpts_x'] != -1
        # srcptsx = ocoeff['orgpts_x'][selector]
        # srcptsy = ocoeff['orgpts_y'][selector]
        # dstPtsx = ocoeff['correspts_x'][selector]
        # dstPtsy = ocoeff['correspts_y'][selector]
        # plt.figure()
        # plt.imshow(fig_seman)
        # for i in range(0, srcptsx.shape[0]):
        #     plt.plot([srcptsx[i], dstPtsx[i]], [srcptsy[i], dstPtsy[i]])
        # plt.show()
        # tensor2semantic(seman_morphed, ind=0).show()
        # tensor2semantic(semantics, ind=0).show()

        return seman_morphed
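A hypothetical driver for the class above, using only the methods it defines; the resolution and tensor shapes below are illustrative assumptions, not values from the source:

# hypothetical usage: morph predicted semantics toward depth edges and report edge distances
morpher = Morph_semantics(height=320, width=1024)
# depth: [1, 1, 320, 1024] metric depth; semantics / semantics_gt: [1, 1, 320, 1024] label maps
seman_morphed = morpher.morh_semantics(depth, semantics, isdisp=False)
morpher.compute_edge_distance(depth, semantics, semantics_gt, isdisp=False)
morpher.show_dis_comp()  # prints mean squared edge distance for semantic vs. disparity edges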
Example #8
def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80
    viewPythonVer = False
    viewCudaVer = True

    if viewCudaVer:
        bnmorph = BNMorph(height=opt.height, width=opt.width).cuda()

    opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

    assert os.path.isdir(opt.load_weights_folder), \
        "Cannot find a folder at {}".format(opt.load_weights_folder)

    print("-> Loading weights from {}".format(opt.load_weights_folder))

    filenames = readlines(os.path.join(splits_dir, opt.split, "val_files.txt"))
    encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
    decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")

    encoder_dict = torch.load(encoder_path)

    if opt.use_stereo:
        opt.frame_ids.append("s")
    if opt.dataset == 'cityscape':
        dataset = datasets.CITYSCAPERawDataset(
            opt.data_path,
            filenames,
            opt.height,
            opt.width,
            opt.frame_ids,
            4,
            is_train=False,
            tag=opt.dataset,
            load_meta=True,
            direction_left=opt.direction_left)
    elif opt.dataset == 'kitti':
        dataset = datasets.KITTIRAWDataset(
            opt.data_path,
            filenames,
            opt.height,
            opt.width,
            opt.frame_ids,
            4,
            is_train=False,
            tag=opt.dataset,
            is_load_semantics=opt.use_kitti_gt_semantics,
            is_predicted_semantics=opt.is_predicted_semantics,
            direction_left=opt.direction_left)
    else:
        raise ValueError("No predefined dataset")
    dataloader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.num_workers,
                            pin_memory=True,
                            drop_last=True)

    encoder = networks.ResnetEncoder(opt.num_layers, False, num_input_images=2)
    if opt.switchMode == 'on':
        depth_decoder = networks.DepthDecoder(
            encoder.num_ch_enc,
            isSwitch=True,
            isMulChannel=opt.isMulChannel,
            outputtwoimage=(opt.outputtwoimage == True))
    else:
        depth_decoder = networks.DepthDecoder(encoder.num_ch_enc)

    model_dict = encoder.state_dict()
    encoder.load_state_dict(
        {k: v
         for k, v in encoder_dict.items() if k in model_dict})
    depth_decoder.load_state_dict(torch.load(decoder_path))

    encoder.cuda()
    encoder.eval()
    depth_decoder.cuda()
    depth_decoder.eval()

    viewIndex = 0
    tool = grad_computation_tools(batch_size=opt.batch_size,
                                  height=opt.height,
                                  width=opt.width).cuda()
    auto_morph = AutoMorph(height=opt.height, width=opt.width)
    with torch.no_grad():
        for idx, inputs in enumerate(dataloader):
            for key, ipt in inputs.items():
                if not (key == 'height' or key == 'width' or key == 'tag'
                        or key == 'cts_meta' or key == 'file_add'):
                    inputs[key] = ipt.to(torch.device("cuda"))

            input_color = torch.cat(
                [inputs[("color_aug", 0, 0)], inputs[("color_aug", 's', 0)]],
                dim=1).cuda()
            # input_color = inputs[("color", 0, 0)].cuda()
            # tensor2rgb(inputs[("color_aug", 0, 0)], ind=0).show()
            # tensor2rgb(inputs[("color_aug", 's', 0)], ind=0).show()
            features = encoder(input_color)
            outputs = dict()
            outputs.update(
                depth_decoder(features,
                              computeSemantic=True,
                              computeDepth=False))
            outputs.update(
                depth_decoder(features,
                              computeSemantic=False,
                              computeDepth=True))

            if not opt.view_right:
                disparityMap = outputs[('mul_disp', 0)][:, 0:1, :, :]
            else:
                disparityMap = outputs[('mul_disp', 0)][:, 1:2, :, :]
            depthMap = torch.clamp(disparityMap, max=80)
            fig_seman = tensor2semantic(inputs['seman_gt'],
                                        ind=viewIndex,
                                        isGt=True)
            fig_rgb = tensor2rgb(inputs[('color', 0, 0)], ind=viewIndex)
            fig_disp = tensor2disp(disparityMap, ind=viewIndex, vmax=0.1)

            segmentationMapGt = inputs['seman_gt']
            foregroundType = [
                5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18
            ]  # pole, traffic light, traffic sign, person, rider, car, truck, bus, train, motorcycle, bicycle
            foregroundMapGt = torch.ones(disparityMap.shape).cuda().byte()
            for m in foregroundType:
                foregroundMapGt = foregroundMapGt * (segmentationMapGt != m)
            foregroundMapGt = (1 - foregroundMapGt).float()

            disparity_grad = torch.abs(
                tool.convDispx(disparityMap)) + torch.abs(
                    tool.convDispy(disparityMap))
            semantics_grad = torch.abs(
                tool.convDispx(foregroundMapGt)) + torch.abs(
                    tool.convDispy(foregroundMapGt))
            disparity_grad = disparity_grad * tool.zero_mask
            semantics_grad = semantics_grad * tool.zero_mask

            disparity_grad_bin = disparity_grad > tool.disparityTh
            semantics_grad_bin = semantics_grad > tool.semanticsTh

            # tensor2disp(disparity_grad_bin, ind=viewIndex, vmax=1).show()
            # tensor2disp(semantics_grad_bin, ind=viewIndex, vmax=1).show()

            if viewPythonVer:
                disparity_grad_bin = disparity_grad_bin.detach().cpu().numpy()
                semantics_grad_bin = semantics_grad_bin.detach().cpu().numpy()

                disparityMap_to_processed = disparityMap.detach().cpu().numpy(
                )[viewIndex, 0, :, :]
                dispMap_morphed, dispMap_morphRec = auto_morph.automorph(
                    disparity_grad_bin[viewIndex, 0, :, :],
                    semantics_grad_bin[viewIndex,
                                       0, :, :], disparityMap_to_processed)

                fig_disp_processed = visualizeNpDisp(dispMap_morphed, vmax=0.1)
                overlay_processed = pil.fromarray(
                    (np.array(fig_disp_processed) * 0.7 +
                     np.array(fig_seman) * 0.3).astype(np.uint8))
                overlay_org = pil.fromarray(
                    (np.array(fig_disp) * 0.7 +
                     np.array(fig_seman) * 0.3).astype(np.uint8))
                combined_fig = pil.fromarray(
                    np.concatenate([
                        np.array(overlay_org),
                        np.array(overlay_processed),
                        np.array(fig_disp),
                        np.array(fig_disp_processed)
                    ],
                                   axis=0))
                combined_fig.save(
                    "/media/shengjie/other/sceneUnderstanding/Stereo_SDNET/visualization/border_morph_l2_3/"
                    + str(idx) + ".png")
            if viewCudaVer:
                # morphedx, morphedy = bnmorph.find_corresponding_pts(disparity_grad_bin, semantics_grad_bin, disparityMap, fig_seman, 10)
                # morphedx = (morphedx / (opt.width - 1) - 0.5) * 2
                # morphedy = (morphedy / (opt.height - 1) - 0.5) * 2
                # grid = torch.cat([morphedx, morphedy], dim = 1).permute(0,2,3,1)
                # disparityMap_morphed = F.grid_sample(disparityMap, grid, padding_mode="border")
                # fig_morphed = tensor2disp(disparityMap_morphed, vmax=0.08, ind=0)
                # fig_disp = tensor2disp(disparityMap, vmax=0.08, ind=0)
                # fig_combined = pil.fromarray(np.concatenate([np.array(fig_morphed), np.array(fig_disp)], axis=0))
                # fig_combined.show()
                svpath = os.path.join(opt.load_weights_folder).split('/')
                try:
                    svpath = os.path.join(
                        "/media/shengjie/other/sceneUnderstanding/Stereo_SDNET/visualization",
                        svpath[-3])
                    os.mkdir(svpath)
                except FileExistsError:
                    pass
                morphedx, morphedy, coeff = bnmorph.find_corresponding_pts(
                    disparity_grad_bin, semantics_grad_bin)
                morphedx = (morphedx / (opt.width - 1) - 0.5) * 2
                morphedy = (morphedy / (opt.height - 1) - 0.5) * 2
                grid = torch.cat([morphedx, morphedy],
                                 dim=1).permute(0, 2, 3, 1)
                disparityMap_morphed = F.grid_sample(disparityMap,
                                                     grid,
                                                     padding_mode="border")

                fig_morphed = tensor2disp(disparityMap_morphed,
                                          vmax=0.08,
                                          ind=0)
                fig_disp = tensor2disp(disparityMap, vmax=0.08, ind=0)
                fig_morphed_overlayed = pil.fromarray(
                    (np.array(fig_seman) * 0.5 +
                     np.array(fig_morphed) * 0.5).astype(np.uint8))
                fig_disp_overlayed = pil.fromarray(
                    (np.array(fig_seman) * 0.5 +
                     np.array(fig_disp) * 0.5).astype(np.uint8))
                # fig_rgb =  tensor2rgb(inputs[("color", 0, 0)], ind=0)
                # fig_combined = pil.fromarray(np.concatenate([np.array(fig_disp_overlayed), np.array(fig_morphed_overlayed), np.array(fig_disp), np.array(fig_morphed), np.array(fig_rgb)], axis=0))
                fig_combined = pil.fromarray(
                    np.concatenate([
                        np.array(fig_disp_overlayed),
                        np.array(fig_morphed_overlayed),
                        np.array(fig_disp),
                        np.array(fig_morphed)
                    ],
                                   axis=0))
                fig_combined.save(os.path.join(svpath, str(idx) + ".png"))
Example #9
class Morph_semantics():
    def __init__(self, height, width):
        self.height = height
        self.width = width
        self.tool = grad_computation_tools(batch_size=1,
                                           height=self.height,
                                           width=self.width).cuda()
        self.auto_morph = BNMorph(height=self.height,
                                  width=self.width,
                                  senseRange=20).cuda()
        self.tool.disparityTh = 0.07

    def morh_semantics(self, depth, semantics):
        depth = depth.cuda()
        depth = 1 / depth
        semantics = semantics.cuda()
        foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18]
        batch_size = semantics.shape[0]
        height = semantics.shape[2]
        width = semantics.shape[3]
        foregroundMapGt = torch.ones([batch_size, 1, height, width],
                                     dtype=torch.uint8,
                                     device=torch.device("cuda"))
        for m in foregroundType:
            foregroundMapGt = foregroundMapGt * (semantics != m)
        foregroundMapGt = (1 - foregroundMapGt).float()

        disparity_grad = torch.abs(self.tool.convDispx(depth)) + torch.abs(
            self.tool.convDispy(depth))
        semantics_grad = torch.abs(
            self.tool.convDispx(foregroundMapGt)) + torch.abs(
                self.tool.convDispy(foregroundMapGt))
        disparity_grad = disparity_grad * self.tool.zero_mask
        semantics_grad = semantics_grad * self.tool.zero_mask

        disparity_grad_bin = disparity_grad > self.tool.disparityTh
        semantics_grad_bin = semantics_grad > self.tool.semanticsTh
        # tensor2disp(disparity_grad, ind=0, vmax=0.1).show()
        # tensor2disp(disparity_grad_bin, ind=0, vmax=1).show()
        # fig_seman = tensor2disp(semantics_grad_bin, ind=0, vmax=1)
        morphedx, morphedy, ocoeff = self.auto_morph.find_corresponding_pts(
            semantics_grad_bin, disparity_grad_bin, pixel_distance_weight=20)
        morphedx = (morphedx / (self.width - 1) - 0.5) * 2
        morphedy = (morphedy / (self.height - 1) - 0.5) * 2
        grid = torch.cat([morphedx, morphedy], dim=1).permute(0, 2, 3, 1)
        seman_morphed = F.grid_sample(semantics.detach().float(),
                                      grid,
                                      mode='nearest',
                                      padding_mode="border")
        # tensor2semantic(seman_morphed, ind=0).show()
        # tensor2semantic(semantics, ind=0).show()
        # joint = torch.zeros([height, width, 3])
        # joint[:,:,0] = disparity_grad_bin[0,0,:,:] * 255
        # joint[:, :, 1] = semantics_grad_bin[0, 0, :, :] * 255
        # joint = joint.cpu().numpy().astype(np.uint8)
        # pil.fromarray(joint).show()

        # foregroundMapGt = torch.ones([batch_size, 1, height, width],
        #                              dtype=torch.uint8, device=torch.device("cuda"))
        # for m in foregroundType:
        #     foregroundMapGt = foregroundMapGt * (seman_morphed != m)
        # foregroundMapGt = (1 - foregroundMapGt).float()
        # semantics_grad_morphed = torch.abs(self.tool.convDispx(foregroundMapGt)) + torch.abs(
        #     self.tool.convDispy(foregroundMapGt)) * self.tool.zero_mask
        # semantics_grad_bin_morphed = semantics_grad_morphed > self.tool.semanticsTh
        # joint = torch.zeros([height, width, 3])
        # joint[:,:,0] = disparity_grad_bin[0,0,:,:] * 255
        # joint[:, :, 1] = semantics_grad_bin_morphed[0, 0, :, :] * 255
        # joint = joint.cpu().numpy().astype(np.uint8)
        # pil.fromarray(joint).show()
        #
        # selector = ocoeff['orgpts_x'] != -1
        # srcptsx = ocoeff['orgpts_x'][selector]
        # srcptsy = ocoeff['orgpts_y'][selector]
        # dstPtsx = ocoeff['correspts_x'][selector]
        # dstPtsy = ocoeff['correspts_y'][selector]
        # plt.figure()
        # plt.imshow(fig_seman)
        # for i in range(0, srcptsx.shape[0]):
        #     plt.plot([srcptsx[i], dstPtsx[i]], [srcptsy[i], dstPtsy[i]])
        # plt.show()
        return seman_morphed
Example #10
class Trainer:
    def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        # checking height and width are multiples of 32
        assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"
        self.STEREO_SCALE_FACTOR = 5.4

        if self.opt.switchMode == 'on':
            self.switchMode = True
        else:
            self.switchMode = False
        self.models = {}
        self.parameters_to_train = []

        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        self.num_scales = len(self.opt.scales)
        self.num_input_frames = len(self.opt.frame_ids)
        self.semanticCoeff = self.opt.semanticCoeff
        self.sfx = nn.Softmax()
        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"
        if self.opt.use_stereo:
            self.opt.frame_ids.append("s")
        self.models["encoder"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")
        self.models["encoder"].to(self.device)
        self.parameters_to_train += list(self.models["encoder"].parameters())
        self.models["depth"] = networks.DepthDecoder(
            self.models["encoder"].num_ch_enc, self.opt.scales)
        self.models["depth"].to(self.device)

        self.parameters_to_train += list(self.models["depth"].parameters())
        self.model_optimizer = optim.Adam(self.parameters_to_train,
                                          self.opt.learning_rate)
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.opt.scheduler_step_size, 0.1)

        self.morph_optimizer = optim.SGD(self.parameters_to_train,
                                         self.opt.learning_rate)

        print("Training model named:\n  ", self.opt.model_name)
        print("Models and tensorboard events files are saved to:\n  ",
              self.opt.log_dir)
        print("Training is using:\n  ", self.device)

        self.set_dataset()
        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.log_path, mode))
        if not self.opt.no_ssim:
            self.ssim = SSIM()
            self.ssim.to(self.device)

        self.set_layers()
        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1",
            "da/a2", "da/a3"
        ]

        print("Using split:\n  ", self.opt.split)
        print("Switch mode on") if self.switchMode else print(
            "Switch mode off")
        print(
            "There are {:d} training items and {:d} validation items\n".format(
                self.train_num, self.val_num))

        if self.opt.load_weights_folder is not None:
            self.load_model()
        self.save_opts()

        self.sl1 = torch.nn.SmoothL1Loss()
        self.mp2d = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.disp_range = np.arange(0, 150, 1)
        self.bins = np.zeros(len(self.disp_range) - 1)
        self.deptherrRec = np.zeros(7)
        self.tot_rec = 0

    def set_layers(self):
        """properly handle layer initialization under multiple dataset situation
        """
        self.backproject_depth = {}
        self.project_3d = {}
        if self.opt.selfocclu:
            self.selfOccluMask = SelfOccluMask().cuda()

        for n, scale in enumerate(self.opt.scales):
            h = self.opt.height // (2**scale)
            w = self.opt.width // (2**scale)

            self.backproject_depth[scale] = BackprojectDepth(
                self.opt.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        if self.opt.bnMorphLoss:
            from bnmorph.bnmorph import BNMorph
            self.tool = grad_computation_tools(batch_size=self.opt.batch_size,
                                               height=self.opt.height,
                                               width=self.opt.width).cuda()

            self.auto_morph = BNMorph(height=self.opt.height,
                                      width=self.opt.width,
                                      senseRange=20).cuda()
            self.textureMeasure = TextureIndicatorM().cuda()

    def set_dataset(self):
        fpath = os.path.join(os.path.dirname(__file__), "splits",
                             self.opt.split, "{}_files.txt")
        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))

        train_dataset = datasets.KITTIRAWDataset(
            self.opt.data_path,
            train_filenames,
            self.opt.height,
            self.opt.width,
            self.opt.frame_ids,
            4,
            is_train=True,
            load_meta=self.opt.load_meta,
            is_load_semantics=True,
            is_predicted_semantics=self.opt.is_predicted_semantics,
            load_morphed_depth=self.opt.load_morphed_depth,
            read_stereo=self.opt.read_stereo,
            stereo_meta=self.opt.SGMStereo_prediction_folder,
            morphFolder=self.opt.read_processed_results_path)
        val_dataset = datasets.KITTIRAWDataset(
            self.opt.data_path,
            val_filenames,
            self.opt.height,
            self.opt.width,
            self.opt.frame_ids,
            4,
            is_train=False,
            load_meta=self.opt.load_meta,
            is_load_semantics=True,
            read_stereo=self.opt.read_stereo,
            stereo_meta=self.opt.SGMStereo_prediction_folder,
            is_predicted_semantics=self.opt.is_predicted_semantics)

        self.train_loader = DataLoader(train_dataset,
                                       self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
        self.val_loader = DataLoader(val_dataset,
                                     self.opt.batch_size,
                                     shuffle=True,
                                     num_workers=self.opt.num_workers,
                                     pin_memory=True,
                                     drop_last=True)
        self.val_iter = iter(self.val_loader)

        self.train_num = train_dataset.__len__()
        self.val_num = val_dataset.__len__()
        self.num_total_steps = self.train_num // self.opt.batch_size * self.opt.num_epochs

    def set_train(self):
        """Convert all models to training mode
        """
        for m in self.models.values():
            m.train()

    def set_eval(self):
        """Convert all models to testing/evaluation mode
        """
        for m in self.models.values():
            m.eval()

    def train(self):
        """Run the entire training pipeline
        """
        self.epoch = 0
        self.step = 0
        self.start_time = time.time()
        for self.epoch in range(self.opt.num_epochs):
            self.run_epoch()
            if (self.epoch + 1) % self.opt.save_frequency == 0:
                self.save_model()

    def supervised_with_morph(self, inputs):
        if not self.opt.inline_finetune:
            outputs = dict()
            losses = dict()
            for key, ipt in inputs.items():
                if not (key == 'height' or key == 'width' or key == 'tag'
                        or key == 'cts_meta' or key == 'file_add'):
                    inputs[key] = ipt.to(self.device)
            features = self.models["encoder"](inputs["color_aug", 0, 0])
            outputs.update(self.models["depth"](features,
                                                computeSemantic=False,
                                                computeDepth=True))

            diffMap = (outputs['disp', 0] - inputs['depth_morphed'])**2
            losses['totLoss'] = torch.mean(diffMap) * 1e3
            losses["similarity_loss"] = losses["totLoss"]

            self.morph_optimizer.zero_grad()
            losses['totLoss'].backward()
            self.morph_optimizer.step()
        else:
            outputs, losses = self.process_batch(inputs)

            stable_disp = outputs['disp', 0].detach()
            disparity_grad_bin = self.tool.get_disparityEdge(outputs['disp',
                                                                     0])
            semantics_grad_bin = self.tool.get_semanticsEdge(
                inputs['seman_gt'])

            morphedx, morphedy, ocoeff = self.auto_morph.find_corresponding_pts(
                disparity_grad_bin, semantics_grad_bin)
            morphedx = (morphedx / (self.opt.width - 1) - 0.5) * 2
            morphedy = (morphedy / (self.opt.height - 1) - 0.5) * 2
            grid = torch.cat([morphedx, morphedy], dim=1).permute(0, 2, 3, 1)
            dispMaps_morphed = F.grid_sample(stable_disp,
                                             grid,
                                             padding_mode="border")
            outputs['dispMaps_morphed'] = dispMaps_morphed
            ssim_morph = self.compute_reprojection_loss(
                dispMaps_morphed, outputs['disp', 0])

            if not self.opt.use_ssim_compare_mask:
                kth_val, kth_ind = torch.kthvalue(ssim_morph.cpu().view(
                    self.opt.batch_size, 1, -1),
                                                  dim=2,
                                                  k=self.topk_kval)
                kth_val = kth_val.cuda()
                selector_mask = (ssim_morph > kth_val.view(-1, 1, 1, 1).expand(
                    -1, 1, self.opt.height, self.opt.width)).float()
                losses["similarity_loss"] = torch.sum(
                    ssim_morph * selector_mask * outputs['grad_proj_msak'] *
                    (1 - outputs['ssimMask'])) / (torch.sum(selector_mask) + 1)
                losses['totLoss'] = losses[
                    "similarity_loss"] * self.opt.l1_weight + losses['totLoss']
            else:
                with torch.no_grad():
                    th = 1.05
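                    # th = 1.05: the morphed disparity must lower the reprojection error by roughly 5% to act as supervision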
                    ssim_val_predict = self.compute_reprojection_loss(
                        outputs[('color', 's', 0)], inputs[('color', 0, 0)])
                    scaledDisp, depth = disp_to_depth(dispMaps_morphed,
                                                      self.opt.min_depth,
                                                      self.opt.max_depth)
                    frame_id = "s"
                    T = inputs["stereo_T"]
                    cam_points = self.backproject_depth[0](depth,
                                                           inputs[("inv_K",
                                                                   0)])
                    pix_coords = self.project_3d[0](cam_points,
                                                    inputs[("K", 0)], T)
                    morphed_rgb = F.grid_sample(inputs[("color", frame_id, 0)],
                                                pix_coords,
                                                padding_mode="border")
                    ssim_val_morph = self.compute_reprojection_loss(
                        morphed_rgb, inputs[('color', 0, 0)])
                    if self.opt.is_stable_mask:
                        selector_mask = (
                            ssim_val_predict - th * ssim_val_morph >
                            0).float() * outputs['grad_proj_msak'] * (
                                1 - outputs['ssimMask'])
                    else:
                        selector_mask = (ssim_val_predict - th * ssim_val_morph
                                         >
                                         0).float() * outputs['grad_proj_msak']

                if self.opt.is_texture_weighted:
                    texture_measure = torch.mean(self.textureMeasure(
                        inputs[('color', 0, 0)]),
                                                 dim=1,
                                                 keepdim=True)
                    losses["similarity_loss"] = 100 * torch.sum(
                        torch.log(1 + torch.abs(dispMaps_morphed - outputs[
                            'disp', 0]) * texture_measure) *
                        selector_mask) / (torch.sum(selector_mask) + 1)

                losses['totLoss'] = losses[
                    "similarity_loss"] * self.opt.l1_weight + losses['totLoss']

                self.model_optimizer.zero_grad()
                losses['totLoss'].backward()
                self.model_optimizer.step()

        return outputs, losses

    def run_epoch(self):
        """Run a single epoch of training and validation
        """
        self.model_lr_scheduler.step()
        self.set_train()

        for batch_idx, inputs in enumerate(self.train_loader):

            before_op_time = time.time()

            outputs, losses = self.supervised_with_morph(inputs)

            duration = time.time() - before_op_time

            # log frequently only on the very first step; afterwards log at the validation frequency to save time & disk space
            early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 1
            late_phase = self.step % self.opt.val_frequency == 0

            if early_phase or late_phase:
                if "loss_depth/0" in losses:
                    loss_depth = losses["loss_depth/0"].cpu().data
                else:
                    loss_depth = -1

                self.log_time(batch_idx, duration, loss_depth,
                              losses["totLoss"])
                if self.step % self.opt.val_frequency == 0:
                    if "depth_gt" in inputs and ('depth', 0, 0) in outputs:
                        self.compute_depth_losses(inputs, outputs, losses)

                self.log("train", inputs, outputs, losses, writeImage=False)

                if self.step % self.opt.val_frequency == 0:
                    self.val()
                    if self.opt.writeImg:
                        if 'dispMaps_morphed' in outputs:
                            self.record_img(
                                disp=outputs['disp', 0],
                                semantic_gt=inputs['seman_gt'],
                                disp_morphed=outputs['dispMaps_morphed'],
                                mask=outputs[('depth_hint_sel', 0)])
                        else:
                            self.record_img(disp=outputs['disp', 0],
                                            semantic_gt=inputs['seman_gt'])
            self.step += 1

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses
        """
        for key, ipt in inputs.items():
            if not (key == 'height' or key == 'width' or key == 'tag'
                    or key == 'cts_meta' or key == 'file_add'):
                inputs[key] = ipt.to(self.device)

        features = self.models["encoder"](inputs["color_aug", 0, 0])
        outputs = dict()
        outputs.update(self.models["depth"](features))
        self.generate_images_pred(inputs, outputs)
        losses = self.compute_losses(inputs, outputs)
        return outputs, losses

    def val(self):
        """Validate the model on a single minibatch
        """
        self.set_eval()
        try:
            inputs = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)

        with torch.no_grad():
            outputs, losses = self.process_batch(inputs)
            self.compute_depth_losses(inputs, outputs, losses)
            self.log("val", inputs, outputs, losses, self.opt.writeImg)
            del inputs, outputs, losses
        self.set_train()

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.
        Generated images are saved into the `outputs` dictionary.
        """
        height = self.opt.height
        width = self.opt.width
        source_scale = 0
        for scale in self.opt.scales:
            disp = outputs[("disp", scale)]
            disp = F.interpolate(disp, [height, width],
                                 mode="bilinear",
                                 align_corners=False)
            scaledDisp, depth = disp_to_depth(disp, self.opt.min_depth,
                                              self.opt.max_depth)

            frame_id = "s"
            T = inputs["stereo_T"]
            cam_points = self.backproject_depth[source_scale](
                depth, inputs[("inv_K", source_scale)])
            pix_coords = self.project_3d[source_scale](cam_points,
                                                       inputs[("K",
                                                               source_scale)],
                                                       T)

            outputs[("disp", scale)] = disp
            outputs[("depth", 0, scale)] = depth
            outputs[("sample", frame_id, scale)] = pix_coords
            outputs[("color", frame_id, scale)] = F.grid_sample(
                inputs[("color", frame_id, source_scale)],
                outputs[("sample", frame_id, scale)],
                padding_mode="border")

            if scale == 0:
                cam_points = self.backproject_depth[source_scale](
                    inputs['depth_hint'], inputs[("inv_K", source_scale)])
                pix_coords = self.project_3d[source_scale](
                    cam_points, inputs[("K", source_scale)], T)
                outputs[("color_depth_hint", frame_id,
                         scale)] = F.grid_sample(inputs[("color", frame_id,
                                                         source_scale)],
                                                 pix_coords,
                                                 padding_mode="border")

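                # mark pixels whose depth-hint reprojection lands inside the image (valid grid_sample range is [-1, 1])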
                outputs['grad_proj_msak'] = (
                    (pix_coords[:, :, :, 0] > -1) *
                    (pix_coords[:, :, :, 1] > -1) *
                    (pix_coords[:, :, :, 0] < 1) *
                    (pix_coords[:, :, :, 1] < 1)).unsqueeze(1).float()
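                # convert normalized disparity to pixel-unit stereo disparity by scaling with |fx * baseline|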
                outputs[("real_scale_disp", scale)] = scaledDisp * (torch.abs(
                    inputs[("K", source_scale)][:, 0, 0] * T[:, 0, 3]).view(
                        self.opt.batch_size, 1, 1, 1).expand_as(scaledDisp))

    def compute_reprojection_loss(self, pred, target):
        """Computes reprojection loss between a batch of predicted and target images
        """

        l1_loss = torch.abs(target - pred).mean(1, True)
        ssim_loss = self.ssim(pred, target).mean(1, True)
        reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

        return reprojection_loss

    def compute_losses(self, inputs, outputs):
        """Compute the reprojection and smoothness losses for a minibatch
        """
        losses = {}
        losses["totLoss"] = 0

        source_scale = 0
        target = inputs[("color", 0, source_scale)]
        if self.opt.selfocclu:
            sourceSSIMMask = self.selfOccluMask(
                outputs[('real_scale_disp', source_scale)],
                inputs['stereo_T'][:, 0, 3])
        else:
            sourceSSIMMask = torch.zeros_like(outputs[('real_scale_disp',
                                                       source_scale)])
        outputs['ssimMask'] = sourceSSIMMask

        # compute depth hint reprojection loss
        if self.opt.read_stereo:
            pred = outputs[("color_depth_hint", 's', 0)]
            depth_hint_reproj_loss = self.compute_reprojection_loss(
                pred, inputs[("color", 0, 0)])
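            # a large constant keeps invalid depth-hint pixels from being selected by the per-pixel min() below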
            depth_hint_reproj_loss += 1000 * (1 - inputs['depth_hint_mask'])
        else:
            depth_hint_reproj_loss = None

        for scale in self.opt.scales:
            reprojection_loss = self.compute_reprojection_loss(
                outputs[("color", 's', scale)], target)
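            # identity reprojection term for automasking; the tiny noise below breaks ties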
            identity_reprojection_loss = self.compute_reprojection_loss(
                inputs[("color", 's', source_scale)],
                target) + torch.randn(reprojection_loss.shape).cuda() * 0.00001
            combined = torch.cat(
                (reprojection_loss, identity_reprojection_loss,
                 depth_hint_reproj_loss),
                dim=1)
            to_optimise, idxs = torch.min(combined, dim=1, keepdim=True)

            reprojection_loss_mask = (idxs !=
                                      1).float() * (1 - outputs['ssimMask'])
            depth_hint_loss_mask = (idxs == 2).float()

            losses["loss_depth/{}".format(scale)] = (
                reprojection_loss * reprojection_loss_mask).sum() / (
                    reprojection_loss_mask.sum() + 1e-7)
            losses["totLoss"] += losses["loss_depth/{}".format(
                scale)] / self.num_scales
            # proxy supervision loss
            if self.opt.read_stereo:
                valid_pixels = inputs['depth_hint_mask']

                depth_hint_loss = self.compute_proxy_supervised_loss(
                    outputs[('depth', 0, scale)], inputs['depth_hint'],
                    valid_pixels, depth_hint_loss_mask)
                depth_hint_loss = depth_hint_loss.sum() / (
                    depth_hint_loss_mask.sum() + 1e-7)
                losses['depth_hint_loss/{}'.format(scale)] = depth_hint_loss
                losses[
                    "totLoss"] += depth_hint_loss / self.num_scales * self.opt.depth_hint_param

            if self.opt.disparity_smoothness > 0:
                mult_disp = outputs[('disp', scale)]
                mean_disp = mult_disp.mean(2, True).mean(3, True)
                norm_disp = mult_disp / (mean_disp + 1e-7)
                losses["loss_smooth"] = get_smooth_loss(norm_disp,
                                                        target) / (2**scale)
                losses["totLoss"] += self.opt.disparity_smoothness * losses[
                    "loss_smooth"] / self.num_scales

        return losses

    @staticmethod
    def compute_proxy_supervised_loss(pred, target, valid_pixels, loss_mask):
        """ Compute proxy supervised loss (depth hint loss) for prediction.

            - valid_pixels is a mask of valid depth hint pixels (i.e. non-zero depth values).
            - loss_mask is a mask of where to apply the proxy supervision (i.e. the depth hint gave
            the smallest reprojection error)"""

        # first compute proxy supervised loss for all valid pixels
        depth_hint_loss = torch.log(torch.abs(target - pred) +
                                    1) * valid_pixels

        # only keep pixels where depth hints reprojection loss is smallest
        depth_hint_loss = depth_hint_loss * loss_mask

        return depth_hint_loss

    def compute_depth_losses(self, inputs, outputs, losses):
        """Compute depth metrics, to allow monitoring during training

        This isn't particularly accurate as it averages over the entire batch,
        so is only used to give an indication of validation performance
        """
        depth_pred = outputs[("depth", 0, 0)]
        depth_pred = torch.clamp(
            F.interpolate(depth_pred, [375, 1242],
                          mode="bilinear",
                          align_corners=False), 1e-3, 80)
        depth_pred = depth_pred.detach()

        depth_gt = inputs["depth_gt"]
        mask = depth_gt > 0

        # garg/eigen crop
        crop_mask = torch.zeros_like(mask)
        crop_mask[:, :, 153:371, 44:1197] = 1
        mask = mask * crop_mask
        depth_gt = depth_gt[mask]
        depth_pred = depth_pred[mask]
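        # median scaling: align the prediction's scale with the ground truth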
        depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)

        depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)

        depth_errors = compute_depth_errors(depth_gt, depth_pred)

        for i, metric in enumerate(self.depth_metric_names):
            losses[metric] = np.array(depth_errors[i].cpu())
        return depth_errors

    def log_time(self, batch_idx, duration, loss_depth, loss_tot):
        """Print a logging statement to the terminal
        """
        samples_per_sec = self.opt.batch_size / duration
        time_sofar = time.time() - self.start_time
        training_time_left = (self.num_total_steps / self.step -
                              1.0) * time_sofar if self.step > 0 else 0
        print_string = "epoch {:>3} | batch {:>6} | examples/s: {:5.1f}\nloss_depth: {:.5f} | loss_tot: {:.5f} | time elapsed: {} | time left: {}"
        print(
            print_string.format(self.epoch, batch_idx,
                                samples_per_sec, loss_depth, loss_tot,
                                sec_to_hm_str(time_sofar),
                                sec_to_hm_str(training_time_left)))

    def record_img(self, disp, semantic_gt, disp_morphed=None, mask=None):
        dirpath = os.path.join(
            "/media/shengjie/other/sceneUnderstanding/SDNET/visualization",
            self.opt.model_name)
        if not os.path.exists(dirpath):
            os.makedirs(dirpath)

        viewIndex = 0
        fig_seman = tensor2semantic(semantic_gt, ind=viewIndex, isGt=True)
        fig_disp = tensor2disp(disp, ind=viewIndex, vmax=0.09)
        overlay_org = pil.fromarray(
            (np.array(fig_disp) * 0.7 + np.array(fig_seman) * 0.3).astype(
                np.uint8))
        if disp_morphed is not None:
            fig_disp_morphed = tensor2disp(disp_morphed,
                                           ind=viewIndex,
                                           vmax=0.09)
            overlay_dst = pil.fromarray(
                (np.array(fig_disp_morphed) * 0.7 +
                 np.array(fig_seman) * 0.3).astype(np.uint8))

            fig_disp_masked = tensor2disp(mask, vmax=1, ind=viewIndex)
            fig_disp_masked_overlay = pil.fromarray(
                (np.array(fig_disp_masked) * 0.3 +
                 np.array(fig_seman) * 0.7).astype(np.uint8))
            combined_fig = pil.fromarray(
                np.concatenate([
                    np.array(overlay_org),
                    np.array(fig_disp),
                    np.array(overlay_dst),
                    np.array(fig_disp_morphed),
                    np.array(fig_disp_masked_overlay)
                ],
                               axis=0))
        else:
            combined_fig = pil.fromarray(
                np.concatenate([np.array(overlay_org),
                                np.array(fig_disp)],
                               axis=0))
        combined_fig.save(dirpath + '/' + str(self.step) + ".png")

    def log(self, mode, inputs, outputs, losses, writeImage=False):
        """Write an event to the tensorboard events file
        """
        writer = self.writers[mode]
        for l, v in losses.items():
            if l != 'totLoss':
                writer.add_scalar("{}".format(l), v, self.step)

    def save_opts(self):
        """Save options to disk so we know what we ran this experiment with
        """
        models_dir = os.path.join(self.log_path, "models")
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        to_save = self.opt.__dict__.copy()

        with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
            json.dump(to_save, f, indent=2)

    def save_model(self, identifier=None):
        """Save model weights to disk
        """
        if identifier is None:
            save_folder = os.path.join(self.log_path, "models",
                                       "weights_{}".format(self.epoch))
        else:
            save_folder = os.path.join(self.log_path, "models",
                                       "weights_{}".format(identifier))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        for model_name, model in self.models.items():
            save_path = os.path.join(save_folder, "{}.pth".format(model_name))
            if model_name == 'stable_encoder' or model_name == 'stable_depth':
                continue
            to_save = model.state_dict()
            if model_name == 'encoder':
                to_save['height'] = self.opt.height
                to_save['width'] = self.opt.width
                to_save['use_stereo'] = self.opt.use_stereo
            torch.save(to_save, save_path)

        save_path = os.path.join(save_folder, "{}.pth".format("adam"))
        torch.save(self.model_optimizer.state_dict(), save_path)
        print("Saved to %s" % save_folder)

    def load_model(self):
        """Load model(s) from disk
        """
        self.opt.load_weights_folder = os.path.expanduser(
            self.opt.load_weights_folder)

        assert os.path.isdir(self.opt.load_weights_folder), \
            "Cannot find folder {}".format(self.opt.load_weights_folder)
        print("loading model from folder {}".format(
            self.opt.load_weights_folder))

        for n in self.opt.models_to_load:
            print("Loading {} weights...".format(n))
            path = os.path.join(self.opt.load_weights_folder,
                                "{}.pth".format(n))
            model_dict = self.models[n].state_dict()
            pretrained_dict = torch.load(path)
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            self.models[n].load_state_dict(model_dict)
        # loading adam state
        optimizer_load_path = os.path.join(self.opt.load_weights_folder,
                                           "adam.pth")
        if os.path.isfile(optimizer_load_path):
            print("Loading Adam weights")
            optimizer_dict = torch.load(optimizer_load_path)
            self.model_optimizer.load_state_dict(optimizer_dict)
        else:
            print("Cannot find Adam weights so Adam is randomly initialized")

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


def evaluate(opt):
    """Evaluates a pretrained model using a specified test set
    """
    MIN_DEPTH = 1e-3
    MAX_DEPTH = 80
    assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
        "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"

    if opt.ext_disp_to_eval is None:

        opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)

        assert os.path.isdir(opt.load_weights_folder), \
            "Cannot find a folder at {}".format(opt.load_weights_folder)

        print("-> Loading weights from {}".format(opt.load_weights_folder))

        filenames = readlines(
            os.path.join(splits_dir, opt.eval_split, "test_files.txt"))

        encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
        decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")

        encoder_dict = torch.load(encoder_path)

        dataset = datasets.KITTIRAWDataset(opt.data_path,
                                           filenames,
                                           encoder_dict['height'],
                                           encoder_dict['width'], [0],
                                           4,
                                           is_train=False,
                                           load_semantics=opt.load_semantics,
                                           seman_path=opt.seman_path)

        dataloader = DataLoader(dataset,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                drop_last=False)

        encoder = networks.ResnetEncoder(opt.num_layers, False)
        depth_decoder = networks.DepthDecoder(encoder.num_ch_enc)

        if opt.bnMorphLoss:
            from bnmorph.bnmorph import BNMorph
            bnmorph = BNMorph(height=encoder_dict['height'],
                              width=encoder_dict['width']).cuda()
            if opt.post_process:
                tool = grad_computation_tools(
                    batch_size=opt.batch_size * 2,
                    height=encoder_dict['height'],
                    width=encoder_dict['width']).cuda()
            else:
                tool = grad_computation_tools(
                    batch_size=opt.batch_size,
                    height=encoder_dict['height'],
                    width=encoder_dict['width']).cuda()

        model_dict = encoder.state_dict()
        encoder.load_state_dict(
            {k: v
             for k, v in encoder_dict.items() if k in model_dict})
        depth_decoder.load_state_dict(torch.load(decoder_path))

        encoder.cuda()
        encoder.eval()
        depth_decoder.cuda()
        depth_decoder.eval()

        pred_disps = []
        count = 0
        with torch.no_grad():
            for data in dataloader:
                input_color = data[("color", 0, 0)].cuda()
                if opt.post_process:
                    input_color = torch.cat(
                        (input_color, torch.flip(input_color, [3])), 0)
                    if 'seman_gt' in data:
                        data['seman_gt'] = torch.cat(
                            (data['seman_gt'], torch.flip(
                                data['seman_gt'], [3])), 0)

                features = encoder(input_color)
                outputs = dict()
                outputs.update(depth_decoder(features))

                if opt.bnMorphLoss:
                    for key, ipt in data.items():
                        if not (key == 'height' or key == 'width'
                                or key == 'tag' or key == 'cts_meta'
                                or key == 'file_add'):
                            data[key] = ipt.to(torch.device("cuda"))

                    disparity_grad_bin = tool.get_disparityEdge(outputs['disp',
                                                                        0])
                    semantics_grad_bin = tool.get_semanticsEdge(
                        data['seman_gt'])

                    morphedx, morphedy, coeff = bnmorph.find_corresponding_pts(
                        disparity_grad_bin, semantics_grad_bin)
                    morphedx = (morphedx /
                                (encoder_dict['width'] - 1) - 0.5) * 2
                    morphedy = (morphedy /
                                (encoder_dict['height'] - 1) - 0.5) * 2
                    grid = torch.cat([morphedx, morphedy],
                                     dim=1).permute(0, 2, 3, 1)
                    dispMaps_morphed = F.grid_sample(outputs['disp', 0],
                                                     grid,
                                                     padding_mode="border")
                    outputs[("disp", 0)] = dispMaps_morphed

                count = count + 1
                pred_disp, _ = disp_to_depth(outputs[("disp", 0)],
                                             opt.min_depth, opt.max_depth)
                pred_disp = pred_disp.cpu()[:, 0].numpy()

                if opt.post_process:
                    N = pred_disp.shape[0] // 2
                    pred_disp = batch_post_process_disparity(
                        pred_disp[:N], pred_disp[N:, :, ::-1])
                pred_disps.append(pred_disp)

        pred_disps = np.concatenate(pred_disps)
    else:
        # Load predictions from file
        print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
        pred_disps = np.load(opt.ext_disp_to_eval)

        if opt.eval_eigen_to_benchmark:
            eigen_to_benchmark_ids = np.load(
                os.path.join(splits_dir, "benchmark",
                             "eigen_to_benchmark_ids.npy"))

            pred_disps = pred_disps[eigen_to_benchmark_ids]

    if opt.save_pred_disps:
        output_path = os.path.join(opt.load_weights_folder,
                                   "disps_{}_split.npy".format(opt.eval_split))
        print("-> Saving predicted disparities to ", output_path)
        np.save(output_path, pred_disps)

    if opt.no_eval:
        print("-> Evaluation disabled. Done.")
        quit()

    elif opt.eval_split == 'benchmark':
        save_dir = os.path.join(opt.load_weights_folder,
                                "benchmark_predictions")
        print("-> Saving out benchmark predictions to {}".format(save_dir))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx in range(len(pred_disps)):
            disp_resized = cv2.resize(pred_disps[idx], (1216, 352))
            depth = STEREO_SCALE_FACTOR / disp_resized
            depth = np.clip(depth, 0, 80)
            depth = np.uint16(depth * 256)
            save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
            cv2.imwrite(save_path, depth)

        print(
            "-> No ground truth is available for the KITTI benchmark, so not evaluating. Done."
        )
        quit()

    gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
    gt_depths = np.load(gt_path,
                        fix_imports=True,
                        encoding='latin1',
                        allow_pickle=True)["data"]

    print("-> Evaluating")

    if opt.eval_stereo:
        print("   Stereo evaluation - "
              "disabling median scaling, scaling by {}".format(
                  STEREO_SCALE_FACTOR))
        opt.disable_median_scaling = True
        opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
    else:
        print("   Mono evaluation - using median scaling")

    errors = []
    ratios = []

    for i in range(pred_disps.shape[0]):

        gt_depth = gt_depths[i]
        gt_height, gt_width = gt_depth.shape[:2]

        pred_disp = pred_disps[i]
        pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
        pred_depth = 1 / pred_disp

        if opt.eval_split == "eigen" or opt.UseCustTest:
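            # depth-range mask plus the standard Garg/Eigen crop (expressed as fractions of the image size)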
            mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)

            crop = np.array([
                0.40810811 * gt_height, 0.99189189 * gt_height,
                0.03594771 * gt_width, 0.96405229 * gt_width
            ]).astype(np.int32)
            crop_mask = np.zeros(mask.shape)
            crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
            mask = np.logical_and(mask, crop_mask)
        else:
            mask = gt_depth > 0

        pred_depth = pred_depth[mask]
        gt_depth = gt_depth[mask]

        pred_depth *= opt.pred_depth_scale_factor
        if not opt.disable_median_scaling:
            ratio = np.median(gt_depth) / np.median(pred_depth)
            ratios.append(ratio)
            pred_depth *= ratio

        pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
        pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH
        errors.append(
            compute_errors(
                gt_depth,
                pred_depth,
                UseGtMedianScaling=opt.UseGtMedianScaling))

    if not opt.disable_median_scaling:
        ratios = np.array(ratios)
        med = np.median(ratios)
        print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(
            med, np.std(ratios / med)))

    mean_errors = np.array(errors).mean(0)

    print("\n  " +
          ("{:>8} | " * 7
           ).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
    print(("&{: 8.3f}  " * 7).format(*mean_errors.tolist()) + "\\\\")
    print("\n-> Done!")
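

# The examples above call disp_to_depth() without defining it. What follows is a
# minimal, self-contained sketch of the monodepth2-style conversion they appear to
# assume (sigmoid disparity mapped into [1/max_depth, 1/min_depth]); the helper in
# the original repository may differ.
def disp_to_depth(disp, min_depth, max_depth):
    """Convert the network's sigmoid output into a (scaled_disp, depth) pair."""
    min_disp = 1 / max_depth
    max_disp = 1 / min_depth
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1 / scaled_disp
    return scaled_disp, depth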
Example #12
0
    def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        # checking height and width are multiples of 32
        assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"

        self.models = {}
        self.parameters_to_train = []

        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        self.num_scales = len(self.opt.scales)
        self.num_input_frames = len(self.opt.frame_ids)
        self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames

        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"

        self.use_pose_net = not (self.opt.use_stereo
                                 and self.opt.frame_ids == [0])

        if self.opt.use_stereo:
            self.opt.frame_ids.append("s")

        self.models["encoder"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")
        self.models["encoder"].to(self.device)
        self.parameters_to_train += list(self.models["encoder"].parameters())

        self.models["depth"] = networks.DepthDecoder(
            self.models["encoder"].num_ch_enc,
            self.opt.scales,
            use_ordConv=self.opt.isOrdConv)
        self.models["depth"].to(self.device)
        self.parameters_to_train += list(self.models["depth"].parameters())

        if self.use_pose_net:
            if self.opt.pose_model_type == "separate_resnet":
                self.models["pose_encoder"] = networks.ResnetEncoder(
                    self.opt.num_layers,
                    self.opt.weights_init == "pretrained",
                    num_input_images=self.num_pose_frames)

                self.models["pose_encoder"].to(self.device)
                self.parameters_to_train += list(
                    self.models["pose_encoder"].parameters())

                self.models["pose"] = networks.PoseDecoder(
                    self.models["pose_encoder"].num_ch_enc,
                    num_input_features=1,
                    num_frames_to_predict_for=2)

            elif self.opt.pose_model_type == "shared":
                self.models["pose"] = networks.PoseDecoder(
                    self.models["encoder"].num_ch_enc, self.num_pose_frames)

            elif self.opt.pose_model_type == "posecnn":
                self.models["pose"] = networks.PoseCNN(
                    self.num_input_frames if self.opt.pose_model_input ==
                    "all" else 2)

            self.models["pose"].to(self.device)
            self.parameters_to_train += list(self.models["pose"].parameters())

        if self.opt.predictive_mask:
            assert self.opt.disable_automasking, \
                "When using predictive_mask, please disable automasking with --disable_automasking"

            # Our implementation of the predictive masking baseline has the same architecture
            # as our depth decoder. We predict a separate mask for each source frame.
            self.models["predictive_mask"] = networks.DepthDecoder(
                self.models["encoder"].num_ch_enc,
                self.opt.scales,
                num_output_channels=(len(self.opt.frame_ids) - 1))
            self.models["predictive_mask"].to(self.device)
            self.parameters_to_train += list(
                self.models["predictive_mask"].parameters())

        self.model_optimizer = optim.Adam(self.parameters_to_train,
                                          self.opt.learning_rate)
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.opt.scheduler_step_size, 0.1)

        if self.opt.load_weights_folder is not None:
            self.load_model()

        print("Training model named:\n  ", self.opt.model_name)
        print("Models and tensorboard events files are saved to:\n  ",
              self.opt.log_dir)
        print("Training is using:\n  ", self.device)

        # data
        datasets_dict = {
            "kitti": datasets.KITTIRAWDataset,
            "kitti_odom": datasets.KITTIOdomDataset
        }
        self.dataset = datasets_dict[self.opt.dataset]

        fpath = os.path.join(os.path.dirname(__file__), "splits",
                             self.opt.split, "{}_files.txt")

        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        img_ext = '.png'

        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        train_dataset = self.dataset(self.opt.data_path,
                                     train_filenames,
                                     self.opt.height,
                                     self.opt.width,
                                     self.opt.frame_ids,
                                     4,
                                     is_train=True,
                                     img_ext=img_ext)
        self.train_loader = DataLoader(train_dataset,
                                       self.opt.batch_size,
                                       True,
                                       num_workers=self.opt.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
        val_dataset = self.dataset(self.opt.data_path,
                                   val_filenames,
                                   self.opt.height,
                                   self.opt.width,
                                   self.opt.frame_ids,
                                   4,
                                   is_train=False,
                                   img_ext=img_ext)
        self.val_loader = DataLoader(val_dataset,
                                     self.opt.batch_size,
                                     True,
                                     num_workers=self.opt.num_workers,
                                     pin_memory=True,
                                     drop_last=True)
        self.val_iter = iter(self.val_loader)

        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.log_path, mode))

        if not self.opt.no_ssim:
            self.ssim = SSIM()
            self.ssim.to(self.device)

        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.opt.scales:
            h = self.opt.height // (2**scale)
            w = self.opt.width // (2**scale)

            self.backproject_depth[scale] = BackprojectDepth(
                self.opt.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1",
            "da/a2", "da/a3"
        ]

        print("Using split:\n  ", self.opt.split)
        print(
            "There are {:d} training items and {:d} validation items\n".format(
                len(train_dataset), len(val_dataset)))

        self.save_opts()

        if self.opt.isCudaMorphing:
            self.foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18]
            self.auto_morph = BNMorph(height=self.opt.height,
                                      width=self.opt.width,
                                      senseRange=20).cuda()
            self.tool = grad_computation_tools(batch_size=self.opt.batch_size,
                                               height=self.opt.height,
                                               width=self.opt.width).cuda()
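

# Self-contained sketch (not part of the example above) of the pixel-coordinate
# normalization these trainers apply before F.grid_sample: coordinates in
# [0, W-1] / [0, H-1] are mapped to [-1, 1]. align_corners=True is used here only
# so that the identity grid reproduces the input exactly; the trainers above rely
# on the PyTorch default.
import torch
import torch.nn.functional as F

b, h, w = 1, 4, 6
img = torch.arange(b * h * w, dtype=torch.float32).view(b, 1, h, w)

# identity sampling grid in normalized coordinates, mirroring (x / (W - 1) - 0.5) * 2
xs = (torch.arange(w, dtype=torch.float32) / (w - 1) - 0.5) * 2
ys = (torch.arange(h, dtype=torch.float32) / (h - 1) - 0.5) * 2
grid = torch.stack([xs.view(1, 1, w).expand(1, h, w),
                    ys.view(1, h, 1).expand(1, h, w)], dim=-1)  # (1, H, W, 2), (x, y) order

resampled = F.grid_sample(img, grid, padding_mode="border", align_corners=True)
assert torch.allclose(resampled, img)  # the identity grid reproduces the input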
Example #13
0
class Trainer:
    def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        # checking height and width are multiples of 32
        assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
        assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"

        self.models = {}
        self.parameters_to_train = []

        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        self.num_scales = len(self.opt.scales)
        self.num_input_frames = len(self.opt.frame_ids)
        self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames

        assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"

        self.use_pose_net = not (self.opt.use_stereo
                                 and self.opt.frame_ids == [0])

        if self.opt.use_stereo:
            self.opt.frame_ids.append("s")

        self.models["encoder"] = networks.ResnetEncoder(
            self.opt.num_layers, self.opt.weights_init == "pretrained")
        self.models["encoder"].to(self.device)
        self.parameters_to_train += list(self.models["encoder"].parameters())

        self.models["depth"] = networks.DepthDecoder(
            self.models["encoder"].num_ch_enc,
            self.opt.scales,
            use_ordConv=self.opt.isOrdConv)
        self.models["depth"].to(self.device)
        self.parameters_to_train += list(self.models["depth"].parameters())

        if self.use_pose_net:
            if self.opt.pose_model_type == "separate_resnet":
                self.models["pose_encoder"] = networks.ResnetEncoder(
                    self.opt.num_layers,
                    self.opt.weights_init == "pretrained",
                    num_input_images=self.num_pose_frames)

                self.models["pose_encoder"].to(self.device)
                self.parameters_to_train += list(
                    self.models["pose_encoder"].parameters())

                self.models["pose"] = networks.PoseDecoder(
                    self.models["pose_encoder"].num_ch_enc,
                    num_input_features=1,
                    num_frames_to_predict_for=2)

            elif self.opt.pose_model_type == "shared":
                self.models["pose"] = networks.PoseDecoder(
                    self.models["encoder"].num_ch_enc, self.num_pose_frames)

            elif self.opt.pose_model_type == "posecnn":
                self.models["pose"] = networks.PoseCNN(
                    self.num_input_frames if self.opt.pose_model_input ==
                    "all" else 2)

            self.models["pose"].to(self.device)
            self.parameters_to_train += list(self.models["pose"].parameters())

        if self.opt.predictive_mask:
            assert self.opt.disable_automasking, \
                "When using predictive_mask, please disable automasking with --disable_automasking"

            # Our implementation of the predictive masking baseline has the same architecture
            # as our depth decoder. We predict a separate mask for each source frame.
            self.models["predictive_mask"] = networks.DepthDecoder(
                self.models["encoder"].num_ch_enc,
                self.opt.scales,
                num_output_channels=(len(self.opt.frame_ids) - 1))
            self.models["predictive_mask"].to(self.device)
            self.parameters_to_train += list(
                self.models["predictive_mask"].parameters())

        self.model_optimizer = optim.Adam(self.parameters_to_train,
                                          self.opt.learning_rate)
        self.model_lr_scheduler = optim.lr_scheduler.StepLR(
            self.model_optimizer, self.opt.scheduler_step_size, 0.1)

        if self.opt.load_weights_folder is not None:
            self.load_model()

        print("Training model named:\n  ", self.opt.model_name)
        print("Models and tensorboard events files are saved to:\n  ",
              self.opt.log_dir)
        print("Training is using:\n  ", self.device)

        # data
        datasets_dict = {
            "kitti": datasets.KITTIRAWDataset,
            "kitti_odom": datasets.KITTIOdomDataset
        }
        self.dataset = datasets_dict[self.opt.dataset]

        fpath = os.path.join(os.path.dirname(__file__), "splits",
                             self.opt.split, "{}_files.txt")

        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        img_ext = '.png'

        num_train_samples = len(train_filenames)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        train_dataset = self.dataset(self.opt.data_path,
                                     train_filenames,
                                     self.opt.height,
                                     self.opt.width,
                                     self.opt.frame_ids,
                                     4,
                                     is_train=True,
                                     img_ext=img_ext)
        self.train_loader = DataLoader(train_dataset,
                                       self.opt.batch_size,
                                       True,
                                       num_workers=self.opt.num_workers,
                                       pin_memory=True,
                                       drop_last=True)
        val_dataset = self.dataset(self.opt.data_path,
                                   val_filenames,
                                   self.opt.height,
                                   self.opt.width,
                                   self.opt.frame_ids,
                                   4,
                                   is_train=False,
                                   img_ext=img_ext)
        self.val_loader = DataLoader(val_dataset,
                                     self.opt.batch_size,
                                     True,
                                     num_workers=self.opt.num_workers,
                                     pin_memory=True,
                                     drop_last=True)
        self.val_iter = iter(self.val_loader)

        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.log_path, mode))

        if not self.opt.no_ssim:
            self.ssim = SSIM()
            self.ssim.to(self.device)

        self.backproject_depth = {}
        self.project_3d = {}
        for scale in self.opt.scales:
            h = self.opt.height // (2**scale)
            w = self.opt.width // (2**scale)

            self.backproject_depth[scale] = BackprojectDepth(
                self.opt.batch_size, h, w)
            self.backproject_depth[scale].to(self.device)

            self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
            self.project_3d[scale].to(self.device)

        self.depth_metric_names = [
            "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1",
            "da/a2", "da/a3"
        ]

        print("Using split:\n  ", self.opt.split)
        print(
            "There are {:d} training items and {:d} validation items\n".format(
                len(train_dataset), len(val_dataset)))

        self.save_opts()

        if self.opt.isCudaMorphing:
            self.foregroundType = [5, 6, 7, 11, 12, 13, 14, 15, 16, 17, 18]
            self.auto_morph = BNMorph(height=self.opt.height,
                                      width=self.opt.width,
                                      senseRange=20).cuda()
            self.tool = grad_computation_tools(batch_size=self.opt.batch_size,
                                               height=self.opt.height,
                                               width=self.opt.width).cuda()

    def set_train(self):
        """Convert all models to training mode
        """
        for m in self.models.values():
            m.train()

    def set_eval(self):
        """Convert all models to testing/evaluation mode
        """
        for m in self.models.values():
            m.eval()

    def train(self):
        """Run the entire training pipeline
        """
        self.epoch = 0
        self.step = 0
        self.start_time = time.time()
        for self.epoch in range(self.opt.num_epochs):
            self.run_epoch()
            if (self.epoch + 1) % self.opt.save_frequency == 0:
                self.save_model()

    def run_epoch(self):
        """Run a single epoch of training and validation
        """
        self.model_lr_scheduler.step()

        print("Training")
        self.set_train()

        for batch_idx, inputs in enumerate(self.train_loader):

            before_op_time = time.time()

            outputs, losses = self.process_batch(inputs)

            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()

            duration = time.time() - before_op_time

            if np.mod(batch_idx, self.opt.val_frequency) == 0:
                self.log_time(batch_idx, duration, losses["loss"].cpu().data)

                if "depth_gt" in inputs:
                    self.compute_depth_losses(inputs, outputs, losses)

                self.log("train", inputs, outputs, losses)
                self.val()
                if self.opt.writeImg:
                    self.writeImg(outputs['disp', 0], inputs['seman_gt'],
                                  outputs['pred_morphed'])
            self.step += 1

    def writeImg(self, disp, semantic_gt, disp_morphed=None):
        viewIndex = 0
        fig_seman = tensor2semantic(semantic_gt, ind=viewIndex, isGt=True)
        fig_disp = tensor2disp(disp, ind=viewIndex, vmax=0.09)
        overlay_org = pil.fromarray(
            (np.array(fig_disp) * 0.7 + np.array(fig_seman) * 0.3).astype(
                np.uint8))

        if disp_morphed is not None:
            fig_disp_morphed = tensor2disp(disp_morphed,
                                           ind=viewIndex,
                                           vmax=0.09)
            overlay_dst = pil.fromarray(
                (np.array(fig_disp_morphed) * 0.7 +
                 np.array(fig_seman) * 0.3).astype(np.uint8))

            disp_masked = disp
            fig_disp_masked = tensor2disp(disp_masked,
                                          vmax=0.09,
                                          ind=viewIndex)
            fig_disp_masked_overlay = pil.fromarray(
                (np.array(fig_disp_masked) * 0.7 +
                 np.array(fig_seman) * 0.3).astype(np.uint8))

            combined_fig = pil.fromarray(
                np.concatenate([
                    np.array(overlay_org),
                    np.array(fig_disp),
                    np.array(overlay_dst),
                    np.array(fig_disp_morphed),
                    np.array(fig_disp_masked_overlay)
                ],
                               axis=0))
            sv_path = os.path.join(
                '/media/shengjie/other/sceneUnderstanding/Godard19/visualization/mono+stereo_1024x320',
                str(self.step) + '.png')
            combined_fig.save(sv_path)

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses
        """
        for key, ipt in inputs.items():
            inputs[key] = ipt.to(self.device)

        if self.opt.pose_model_type == "shared":
            # If we are using a shared encoder for both depth and pose (as advocated
            # in monodepthv1), then all images are fed separately through the depth encoder.
            all_color_aug = torch.cat(
                [inputs[("color_aug", i, 0)] for i in self.opt.frame_ids])
            all_features = self.models["encoder"](all_color_aug)
            all_features = [
                torch.split(f, self.opt.batch_size) for f in all_features
            ]

            features = {}
            for i, k in enumerate(self.opt.frame_ids):
                features[k] = [f[i] for f in all_features]

            outputs = self.models["depth"](features[0])
        else:
            # Otherwise, we only feed the image with frame_id 0 through the depth encoder
            features = self.models["encoder"](inputs["color_aug", 0, 0])
            outputs = self.models["depth"](features)

        if self.opt.predictive_mask:
            outputs["predictive_mask"] = self.models["predictive_mask"](
                features)

        if self.use_pose_net:
            outputs.update(self.predict_poses(inputs, features))

        self.generate_images_pred(inputs, outputs)
        losses = self.compute_losses(inputs, outputs)

        return outputs, losses

    def predict_poses(self, inputs, features):
        """Predict poses between input frames for monocular sequences.
        """
        outputs = {}
        if self.num_pose_frames == 2:
            # In this setting, we compute the pose to each source frame via a
            # separate forward pass through the pose network.

            # select what features the pose network takes as input
            if self.opt.pose_model_type == "shared":
                pose_feats = {f_i: features[f_i] for f_i in self.opt.frame_ids}
            else:
                pose_feats = {
                    f_i: inputs["color_aug", f_i, 0]
                    for f_i in self.opt.frame_ids
                }

            for f_i in self.opt.frame_ids[1:]:
                if f_i != "s":
                    # To maintain ordering we always pass frames in temporal order
                    if f_i < 0:
                        pose_inputs = [pose_feats[f_i], pose_feats[0]]
                    else:
                        pose_inputs = [pose_feats[0], pose_feats[f_i]]

                    if self.opt.pose_model_type == "separate_resnet":
                        pose_inputs = [
                            self.models["pose_encoder"](torch.cat(
                                pose_inputs, 1))
                        ]
                    elif self.opt.pose_model_type == "posecnn":
                        pose_inputs = torch.cat(pose_inputs, 1)

                    axisangle, translation = self.models["pose"](pose_inputs)
                    outputs[("axisangle", 0, f_i)] = axisangle
                    outputs[("translation", 0, f_i)] = translation

                    # Invert the matrix if the frame id is negative
                    outputs[("cam_T_cam", 0,
                             f_i)] = transformation_from_parameters(
                                 axisangle[:, 0],
                                 translation[:, 0],
                                 invert=(f_i < 0))

        else:
            # Here we input all frames to the pose net (and predict all poses) together
            if self.opt.pose_model_type in ["separate_resnet", "posecnn"]:
                pose_inputs = torch.cat([
                    inputs[("color_aug", i, 0)]
                    for i in self.opt.frame_ids if i != "s"
                ], 1)

                if self.opt.pose_model_type == "separate_resnet":
                    pose_inputs = [self.models["pose_encoder"](pose_inputs)]

            elif self.opt.pose_model_type == "shared":
                pose_inputs = [
                    features[i] for i in self.opt.frame_ids if i != "s"
                ]

            axisangle, translation = self.models["pose"](pose_inputs)

            for i, f_i in enumerate(self.opt.frame_ids[1:]):
                if f_i != "s":
                    outputs[("axisangle", 0, f_i)] = axisangle
                    outputs[("translation", 0, f_i)] = translation
                    outputs[("cam_T_cam", 0,
                             f_i)] = transformation_from_parameters(
                                 axisangle[:, i], translation[:, i])

        return outputs

    def val(self):
        """Validate the model on a single minibatch
        """
        self.set_eval()
        try:
            inputs = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
            inputs = next(self.val_iter)

        with torch.no_grad():
            outputs, losses = self.process_batch(inputs)

            if "depth_gt" in inputs:
                self.compute_depth_losses(inputs, outputs, losses)

            self.log("val", inputs, outputs, losses)
            del inputs, outputs, losses

        self.set_train()

    def generate_images_pred(self, inputs, outputs):
        """Generate the warped (reprojected) color images for a minibatch.
        Generated images are saved into the `outputs` dictionary.
        """
        for scale in self.opt.scales:
            disp = outputs[("disp", scale)]
            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                disp = F.interpolate(disp, [self.opt.height, self.opt.width],
                                     mode="bilinear",
                                     align_corners=False)
                source_scale = 0

            _, depth = disp_to_depth(disp, self.opt.min_depth,
                                     self.opt.max_depth)

            outputs[("depth", 0, scale)] = depth

            for i, frame_id in enumerate(self.opt.frame_ids[1:]):

                if frame_id == "s":
                    T = inputs["stereo_T"]
                else:
                    T = outputs[("cam_T_cam", 0, frame_id)]

                # from the authors of https://arxiv.org/abs/1712.00175
                if self.opt.pose_model_type == "posecnn":

                    axisangle = outputs[("axisangle", 0, frame_id)]
                    translation = outputs[("translation", 0, frame_id)]

                    inv_depth = 1 / depth
                    mean_inv_depth = inv_depth.mean(3, True).mean(2, True)

                    T = transformation_from_parameters(
                        axisangle[:, 0],
                        translation[:, 0] * mean_inv_depth[:, 0], frame_id < 0)

                cam_points = self.backproject_depth[source_scale](
                    depth, inputs[("inv_K", source_scale)])
                pix_coords = self.project_3d[source_scale](
                    cam_points, inputs[("K", source_scale)], T)

                outputs[("sample", frame_id, scale)] = pix_coords

                if scale == 0:
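                    # flag pixels whose reprojection lands inside the image boundaries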
                    grad_proj_msak = (pix_coords[:, :, :, 0] >
                                      -1) * (pix_coords[:, :, :, 1] > -1) * (
                                          pix_coords[:, :, :, 0] <
                                          1) * (pix_coords[:, :, :, 1] < 1)
                    grad_proj_msak = grad_proj_msak.unsqueeze(1).float()
                    outputs['grad_proj_msak'] = grad_proj_msak

                outputs[("color", frame_id, scale)] = F.grid_sample(
                    inputs[("color", frame_id, source_scale)],
                    outputs[("sample", frame_id, scale)],
                    padding_mode="border")

                if not self.opt.disable_automasking:
                    outputs[("color_identity", frame_id, scale)] = \
                        inputs[("color", frame_id, source_scale)]

    def compute_reprojection_loss(self, pred, target):
        """Computes reprojection loss between a batch of predicted and target images
        """
        abs_diff = torch.abs(target - pred)
        l1_loss = abs_diff.mean(1, True)

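        # photometric error: 0.85 * SSIM + 0.15 * L1, unless SSIM is disabled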
        if self.opt.no_ssim:
            reprojection_loss = l1_loss
        else:
            ssim_loss = self.ssim(pred, target).mean(1, True)
            reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss

        return reprojection_loss

    def compute_losses(self, inputs, outputs):
        """Compute the reprojection and smoothness losses for a minibatch
        """
        losses = {}
        total_loss = 0

        for scale in self.opt.scales:
            loss = 0
            reprojection_losses = []

            if self.opt.v1_multiscale:
                source_scale = scale
            else:
                source_scale = 0

            disp = outputs[("disp", scale)]
            color = inputs[("color", 0, scale)]
            target = inputs[("color", 0, source_scale)]

            for frame_id in self.opt.frame_ids[1:]:
                pred = outputs[("color", frame_id, scale)]
                reprojection_losses.append(
                    self.compute_reprojection_loss(pred, target))

            reprojection_losses = torch.cat(reprojection_losses, 1)

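            # automasking: also compute the loss of the un-warped source frames
            # against the target; pixels where this identity loss wins (static
            # scenes, objects moving with the camera) are excluded from training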
            if not self.opt.disable_automasking:
                identity_reprojection_losses = []
                for frame_id in self.opt.frame_ids[1:]:
                    pred = inputs[("color", frame_id, source_scale)]
                    identity_reprojection_losses.append(
                        self.compute_reprojection_loss(pred, target))

                identity_reprojection_losses = torch.cat(
                    identity_reprojection_losses, 1)

                if self.opt.avg_reprojection:
                    identity_reprojection_loss = identity_reprojection_losses.mean(
                        1, keepdim=True)
                else:
                    # save both images, and do min all at once below
                    identity_reprojection_loss = identity_reprojection_losses

            elif self.opt.predictive_mask:
                # use the predicted mask
                mask = outputs["predictive_mask"]["disp", scale]
                if not self.opt.v1_multiscale:
                    mask = F.interpolate(mask,
                                         [self.opt.height, self.opt.width],
                                         mode="bilinear",
                                         align_corners=False)

                reprojection_losses *= mask

                # add a loss pushing mask to 1 (using nn.BCELoss for stability)
                weighting_loss = 0.2 * nn.BCELoss()(mask, torch.ones_like(mask))
                loss += weighting_loss.mean()

            if self.opt.avg_reprojection:
                reprojection_loss = reprojection_losses.mean(1, keepdim=True)
            else:
                reprojection_loss = reprojection_losses

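            # take the per-pixel minimum over all candidate losses below; with
            # the identity losses included this implements the minimum
            # reprojection loss and the automask in a single step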
            if not self.opt.disable_automasking:
                # add random numbers to break ties
                identity_reprojection_loss += torch.randn_like(
                    identity_reprojection_loss) * 0.00001

                combined = torch.cat(
                    (identity_reprojection_loss, reprojection_loss), dim=1)
            else:
                combined = reprojection_loss

            if combined.shape[1] == 1:
                to_optimise = combined
            else:
                to_optimise, idxs = torch.min(combined, dim=1)

            if not self.opt.disable_automasking:
                outputs["identity_selection/{}".format(scale)] = (
                    idxs > identity_reprojection_loss.shape[1] - 1).float()

            loss += to_optimise.mean()

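            # edge-aware smoothness on mean-normalized disparity, down-weighted
            # by 2 ** scale so coarser scales contribute less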
            mean_disp = disp.mean(2, True).mean(3, True)
            norm_disp = disp / (mean_disp + 1e-7)
            smooth_loss = get_smooth_loss(norm_disp, color)

            loss += self.opt.disparity_smoothness * smooth_loss / (2**scale)
            total_loss += loss
            losses["loss/{}".format(scale)] = loss

        total_loss /= self.num_scales

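        # border-morphing consistency term (appears to correspond to the
        # bnMorph / isCudaMorphing option): with gradients disabled, disparity
        # edges are morphed towards the semantic foreground edges, the morphed
        # disparity is used to re-warp the stereo frame, and the current
        # disparity is then pulled towards the morphed one only at pixels where
        # the morphed warp gives a lower photometric error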
        if self.opt.isCudaMorphing and (self.epoch > 2
                                        or self.opt.is_no_delay):
            with torch.no_grad():
                stable_disp = outputs['disp', 0]
                foregroundMapGt = torch.ones(
                    [self.opt.batch_size, 1, self.opt.height, self.opt.width],
                    dtype=torch.uint8,
                    device=torch.device("cuda"))
                for m in self.foregroundType:
                    foregroundMapGt = foregroundMapGt * (inputs['seman_gt'] !=
                                                         m)
                foregroundMapGt = (1 - foregroundMapGt).float()

                disparity_grad = torch.abs(
                    self.tool.convDispx(outputs['disp', 0])) + torch.abs(
                        self.tool.convDispy(outputs['disp', 0]))
                semantics_grad = torch.abs(
                    self.tool.convDispx(foregroundMapGt)) + torch.abs(
                        self.tool.convDispy(foregroundMapGt))
                disparity_grad = disparity_grad * self.tool.zero_mask
                semantics_grad = semantics_grad * self.tool.zero_mask

                disparity_grad_bin = disparity_grad > self.tool.disparityTh
                semantics_grad_bin = semantics_grad > self.tool.semanticsTh

                morphedx, morphedy, ocoeff = self.auto_morph.find_corresponding_pts(
                    disparity_grad_bin,
                    semantics_grad_bin,
                    pixel_distance_weight=20)

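                # convert the morphed pixel coordinates to the [-1, 1] range
                # expected by F.grid_sample, then resample the detached disparity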
                morphedx = (morphedx / (self.opt.width - 1) - 0.5) * 2
                morphedy = (morphedy / (self.opt.height - 1) - 0.5) * 2
                grid = torch.cat([morphedx, morphedy],
                                 dim=1).permute(0, 2, 3, 1)
                dispMaps_morphed = F.grid_sample(stable_disp.detach(),
                                                 grid,
                                                 padding_mode="border")
                scaledDisp, depth = disp_to_depth(dispMaps_morphed,
                                                  self.opt.min_depth,
                                                  self.opt.max_depth)
                frame_id = "s"
                T = inputs["stereo_T"]
                cam_points = self.backproject_depth[0](depth,
                                                       inputs[("inv_K", 0)])
                pix_coords = self.project_3d[0](cam_points, inputs[("K", 0)],
                                                T)
                morphed_rgb = F.grid_sample(inputs[("color", frame_id, 0)],
                                            pix_coords,
                                            padding_mode="border")

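                # keep the morphed disparity only where its reprojection error
                # beats the best original reprojection, restricted to pixels
                # that project inside the image bounds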
                ssim_val_morph = self.compute_reprojection_loss(
                    morphed_rgb, inputs[('color', 0, 0)])
                reprojection_loss_min, _ = torch.min(reprojection_loss,
                                                     dim=1,
                                                     keepdim=True)
                selector_mask = (reprojection_loss_min - ssim_val_morph > 0).float() * \
                    outputs['grad_proj_mask']
                outputs['pred_morphed'] = dispMaps_morphed
            losses["similarity_loss"] = torch.sum(
                torch.log(1 + torch.abs(dispMaps_morphed - outputs['disp', 0]))
                * selector_mask) / (torch.sum(selector_mask) + 1)
            total_loss = total_loss + losses[
                "similarity_loss"] * self.opt.morphScale

        losses["loss"] = total_loss
        return losses

    def compute_depth_losses(self, inputs, outputs, losses):
        """Compute depth metrics, to allow monitoring during training

        This isn't particularly accurate as it averages over the entire batch,
        so is only used to give an indication of validation performance
        """
        depth_pred = outputs[("depth", 0, 0)]
        depth_pred = torch.clamp(
            F.interpolate(depth_pred, [375, 1242],
                          mode="bilinear",
                          align_corners=False), 1e-3, 80)
        depth_pred = depth_pred.detach()

        depth_gt = inputs["depth_gt"]
        mask = depth_gt > 0

        # garg/eigen crop
        crop_mask = torch.zeros_like(mask)
        crop_mask[:, :, 153:371, 44:1197] = 1
        mask = mask * crop_mask

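        # median scaling: align the scale-ambiguous prediction with the ground
        # truth before computing the error metrics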
        depth_gt = depth_gt[mask]
        depth_pred = depth_pred[mask]
        depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)

        depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)

        depth_errors = compute_depth_errors(depth_gt, depth_pred)

        for i, metric in enumerate(self.depth_metric_names):
            losses[metric] = np.array(depth_errors[i].cpu())

    def log_time(self, batch_idx, duration, loss):
        """Print a logging statement to the terminal
        """
        samples_per_sec = self.opt.batch_size / duration
        time_sofar = time.time() - self.start_time
        training_time_left = (self.num_total_steps / self.step -
                              1.0) * time_sofar if self.step > 0 else 0
        print_string = "epoch {:>3} | batch {:>6} | examples/s: {:5.1f}" + \
            " | loss: {:.5f} | time elapsed: {} | time left: {}"
        print(
            print_string.format(self.epoch, batch_idx, samples_per_sec, loss,
                                sec_to_hm_str(time_sofar),
                                sec_to_hm_str(training_time_left)))

    def log(self, mode, inputs, outputs, losses):
        """Write an event to the tensorboard events file
        """
        writer = self.writers[mode]
        for l, v in losses.items():
            writer.add_scalar("{}".format(l), v, self.step)

        # for j in range(min(4, self.opt.batch_size)):  # write a maximum of four images
        #     for s in self.opt.scales:
        #         for frame_id in self.opt.frame_ids:
        #             writer.add_image(
        #                 "color_{}_{}/{}".format(frame_id, s, j),
        #                 inputs[("color", frame_id, s)][j].data, self.step)
        #             if s == 0 and frame_id != 0:
        #                 writer.add_image(
        #                     "color_pred_{}_{}/{}".format(frame_id, s, j),
        #                     outputs[("color", frame_id, s)][j].data, self.step)
        #
        #         writer.add_image(
        #             "disp_{}/{}".format(s, j),
        #             normalize_image(outputs[("disp", s)][j]), self.step)
        #
        #         if self.opt.predictive_mask:
        #             for f_idx, frame_id in enumerate(self.opt.frame_ids[1:]):
        #                 writer.add_image(
        #                     "predictive_mask_{}_{}/{}".format(frame_id, s, j),
        #                     outputs["predictive_mask"][("disp", s)][j, f_idx][None, ...],
        #                     self.step)
        #
        #         elif not self.opt.disable_automasking:
        #             writer.add_image(
        #                 "automask_{}/{}".format(s, j),
        #                 outputs["identity_selection/{}".format(s)][j][None, ...], self.step)

    def save_opts(self):
        """Save options to disk so we know what we ran this experiment with
        """
        models_dir = os.path.join(self.log_path, "models")
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)
        to_save = self.opt.__dict__.copy()

        with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
            json.dump(to_save, f, indent=2)

    def save_model(self):
        """Save model weights to disk
        """
        save_folder = os.path.join(self.log_path, "models",
                                   "weights_{}".format(self.epoch))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        for model_name, model in self.models.items():
            save_path = os.path.join(save_folder, "{}.pth".format(model_name))
            to_save = model.state_dict()
            if model_name == 'encoder':
                # save the sizes - these are needed at prediction time
                to_save['height'] = self.opt.height
                to_save['width'] = self.opt.width
                to_save['use_stereo'] = self.opt.use_stereo
            torch.save(to_save, save_path)

        save_path = os.path.join(save_folder, "{}.pth".format("adam"))
        torch.save(self.model_optimizer.state_dict(), save_path)

    def load_model(self):
        """Load model(s) from disk
        """
        self.opt.load_weights_folder = os.path.expanduser(
            self.opt.load_weights_folder)

        assert os.path.isdir(self.opt.load_weights_folder), \
            "Cannot find folder {}".format(self.opt.load_weights_folder)
        print("loading model from folder {}".format(
            self.opt.load_weights_folder))

        for n in self.opt.models_to_load:
            print("Loading {} weights...".format(n))
            path = os.path.join(self.opt.load_weights_folder,
                                "{}.pth".format(n))
            model_dict = self.models[n].state_dict()
            pretrained_dict = torch.load(path)
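            # keep only weights whose keys exist in the current model, so that
            # checkpoints with extra or missing heads can still be loaded partially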
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            self.models[n].load_state_dict(model_dict)

        # loading adam state
        optimizer_load_path = os.path.join(self.opt.load_weights_folder,
                                           "adam.pth")
        if os.path.isfile(optimizer_load_path):
            print("Loading Adam weights")
            optimizer_dict = torch.load(optimizer_load_path)
            self.model_optimizer.load_state_dict(optimizer_dict)
        else:
            print("Cannot find Adam weights so Adam is randomly initialized")