def __init__(self, feature_extractor, head, norm_layer=nn.BatchNorm2d, use_rgb_conv=True,
                 norm_radius=260):
        """Set up the optional input-fusion conv and the distance-map generator, and store the backbone and head.

        Args:
            feature_extractor: Module stored as ``self.feature_extractor``
                (presumably the backbone; it is not called here).
            head: Module stored as ``self.head``; it is not called here.
            norm_layer: Normalization layer factory, called with the channel
                count (8) of the hidden layer inside ``rgb_conv``.
            use_rgb_conv: If True, build ``rgb_conv``, a 1x1-conv stack that
                maps 5 input channels to 3; otherwise ``rgb_conv`` is None.
            norm_radius: Radius forwarded to ``DistMaps``. Defaults to 260.
        """
        super(DistMapsModel, self).__init__()

        if use_rgb_conv:
            # NOTE(review): the 5 input channels are presumably 3 image channels
            # plus 2 distance-map channels -- confirm against forward().
            # Layer order is kept as-is because Sequential indices become
            # state_dict keys (checkpoint compatibility).
            self.rgb_conv = nn.Sequential(
                nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
                nn.LeakyReLU(negative_slope=0.2),
                norm_layer(8),
                nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
            )
        else:
            self.rgb_conv = None

        self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0)
        self.feature_extractor = feature_extractor
        self.head = head
    def __init__(self, feature_extractor, max_interactive_points=10, use_rgb_conv=True, with_aux_output=False,
                 norm_layer=nn.BatchNorm2d, norm_radius=260):
        """Set up the optional input-fusion conv and the distance-map generator, and store the backbone.

        Args:
            feature_extractor: Module stored as ``self.feature_extractor``
                (presumably the HRNet backbone; it is not called here).
            max_interactive_points: Forwarded unchanged to ``DistMaps``.
            use_rgb_conv: If True, build ``rgb_conv``, a 1x1-conv stack that
                maps 5 input channels to 3; otherwise ``rgb_conv`` is None.
            with_aux_output: Stored as ``self.with_aux_output``; not used here.
            norm_layer: Normalization layer factory, called with the channel
                count (8) of the hidden layer inside ``rgb_conv``.
            norm_radius: Radius forwarded to ``DistMaps``. Defaults to 260.
        """
        super(DistMapsHRNetModel, self).__init__()
        self.with_aux_output = with_aux_output

        if use_rgb_conv:
            # NOTE(review): the 5 input channels are presumably 3 image channels
            # plus 2 distance-map channels -- confirm against forward().
            # Layer order is kept as-is because Sequential indices become
            # state_dict keys (checkpoint compatibility).
            self.rgb_conv = nn.Sequential(
                nn.Conv2d(in_channels=5, out_channels=8, kernel_size=1),
                nn.LeakyReLU(negative_slope=0.2),
                norm_layer(8),
                nn.Conv2d(in_channels=8, out_channels=3, kernel_size=1),
            )
        else:
            self.rgb_conv = None

        self.dist_maps = DistMaps(norm_radius=norm_radius, max_interactive_points=max_interactive_points,
                                  spatial_scale=1.0)
        self.feature_extractor = feature_extractor
    def __init__(self,
                 use_rgb_conv=True,
                 with_aux_output=False,
                 norm_radius=260,
                 use_disks=False,
                 cpu_dist_maps=False,
                 clicks_groups=None,
                 with_prev_mask=False,
                 use_leaky_relu=False,
                 binary_prev_mask=False,
                 conv_extend=False,
                 norm_layer=nn.BatchNorm2d,
                 norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
        """Configure input normalization, coordinate-feature encoding and distance maps.

        Exactly one of two encoders for the coordinate features is created:

        * ``use_rgb_conv=True``: ``rgb_conv`` fuses ``3 + coord_feature_ch``
          channels down to 3 channels; ``maps_transform`` is not created.
        * otherwise ``rgb_conv`` is None and ``maps_transform`` is built:
          a single 3x3 stride-2 conv to 64 channels when ``conv_extend`` is
          set, else 1x1 conv (16 ch) -> activation -> 3x3 stride-2 conv
          (64 ch) -> ``ScaleLayer(init_value=0.05)``.

        Args:
            use_rgb_conv: Select the ``rgb_conv`` encoder (takes precedence
                over ``conv_extend``).
            with_aux_output: Stored as ``self.with_aux_output``.
            norm_radius: Radius for the single ``DistMaps`` used when
                ``clicks_groups`` is None.
            use_disks: Forwarded to every ``DistMaps`` as ``use_disks``.
            cpu_dist_maps: Forwarded to every ``DistMaps`` as ``cpu_mode``.
            clicks_groups: Optional sequence of radii. When given, one
                ``DistMaps`` per radius is kept in a ``ModuleList``, and the
                coordinate-feature channel count is multiplied by its length.
            with_prev_mask: If truthy, reserve one extra coordinate-feature
                channel (presumably for the previous mask).
            use_leaky_relu: Use ``LeakyReLU(0.2)`` instead of in-place ``ReLU``.
            binary_prev_mask: Stored as ``self.binary_prev_mask``; unused here.
            conv_extend: Select the single-conv ``maps_transform`` variant,
                with ``LRMult(0.1)`` applied to it (presumably a
                learning-rate multiplier -- confirm in LRMult).
            norm_layer: Normalization layer factory, only used inside
                ``rgb_conv``; called with the channel count.
            norm_mean_std: ``(mean, std)`` pair passed to
                ``BatchImageNormalize``.
        """
        super().__init__()
        self.with_aux_output = with_aux_output
        self.clicks_groups = clicks_groups
        self.with_prev_mask = with_prev_mask
        self.binary_prev_mask = binary_prev_mask

        mean, std = norm_mean_std[0], norm_mean_std[1]
        self.normalization = BatchImageNormalize(mean, std)

        # Two channels per click group (a single group when none are given),
        # plus one more when the previous mask is fed in.
        group_count = 1 if clicks_groups is None else len(clicks_groups)
        extra_ch = 1 if self.with_prev_mask else 0
        self.coord_feature_ch = 2 * group_count + extra_ch

        def activation():
            # A fresh activation module for each Sequential it is placed in.
            if use_leaky_relu:
                return nn.LeakyReLU(negative_slope=0.2)
            return nn.ReLU(inplace=True)

        # Module creation order below matches the original layout exactly:
        # it determines both RNG-driven weight init and state_dict key indices.
        if use_rgb_conv:
            hidden_ch = 6 + self.coord_feature_ch
            self.rgb_conv = nn.Sequential(
                nn.Conv2d(in_channels=3 + self.coord_feature_ch,
                          out_channels=hidden_ch,
                          kernel_size=1),
                norm_layer(hidden_ch),
                activation(),
                nn.Conv2d(in_channels=hidden_ch,
                          out_channels=3,
                          kernel_size=1),
            )
        else:
            self.rgb_conv = None
            if conv_extend:
                self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch,
                                                out_channels=64,
                                                kernel_size=3,
                                                stride=2,
                                                padding=1)
                self.maps_transform.apply(LRMult(0.1))
            else:
                self.maps_transform = nn.Sequential(
                    nn.Conv2d(in_channels=self.coord_feature_ch,
                              out_channels=16,
                              kernel_size=1),
                    activation(),
                    nn.Conv2d(in_channels=16,
                              out_channels=64,
                              kernel_size=3,
                              stride=2,
                              padding=1),
                    ScaleLayer(init_value=0.05, lr_mult=1),
                )

        def make_dist_maps(radius):
            # Every distance-map generator shares these settings; only the radius varies.
            return DistMaps(norm_radius=radius,
                            spatial_scale=1.0,
                            cpu_mode=cpu_dist_maps,
                            use_disks=use_disks)

        if self.clicks_groups is None:
            self.dist_maps = make_dist_maps(norm_radius)
        else:
            self.dist_maps = nn.ModuleList(
                [make_dist_maps(click_radius) for click_radius in self.clicks_groups])