Beispiel #1
0
    def __init__(self, num_classes):
        """Build the GoogLeNet (Inception v1) layers and classifier head.

        Args:
            num_classes: output width of the final ``nn.Dense`` classifier,
                whose input width is fixed at 1024.
        """
        super(GooGLeNet, self).__init__()
        # Stem: 7x7 stride-2 convolution block, then a 3x3 stride-2 max pool.
        self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
        self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")

        # 1x1 then 3x3 convolution blocks (64 -> 64 -> 192 channels), then
        # another 3x3 stride-2 max pool.
        self.conv2 = Conv2dBlock(64, 64, kernel_size=1)
        self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
        self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")

        # Inception stage 3. Each block's first argument (input channels)
        # equals the sum of the previous block's 2nd, 4th, 6th and 7th
        # arguments, e.g. 64 + 128 + 32 + 32 = 256. This assumes Inception
        # concatenates four branches of those widths -- confirm against the
        # Inception definition.
        self.block3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.block3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")

        # Inception stage 4.
        self.block4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.block4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.block4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.block4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.block4e = Inception(528, 256, 160, 320, 32, 128, 128)
        # This pool uses a 2x2 window, unlike the 3x3 windows of the pools above.
        self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same")

        # Inception stage 5. block5b's branch widths sum to
        # 384 + 384 + 128 + 128 = 1024, the classifier's input width.
        self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Head. The axes reduced by ``mean`` are chosen in construct(), which
        # is not shown; presumably this is global average pooling.
        self.mean = P.ReduceMean(keep_dims=True)
        self.dropout = nn.Dropout(keep_prob=0.8)
        self.flatten = nn.Flatten()
        # The bias is initialized with the same initializer as the weights.
        self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
                                   bias_init=weight_variable())
Beispiel #2
0
    def __init__(self, block, num_classes=100, batch_size=32):
        """Build a four-stage ResNet.

        Args:
            block: residual block type passed to each ``MakeLayer*`` helper;
                its ``expansion`` attribute also sizes the final FC layer.
            num_classes: output width of ``self.fc``.
            batch_size: stored on the instance; not otherwise used here.
        """
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        # Stem: stride-2 conv7x7, bn_with_initialize(64), ReLU and a 3x3
        # stride-2 max pool.
        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")

        # Four residual stages (64 -> 256 -> 512 -> 1024 -> 2048 channels);
        # stages 2-4 are built with stride 2.
        self.layer1 = MakeLayer0(block,
                                 in_channels=64,
                                 out_channels=256,
                                 stride=1)
        self.layer2 = MakeLayer1(block,
                                 in_channels=256,
                                 out_channels=512,
                                 stride=2)
        self.layer3 = MakeLayer2(block,
                                 in_channels=512,
                                 out_channels=1024,
                                 stride=2)
        self.layer4 = MakeLayer3(block,
                                 in_channels=1024,
                                 out_channels=2048,
                                 stride=2)

        # Head: mean reduction (keep_dims=True), squeeze of axes 2 and 3,
        # then the fully connected layer.
        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        # NOTE(review): the FC input width is 512 * block.expansion, while
        # layer4 is built with out_channels=2048. The two agree only when
        # block.expansion == 4 -- confirm for every block type used.
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)
Beispiel #3
0
    def __init__(self, block, layer_nums, in_channels, out_channels, strides,
                 num_classes, damping, loss_scale, frequency):
        """Build a ResNet whose conv and FC layers take THOR-style arguments.

        Args:
            block: residual block type handed to ``self._make_layer``.
            layer_nums: number of blocks in each of the four stages.
            in_channels: input channel count of each stage.
            out_channels: output channel count of each stage; the last entry
                is also the input width of the final FC layer.
            strides: stride of each stage.
            num_classes: output width of the final FC layer.
            damping, loss_scale, frequency: forwarded unchanged to every
                ``_conv7x7``, ``_make_layer`` and ``_fc`` call.

        Raises:
            ValueError: if ``layer_nums``, ``in_channels`` and
                ``out_channels`` are not all of length 4.
        """
        super(ResNet, self).__init__()

        stage_count = len(layer_nums)
        if not (stage_count == len(in_channels) == len(out_channels) == 4):
            raise ValueError(
                "the length of layer_num, in_channels, out_channels list must be 4!"
            )

        # Keyword arguments shared by every conv / stage / FC constructor.
        extra = {"damping": damping, "loss_scale": loss_scale, "frequency": frequency}

        self.conv1 = _conv7x7(3, 64, stride=2, **extra)
        self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")

        # Register layer1..layer4 in order, exactly as four explicit
        # ``self.layerN = ...`` assignments would (setattr goes through the
        # same Cell.__setattr__).
        for idx in range(4):
            stage = self._make_layer(block,
                                     layer_nums[idx],
                                     in_channel=in_channels[idx],
                                     out_channel=out_channels[idx],
                                     stride=strides[idx],
                                     **extra)
            setattr(self, "layer%d" % (idx + 1), stage)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3], num_classes, **extra)
