Example No. 1
    def __init__(self, batchnorm, dropout):
        super(CosineNet, self).__init__()
        layers = []
        if batchnorm:
            layers.append(nn.BatchNorm2d(INPUT_DIM))

        # initialize hidden layers
        for l_n in range(N_LAYERS):
            in_channels = HIDDEN_DIM if l_n > 0 else INPUT_DIM
            # Use a 1x1 Conv2d instead of Dense; it composes better with the BatchNorm2d operator.
            conv = nn.Conv2d(in_channels,
                             HIDDEN_DIM,
                             kernel_size=1,
                             pad_mode='valid',
                             has_bias=True,
                             weight_init=Normal(0.01))
            layers.append(conv)
            if batchnorm:
                layers.append(nn.BatchNorm2d(HIDDEN_DIM))
            if dropout:
                layers.append(nn.Dropout(DROPOUT_RATE))
            layers.append(ACTIVATION())
        self.layers = nn.SequentialCell(layers)

        # initialize output layers
        self.flatten = nn.Flatten()  # convert a 4-dim tensor (N, C, H, W) to a 2-dim tensor (N, C*H*W)
        self.fc = nn.Dense(HIDDEN_DIM,
                           OUTPUT_DIM,
                           weight_init=Normal(0.1),
                           bias_init='zeros')
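
The constructor above relies on module-level constants (INPUT_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, DROPOUT_RATE, ACTIVATION) defined elsewhere in the original file. A minimal sketch of plausible, purely illustrative definitions under which the cell can be built:

# Illustrative constants; the real values live elsewhere in the original source.
import mindspore.nn as nn
from mindspore.common.initializer import Normal

INPUT_DIM = 1          # channels of the (N, C, H, W) input
HIDDEN_DIM = 128       # width of every hidden 1x1 convolution
OUTPUT_DIM = 1         # size of the final Dense output
N_LAYERS = 3           # number of hidden blocks
DROPOUT_RATE = 0.7     # value passed to nn.Dropout (a keep probability in older MindSpore)
ACTIVATION = nn.ReLU   # activation class, instantiated once per block

# net = CosineNet(batchnorm=True, dropout=True)
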
Example No. 2
def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights.

    :param net: network to be initialized
    :type net: nn.Cell
    :param init_type: the name of an initialization method: normal | xavier | he
    :type init_type: str
    :param init_gain: scaling factor for normal and xavier.
    :type init_gain: float
    """

    for _, cell in net.cells_and_names():
        classname = cell.__class__.__name__
        if hasattr(cell, 'in_proj_layer'):
            cell.in_proj_layer = Parameter(initializer(
                HeUniform(negative_slope=math.sqrt(5)),
                cell.in_proj_layer.shape, cell.in_proj_layer.dtype),
                                           name=cell.in_proj_layer.name)
        if hasattr(cell, 'weight'):
            if init_type == 'normal':
                cell.weight = Parameter(initializer(Normal(init_gain),
                                                    cell.weight.shape,
                                                    cell.weight.dtype),
                                        name=cell.weight.name)
            elif init_type == 'xavier':
                cell.weight = Parameter(initializer(XavierUniform(init_gain),
                                                    cell.weight.shape,
                                                    cell.weight.dtype),
                                        name=cell.weight.name)
            elif init_type == "he":
                cell.weight = Parameter(initializer(
                    HeUniform(negative_slope=math.sqrt(5)), cell.weight.shape,
                    cell.weight.dtype),
                                        name=cell.weight.name)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' %
                    init_type)

            if hasattr(cell, 'bias') and cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
                bound = 1 / math.sqrt(fan_in)
                cell.bias = Parameter(initializer(Uniform(bound),
                                                  cell.bias.shape,
                                                  cell.bias.dtype),
                                      name=cell.bias.name)
        elif classname.find('BatchNorm2d') != -1:
            cell.gamma = Parameter(initializer(Normal(1.0), cell.gamma.shape,
                                               cell.gamma.dtype),
                                   name=cell.gamma.name)
            cell.beta = Parameter(initializer(Zero(), cell.beta.shape,
                                              cell.beta.dtype),
                                  name=cell.beta.name)

    print('initialized network weights with %s' % init_type)
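
A hedged usage sketch for init_weights. TinyNet is an assumption made only for this example, and the call relies on the original module's imports (Parameter, initializer, the init classes, and _calculate_fan_in_and_fan_out) being in scope.

# Hypothetical network used only to demonstrate the call.
import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, pad_mode='same', has_bias=True)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.bn(self.conv(x)))

net = TinyNet()
init_weights(net, init_type='xavier', init_gain=0.02)
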
Example No. 3
 def __init__(self, num_class=10, num_channel=1):
     super(LeNet5, self).__init__()
     self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
     self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
     self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
     self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
     self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
     self.relu = nn.ReLU()
     self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
     self.flatten = nn.Flatten()
Example No. 4
    def __init__(self, num_class=10, num_channel=3, count=0):
        super(myNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 16, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(16, 32, 5)

        self.fc1 = nn.Dense(3872, 128, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(128, 64, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(64, num_class, weight_init=Normal(0.02))

        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(keep_prob=0.9)
Example No. 5
    def __init__(self, num_classes=10, cut_layer=None):
        super().__init__()
        self.cut_layer = cut_layer

        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()

        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.relu3 = nn.ReLU()

        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.relu4 = nn.ReLU()

        self.fc3 = nn.Dense(84, num_classes, weight_init=Normal(0.02))

        # Prepare named layers so that the model can be split across
        # the client and the server
        self.layerdict = collections.OrderedDict([
            ('conv1', self.conv1), ('relu1', self.relu1), ('pool1', self.pool1),
            ('conv2', self.conv2), ('relu2', self.relu2), ('pool2', self.pool2),
            ('flatten', self.flatten), ('fc1', self.fc1), ('relu3', self.relu3),
            ('fc2', self.fc2), ('relu4', self.relu4), ('fc3', self.fc3),
        ])
        self.layers = list(self.layerdict.keys())
Example No. 6
 def __init__(self, config):
     """
     The embedding lookup table for vocabulary
     Args:
         config(PANGUALPHAConfig): the config of network
     Inputs:
         input_ids: the tokenized inputs with datatype int32
     Returns:
         output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
         self.embedding_table: Tensor, the embedding table for the vocabulary
     """
     super(EmbeddingLookup, self).__init__()
     self.vocab_size = config.vocab_size
     self.embedding_size = config.embedding_size
     if config.load_ckpt_path:
         # Loading the embedding table from the ckpt path:
         embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
         if os.path.exists(embedding_path):
             e_table = np.load(embedding_path)
             e_table = Tensor(e_table, mstype.float32)
             self.embedding_table = Parameter(e_table, name="embedding_table")
         else:
             raise ValueError(f"{embedding_path} file not exits, please check whether word_embedding file exist.")
     else:
         self.embedding_table = Parameter(initializer(
             Normal(0.02), [self.vocab_size, self.embedding_size]),
                                          name="embedding_table")
     if config.word_emb_dp:
         self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
     else:
         self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
         self.gather.add_prim_attr("repeated_calc_num_direction", "left")
         if config.forward_reduce_scatter:
             self.gather.add_prim_attr("forward_type", "ReduceScatter")
     self.shape = (-1, config.seq_length, config.embedding_size)
Example No. 7
def init_var_dict(init_args, values):
    """
    Init parameter.

    Args:
        init_args (list): Define max and min value of parameters.
        values (list): Define name, shape and init method of parameters.

    Returns:
        dict, a dict of Parameter.
    """
    var_map = {}
    _, _max_val = init_args
    for key, shape, init_flag in values:
        if key not in var_map.keys():
            if init_flag in ['random', 'uniform']:
                var_map[key] = Parameter(initializer(Uniform(_max_val), shape,
                                                     ms_type),
                                         name=key)
            elif init_flag == "one":
                var_map[key] = Parameter(initializer("ones", shape, ms_type),
                                         name=key)
            elif init_flag == "zero":
                var_map[key] = Parameter(initializer("zeros", shape, ms_type),
                                         name=key)
            elif init_flag == 'normal':
                var_map[key] = Parameter(initializer(Normal(_max_val), shape,
                                                     ms_type),
                                         name=key)
    return var_map
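
An illustrative call to init_var_dict; ms_type is assumed to be mindspore.float32 as typically defined at module level in the original file, and the names and shapes below are made up for the example.

# Hypothetical names and shapes.
init_args = [-0.01, 0.01]                     # [min, max]; only the max value is used
values = [
    ('wide_w', [80000, 1], 'uniform'),        # (name, shape, init method)
    ('wide_b', [1], 'zero'),
    ('embedding', [80000, 80], 'normal'),
]
var_map = init_var_dict(init_args, values)
print({name: param.shape for name, param in var_map.items()})
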
Example No. 8
    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'

        layers = OrderedDict()
        for i in range(num_layers):
            kernel, padding, _ = \
                self._get_deconv_cfg(num_kernels[i])

            planes = num_filters[i]
            layers['deconv_{}'.format(i)] = nn.SequentialCell(
                OrderedDict([
                    ('deconv',
                     nn.Conv2dTranspose(
                         in_channels=self.inplanes,
                         out_channels=planes,
                         kernel_size=kernel,
                         stride=2,
                         pad_mode='pad',
                         padding=padding,
                         has_bias=self.deconv_with_bias,
                         weight_init=Normal(0.001),
                     )),
                    ('bn', nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)),
                    ('relu', nn.ReLU()),
                ]))
            self.inplanes = planes

        return nn.SequentialCell(layers)
Example No. 9
    def __init__(self, filters, n_filters, max_chars_per_token, char_embed_dim,
                 n_chars, n_highway, output_dim, activation):
        super().__init__()

        self.max_chars_per_token = max_chars_per_token

        # activation for convolutions
        if activation == 'tanh':
            self._activation = nn.Tanh()
        elif activation == 'relu':
            self._activation = nn.ReLU()
        else:
            raise ValueError("Unknown activation")

        # init char_embedding
        self.char_embedding = Embedding(n_chars + 1,
                                        char_embed_dim,
                                        embedding_table=Uniform(1.0),
                                        padding_idx=0)
        # run convolutions
        convolutions = []
        for (width, num) in filters:
            if activation == 'tanh':
                cnn_weight_init = Normal(np.sqrt(1.0 / width * char_embed_dim))
            elif activation == 'relu':
                cnn_weight_init = Uniform(0.05)
            conv = nn.Conv1d(in_channels=char_embed_dim,
                             out_channels=num,
                             kernel_size=width,
                             has_bias=True,
                             weight_init=cnn_weight_init,
                             pad_mode='valid')
            convolutions.append(conv)
        self._convolutions = nn.CellList(convolutions)

        # highway layers
        self._highways = HighWay(n_filters, n_highway, 'relu')
        # projection layer
        self._projection = nn.Dense(n_filters,
                                    output_dim,
                                    has_bias=True,
                                    weight_init=Normal(np.sqrt(1.0 /
                                                               n_filters)))
        # array operations
        self.transpose = P.Transpose()
        self.concat = P.Concat(-1)
        self.max = P.ReduceMax()
Example No. 10
 def __init__(self, input_size, output_size, dtype, scale=1.0):
     super(Mapping, self).__init__()
     self.output_size = output_size
     self.input_size = input_size
     self.weight = Parameter(initializer(Normal(sigma=0.02*scale), [input_size, output_size]), name="mapping_weight")
     self.bias = Parameter(initializer("zeros", [output_size,]), name="mapping_bias")
     self.dtype = dtype
     self.cast = P.Cast()
Example No. 11
    def __init__(self, num_class=10, num_channel=1, include_top=True):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.include_top = include_top
        if self.include_top:
            self.flatten = nn.Flatten()
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))

        self.scalar_summary = P.ScalarSummary()
        self.image_summary = P.ImageSummary()
        self.tensor_summary = P.TensorSummary()
        self.channel = Tensor(num_channel)
Example No. 12
    def __init__(self,
                 input_dim: int,
                 num_layers: int = 1,
                 activation: str = 'relu'):
        super().__init__()
        self._input_dim = input_dim
        self._layers = []
        for _ in range(num_layers):
            carry = nn.Dense(input_dim,
                             input_dim,
                             weight_init=Normal(np.sqrt(1.0 / input_dim)),
                             bias_init=Constant(-2.0))
            transform = nn.Dense(input_dim,
                                 input_dim,
                                 weight_init=Normal(np.sqrt(1.0 / input_dim)))
            self._layers.append((carry, transform))

        self._activation = nn.get_activation(activation)
Example No. 13
def _conv7x7(in_channel, out_channel, stride=1):
    n = 7 * 7 * out_channel
    normal = Normal(math.sqrt(2. / n))
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=7,
                     stride=stride,
                     padding=3,
                     pad_mode='pad',
                     weight_init=normal)
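
The sigma here, sqrt(2 / (7 * 7 * out_channel)), is the He-normal standard deviation (fan-out mode, ReLU) computed by hand rather than via a HeNormal initializer. A quick, illustrative check of the resulting layer; the input shape is made up:

# Illustrative check only; relies on _conv7x7 and its module imports above.
import numpy as np
import mindspore as ms
from mindspore import Tensor

conv = _conv7x7(3, 64, stride=2)
x = Tensor(np.ones((1, 3, 224, 224)), ms.float32)
print(conv(x).shape)   # (1, 64, 112, 112)
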
Example No. 14
 def __init__(self, config):
     super(PANGUALPHAPipeline, self).__init__()
     self.backbone = PANGUALPHA_ModelPipeline(config)
     self.head = PANGUALPHA_Head(config)
     self.head.stage = config.stage_num - 1
     self.vocab_size = config.vocab_size
     self.embedding_size = config.embedding_size
     self.embedding_table = Parameter(initializer(
         Normal(0.02), [self.vocab_size, self.embedding_size]),
         name="embedding_table")
Example No. 15
def _conv1x1(in_channel, out_channel, stride=1):
    n = 1 * 1 * out_channel
    normal = Normal(math.sqrt(2. / n))
    return nn.Conv2d(in_channel,
                     out_channel,
                     kernel_size=1,
                     stride=stride,
                     padding=0,
                     pad_mode='same',
                     weight_init=normal)
Example No. 16
def test_conv2d_depthwiseconv2d_initializer():
    net = nn.Conv2d(128,
                    128, (2, 3),
                    stride=4,
                    pad_mode='valid',
                    padding=0,
                    group=128,
                    weight_init=Normal())
    input_data = Tensor(np.ones([3, 128, 127, 114]), dtype=mstype.float32)
    output = net(input_data)
    assert output.shape == (3, 128, 32, 28)
Example No. 17
 def __init__(self,
              vocab_size,
              embedding_size,
              use_one_hot=False,
              embedding_table='normal',
              dtype=mindspore.float32,
              padding_idx=None):
     if embedding_table == 'normal':
         embedding_table = Normal(1.0)
     super().__init__(vocab_size, embedding_size, use_one_hot,
                      embedding_table, dtype, padding_idx)
Example No. 18
 def __init__(self, config):
     super(PANGUALPHA_EmbeddingPipeLine, self).__init__()
     self.word_embedding = EmbeddingLookupPipeline(config)
     self.position_embedding = nn.Embedding(config.seq_length,
                                            config.embedding_size,
                                            embedding_table=Normal(0.02))
      self.position_embedding.gather.shard(((1, 1), (config.dp,)))
     self.position_embedding.expand.shard(((config.dp, 1),))
     self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
     self.dropout = nn.Dropout(1 - config.dropout_rate)
     self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
     self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
Example No. 19
 def __init__(self, config, input_size, output_size, scale=1.0):
     super(Mapping_output, self).__init__()
     self.output_size = output_size
     self.input_size = input_size
     self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
                                         [input_size, output_size]),
                             name="mapping_weight")
     self.bias = Parameter(initializer("zeros", [
         output_size,
     ]),
                           name="mapping_bias")
     self.dtype = config.compute_dtype
     self.cast = P.Cast()
     #self.cast.add_prim_attr("_side_effect", True)
     self.add = P.TensorAdd().shard(((config.dp, config.mp), (config.mp,)))
     self.matmul = P.MatMul().shard(((config.dp, 1), (1, config.mp)))
Example No. 20
    def __init__(self, block, layers, cfg, pytorch_mode=True):
        self.inplanes = 64
        extra = cfg.MODEL.EXTRA
        self.deconv_with_bias = extra.DECONV_WITH_BIAS

        super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(3,
                               64,
                               kernel_size=7,
                               stride=2,
                               pad_mode='pad',
                               padding=3,
                               has_bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU()
        if pytorch_mode:
            self.maxpool = MaxPool2dPytorch(kernel_size=3,
                                            stride=2,
                                            pad_mode='same')
            print("use pytorch-style maxpool")
        else:
            self.maxpool = nn.MaxPool2d(kernel_size=3,
                                        stride=2,
                                        pad_mode='same')
            print("use mindspore-style maxpool")
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # used for deconv layers
        self.deconv_layers = self._make_deconv_layer(
            extra.NUM_DECONV_LAYERS,
            extra.NUM_DECONV_FILTERS,
            extra.NUM_DECONV_KERNELS,
        )

        self.final_layer = nn.Conv2d(
            in_channels=extra.NUM_DECONV_FILTERS[-1],
            out_channels=cfg.MODEL.NUM_JOINTS,
            kernel_size=extra.FINAL_CONV_KERNEL,
            stride=1,
            pad_mode='pad',
            padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0,
            has_bias=True,
            weight_init=Normal(0.001),
        )
Example No. 21
def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
    if layer_type not in ('conv', 'bn', 'fc'):
        raise ValueError('The layer type is not known, the supported are conv, bn and fc')
    if bias:
        return Zero()
    if layer_type == 'conv':
        assert isinstance(shape, (tuple, list)) and len(shape) == 3, \
            'The shape must contain 3 scalars: in_chs, ks, out_chs'
        n = shape[1] * shape[1] * shape[2]
        return Normal(math.sqrt(2.0 / n))
    if layer_type == 'bn':
        return One()
    assert isinstance(shape, (tuple, list)) and len(shape) == 2, \
        'The shape must contain 2 scalars: in_chs, out_chs'
    n = shape[1]
    init_range = 1.0 / math.sqrt(n)
    return Uniform(init_range)
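
An illustrative way this helper could be wired into layer construction; the channel sizes below are assumptions made for the example.

# Hypothetical shapes: (in_chs, ks, out_chs) for conv, (in_chs, out_chs) for fc.
import mindspore.nn as nn

conv_init = _initialize_weight_goog((32, 3, 64), layer_type='conv')
fc_init = _initialize_weight_goog((1280, 1000), layer_type='fc')

conv = nn.Conv2d(32, 64, 3, weight_init=conv_init)
fc = nn.Dense(1280, 1000, weight_init=fc_init,
              bias_init=_initialize_weight_goog(bias=True))
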
Example No. 22
 def __init__(self,
              input_size,
              output_size,
              initializer_range=0.02,
              dtype=ms.float32,
              scale=1.0):
     super(Mapping, self).__init__()
     self.output_size = output_size
     self.input_size = input_size
     self.weight = Parameter(initializer(
         Normal(sigma=initializer_range * scale),
         [input_size, output_size]),
                             name="Weight")
     self.bias = Parameter(initializer("zeros", [
         output_size,
     ]),
                           name="Bias")
     self.dtype = dtype
     self.cast = P.Cast()
Example No. 23
 def __init__(self,
              vocab_size,
              embedding_size,
              use_one_hot_embeddings=False,
              initializer_range=0.02):
     super(EmbeddingLookup, self).__init__()
     self.vocab_size = vocab_size
     self.embedding_size = embedding_size
     self.use_one_hot_embeddings = use_one_hot_embeddings
     self.embedding_table = Parameter(initializer(
         Normal(sigma=initializer_range), [vocab_size, embedding_size]),
                                      name="embedding_table")
     self.expand = P.ExpandDims()
     self.shape_flat = (-1, )
     self.gather = P.GatherV2()
     self.one_hot = P.OneHot()
     self.on_value = Tensor(1.0, mstype.float32)
     self.off_value = Tensor(0.0, mstype.float32)
     self.array_mul = P.MatMul()
     self.reshape = P.Reshape()
     self.shape = P.Shape()
Example No. 24
    def __init__(
            self, 
            hidden_size,
            vocab_size, 
            sample_softmax, 
            num_sampled, 
            num_true=1,
            seed=0,
            training=True):
        super().__init__()
        self.training = training
        self.sample_softmax = sample_softmax
        self.hidden_size = hidden_size

        self.weight = Parameter(initializer(Normal(1.0 / np.sqrt(hidden_size)), (vocab_size, hidden_size), mindspore.float32))
        self.bias = Parameter(initializer(Zero(), (vocab_size,), mindspore.float32))

        self.sampled_softmax_loss = SampledSoftmaxLoss(num_sampled, vocab_size, num_true, seed=seed)
        self.sparse_softmax_cross_entropy_with_logits = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        self.matmul = nn.MatMul(False, True)
        self.reduce_mean = P.ReduceMean()
Example No. 25
def init_method(method, shape, name, max_val=0.01):
    """
    The method of init parameters.

    Args:
        method (str): The method used to initialize the parameter.
        shape (list): The shape of the parameter.
        name (str): The name of the parameter.
        max_val (float): Max value of the parameter when using 'random' or 'uniform' initialization.

    Returns:
        Parameter.
    """
    if method in ['random', 'uniform']:
        params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
    elif method == "one":
        params = Parameter(initializer("ones", shape, ms_type), name=name)
    elif method == 'zero':
        params = Parameter(initializer("zeros", shape, ms_type), name=name)
    elif method == "normal":
        params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
    else:
        raise ValueError("Unknown init method: %s" % method)
    return params
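
An illustrative call; ms_type is assumed to be mindspore.float32, defined at module level in the original file, and the names and shapes are made up.

# Hypothetical names and shapes.
weight = init_method('uniform', [64, 32], name='dense_weight', max_val=0.05)
bias = init_method('zero', [64], name='dense_bias')
print(weight.shape, bias.shape)   # (64, 32) (64,)
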
Example No. 26
 def __init__(self):
     super(LinearNet, self).__init__()
     self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
Example No. 27
def dense_weight_variable():
    """The weight for dense."""
    return Normal(0.01)
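
Worth noting when reading these examples: in mindspore.common.initializer, Normal(sigma, mean=0.0) takes the standard deviation as its first argument, so Normal(0.01) means a zero-mean distribution with sigma 0.01. A minimal sketch:

# Minimal sketch showing the Normal(sigma, mean) argument order.
from mindspore import Parameter
from mindspore.common.initializer import Normal, initializer
import mindspore.common.dtype as mstype

w = Parameter(initializer(Normal(0.01), [256, 128], mstype.float32), name='w')
print(w.shape)   # (256, 128); values drawn from N(0, 0.01**2)
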
Example No. 28
def init_linear_wt(linear):
    linear.weight.set_data(initializer(Normal(1e-4), linear.weight.shape))
    if linear.has_bias:
        linear.bias.set_data(initializer(Normal(1e-4), linear.bias.shape))
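
An illustrative in-place use; the Dense dimensions are made up for the example, and initializer/Normal are assumed to be imported as in the original module.

# Hypothetical layer; init_linear_wt overwrites its weight and bias in place.
import mindspore.nn as nn

linear = nn.Dense(512, 256, has_bias=True)
init_linear_wt(linear)
print(linear.weight.shape, linear.bias.shape)   # (256, 512) (256,)
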
Example No. 29
def init_wt_normal(wt):
    wt.set_data(initializer(Normal(1e-4), wt.shape))
Example No. 30
    def __init__(self, block, layer_nums, in_channels, channels, out_channels,
                 strides, num_classes, is_train):
        super(ResNet, self).__init__()

        if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
            raise ValueError(
                "the lengths of the layer_nums, in_channels and out_channels lists must all be 4!"
            )

        self.ha3 = HardAttn(2048)
        self.is_train = is_train
        self.conv1 = _conv7x7(3, 64, stride=2)
        self.bn1 = _bn(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = self._make_layer(block,
                                       layer_nums[0],
                                       in_channel=in_channels[0],
                                       channel=channels[0],
                                       out_channel=out_channels[0],
                                       stride=strides[0])
        self.layer2 = self._make_layer(block,
                                       layer_nums[1],
                                       in_channel=in_channels[1],
                                       channel=channels[1],
                                       out_channel=out_channels[1],
                                       stride=strides[1])
        self.layer3 = self._make_layer(block,
                                       layer_nums[2],
                                       in_channel=in_channels[2],
                                       channel=channels[2],
                                       out_channel=out_channels[2],
                                       stride=strides[2])
        self.layer4 = self._make_layer(block,
                                       layer_nums[3],
                                       in_channel=in_channels[3],
                                       channel=channels[3],
                                       out_channel=out_channels[3],
                                       stride=strides[3])

        self.max = P.ReduceMax(keep_dims=True)
        self.flatten = nn.Flatten()
        self.global_bn = _bn2_kaiming(out_channels[3])
        self.partial_bn = _bn2_kaiming(out_channels[3])
        normal = Normal(0.001)
        self.global_fc = nn.Dense(out_channels[3],
                                  num_classes,
                                  has_bias=False,
                                  weight_init=normal,
                                  bias_init='zeros')
        self.partial_fc = nn.Dense(out_channels[3],
                                   num_classes,
                                   has_bias=False,
                                   weight_init=normal,
                                   bias_init='zeros')
        self.theta_0 = Tensor(np.zeros((128, 4)), mindspore.float32)
        self.theta_6 = Tensor(np.zeros((128, 4)) + 0.6, mindspore.float32)
        self.STN = STN(128, 128)
        self.concat = P.Concat(axis=1)
        self.shape = P.Shape()
        self.tanh = P.Tanh()
        self.slice = P.Slice()
        self.split = P.Split(1, 4)