def testCase1(self):
        """Smoke-test the configured disparity predictor on a uniform cost volume.

        A predictor of type ``self.pred_type`` is built from an inline config.
        It is then applied to an all-ones cost volume of shape
        ``(1, num_samples, 2, 2)``. ``num_samples`` is ceil(max_disp / dilation).
        When ``self.pred_type == 'DEFAULT'``, the predictor is also called a
        second time with an explicit tensor of disparity samples.

        Nothing is asserted: the cost volume, the predictor and the regressed
        disparity maps are printed for manual inspection.
        """
        max_disp, start_disp, dilation = 9, -4, 2
        alpha, normalize = 1.0, True
        h, w = 2, 2
        separator = '*' * 60

        # ceil(max_disp / dilation) disparity samples along the cost-volume channel axis
        num_samples = (max_disp + dilation - 1) // dilation

        predictor_cfg = dict(
            type=self.pred_type,
            # the maximum disparity of disparity search range
            max_disp=max_disp,
            # disparity sample radius
            radius=self.radius,
            # the start disparity of disparity search range
            start_disp=start_disp,
            # the step between near disparity sample
            dilation=dilation,
            # the step between near disparity sample when local sampling
            radius_dilation=self.radius_dilation,
            # the temperature coefficient of soft argmin
            alpha=alpha,
            # whether normalize the estimated cost volume
            normalize=normalize,
        )
        cfg = Config(dict(model=dict(disp_predictor=predictor_cfg)))

        # drop options left as None (e.g. fixture attributes that are unset)
        cfg.model.update(
            disp_predictor=kick_out_none_keys(cfg.model.disp_predictor))

        cost = torch.ones(1, num_samples, h, w).to(self.device)
        cost.requires_grad_(True)
        print(separator, 'Cost volume:', cost, sep='\n')

        disp_predictor = build_disp_predictor(cfg).to(self.device)
        print(disp_predictor)

        disp = disp_predictor(cost)
        print(separator, 'Regressed disparity map :', disp, sep='\n')

        # soft argmin: only the DEFAULT predictor is exercised with explicit samples
        if self.pred_type != 'DEFAULT':
            return

        print(separator, 'Test directly providing disparity samples', sep='\n')

        # With the constants above this yields samples -4, -2, 0, 2, 4,
        # i.e. spaced exactly `dilation` apart.
        end_disp = start_disp + max_disp - 1

        # One sample vector broadcast to every pixel: shape (1, num_samples, h, w)
        disp_samples = (
            torch.linspace(start_disp, end_disp, num_samples)
            .view(1, num_samples, 1, 1)
            .repeat(1, 1, h, w)
            .to(cost.device)
        )
        disp = disp_predictor(cost, disp_samples)
        print('Regressed disparity map :', disp, sep='\n')
    def __init__(self, cfg):
        """Assemble the stereo model's components from a configuration.

        Args:
            cfg: configuration object. This method reads
                ``cfg.model.max_disp``, ``cfg.model.backbone.scale`` and
                checks for an optional ``'cmn'`` entry in ``cfg.model``. The
                whole ``cfg`` is passed on to each ``build_*`` helper.

        The result of ``cfg.copy()`` is kept on ``self.cfg``. ``self.cmn`` is
        ``None`` when ``cfg.model`` has no ``'cmn'`` entry.
        """
        super().__init__()
        self.cfg = cfg.copy()

        model_cfg = cfg.model
        self.max_disp = model_cfg.max_disp
        self.scale = model_cfg.backbone.scale

        # Components are built in a fixed order; keep it stable, since
        # registration / random initialisation order may depend on it.
        self.backbone = build_backbone(cfg)
        self.cost_processor = build_cost_processor(cfg)

        # confidence measurement network (optional)
        self.cmn = build_cmn(cfg) if 'cmn' in model_cfg else None

        self.disp_predictor = build_disp_predictor(cfg)
        self.loss_evaluator = make_gsm_loss_evaluator(cfg)