Beispiel #4
0
 def __init__(self, num_classes=10):
     """Create the layers of this network.

     The layers are a conv, batch norm, ReLU, max pool, flatten and a dense
     classifier.

     Args:
         num_classes: output width of the final ``nn.Dense`` layer.
     """
     super(DefinedNet, self).__init__()
     # Convolution weights are initialized to all zeros.
     self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
     self.bn1 = nn.BatchNorm2d(64)
     self.relu = nn.ReLU()
     self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
     self.flatten = nn.Flatten()
     # NOTE(review): the flattened width 56*56*64 assumes a 64x56x56 feature
     # map after conv1 and maxpool (e.g. a 224x224 input). Confirm against the
     # input shape used by callers.
     self.fc = nn.Dense(int(56*56*64), num_classes)
Beispiel #5
0
 def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
     """Create the max-pooling primitives for this cell.

     Args:
         kernel_size: pooling window size, forwarded to the base class.
         stride: pooling stride, forwarded to the base class.
         pad_mode: padding mode, forwarded to the base class.
     """
     super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode)
     # Both primitives are configured from self.kernel_size, self.stride and
     # self.pad_mode. These are presumably set by the base-class constructor,
     # which is not shown here.
     self.max_pool = P.MaxPool(ksize=self.kernel_size,
                               strides=self.stride,
                               padding=self.pad_mode)
     self.max_pool_with_arg_max = P.MaxPoolWithArgmax(
         ksize=self.kernel_size, strides=self.stride, padding=self.pad_mode)
     # True when the device target is "Ascend". Presumably construct() (not
     # shown) uses it to choose between the two primitives above.
     self.is_tbe = context.get_context("device_target") == "Ascend"
    def __init__(self):
        """Create a max-pool primitive, a parameter named 'w' and an add primitive."""
        super(Net, self).__init__()

        # 3x3 window, stride 2, "same" padding.
        self.maxpool = P.MaxPoolWithArgmax(pad_mode="same",
                                           kernel_size=3,
                                           strides=2)
        # Parameter named 'w' of shape [1, 64, 112, 112], filled by the
        # 'normal' initializer.
        self.x = Parameter(initializer(
            'normal', [1, 64, 112, 112]), name='w')
        self.add = P.TensorAdd()
Beispiel #7
0
 def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
     """Create the four branches of an Inception block.

     Args:
         in_channels: channel count of the block input, shared by all branches.
         n1x1: output channels of the 1x1 branch ``b1``.
         n3x3red, n3x3: 1x1 reduction and 3x3 output channels of ``b2``.
         n5x5red, n5x5: 1x1 reduction and output channels of ``b3``.
         pool_planes: output channels of the 1x1 conv ``b4``.
     """
     super(Inception, self).__init__()
     self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1)
     self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red, kernel_size=1),
                                  Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)])
     # NOTE(review): despite the n5x5 naming, this branch uses a 3x3 kernel.
     self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1),
                                  Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)])
     # 3x3 max pool with stride 1 and "same" padding. b4 is a 1x1 conv from
     # in_channels, presumably applied to the pooled input in construct().
     self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same")
     self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
     # Concatenates along axis 1 (presumably the channel axis).
     self.concat = P.Concat(axis=1)
