def __init__(self, backbone_name, num_classes, platform="Ascend"):
        """Build a ResNet classifier from a named backbone plus a CommonHead.

        After the cell tree is assembled, every ``nn.Conv2d`` weight is
        re-initialized with KaimingNormal (fan_out, relu). Every
        ``nn.BatchNorm2d`` gets gamma=1 and beta=0. Finally, the last BN
        gamma of each residual block is set to zero.

        Args:
            backbone_name: Key into ``backbones.__dict__`` selecting the
                backbone constructor; it is called with ``platform=platform``.
            num_classes: Number of output classes, forwarded to
                ``heads.CommonHead``.
            platform: Forwarded to the backbone constructor.
                Defaults to "Ascend".
        """
        self.backbone_name = backbone_name
        backbone = backbones.__dict__[self.backbone_name](platform=platform)
        out_channels = backbone.get_out_channels()
        head = heads.CommonHead(num_classes=num_classes,
                                out_channels=out_channels)
        super(Resnet, self).__init__(backbone, head)

        default_recurisive_init(self)

        # cells_and_names() yields (name, cell) pairs. Unpack them so the
        # isinstance checks inspect the cell itself. Iterating the raw tuples
        # makes every check False and silently skips all re-initialization.
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    init.initializer(
                        KaimingNormal(a=math.sqrt(5),
                                      mode='fan_out',
                                      nonlinearity='relu'), cell.weight.shape,
                        cell.weight.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
                cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        for _, cell in self.cells_and_names():
            if isinstance(cell, backbones.resnet.Bottleneck):
                cell.bn3.gamma.set_data(
                    init.initializer('zeros', cell.bn3.gamma.shape))
            elif isinstance(cell, backbones.resnet.BasicBlock):
                cell.bn2.gamma.set_data(
                    init.initializer('zeros', cell.bn2.gamma.shape))
# Example #2
# 0
    def __init__(self, num_classes):
        """Build a DenseNet-121 classifier: ``_densenet121`` backbone + CommonHead.

        After construction, parameters are re-initialized as follows:
        ``nn.Conv2d`` weights use KaimingNormal (fan_out, relu).
        ``nn.BatchNorm2d`` gets gamma=1 and beta=0.
        ``nn.Dense`` biases (when present) are set to zero.

        Args:
            num_classes: Number of output classes, forwarded to ``CommonHead``.
        """
        super(DenseNet121, self).__init__()
        self.backbone = _densenet121()
        out_channels = self.backbone.get_out_channels()
        self.head = CommonHead(num_classes, out_channels)

        default_recurisive_init(self)
        # cells_and_names() yields (name, cell) pairs; only the cell is needed.
        for _, cell in self.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.set_data(
                    init.initializer(
                        KaimingNormal(a=math.sqrt(5),
                                      mode='fan_out',
                                      nonlinearity='relu'), cell.weight.shape,
                        cell.weight.dtype))
            elif isinstance(cell, nn.BatchNorm2d):
                cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
                cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
            elif isinstance(cell, nn.Dense):
                # A Dense layer built without a bias presumably has no bias
                # parameter (None). Only reset the bias when one exists, rather
                # than failing on None.set_data.
                bias = getattr(cell, 'bias', None)
                if bias is not None:
                    bias.set_data(init.initializer('zeros', bias.shape))