Exemplo n.º 1
0
    def __init__(self, inplanes):
        """Build the refinement head.

        Args:
            inplanes: number of channels of the input feature map consumed by
                the first ``convbn_2d_lrelu`` layer.
        """
        super().__init__()

        # Layers after the first, as (in_channels, out_channels, pad, dilation).
        # All use a 3x3 kernel with stride 1 and pad == dilation (presumably so
        # the spatial size is preserved, assuming convbn_2d_lrelu wraps a
        # standard Conv2d -- confirm against its definition).
        trailing_specs = (
            (32, 32, 1, 1),
            (32, 32, 1, 1),
            (32, 16, 2, 2),
            (16, 16, 4, 4),
            (16, 16, 1, 1),
        )

        # The first layer is created on its own because it does not pass
        # `dilation` explicitly; it is constructed before the others, as before.
        layers = [convbn_2d_lrelu(inplanes, 32, kernel_size=3, stride=1, pad=1)]
        layers.extend(
            convbn_2d_lrelu(c_in, c_out, kernel_size=3, stride=1, pad=padding,
                            dilation=dil)
            for c_in, c_out, padding, dil in trailing_specs
        )
        self.conv1 = nn.Sequential(*layers)

        # Final 3x3 projection from 16 channels down to one output channel,
        # without a bias term.
        self.classif1 = nn.Conv2d(16, 1, kernel_size=3, padding=1, stride=1,
                                  bias=False)
        self.relu = nn.ReLU(inplace=True)

        # Parameter initialisation is delegated to weight_init(), which is
        # defined elsewhere in the class.
        self.weight_init()
Exemplo n.º 2
0
 def __init__(self, params=None):
     """Set up the two-layer guide network.

     Args:
         params: optional configuration object, stored on the instance as-is.
     """
     super().__init__()
     self.params = params

     # Channel widths: 32 in -> 16 hidden -> 1 out (assuming the first two
     # positional arguments of both helpers are in/out channels). The shared
     # positional tail (1, 1, 0) is passed through unchanged; presumably kernel
     # size, stride and padding -- confirm against the helper definitions.
     in_ch, hidden_ch, out_ch = 32, 16, 1
     shared_tail = (1, 1, 0)
     self.conv1 = convbn_2d_lrelu(in_ch, hidden_ch, *shared_tail)
     self.conv2 = convbn_2d_Tanh(hidden_ch, out_ch, *shared_tail)