Beispiel #8
0
 def __init__(self, network):
     """Wrap a CenterFace network together with its post-processing primitives.

     Args:
         network: the CenterFace cell, stored as ``self.centerface_network``.
     """
     super(CenterFaceWithNms, self).__init__()
     self.centerface_network = network
     self.config = ConfigCenterface()
     # Alternative max pool (not used):
     # self.maxpool2d = nn.MaxPool2d(kernel_size=3, stride=1, pad_mode='same')
     self.maxpool2d = P.MaxPoolWithArgmax(kernel_size=3, strides=1, pad_mode='same')
     self.topk = P.TopK(sorted=True)
     self.reshape = P.Reshape()
     self.print = P.Print()
     # The test batch size and K are read from the CenterFace config.
     self.test_batch = self.config.test_batch_size
     self.k = self.config.K
Beispiel #9
0
    def __init__(self, kernel_size=1, stride=1, pad_mode="VALID", padding=0):
        """Build the pooling primitive and hand it to the base class.

        Args:
            kernel_size: pooling window size.
            stride: pooling stride.
            pad_mode: padding mode string, passed to ``P.MaxPool`` as given.
            padding: explicit padding amount. It is used only by the disabled
                MaxPoolWithArgmax branch and is forwarded to the base class.
        """
        max_pool = P.MaxPool(ksize=kernel_size,
                             strides=stride,
                             padding=pad_mode)
        # Hard-coded to False, so the branch below never runs. The P.MaxPool
        # primitive above is therefore always the one passed to the base
        # class. NOTE(review): this attribute is assigned before the base-class
        # __init__ has run.
        self.is_autodiff_backend = False
        if self.is_autodiff_backend:

            # The pad mode of max pool is not unified yet, so this path is
            # temporarily avoided.
            pad_mode = validator.check_string('pad_mode', pad_mode.lower(),
                                              ['valid', 'same'])

            max_pool = P.MaxPoolWithArgmax(window=kernel_size,
                                           stride=stride,
                                           pad_mode=pad_mode,
                                           pad=padding)
        super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, padding,
                                        max_pool)
Beispiel #10
0
 def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
     """Validate the 1-D pooling arguments and build the underlying 2-D primitives.

     Args:
         kernel_size: int window length; checked to be >= 1.
         stride: int stride; checked to be >= 1.
         pad_mode: "valid" or "same" in any case. It is upper-cased before
             being checked against 'VALID'/'SAME' and stored as self.pad_mode.

     The ``validator`` checks presumably raise on invalid values; the exact
     exception types come from ``validator``, which is not shown here.
     """
     super(MaxPool1d, self).__init__(kernel_size, stride, pad_mode)
     validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
     validator.check_value_type('stride', stride, [int], self.cls_name)
     self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
     validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
     validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
     # The 1-D window and stride are stored as 2-D (1, n) tuples so the 2-D
     # pooling primitives below can be reused.
     self.kernel_size = (1, kernel_size)
     self.stride = (1, stride)
     self.max_pool = P.MaxPool(ksize=self.kernel_size,
                               strides=self.stride,
                               padding=self.pad_mode)
     self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
                                                      strides=self.stride,
                                                      padding=self.pad_mode)
     # Shape/axis helpers, presumably used by construct() (not shown) to add
     # and remove the extra axis around the 2-D primitives. Squeeze drops axis 2.
     self.shape = F.shape
     self.reduce_mean = P.ReduceMean(keep_dims=True)
     self.expand = P.ExpandDims()
     self.squeeze = P.Squeeze(2)
     # True when the device target is "Ascend".
     self.is_tbe = context.get_context("device_target") == "Ascend"
