Example #1
0
    def init_weights(self, pretrained=None):
        """Initialize model weights.

        Args:
            pretrained (str | None): Path to a pretrained checkpoint. If a
                str, the checkpoint is loaded (non-strictly) and the input
                conv is expanded when ``self.in_channels > 3``. If None,
                weights are initialized from scratch.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
            self._expand_input_conv()
        elif pretrained is None:
            self._expand_input_conv()
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    xavier_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')

    def _expand_input_conv(self):
        """Expand the encoder's first conv to accept extra input channels.

        No-op when ``self.in_channels <= 3``. Otherwise the existing
        3-channel ``conv1`` weights are copied into the first three input
        channels and the extra channels are zero-initialized.
        """
        if self.in_channels <= 3:
            return
        print(f'modifying input layer to accept {self.in_channels} channels')
        net_encoder_sd = self.encoder.state_dict()
        conv1_weights = net_encoder_sd['conv1.weight']

        c_out, _, h, w = conv1_weights.size()
        # new_zeros keeps the dtype/device of the existing weights, unlike
        # torch.zeros which uses the global defaults.
        conv1_mod = conv1_weights.new_zeros(c_out, self.in_channels, h, w)
        conv1_mod[:, :3, :, :] = conv1_weights

        conv1 = self.encoder.conv1
        conv1.in_channels = self.in_channels
        conv1.weight = torch.nn.Parameter(conv1_mod)
        self.encoder.conv1 = conv1

        # Keep the state dict in sync with the widened layer so a reload
        # does not restore the old 3-channel shape.
        net_encoder_sd['conv1.weight'] = conv1_mod
        self.encoder.load_state_dict(net_encoder_sd)
 def init_weights(self, pretrained=None):
     """Initialize weights, optionally from a pretrained checkpoint.

     Args:
         pretrained (str | None): Checkpoint path, or None to fall back to
             the parent class's default initialization.

     Raises:
         TypeError: If ``pretrained`` is neither a str nor None.
     """
     if pretrained is None:
         super(ResGCAEncoder, self).init_weights()
         return
     if not isinstance(pretrained, str):
         raise TypeError('"pretrained" must be a str or None. '
                         f'But received {type(pretrained)}.')
     logger = get_root_logger()
     load_checkpoint(self, pretrained, strict=False, logger=logger)
    def init_weights(self, pretrained=None):
        """Initialize model weights.

        Args:
            pretrained (str | None): Path to a pretrained checkpoint trained
                on 3-channel images. If None, weights are initialized from
                scratch.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # if pretrained weight is trained on 3-channel images,
            # initialize other channels with zeros
            self.conv1.conv.weight.data[:, 3:, :, :] = 0

            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    # constant_init takes the module itself (weight -> 1,
                    # bias -> 0), not its Parameter tensors.
                    constant_init(m, 1)

            # Zero-initialize the last BN in each residual branch, so that the
            # residual branch starts with zeros, and each residual block
            # behaves like an identity. This improves the model by 0.2~0.3%
            # according to https://arxiv.org/abs/1706.02677
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    constant_init(m.conv2.bn, 0)
        else:
            raise TypeError('"pretrained" must be a str or None. '
                            f'But received {type(pretrained)}.')