Exemplo n.º 3
0
    def __init__(self):
        """Construct every DeepPruner sub-module from the global ``args`` config.

        All hyper-parameters are read from an ``args`` object that is not
        defined in this method (NOTE(review): presumably a module-level project
        config -- verify). The feature extractor is selected by
        ``args.cost_aggregator_scale``: ``models.feature_extractor_fast`` when
        the scale is 8 (which also creates an extra ``refinement_net1``),
        otherwise ``models.feature_extractor_best``.

        Sub-modules are registered in source order. Keep that order stable:
        nn.Module records sub-modules in assignment order, which fixes the
        parameter/state_dict ordering and, for layers that initialise randomly
        on construction, the order in which random draws are consumed.
        """
        super(DeepPruner, self).__init__()

        # Cost-aggregator scale factor; max_disp is integer-divided by it.
        self.scale = args.cost_aggregator_scale
        self.max_disp = args.max_disp // self.scale
        self.mode = args.mode

        # PatchMatch settings: number of disparity samples, number of
        # iterations, and the propagation filter size passed to PatchMatch.
        self.patch_match_args = args.patch_match_args
        self.patch_match_sample_count = self.patch_match_args.sample_count
        self.patch_match_iteration_count = self.patch_match_args.iteration_count
        self.patch_match_propagation_filter_size = self.patch_match_args.propagation_filter_size

        # Settings for sampling after the Confidence Range Predictor (CRP).
        self.post_CRP_sample_count = args.post_CRP_sample_count
        self.post_CRP_sampler_type = args.post_CRP_sampler_type
        hourglass_inplanes = args.hourglass_inplanes

        #   refinement input features are composed of:
        #                                       left image low level features +
        #                                       CA output features + CA output disparity

        # The feature-extractor variant (and whether the extra refinement_net1
        # exists) depends on the cost-aggregator scale. The import is local so
        # only the selected variant is loaded.
        if self.scale == 8:
            from models.feature_extractor_fast import feature_extraction
            # "+ 1": presumably one disparity channel concatenated with the
            # level-1 features -- confirm against forward().
            refinement_inplanes_1 = args.feature_extractor_refinement_level_1_outplanes + 1
            self.refinement_net1 = RefinementNet(refinement_inplanes_1)
        else:
            from models.feature_extractor_best import feature_extraction

        # "+ 2" matches the channel count of ca_features_conv below
        # (post_CRP_sample_count + 2); "+ 1" is presumably the CA output
        # disparity channel from the comment above.
        refinement_inplanes = args.feature_extractor_refinement_level_outplanes + self.post_CRP_sample_count + 2 + 1
        self.refinement_net = RefinementNet(refinement_inplanes)

        # cost_aggregator_inplanes are composed of:
        #                            left and right image features from feature_extractor (ca_level) +
        #                            features from min/max predictors +
        #                            min_disparity + max_disparity + disparity_samples

        cost_aggregator_inplanes = 2 * (
            args.feature_extractor_ca_level_outplanes +
            self.patch_match_sample_count + 2) + 1
        self.cost_aggregator = CostAggregator(cost_aggregator_inplanes,
                                              hourglass_inplanes)

        # feature_extraction is whichever variant was imported above.
        self.feature_extraction = feature_extraction()
        self.min_disparity_predictor = MinDisparityPredictor(
            hourglass_inplanes)
        self.max_disparity_predictor = MaxDisparityPredictor(
            hourglass_inplanes)
        self.spatial_transformer = SpatialTransformer()
        self.patch_match = PatchMatch(self.patch_match_propagation_filter_size)
        self.uniform_sampler = UniformSampler()

        # Confidence Range Predictor(CRP) input features are composed of:
        #                            left and right image features from feature_extractor (ca_level) +
        #                            disparity_samples

        CRP_feature_count = 2 * args.feature_extractor_ca_level_outplanes + 1
        # 3D conv stacks: CRP_feature_count -> 64 -> 32, then 32 -> 32 ->
        # hourglass_inplanes.
        self.dres0 = nn.Sequential(
            convbn_3d_lrelu(CRP_feature_count, 64, 3, 1, 1),
            convbn_3d_lrelu(64, 32, 3, 1, 1))

        self.dres1 = nn.Sequential(
            convbn_3d_lrelu(32, 32, 3, 1, 1),
            convbn_3d_lrelu(32, hourglass_inplanes, 3, 1, 1))

        # Single-channel conv_relu layers for the min, max and CA disparity
        # maps (positional args presumably in, out, kernel, stride, pad --
        # confirm against conv_relu).
        self.min_disparity_conv = conv_relu(1, 1, 5, 1, 2)
        self.max_disparity_conv = conv_relu(1, 1, 5, 1, 2)
        self.ca_disparity_conv = conv_relu(1, 1, 5, 1, 2)

        # Feature convs below keep the channel count unchanged (in == out),
        # using a bias term.
        self.ca_features_conv = convbn_2d_lrelu(self.post_CRP_sample_count + 2,
                                                self.post_CRP_sample_count + 2,
                                                5,
                                                1,
                                                2,
                                                dilation=1,
                                                bias=True)
        self.min_disparity_features_conv = convbn_2d_lrelu(
            self.patch_match_sample_count + 2,
            self.patch_match_sample_count + 2,
            5,
            1,
            2,
            dilation=1,
            bias=True)
        self.max_disparity_features_conv = convbn_2d_lrelu(
            self.patch_match_sample_count + 2,
            self.patch_match_sample_count + 2,
            5,
            1,
            2,
            dilation=1,
            bias=True)