Beispiel #11
0
     'desc_inputs': [[3, 4, 6, 6], [3, 4, 3, 3], [3, 4, 3, 3]],
     'desc_bprop': [[3, 4, 6, 6]],
     'skip': ['backward']}),
 ('AvgPool', {
     'block': P.AvgPool(ksize=(2, 2), strides=(2, 2), padding="VALID"),
     'desc_inputs': [[100, 3, 28, 28]],
     'desc_bprop': [[100, 3, 14, 14]]}),
 ('AvgPoolGrad', {
     'block': G.AvgPoolGrad(ksize=(2, 2), strides=(2, 2), padding="VALID"),
     'desc_const': [(3, 4, 6, 6)],
     'const_first': True,
     'desc_inputs': [[3, 4, 6, 6]],
     'desc_bprop': [[3, 4, 6, 6]],
     'skip': ['backward']}),
 ('MaxPoolWithArgmax', {
     'block': P.MaxPoolWithArgmax(ksize=2, strides=2),
     'desc_inputs': [[128, 32, 32, 64]],
     'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}),
 ('SoftmaxCrossEntropyWithLogits', {
     'block': P.SoftmaxCrossEntropyWithLogits(),
     'desc_inputs': [[1, 10], [1, 10]],
     'desc_bprop': [[1], [1, 10]],
     'skip': ['backward_exec']}),
 ('Flatten', {
     'block': P.Flatten(),
     'desc_inputs': [[128, 32, 32, 64]],
     'desc_bprop': [[128 * 32 * 8 * 16]]}),
 ('LogSoftmax', {
     'block': P.LogSoftmax(),
     'desc_inputs': [[64, 2]],
     'desc_bprop': [[160, 30522]]}),
    def __init__(self):
        """Hold a MaxPoolWithArgmax primitive: 3x3 window, stride 2, "same" padding."""
        super(Net, self).__init__()
        self.maxpool = P.MaxPoolWithArgmax(ksize=3,
                                           strides=2,
                                           padding="same")
Beispiel #13
0
     'skip': ['backward']
 }),
 ('Conv2d_ValueError_1', {
     'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {
         'exception': TypeError
     }),
     'desc_inputs': [0],
 }),
 ('Conv2d_ValueError_2', {
     'block': (lambda _: P.Conv2D(3, 4, mode=-2), {
         'exception': ValueError
     }),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_1', {
     'block': (lambda _: P.MaxPoolWithArgmax(padding='sane'), {
         'exception': ValueError
     }),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_2', {
     'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {
         'exception': TypeError
     }),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_3', {
     'block': (lambda _: P.MaxPoolWithArgmax(ksize=-2), {
         'exception': ValueError
     }),
     'desc_inputs': [0],
