def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             pad_mode='same',
             padding=0,
             dilation=1,
             group=1,
             has_bias=False,
             weight_init='normal',
             bias_init='zeros'):
    """Build a 2-D convolution layer that also carries THOR second-order state.

    Besides the convolution primitive, this allocates non-trainable buffers
    (``matrix_A_cov``, ``matrix_G_cov``, ``A_normalizer``, ``G_normalizer``)
    and the operators used to fill them. The operator set depends on the
    device target:

    * ``"Ascend"``: custom ``Cus*`` kernels are used. The covariance
      matrices are rounded up to a multiple of ``diag_block_dim`` (128)
      whenever their side length exceeds 128 and is not already a multiple.
    * any other target: generic ``Im2Col`` / ``MatMul`` / ``ReduceMean``
      primitives are used and the covariance matrices are left unpadded.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, normalized with ``twice`` so it can be
            indexed as ``kernel_size[0]`` / ``kernel_size[1]``.
            NOTE(review): presumably an int or a 2-tuple is accepted; confirm
            against ``twice``.
        stride: Convolution stride, normalized with ``twice``.
        pad_mode: Forwarded to the parent constructor.
        padding: Forwarded to the parent constructor.
        dilation: The caller's value is kept in ``self._dilation``. The
            ``twice``-normalized value is forwarded to the parent.
        group: Forwarded to the parent constructor.
        has_bias: Forwarded to the parent constructor.
        weight_init: Forwarded to the parent constructor and also passed to
            ``self._init_depthwise_conv2d`` (defined elsewhere in the class).
        bias_init: Forwarded to the parent constructor.
    """
    kernel_size = twice(kernel_size)
    stride = twice(stride)
    # Keep the caller's dilation argument as given, before twice() normalizes it.
    self._dilation = dilation
    dilation = twice(dilation)
    super(Conv2d_Thor, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding,
                                      dilation, group, has_bias, weight_init, bias_init)
    # Convolution primitive configured from the attributes resolved by the parent constructor.
    self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
                           pad_mode=self.pad_mode, pad=self.padding, stride=self.stride,
                           dilation=self.dilation, group=self.group)
    # Defined elsewhere in the class; receives the same weight initializer as the parent.
    self._init_depthwise_conv2d(weight_init)
    self.bias_add = P.BiasAdd()

    self.thor = True
    # Kernel area (kh * kw).
    self.hw = kernel_size[0] * kernel_size[1]
    # Side lengths of the square A and G matrices before any Ascend padding.
    self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
    self.matrix_G_dim = self.out_channels

    self.shape = P.Shape()
    self.reshape = P.Reshape()
    self.mul = P.Mul()
    self.cast = P.Cast()
    # Scalar, non-trainable float32 buffers initialized to 0.
    self.A_normalizer = Parameter(initializer(0, [1], mstype.float32), name="A_normalizer", requires_grad=False)
    self.G_normalizer = Parameter(initializer(0, [1], mstype.float32), name="G_normalizer", requires_grad=False)
    self.is_Ascend = True
    if context.get_context("device_target") == "Ascend":
        # NHWC-style window/stride tuples expected by the custom img2col kernel.
        ksizes = (1, kernel_size[0], kernel_size[1], 1)
        strides = (1, stride[0], stride[1], 1)
        self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
        self.cube_matmul = P.CusMatMulCube(transpose_a=True)
        self.transpose02314 = P.CusTranspose02314()
        dampingA_dim = self.matrix_A_dim
        self.diag_block_dim = 128
        # Round the A side length up to the next multiple of diag_block_dim, but only
        # when it exceeds one block and is not already a multiple.
        if (self.matrix_A_dim % self.diag_block_dim) != 0 and self.matrix_A_dim > self.diag_block_dim:
            dampingA_dim = (self.matrix_A_dim // self.diag_block_dim + 1) * self.diag_block_dim
        dampingG_dim = self.matrix_G_dim
        # Same rounding rule for the G side length.
        if (self.matrix_G_dim % self.diag_block_dim) != 0 and self.matrix_G_dim > self.diag_block_dim:
            dampingG_dim = (self.matrix_G_dim // self.diag_block_dim + 1) * self.diag_block_dim
        # Zero-initialized, non-trainable buffers sized to the (possibly padded) dimensions.
        self.matrix_A_cov = Parameter(Tensor(
            np.zeros([dampingA_dim, dampingA_dim]).astype(np.float32)), name='matrix_A', requires_grad=False)
        self.matrix_G_cov = Parameter(Tensor(
            np.zeros([dampingG_dim, dampingG_dim]).astype(np.float32)), name='matrix_G', requires_grad=False)

        self.channels_slice_flag = False
        self.C0 = 16
        # Flag layers whose input-channel count is not a multiple of C0 (16).
        # NOTE(review): presumably the Ascend kernels operate on 16-channel tiles,
        # so such layers need slicing elsewhere; confirm against the caller.
        if self.in_channels % self.C0 != 0:
            self.channels_slice_flag = True

        self.padA_flag = False
        # Same condition as the dampingA_dim rounding above, written as a
        # floor-multiply test instead of `%`. padA pads a 2-D matrix with pad_dim
        # rows and columns at the end, so matrix_A_dim grows to dampingA_dim.
        if (self.matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_A_dim \
                and self.matrix_A_dim > self.diag_block_dim:
            self.padA_flag = True
            pad_dim = self.diag_block_dim - self.matrix_A_dim % self.diag_block_dim
            self.padA = P.Pad(((0, pad_dim), (0, pad_dim)))
        self.slice = P.Slice()
    else:
        self.is_Ascend = False
        # NOTE(review): pad_mode is hard-coded to "same" and dilation is not passed,
        # unlike self.conv2d which uses self.pad_mode / self.dilation. Confirm this is
        # intended for layers built with other pad modes or dilation > 1.
        self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
        self.matmul = P.MatMul(transpose_b=True)
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        # Unpadded, zero-initialized, non-trainable buffers.
        self.matrix_A_cov = Parameter(Tensor(
            np.zeros([self.matrix_A_dim, self.matrix_A_dim]).astype(np.float32)), name='matrix_A', requires_grad=False)
        self.matrix_G_cov = Parameter(Tensor(
            np.zeros([self.matrix_G_dim, self.matrix_G_dim]).astype(np.float32)), name='matrix_G', requires_grad=False)
    # Gradient hook built from self.save_gradient (defined elsewhere in the class).
    self.getG = P.InsertGradientOf(self.save_gradient)
def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             pad_mode='same',
             padding=0,
             dilation=1,
             group=1,
             data_format='NCHW',
             has_bias=False,
             weight_init='normal',
             damping=0.03,
             loss_scale=1,
             frequency=278,
             batch_size=32,
             bias_init='zeros'):
    """Build a 2-D convolution layer that carries SKFAC second-order state (GPU path).

    Besides the convolution primitive, this allocates the following
    non-trainable state and the operators that use it:

    * ``matrix_A_inv``: square float32 zero buffer of side
      ``in_channels * kh * kw``.
    * ``matrix_G_inv``: square float32 zero buffer of side ``out_channels``.
    * ``cov_step``: an int32 step counter initialized to 0.
    * ``damping``: a Parameter.
    * Identity matrices (``dampingA``, ``dampingG``, ``I_G``, ``I_A``) used
      by the update.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, normalized with ``twice``. After
            normalization it is indexed as ``kernel_size[0]`` /
            ``kernel_size[1]``, so both an int and a pair work.
        stride: Forwarded to the parent constructor and to ``Im2Col``.
        pad_mode: Forwarded to the parent constructor.
        padding: Forwarded to the parent constructor.
        dilation: Forwarded to the parent constructor.
        group: Forwarded to the parent constructor.
        data_format: Forwarded to the parent constructor.
        has_bias: Forwarded to the parent constructor.
        weight_init: Forwarded to the parent constructor.
        damping: Initial value of the non-trainable ``self.damping``
            Parameter.
        loss_scale: Its reciprocal is stored as a float16 tensor in
            ``self.loss_scale``.
        frequency: Stored as an int32 tensor in ``self.freq``.
        batch_size: Stored as a float16 tensor. Also sets the size of the
            ``dampingA`` / ``dampingG`` identity matrices.
        bias_init: Forwarded to the parent constructor.
    """
    self.skfac = True
    # Normalize kernel_size *before* deriving the kernel area: squaring the raw
    # argument (kernel_size * kernel_size) raises TypeError when a (kh, kw) tuple
    # is passed. For int input the result is identical.
    kernel_size = twice(kernel_size)
    self.hw = kernel_size[0] * kernel_size[1]
    super(Conv2d_SKFAC_GPU, self).__init__(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        pad_mode,
        padding,
        dilation,
        group,
        data_format,
        has_bias,
        weight_init,
        bias_init,
    )
    # Convolution primitive configured from the attributes resolved by the parent constructor.
    self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1,
                           pad_mode=self.pad_mode, pad=self.padding, stride=self.stride,
                           dilation=self.dilation, group=self.group)
    # Side lengths of the square A and G matrices.
    self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
    self.matrix_G_dim = self.out_channels

    split_dim = 128
    self.matrix_A_inv = Parameter(np.zeros(
        (self.matrix_A_dim, self.matrix_A_dim)).astype(np.float32), requires_grad=False)
    self.matrix_G_inv = Parameter(np.zeros(
        (self.matrix_G_dim, self.matrix_G_dim)).astype(np.float32), requires_grad=False)
    self.cov_step = Parameter(initializer(0, [1], mstype.int32), requires_grad=False)
    # NOTE(review): pad_mode is hard-coded to "same" here, unlike self.conv2d; confirm intended.
    self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
    self.matmul = P.MatMul(transpose_a=True)
    self.matmul_ = P.MatMul()
    self.shape = P.Shape()
    self.reshape = P.Reshape()
    self.mul = P.Mul()
    # Gradient hook built from self.save_gradient (defined elsewhere in the class).
    self.getG = P.InsertGradientOf(self.save_gradient)
    # Stored as the reciprocal so it can be applied by multiplication.
    self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
    self.batch_size = Tensor(batch_size, mstype.float16)
    self.transpose = P.Transpose()
    self.cast = P.Cast()
    self.gather = P.Gather()
    self.freq = Tensor(frequency, mstype.int32)
    self.axis = 0
    self.sqrt = P.Sqrt()
    self.reduce_mean = P.ReduceMean(keep_dims=False)
    self.damping = Parameter(Tensor(damping), requires_grad=False)
    # Identity matrices sized by batch_size, out_channels, and matrix_A_dim respectively.
    self.dampingA = Tensor(np.identity(batch_size), mstype.float32)
    self.dampingG = Tensor(np.identity(batch_size), mstype.float32)
    self.I_G = Tensor(np.identity(out_channels), mstype.float32)
    self.I_A = Tensor(np.identity(self.matrix_A_dim), mstype.float32)
    self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
    self.vector_matmul = P.BatchMatMul(transpose_a=True)
    # NOTE(review): hard-codes 32 rather than using `batch_size`; confirm this is
    # intended when the layer is built with batch_size != 32.
    self.batch_coefficient = Tensor((1 / 32)**0.5, mstype.float32)