Example #1
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')
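All of these snippets rely on the small initialization helpers exported by mmcv.cnn (kaiming_init, constant_init, normal_init, xavier_init, trunc_normal_init, ...). As a rough sketch of what the two most common ones do (simplified; the mmcv source is authoritative), they wrap the corresponding torch.nn.init call on a module's weight and additionally fill its bias:

import torch.nn as nn

def constant_init(module, val, bias=0):
    # Fill module.weight with `val` and module.bias with `bias`, when present.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu',
                 bias=0, distribution='normal'):
    # He/Kaiming initialization of the weight plus a constant fill of the bias.
    if distribution == 'uniform':
        nn.init.kaiming_uniform_(module.weight, a=a, mode=mode,
                                 nonlinearity=nonlinearity)
    else:
        nn.init.kaiming_normal_(module.weight, a=a, mode=mode,
                                nonlinearity=nonlinearity)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)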
Example #2
 def init_weights(self):
     logger = get_root_logger()
     if self.init_cfg is None:
         logger.warning(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training starts from scratch')
         for m in self.modules():
             if isinstance(m, nn.Linear):
                 trunc_normal_init(m, std=.02, bias=0.)
             elif isinstance(m, nn.LayerNorm):
                 constant_init(m.bias, 0)
                 constant_init(m.weight, 1.0)
             elif isinstance(m, nn.Conv2d):
                 fan_out = (m.kernel_size[0] * m.kernel_size[1] *
                            m.out_channels)
                 fan_out //= m.groups
                 normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
                 if m.bias is not None:
                     constant_init(m.bias, 0)
             elif isinstance(m, AbsolutePositionEmbedding):
                 m.init_weights()
     else:
         assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                               f'specifying `Pretrained` in ' \
                                               f'`init_cfg` in ' \
                                               f'{self.__class__.__name__} '
         checkpoint = _load_checkpoint(self.init_cfg.checkpoint,
                                       logger=logger,
                                       map_location='cpu')
         logger.warning(f'Load pre-trained model for '
                        f'{self.__class__.__name__} from original repo')
         if 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
         elif 'model' in checkpoint:
             state_dict = checkpoint['model']
         else:
             state_dict = checkpoint
         if self.convert_weights:
             # pvt backbones are not supported by mmcls, so the
             # pre-trained weights need to be converted to match this
             # implementation.
             state_dict = pvt_convert(state_dict)
         load_state_dict(self, state_dict, strict=False, logger=logger)
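The checkpoint-unwrapping logic above (try 'state_dict', then 'model', then the raw dict) recurs in several of these examples. A minimal, self-contained sketch of that branch using plain torch instead of mmcv's _load_checkpoint/load_state_dict (the throwaway model and path are placeholders):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
torch.save({'state_dict': model.state_dict()}, '/tmp/demo_ckpt.pth')

checkpoint = torch.load('/tmp/demo_ckpt.pth', map_location='cpu')
# Unwrap the actual weights, whichever key the checkpoint uses.
if 'state_dict' in checkpoint:
    state_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
    state_dict = checkpoint['model']
else:
    state_dict = checkpoint
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing, 'unexpected keys:', unexpected)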
Example #3
 def init_weights(self, pretrained=None):
     if isinstance(pretrained, str):
         logger = logging.getLogger()
         load_checkpoint(self, pretrained, strict=False, logger=logger)
         for m in self.modules():
             if isinstance(m, nn.Conv2d) and hasattr(
                     m, 'zero_init') and m.zero_init:
                 constant_init(m, 0)
     elif pretrained is None:
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 if hasattr(m, 'zero_init') and m.zero_init:
                     constant_init(m, 0)
                 else:
                     kaiming_init(m)
             elif isinstance(m, nn.BatchNorm2d):
                 constant_init(m, 1)
     else:
         raise TypeError('pretrained must be a str or None')
Example #4
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            print("load pretrained weights from {}".format(pretrained))
            model_dict = torch.load(pretrained)
            if 'state_dict' in model_dict:
                model_dict = model_dict['state_dict']
            state_dict = {}
            for k,v in model_dict.items():
                state_dict[k[9:]] = v
            # print(state_dict.keys())
            missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
            print("missing kyes:", missing_keys)
            print("unexpected keys:", unexpected_keys)            
#         load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.features.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

        for m in self.extra.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
                # xavier_init(m, distribution='uniform')
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.01)

        constant_init(self.l2_norm, self.l2_norm.scale)
Example #5
    def _init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Args:
            pretrained (str | None): The path of the pretrained weight. Will
                override the original `pretrained` if set. The arg is added to
                be compatible with mmdet. Default: None.
        """
        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info(f'load model from: {self.pretrained}')

            if self.pretrained2d:
                # Inflate 2D model into 3D model.
                self.inflate_weights(logger)

            else:
                # Directly load 3D model.
                load_checkpoint(self,
                                self.pretrained,
                                strict=False,
                                logger=logger)

        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck3d):
                        constant_init(m.conv3.bn, 0)
                    elif isinstance(m, BasicBlock3d):
                        constant_init(m.conv2.bn, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #6
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        super().init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
Example #7
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
            model_dict = self.state_dict()
            pretrained_dict = checkpoint['state_dict']
            missing_keys = [k for k in model_dict if k not in pretrained_dict]
            unexpected_keys = [k for k in pretrained_dict if k not in model_dict and 'eta' not in k]
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            self.load_state_dict(model_dict)
            if unexpected_keys:
                print(f'unexpected key in source state_dict: {", ".join(unexpected_keys)}\n')
            if missing_keys:
                print(f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottleneck) and hasattr(
                            m.conv2, 'conv_offset'):
                        constant_init(m.conv2.conv_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #8
    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
                    elif isinstance(m, BasicBlock):
                        constant_init(m.norm2, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #9
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            checkpoint = model_zoo.load_url(
                pretrained, map_location=lambda storage, loc: storage)
            self.load_state_dict(checkpoint, strict=False)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, Bottle2neck) and hasattr(
                            m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottle2neck):
                        constant_init(m.norm3, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #10
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)

            if self.dcn is not None:
                for m in self.modules():
                    if isinstance(m, _Bottleneck):
                        for conv2 in m.convs:
                            if hasattr(conv2, 'conv_offset'):
                                constant_init(conv2.conv_offset, 0)

            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, _Bottleneck):
                        constant_init(m.norm3, 0)
        else:
            raise TypeError('pretrained must be a str or None')
Example #11
    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            # if pretrained weight is trained on 3-channel images,
            # initialize other channels with zeros
            self.conv1.conv.weight.data[:, 3:, :, :] = 0

            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    constant_init(m.weight, 1)
                    constant_init(m.bias, 0)

            # Zero-initialize the last BN in each residual branch, so that the
            # residual branch starts with zeros, and each residual block
            # behaves like an identity. This improves the model by 0.2~0.3%
            # according to https://arxiv.org/abs/1706.02677
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    constant_init(m.conv2.bn.weight, 0)
        else:
            raise TypeError(f'"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')
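Example #11 zeroes every input channel of the stem conv beyond the first three, so a checkpoint trained on RGB images can be reused by a backbone that takes extra input channels. A standalone sketch of the same idea (the 4-channel conv and the random stand-in for the pretrained kernel are illustrative):

import torch
import torch.nn as nn

# Stem conv that takes RGB plus one extra channel; the checkpoint covers RGB only.
conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
rgb_weight = torch.randn(64, 3, 7, 7)   # stands in for the pretrained kernel

conv1.weight.data[:, :3] = rgb_weight   # reuse the pretrained RGB filters as-is
conv1.weight.data[:, 3:] = 0            # extra input channels start from zero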
Example #12
 def init_weights(self):
     constant_init(self.temporal_fc, val=0, bias=0)
Example #13
 def init_weights(self):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             normal_init(m, std=0.03)
         elif isinstance(m, nn.BatchNorm2d):
             constant_init(m, 1)
Example #14
 def init_weights(self):
     kaiming_init(self.conv, nonlinearity='relu')
     if self.with_bn:
         constant_init(self.bn, 1, bias=0)
Example #15
 def init_weights(self):
     """Initialize the weights."""
     if self.rfp_inplanes:
         constant_init(self.rfp_conv, 0)
Example #16
 def init_weights(self):
     xavier_init(self.guidance_conv, distribution='uniform')
     xavier_init(self.out_conv.conv, distribution='uniform')
     constant_init(self.out_conv.norm, 1e-3)
Example #17
 def init_weights(self):
     for m in self.modules():
         if isinstance(m, (nn.Conv2d, nn.Conv3d)):
             kaiming_init(m)
         elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
             constant_init(m, 1.0, 0.0)
Example #18
    def init_weights(self):
        """init weight"""
        if isinstance(self.pretrained, str):
            logger = get_root_logger()
            logger.info("load model from: {}".format(self.pretrained))
            if self.pretrained2d:
                resnet2d = ResNet(
                    self.depth,
                    avg_down=self.avg_down,
                    avd=self.avd,
                    avd_first=self.avd_first,
                    deep_stem=self.deep_stem,
                    stem_width=self.stem_width)
                load_checkpoint(resnet2d, self.pretrained, map_location='cpu',
                                strict=False, logger=logger)
                for name, module in self.named_modules():
                    if isinstance(module, NonLocalModule):
                        module.init_weights()
                    elif isinstance(module, nn.Conv3d) and rhasattr(
                            resnet2d, name):
                        new_weight = rgetattr(
                            resnet2d, name).weight.data.unsqueeze(2).expand_as(
                                module.weight) / module.weight.data.shape[2]
                        module.weight.data.copy_(new_weight)
                        logger.info(
                            "{}.weight loaded from weights file into {}".
                            format(name, new_weight.shape))

                        if hasattr(module, 'bias') and module.bias is not None:
                            new_bias = rgetattr(resnet2d, name).bias.data
                            module.bias.data.copy_(new_bias)
                            logger.info(
                                "{}.bias loaded from weights file into {}".
                                format(name, new_bias.shape))

                    elif isinstance(module, _BatchNorm) and rhasattr(
                            resnet2d, name):
                        for attr in [
                                'weight', 'bias', 'running_mean', 'running_var'
                        ]:
                            logger.info(
                                "{}.{} loaded from weights file into {}"
                                .format(
                                    name, attr, getattr(
                                        rgetattr(resnet2d, name), attr).shape))
                            setattr(module, attr, getattr(
                                rgetattr(resnet2d, name), attr))
            else:
                load_checkpoint(
                    self, self.pretrained, strict=False, logger=logger)

        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        constant_init(m.norm3, 0)
        else:
            raise TypeError('pretrained must be a str or None')
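The pretrained2d branch above "inflates" each 2-D kernel into a 3-D one by repeating it along the new temporal axis and dividing by the temporal kernel size, so a static clip initially produces the same response as the 2-D network. A minimal standalone sketch of that operation (layer shapes are illustrative):

import torch
import torch.nn as nn

conv2d = nn.Conv2d(3, 16, kernel_size=3, padding=1)                  # pretrained 2-D conv
conv3d = nn.Conv3d(3, 16, kernel_size=(5, 3, 3), padding=(2, 1, 1))  # target 3-D conv

# Repeat the 2-D kernel along the temporal dimension and rescale by its length.
t = conv3d.weight.data.shape[2]
inflated = conv2d.weight.data.unsqueeze(2).expand_as(conv3d.weight) / t
conv3d.weight.data.copy_(inflated)
if conv2d.bias is not None and conv3d.bias is not None:
    conv3d.bias.data.copy_(conv2d.bias.data)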
Example #19
 def init_weights(self):
     for m in self.fcs.modules():
         if isinstance(m, nn.Linear):
             xavier_init(m)
     constant_init(self.fc_logits, 0.001)
Example #20
 def init_weights(self):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             xavier_init(m, distribution='uniform')
         elif isinstance(m, nn.BatchNorm2d):
             constant_init(m, 1)
Example #21
 def last_zero_init(self, m):
     if isinstance(m, nn.Sequential):
         constant_init(m[-1], val=0)
     else:
         constant_init(m, val=0)
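last_zero_init is typically applied to the transform branch of a context/attention block so that the branch initially outputs zeros and the enclosing residual connection starts out as an identity mapping. A small illustrative sketch (the GCNet-style channel_add_conv branch below is an assumption, not taken from this repo):

import torch.nn as nn
from mmcv.cnn import constant_init

# Hypothetical transform branch of a GCNet-style context block.
channel_add_conv = nn.Sequential(
    nn.Conv2d(256, 64, kernel_size=1),
    nn.LayerNorm([64, 1, 1]),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 256, kernel_size=1))

# What last_zero_init does for a Sequential: zero the final layer so the
# branch contributes nothing at the start of training.
constant_init(channel_add_conv[-1], val=0)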
Example #22
 def init_weights(self, pretrained=None):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             kaiming_init(m)
         elif isinstance(m, nn.BatchNorm2d):
             constant_init(m, 1)
Example #23
 def _init_weights(self):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             normal_init(m, 0, 0.01)
     if self.zero_init_offset:
         constant_init(self.spatial_conv_offset, 0)
Example #24
    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            if self.use_abs_pos_embed:
                trunc_normal_(self.absolute_pos_embed, std=0.02)
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specifying `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = _load_checkpoint(self.init_cfg.checkpoint,
                                    logger=logger,
                                    map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt
            if self.convert_weights:
                # convert weights from the original repo to match
                # this implementation
                _state_dict = swin_converter(_state_dict)

            state_dict = OrderedDict()
            for k, v in _state_dict.items():
                if k.startswith('backbone.'):
                    state_dict[k[9:]] = v

            # strip prefix of state_dict
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            # reshape absolute position embedding
            if state_dict.get('absolute_pos_embed') is not None:
                absolute_pos_embed = state_dict['absolute_pos_embed']
                N1, L, C1 = absolute_pos_embed.size()
                N2, C2, H, W = self.absolute_pos_embed.size()
                if N1 != N2 or C1 != C2 or L != H * W:
                    logger.warning('Error in loading absolute_pos_embed, pass')
                else:
                    state_dict['absolute_pos_embed'] = absolute_pos_embed.view(
                        N2, H, W, C2).permute(0, 3, 1, 2).contiguous()

            # interpolate position bias table if needed
            relative_position_bias_table_keys = [
                k for k in state_dict.keys()
                if 'relative_position_bias_table' in k
            ]
            for table_key in relative_position_bias_table_keys:
                table_pretrained = state_dict[table_key]
                table_current = self.state_dict()[table_key]
                L1, nH1 = table_pretrained.size()
                L2, nH2 = table_current.size()
                if nH1 != nH2:
                    logger.warning(f'Error in loading {table_key}, pass')
                elif L1 != L2:
                    S1 = int(L1**0.5)
                    S2 = int(L2**0.5)
                    table_pretrained_resized = F.interpolate(
                        table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),
                        size=(S2, S2),
                        mode='bicubic')
                    state_dict[table_key] = table_pretrained_resized.view(
                        nH2, L2).permute(1, 0).contiguous()

            # load state_dict
            self.load_state_dict(state_dict, strict=False)
Example #25
 def init_weights(self):
     nonlinearity = 'relu' if self.activation is None else self.activation
     kaiming_init(self.conv, nonlinearity=nonlinearity)
     if self.with_norm:
         constant_init(self.norm, 1, bias=0)
Example #26
 def init_weights(self):
     """Initialize weights of shared MLP layers."""
     self.scorenet.init_weights()
     if self.bn is not None:
         constant_init(self.bn, val=1)
Example #27
 def init_weights(self):
     """Initialize weight of later layer."""
     if self.out_project is not None:
         if not isinstance(self.out_project, ConvModule):
             constant_init(self.out_project, 0)
Example #28
    def init_weights(self):
        for m in self.upsample_layers.modules():
            if isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)

        for m in self.shortcut_layers.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)

        bias_cls = bias_init_with_prob(0.01)
        for hm in [self.hm_b1, self.hm_b2]:
            for m in hm.modules():
                if isinstance(m, nn.Conv2d):
                    if self.all_kaiming:
                        kaiming_init(m)
                    else:
                        normal_init(m, std=0.01)
            normal_init(hm[-1], std=0.01, bias=bias_cls)

        for wh in [self.wh_b1, self.wh_b2]:
            for m in wh.modules():
                if isinstance(m, nn.Conv2d):
                    if self.all_kaiming:
                        kaiming_init(m)
                    else:
                        normal_init(m, std=0.001)

        if self.mdcn_before_s8 or self.ind_mdcn_for_s8:
            for m in self.mdcn_s8_layer.modules():
                if isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)

        if self.conv_before_s8:
            for m in self.conv_s8_layer.modules():
                if isinstance(m, nn.Conv2d):
                    if self.all_kaiming:
                        kaiming_init(m)
                    else:
                        normal_init(m, std=0.01)
                if isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)

        if self.with_score_loss:
            for m in self.hm_bns:
                constant_init(m, 1)

        if self.conv_exchage:
            for m in self.conv_ex:
                kaiming_init(m)

        if self.extra_shortcut_cfg:
            for m in self.extra_shortcut_layer.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)

        for m in self.modules():
            if isinstance(m, ModulatedDeformConvPack):
                if hasattr(m, 'conv_offset_mask'):
                    constant_init(m.conv_offset_mask, 0)
                else:
                    constant_init(m.conv_offset, 0)
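bias_init_with_prob(0.01) above sets the heatmap/classification bias so that the head's initial sigmoid output is roughly the chosen prior probability (the focal-loss initialization trick). To the best of my understanding, the helper is simply the inverse sigmoid of that prior:

import math

def bias_init_with_prob(prior_prob):
    # Inverse sigmoid: with zero-mean weights and this bias, the initial
    # predicted probability is approximately `prior_prob`.
    return float(-math.log((1 - prior_prob) / prior_prob))

print(bias_init_with_prob(0.01))  # ~ -4.595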
Example #29
 def init_weights(self):
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             kaiming_init(m)
         elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
             constant_init(m, 1)
Example #30
    def __init__(
        self,
        nIn,
        spatial_pooltype="max",
        spatial_topk=1,
        region_softpool=False,
        num_regions=8,
        region_topk=8,
        rot_dim=4,
        mask_attention_type="none",
    ):  # NOTE: not used!!!
        """
        Args:
            nIn: input feature channel
            spatial_pooltype: max | soft | topk
            spatial_topk: 1
            region_softpool (bool): if not softpool, just flatten
        """
        super().__init__()
        self.mask_attention_type = mask_attention_type
        self.spatial_pooltype = spatial_pooltype
        self.spatial_topk = spatial_topk
        self.region_softpool = region_softpool
        self.num_regions = num_regions
        self.region_topk = region_topk
        # -----------------------------------

        self.conv1 = torch.nn.Conv1d(nIn, 128, 1)
        self.conv2 = torch.nn.Conv1d(128, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 128, 1)

        if self.spatial_pooltype == "topk":
            self.conv_topk = nn.Conv2d(128,
                                       128,
                                       kernel_size=(1, self.spatial_topk),
                                       stride=(1, 1))

        if not region_softpool:
            in_dim = 128 * num_regions
        else:
            in_dim = 128 * region_topk
            self.conv_sp = nn.Conv2d(128,
                                     128,
                                     kernel_size=(1, 128),
                                     stride=(1, 1))

        # self.fc1 = nn.Linear(in_dim + 128, 512)  # NOTE: 128 for extents feature
        self.fc1 = nn.Linear(in_dim, 512)  # NOTE: no extent feature
        self.fc2 = nn.Linear(512, 256)
        self.fc_r = nn.Linear(256, rot_dim)  # quat or rot6d
        # TODO: predict centroid and z separately
        self.fc_t = nn.Linear(256, 3)
        self.act = nn.LeakyReLU(0.1, inplace=True)

        # feature for extent
        # self.extent_fc1 = nn.Linear(3, 64)
        # self.extent_fc2 = nn.Linear(64, 128)

        # init ------------------------------------
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
            elif isinstance(m, nn.ConvTranspose2d):
                normal_init(m, std=0.001)
            elif isinstance(m, nn.Linear):
                normal_init(m, std=0.001)
        normal_init(self.fc_r, std=0.01)
        normal_init(self.fc_t, std=0.01)