Beispiel #14
0
     'skip': ['backward']
 }),
 ('Conv2d_ValueError_1', {
     'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {
         'exception': TypeError
     }),
     'desc_inputs': [0],
 }),
 ('Conv2d_ValueError_2', {
     'block': (lambda _: P.Conv2D(3, 4, mode=-2), {
         'exception': ValueError
     }),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_1', {
     'block': (lambda _: P.MaxPoolWithArgmax(pad_mode='sane'), {
         'exception': ValueError
     }),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_2', {
     'block': (lambda _: P.MaxPoolWithArgmax(kernel_size='1'), {
         'exception': TypeError
     }),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_3', {
     'block': (lambda _: P.MaxPoolWithArgmax(kernel_size=-2), {
         'exception': ValueError
     }),
     'desc_inputs': [0],
Beispiel #15
0
 ('ApplyMomentum_Error', {
     'block': (P.ApplyMomentum(), {'exception': TypeError}),
     'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
     'desc_bprop': [[128, 32, 32, 64]],
     'skip': ['backward']
 }),
 ('Conv2d_ValueError_1', {
     'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
     'desc_inputs': [0],
 }),
 ('Conv2d_ValueError_2', {
     'block': (lambda _: P.Conv2D(3, 4, mode=-2), {'exception': ValueError}),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_1', {
     'block': (lambda _: P.MaxPoolWithArgmax(padding='sane'), {'exception': ValueError}),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_2', {
     'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': TypeError}),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_3', {
     'block': (lambda _: P.MaxPoolWithArgmax(ksize=-2), {'exception': ValueError}),
     'desc_inputs': [0],
 }),
 ('MaxPoolWithArgmax_ValueError_4', {
     'block': (lambda _: P.MaxPoolWithArgmax(strides=-1), {'exception': ValueError}),
     'desc_inputs': [0],
 }),
 ('FusedBatchNorm_ValueError_1', {
    def __init__(self):
        """Hold a MaxPoolWithArgmax primitive: 3x3 window, stride 2, "same" padding."""
        super(Net, self).__init__()
        self.maxpool = P.MaxPoolWithArgmax(window=3, stride=2, pad_mode="same")
Beispiel #17
0
     'desc_inputs': [[3, 4, 6, 6], [3, 4, 3, 3], [3, 4, 3, 3]],
     'desc_bprop': [[3, 4, 6, 6]],
     'skip': ['backward']}),
 ('AvgPool', {
     'block': P.AvgPool(ksize=(2, 2), strides=(2, 2), padding="VALID"),
     'desc_inputs': [[100, 3, 28, 28]],
     'desc_bprop': [[100, 3, 14, 14]]}),
 ('AvgPoolGrad', {
     'block': G.AvgPoolGrad(ksize=(2, 2), strides=(2, 2), padding="VALID"),
     'desc_const': [(3, 4, 6, 6)],
     'const_first': True,
     'desc_inputs': [[3, 4, 6, 6]],
     'desc_bprop': [[3, 4, 6, 6]],
     'skip': ['backward']}),
 ('MaxPoolWithArgmax', {
     'block': P.MaxPoolWithArgmax(window=2, stride=2),
     'desc_inputs': [[128, 32, 32, 64]],
     'desc_bprop': [[128, 32, 8, 16], [128, 32, 8, 16]]}),
 ('SoftmaxCrossEntropyWithLogits', {
     'block': P.SoftmaxCrossEntropyWithLogits(),
     'desc_inputs': [[1, 10], [1, 10]],
     'desc_bprop': [[1], [1, 10]],
     'skip': ['backward_exec']}),
 ('Flatten', {
     'block': P.Flatten(),
     'desc_inputs': [[128, 32, 32, 64]],
     'desc_bprop': [[128 * 32 * 8 * 16]]}),
 ('LogSoftmax', {
     'block': P.LogSoftmax(),
     'desc_inputs': [[64, 2]],
     'desc_bprop': [[160, 30522]]}),
 def __init__(self, padding, ksize, strides):
     """Wrap a MaxPoolWithArgmax primitive configured by the caller.

     Args:
         padding: padding mode forwarded to the primitive.
         ksize: pooling window size.
         strides: pooling stride.
     """
     super(MaxPoolWithArgMax_Net, self).__init__()
     pool_op = P.MaxPoolWithArgmax(ksize=ksize,
                                   strides=strides,
                                   padding=padding)
     self.maxpool_with_argmax = pool_op
    # kernel_size != w_shape[2:4]
    ('DepthwiseConv2dNative6', {
        'block': (P.DepthwiseConv2dNative(2, (5, 5)), {
            'exception': ValueError,
            'error_keywords': ['DepthwiseConv2dNative']
        }),
        'desc_inputs': [
            Tensor(np.ones([1, 1, 9, 9]).astype(np.float32)),
            Tensor(np.ones([2, 1, 5, 6]).astype(np.float32))
        ],
        'skip': ['backward']
    }),

    # input is scalar
    ('MaxPoolWithArgmax0', {
        'block': (P.MaxPoolWithArgmax(), {
            'exception': TypeError,
            'error_keywords': ['MaxPoolWithArgmax']
        }),
        'desc_inputs': [5.0],
        'skip': ['backward']
    }),
    # input is Tensor(bool)
    ('MaxPoolWithArgmax1', {
        'block': (P.MaxPoolWithArgmax(), {
            'exception': TypeError,
            'error_keywords': ['MaxPoolWithArgmax']
        }),
        'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_))],
        'skip': ['backward']
    }),
 def __init__(self):
     """Hold a MaxPoolWithArgmax primitive: 3x3 window, stride 2, "SAME" padding."""
     super(Net_Pool2, self).__init__()
     self.maxpool_fun = P.MaxPoolWithArgmax(padding="SAME", ksize=3, strides=2)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level primitive instances, presumably referenced by the test
# functions defined later in this module.
tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd()
# 3x3 window, stride 2, "same" padding.
max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple')
transdata = Primitive("TransData")
Transpose = P.Transpose()


class FnDict:
    """Registry that maps function names to functions.

    An instance is used as a decorator: ``@fns`` stores the decorated function
    under its ``__name__``, and ``fns["name"]`` retrieves it later.
    """

    def __init__(self):
        # name -> function, filled by __call__.
        self.fnDict = {}

    def __call__(self, fn):
        """Register *fn* under ``fn.__name__`` and return it unchanged.

        Returning *fn* keeps a decorated module-level name bound to the
        function. Previously the name was rebound to ``None``.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as *name*; raise KeyError if absent."""
        return self.fnDict[name]
    def __init__(self):
        """Hold a MaxPoolWithArgmax primitive: kernel 3, stride 2, "same" padding."""
        super(Net, self).__init__()
        self.maxpool = P.MaxPoolWithArgmax(kernel_size=3, strides=2, pad_mode="same")
Beispiel #23
0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.common.dtype as mstype
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level primitive instances used by the test functions below.
addn = P.AddN()
add = P.TensorAdd()
reshape = P.Reshape()
cast = P.Cast()
tuple_getitem = Primitive('tuple_getitem')
# 3x3 kernel, stride 2, "same" padding.
max_pool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)


def test_addn_cast(x, y, z):
    """Sum x, y and z with AddN, then cast the result to float16.

    Presumably a graph fixture: its body is turned into a graph that is
    compared elsewhere, so its form is kept as-is.
    """
    mysum = addn((x, y, z))
    res = cast(mysum, mstype.float16)
    return res


def test_addn_with_max_pool(x, y):
    """Sum x and y with AddN, max-pool the sum, and return output element 0.

    ``max_pool`` yields a tuple, and ``tuple_getitem(output, 0)`` selects its
    first element (presumably the pooled values rather than the argmax).
    """
    mysum = addn((x, y))
    output = max_pool(mysum)
    res = tuple_getitem(output, 0)
    return res

Beispiel #24
0
 def __init__(self):
     """Hold a MaxPoolWithArgmax primitive: 2x2 window, stride 1, "valid" padding."""
     super(NetMaxPoolWithArgMax, self).__init__()
     self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=2,
                                                      strides=1,
                                                      padding="valid")
Beispiel #25
0
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import operations as P
from mindspore.ops import Primitive
import mindspore as ms


# Module-level primitive instances used by the test functions below.
addn = P.AddN()
add = P.TensorAdd()
reshape = P.Reshape()
cast = P.Cast()
tuple_getitem = Primitive('tuple_getitem')
# 3x3 window, stride 2, "same" padding (window/stride/pad_mode keywords).
max_pool = P.MaxPoolWithArgmax(pad_mode="same", window=3, stride=2)

def test_addn_cast(x, y, z):
    """Sum x and y with AddN, then cast the result to ``ms.float16``.

    NOTE(review): ``z`` is accepted but never used. Confirm whether AddN
    should also include it before changing the graph this builds.
    """
    # Named ``mysum`` rather than ``sum`` to avoid shadowing the builtin.
    mysum = addn((x, y))
    res = cast(mysum, ms.float16)
    return res

def test_addn_with_max_pool(x, y):
    """Sum x and y with AddN, max-pool the sum, and return output element 0.

    ``max_pool`` yields a tuple, and ``tuple_getitem(output, 0)`` selects its
    first element.
    """
    # Named ``mysum`` rather than ``sum`` to avoid shadowing the builtin.
    mysum = addn((x, y))
    output = max_pool(mysum)
    res = tuple_getitem(output, 0)
    return res


def test_shape_add(x1, x2, y1, y2, z1, z2):
    sum1 = add(x1, x2)
Beispiel #26
0
test_cases_for_verify_exception = [
    ('Conv2d_ValueError_1', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {
            'exception': ValueError
        }),
        'desc_inputs': [0],
    }),
    ('Conv2d_ValueError_2', {
        'block': (lambda _: P.Conv2D(3, 4, mode=-2), {
            'exception': ValueError
        }),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_1', {
        'block': (lambda _: P.MaxPoolWithArgmax(pad_mode='sane'), {
            'exception': ValueError
        }),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_2', {
        'block': (lambda _: P.MaxPoolWithArgmax(data_mode=2), {
            'exception': ValueError
        }),
        'desc_inputs': [0],
    }),
    ('MaxPoolWithArgmax_ValueError_3', {
        'block': (lambda _: P.MaxPoolWithArgmax(ceil_mode=2), {
            'exception': ValueError
        }),
        'desc_inputs': [0],
Beispiel #27
0
    def __init__(self, block, layer_nums, in_channels, out_channels, strides,
                 num_classes, damping, loss_scale, frequency, batch_size,
                 resnet_d, init_new):
        super(ResNetNoBN, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError(
                "the length of layer_num, in_channels, out_channels list must be 4!"
            )

        self.conv1 = _conv7x7(3,
                              64,
                              stride=2,
                              damping=damping,
                              loss_scale=loss_scale,
                              frequency=frequency,
                              batch_size=batch_size)
        # self.bn1 = _bn(64)
        self.relu = P.ReLU()
        self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size,
                                       resnet_d=resnet_d)
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size,
                                       resnet_d=resnet_d)
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size,
                                       resnet_d=resnet_d)
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3],
                                       damping=damping,
                                       loss_scale=loss_scale,
                                       frequency=frequency,
                                       batch_size=batch_size,
                                       resnet_d=resnet_d)

        self.mean = P.ReduceMean(keep_dims=True)
        self.flatten = nn.Flatten()
        self.end_point = _fc(out_channels[3],
                             num_classes,
                             damping=damping,
                             loss_scale=loss_scale,
                             frequency=frequency,
                             batch_size=batch_size)

        L = sum(layer_nums)  # e.g., resnet101 has 33 residual branches
        for name, cell in self.cells_and_names():
            if isinstance(cell, ResidualBlockNoBN):
                # cell.conv3.weight.default_input = initializer('zeros', cell.conv3.weight.default_input.shape()).to_tensor()
                fixup_m = 3  # !!!
                v = math.pow(L, -1 / (2 * fixup_m - 2))
                cell.conv1.weight.default_input = Tensor(
                    cell.conv1.weight.default_input.asnumpy() * v,
                    cell.conv1.weight.default_input.dtype())
                cell.conv2.weight.default_input = Tensor(
                    cell.conv2.weight.default_input.asnumpy() * v,
                    cell.conv2.weight.default_input.dtype())
                if init_new:
                    cell.conv3.weight.default_input = Tensor(
                        cell.conv3.weight.default_input.asnumpy() * v,
                        cell.conv3.weight.default_input.dtype())
                else:
                    cell.conv3.weight.default_input = initializer(
                        'zeros',
                        cell.conv3.weight.default_input.shape()).to_tensor()
            elif isinstance(cell, ResidualBlockNoBN):
                cell.conv2.weight.default_input = initializer(
                    'zeros',
                    cell.conv2.weight.default_input.shape()).to_tensor()
                fixup_m = 2  # !!!
                v = math.pow(L, -1 / (2 * fixup_m - 2))
                cell.conv1.weight.default_input = Tensor(
                    cell.conv1.weight.default_input.asnumpy() * v,
                    cell.conv1.weight.default_input.dtype())