def __init__(self):
    """Create the primitive operators used by GatherFlipFeature.

    Only operators and a fixed permutation are set up here; how they are
    combined is defined in ``construct`` (outside this view).
    """
    super(GatherFlipFeature, self).__init__()
    # Index-based gather primitive.
    self.gather_nd = ops.GatherNd()
    # Transpose with a permutation that swaps axes 0 and 1 of a 4-D tensor.
    swap_leading_axes = (1, 0, 2, 3)
    self.transpose = ops.Transpose()
    self.perm_list = swap_leading_axes
    # Shape inspection / reshaping helpers.
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
def __init__(self):
    """Create the operators for a transpose-then-gather feature lookup.

    The permutation ``(0, 2, 3, 1)`` moves axis 1 of a 4-D tensor to the last
    position (NCHW -> NHWC if the input is laid out NCHW). The gather itself
    is delegated to a ``GatherFeature`` sub-cell.
    """
    super(TransposeGatherFeature, self).__init__()
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
    axis1_to_last = (0, 2, 3, 1)
    self.transpose = ops.Transpose()
    self.perm_list = axis1_to_last
    self.gather_feat = GatherFeature()
def construct(self, *inputs):
    """Run forward + backward and fold the gradients into ``self.grad_sum``.

    The per-parameter combination is done by ``_sum_op`` (defined elsewhere;
    presumably it adds each gradient into its ``grad_sum`` slot).

    Returns:
        The loss, made dependent on the gradient-folding ops so they are
        executed as part of the graph.
    """
    loss = self.network(*inputs)
    # Backprop seed: a tensor with the loss's dtype and shape, filled with self.sens.
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, self.weights)(*inputs, sens)
    folded = self.hyper_map(ops.partial(_sum_op), self.grad_sum, grads)
    return ops.depend(loss, folded)
def construct(self, input_ids, token_type_ids, input_mask):
    """Bidirectional Encoder Representations from Transformers.

    Looks up word embeddings for ``input_ids``, post-processes them with
    ``token_type_ids``, builds the attention mask from ``input_mask`` and
    runs the encoder stack on the embeddings (cast to the compute type).

    Returns:
        The raw output of ``self.bert_encoder``. It is presumably a
        sequence of per-layer outputs; TODO confirm against the encoder.
    """
    # embedding
    word_embeddings = self.bert_embedding_lookup(input_ids)
    embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings)
    # attention mask [batch_size, seq_length, seq_length]
    attention_mask = self._create_attention_mask_from_input_mask(input_mask)
    # bert encoder
    encoder_output = self.bert_encoder(self.cast_compute_type(embedding_output), attention_mask)
    # NOTE: the embedding table, last-layer sequence output and pooled
    # first-token output used to be computed here but were never returned
    # or used, so that dead work has been dropped. If a caller needs the
    # pooled output, extend the return value explicitly.
    return encoder_output
def __init__(self, ks):
    """Create the operators RegenerateFeatureMap uses to rebuild a feature map.

    Args:
        ks: number of equal chunks the last axis is split into (stored as
            ``self.ks``); presumably the deformable kernel size.
    """
    super(RegenerateFeatureMap, self).__init__()
    self.ks = ks
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
    # Split the last axis into `ks` pieces, later re-joined along axis 2.
    self.split = ops.Split(axis=-1, output_num=self.ks)
    self.concat = ops.Concat(axis=2)
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, has_bias=False, modulation=True):
    """Build a deformable 2-D convolution with optional modulation.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        kernel_size: side of the square kernel; must be odd.
        padding: zero padding added on each spatial side (``self.zero_padding``)
            and used by the offset convolution.
        stride: stride of the offset and modulation convolutions.
        has_bias: whether the final convolution ``self.conv`` has a bias.
        modulation: if True, also create ``self.m_conv``, which produces
            ``kernel_size * kernel_size`` modulation channels.

    Raises:
        ValueError: if ``kernel_size`` is even.
    """
    super(DeformConv2d, self).__init__()
    # Validate before building any sub-cells so a bad configuration fails fast.
    if kernel_size % 2 == 0:
        raise ValueError("Only odd number is supported, but current kernel size is {}".format(kernel_size))
    self.kernel_size = kernel_size
    self.padding = padding
    self.stride = stride
    self.zero_padding = nn.Pad(((0, 0), (0, 0), (padding, padding), (padding, padding)))
    # Strides by kernel_size. Presumably it consumes one k x k patch of the
    # regenerated (enlarged) feature map per output location; confirm in construct.
    self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, pad_mode='valid', padding=0,
                          stride=kernel_size, has_bias=has_bias)
    # 2*k*k output channels: presumably an (x, y) offset per kernel sampling point.
    self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=self.kernel_size,
                            pad_mode='pad', padding=self.padding, stride=self.stride)
    self.modulation = modulation
    if modulation:
        # Uses 'valid' padding, unlike p_conv. Presumably it is applied to the
        # already zero-padded input in construct; confirm.
        self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=self.kernel_size,
                                pad_mode='valid', padding=0, stride=self.stride)
    self.N = kernel_size * kernel_size
    self.begin = kernel_size // 2
    self.sigmoid = ops.Sigmoid()
    self.dtype = ops.DType()
    self.perm_list = (0, 2, 3, 1)
    self.transpose = ops.Transpose()
    self.floor = ops.Floor()
    self.half = ops.Split(axis=-1, output_num=2)
    self.clip_value = ClipByValue()
    self.expand_dims = ops.ExpandDims()
    self.shape = ops.Shape()
    self.cast = ops.Cast()
    self._get_offset = GetOffsetPosition(self.begin, self.stride)
    self._get_surround = GetSurroundFeature()
    self._generate_fm = RegenerateFeatureMap(self.kernel_size)
def __init__(self, length, depth, max_relative_position, initializer_range, use_one_hot_embeddings=False):
    """Create the learnable relative-position embedding table and helper ops.

    The table has ``2 * max_relative_position + 1`` rows (presumably one per
    clipped relative distance) and ``depth`` columns. It is initialised from
    ``TruncatedNormal(initializer_range)``.
    """
    super(RelaPosEmbeddingsGenerator, self).__init__()
    self.depth = depth
    self.vocab_size = 2 * max_relative_position + 1
    self.use_one_hot_embeddings = use_one_hot_embeddings
    table_init = initializer(TruncatedNormal(initializer_range), [self.vocab_size, self.depth])
    self.embeddings_table = Parameter(table_init, name='embeddings_for_position')
    self.relative_positions_matrix = RelaPosMatrixGenerator(
        length=length, max_relative_position=max_relative_position)
    self.reshape = ops.Reshape()
    # One-hot path: on/off values for ops.OneHot.
    self.one_hot = ops.OneHot()
    self.on_value = Tensor(1.0, mstype.float32)
    self.off_value = Tensor(0.0, mstype.float32)
    self.shape = ops.Shape()
    self.gather = ops.GatherV2()  # index_select
    self.matmul = ops.BatchMatMul()
def construct(self, lr, hr, width_mult, tea_width_mult):
    """Run one training step: forward, backward, optimizer update.

    Args:
        lr, hr, width_mult, tea_width_mult: forwarded unchanged to
            ``self.network`` and to its gradient function.

    Returns:
        The loss, sequenced after the optimizer update.
    """
    weights = self.weights
    loss = self.network(lr, hr, width_mult, tea_width_mult)
    # Backprop seed: a tensor with the loss's dtype and shape, filled with self.sens.
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(lr, hr, width_mult, tea_width_mult, sens)
    # Tie the update into the returned value, as the other train cells do. On
    # MindSpore versions without automatic side-effect ordering, a bare
    # optimizer call whose result is discarded can be pruned from the graph.
    return ops.depend(loss, self.optimizer(grads))
def __init__(self):
    """Create the operators GatherTopKChannel uses for per-channel top-k selection.

    Besides shape/reshape/cast helpers, this sets up a sorted ``TopK`` and the
    ``Mod``/``Div`` primitives (presumably used to turn flat indices back into
    row/column positions in ``construct``).
    """
    super(GatherTopKChannel, self).__init__()
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
    self.topk = ops.TopK(sorted=True)
    # dtype handling
    self.cast = ops.Cast()
    self.dtype = ops.DType()
    # integer index arithmetic
    self.mod = ops.Mod()
    self.div = ops.Div()
def construct(self, realA, realB):
    """Run one joint GAN training step: discriminator update, then generator update.

    Both losses are evaluated first (in program order). The discriminator
    gradients are then taken and applied with ``self.optimizerD``. After that
    the generator gradients are taken and applied with ``self.optimizerG``.

    Returns:
        (d_res, g_res): the discriminator and generator losses, each made
        dependent on its own optimizer update via ``ops.depend``.
    """
    d_loss = self.loss_netD(realA, realB)
    g_loss = self.loss_netG(realA, realB)
    # Backprop seed for D: same dtype/shape as d_loss, filled with self.sens.
    d_sens = ops.Fill()(ops.DType()(d_loss), ops.Shape()(d_loss), self.sens)
    d_grads = self.grad(self.loss_netD, self.weights_D)(realA, realB, d_sens)
    d_res = ops.depend(d_loss, self.optimizerD(d_grads))
    # NOTE(review): the generator gradients below are requested after the
    # discriminator update has been issued, so they may see the updated D
    # weights. Confirm this ordering is intended.
    g_sens = ops.Fill()(ops.DType()(g_loss), ops.Shape()(g_loss), self.sens)
    g_grads = self.grad(self.loss_netG, self.weights_G)(realA, realB, g_sens)
    g_res = ops.depend(g_loss, self.optimizerG(g_grads))
    return d_res, g_res
def construct(self, img_A, img_B, fake_A, fake_B):
    """Discriminator training step: loss, gradients, optional reduction, update.

    Returns:
        The discriminator loss, sequenced after the optimizer update.
    """
    loss_d = self.D(img_A, img_B, fake_A, fake_B)
    # Backprop seed: same dtype/shape as the loss, filled with self.sens.
    seed = ops.Fill()(ops.DType()(loss_d), ops.Shape()(loss_d), self.sens)
    grads_d = self.grad(self.D, self.weights)(img_A, img_B, fake_A, fake_B, seed)
    if self.reducer_flag:
        # apply grad reducer on grads (presumably cross-device aggregation)
        grads_d = self.grad_reducer(grads_d)
    return ops.depend(loss_d, self.optimizer(grads_d))
def construct(self, img_A, img_B):
    """Generator training step.

    Runs ``self.G`` for its outputs and loss terms. It then takes gradients
    through ``self.net`` (seeded with a tensor shaped like the total loss
    ``lg``), optionally reduces them, and applies ``self.optimizer``.

    Returns:
        (fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib). ``lg`` is
        sequenced after the optimizer update.
    """
    fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B)
    seed = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens)
    grads_g = self.grad(self.net, self.weights)(img_A, img_B, seed)
    if self.reducer_flag:
        # apply grad reducer on grads
        grads_g = self.grad_reducer(grads_g)
    update = self.optimizer(grads_g)
    lg = ops.depend(lg, update)
    return fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib
def __init__(self):
    """Create the operators GetSurroundFeature uses to gather neighbouring features.

    Two fixed permutations are stored. ``perm_list`` (0, 2, 3, 1) moves axis 1
    last, and ``order_list`` (0, 3, 1, 2) moves the last axis back to
    position 1.
    """
    super(GetSurroundFeature, self).__init__()
    self.shape = ops.Shape()
    self.concat = ops.Concat(axis=1)
    self.reshape = ops.Reshape()
    self.half = ops.Split(axis=-1, output_num=2)
    self.tile = ops.Tile()
    self.gather_nd = ops.GatherNd()
    self.transpose = ops.Transpose()
    to_channels_last, to_channels_first = (0, 2, 3, 1), (0, 3, 1, 2)
    self.perm_list = to_channels_last
    self.order_list = to_channels_first
    self.expand_dims = ops.ExpandDims()
def __init__(self, enable_cpu_gather=True):
    """Create the operators GatherFeature uses to gather features by index.

    Args:
        enable_cpu_gather: if True, gather with ``ops.GatherD`` (plus an
            ``ExpandDims`` helper); otherwise use ``ops.GatherNd``.
    """
    super(GatherFeature, self).__init__()
    self.tile = ops.Tile()
    self.shape = ops.Shape()
    self.concat = ops.Concat(axis=1)
    self.reshape = ops.Reshape()
    self.enable_cpu_gather = enable_cpu_gather
    if self.enable_cpu_gather:
        # Stored under the name `gather_nd` on both paths, even though this is GatherD.
        self.gather_nd = ops.GatherD()
        self.expand_dims = ops.ExpandDims()
    else:
        # The MindSpore primitive is `GatherNd`. `ops.GatherND` does not exist,
        # so this branch raised AttributeError whenever enable_cpu_gather=False.
        self.gather_nd = ops.GatherNd()
def __init__(self, begin, stride):
    """Create the operators GetOffsetPosition uses to build sampling positions.

    Args:
        begin: half-width of the sampling window. ``self.range`` covers
            ``-begin .. begin`` inclusive.
        stride: stored as ``self.stride``.
    """
    super(GetOffsetPosition, self).__init__()
    self.begin = begin
    self.stride = stride
    self.meshgrid = ops.Meshgrid()
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
    self.cat_a0 = ops.Concat(axis=0)
    self.cat_a1 = ops.Concat(axis=1)
    self.tile = ops.Tile()
    self.dtype = ops.DType()
    radius = self.begin
    self.range = nn.Range(-radius, radius + 1)
    self.cast = ops.Cast()
def __init__(self, net_config, K=100, enable_nms_fp16=True):
    """Create the sub-cells and operators DetectionDecode uses to decode detections.

    Args:
        net_config: config object; only ``reg_offset`` is read here.
        K: stored as ``self.K`` (presumably the number of top candidates kept).
        enable_nms_fp16: forwarded to ``NMS``.
    """
    super(DetectionDecode, self).__init__()
    self.K = K
    self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
    self.shape = ops.Shape()
    self.gather_topk = GatherTopK()
    # element-wise / tensor-assembly primitives
    self.half = ops.Split(axis=-1, output_num=2)
    self.add = ops.TensorAdd()
    self.concat_a2 = ops.Concat(axis=2)
    self.trans_gather_feature = TransposeGatherFeature()
    self.expand_dims = ops.ExpandDims()
    self.reshape = ops.Reshape()
    # configuration switches and activation
    self.reg_offset = net_config.reg_offset
    self.Sigmoid = nn.Sigmoid()
def __init__(self, alpha=2, beta=4):
    """Create the primitives FocalLoss uses and store its two exponents.

    Args:
        alpha: stored as ``self.alpha`` (presumably the focusing exponent).
        beta: stored as ``self.beta`` (presumably the negative-sample
            down-weighting exponent).
    """
    super(FocalLoss, self).__init__()
    self.alpha, self.beta = alpha, beta
    # arithmetic
    self.pow = ops.Pow()
    self.log = ops.Log()
    # masking / comparison
    self.select = ops.Select()
    self.equal = ops.Equal()
    self.less = ops.Less()
    # tensor construction and dtype helpers
    self.cast = ops.Cast()
    self.fill = ops.Fill()
    self.dtype = ops.DType()
    self.shape = ops.Shape()
    self.reduce_sum = ops.ReduceSum()
def construct(self, *inputs):
    """Defines the computation performed.

    One training step with optional gradient accumulation, driven by
    ``self.accumulation`` and ``self.accumulation_steps``:

    * If ``self.accumulation`` is truthy, this step's gradients are folded
      into ``self.accu_grads`` via ``update_accu_grads``, but only when
      ``accumulation_steps > 1``. The optimizer is not called.
    * Otherwise the gradients go through ``self.grad_reducer``,
      ``self.accu_grads`` is reset via ``reset_accu_grads``, and
      ``self.optimizer`` is applied to this step's reduced gradients.

    Returns:
        The loss, sequenced after the accumulation/reset ops and the
        optimizer result.
    """
    weights = self.weights
    loss = self.network(*inputs)
    # Backprop seed: same dtype/shape as the loss, filled with self.sens.
    sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
    grads = self.grad(self.network, weights)(*inputs, sens)
    if self.accumulation and self.accumulation_steps > 1:
        # update_accu_grads is defined elsewhere (presumably adds grads into accu_grads).
        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        loss = ops.depend(loss, accu_succ)
    if self.accumulation:
        succ = False
    else:
        # NOTE(review): the optimizer below receives this step's reduced grads,
        # not self.accu_grads, and this step's grads are not added to
        # accu_grads first. Confirm the accumulated gradients reach the
        # optimizer by some other path (e.g. inside grad_reducer/optimizer).
        grads = self.grad_reducer(grads)
        # Make the reset depend on grads so it is ordered after the reduction.
        accu_grads = ops.depend(self.accu_grads, grads)
        accu_succ = self.hyper_map(reset_accu_grads, accu_grads)
        loss = ops.depend(loss, accu_succ)
        succ = self.optimizer(grads)
    return ops.depend(loss, succ)
def __init__(self, net_config, K=100, enable_nms_fp16=True):
    """Create the sub-cells, operators and head indices MultiPoseDecode uses.

    Args:
        net_config: config object; ``hm_hp``, ``dense_hp``, ``reg_offset``
            and ``reg_hp_offset`` are read from it.
        K: stored as ``self.K`` (presumably the number of top candidates kept).
        enable_nms_fp16: forwarded to ``NMS``.
    """
    super(MultiPoseDecode, self).__init__()
    self.K = K
    self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
    self.shape = ops.Shape()
    # top-k selection and index-based gathering
    self.gather_topk = GatherTopK()
    self.gather_topk_channel = GatherTopKChannel()
    self.gather_by_ind = GatherFeatureByInd()
    # splitting
    self.half = ops.Split(axis=-1, output_num=2)
    self.half_first = ops.Split(axis=0, output_num=2)
    self.split = ops.Split(axis=-1, output_num=4)
    # flip helpers
    self.flip_lr = FlipLR()
    self.flip_lr_off = FlipLROff()
    self.flip_tensor = FlipTensor()
    # concatenation along axes 1, 2, 3
    self.concat = ops.Concat(axis=1)
    self.concat_a2 = ops.Concat(axis=2)
    self.concat_a3 = ops.Concat(axis=3)
    self.trans_gather_feature = TransposeGatherFeature()
    self.expand_dims = ops.ExpandDims()
    self.reshape = ops.Reshape()
    self.add = ops.TensorAdd()
    self.dtype = ops.DType()
    self.cast = ops.Cast()
    self.thresh = 0.1
    # transpose that swaps axes 1 and 2 of a 4-D tensor
    self.transpose = ops.Transpose()
    self.perm_list = (0, 2, 1, 3)
    self.tile = ops.Tile()
    self.greater = ops.Greater()
    # distance / nearest-match helpers
    self.square = ops.Square()
    self.sqrt = ops.Sqrt()
    self.reduce_sum = ops.ReduceSum()
    self.min = ops.ArgMinWithValue(axis=3)
    self.max = ops.Maximum()
    # optional-head switches from the network config
    self.hm_hp = net_config.hm_hp
    self.dense_hp = net_config.dense_hp
    self.reg_offset = net_config.reg_offset
    self.reg_hp_offset = net_config.reg_hp_offset
    # Each enabled optional head shifts the following indices by one
    # (presumably positions in the network's output tuple).
    hm_hp_ind = 3 if self.hm_hp else 2
    reg_ind = hm_hp_ind + 1 if self.reg_offset else hm_hp_ind
    self.hm_hp_ind = hm_hp_ind
    self.reg_ind = reg_ind
    self.reg_hp_ind = reg_ind + 1 if self.reg_hp_offset else reg_ind
def __init__(self, length, depth, max_relative_position, initializer_range, use_one_hot_embeddings=False):
    """Create the learnable relative-position embedding table and helper ops.

    The table has ``2 * max_relative_position + 1`` rows (presumably one per
    clipped relative distance) and ``depth`` columns. It is initialised from
    ``TruncatedNormal(initializer_range)``. This variant uses ``nn.OneHot``
    with ``depth=vocab_size``.
    """
    super(RelaPosEmbeddingsGenerator, self).__init__()
    self.depth = depth
    self.vocab_size = 2 * max_relative_position + 1
    self.use_one_hot_embeddings = use_one_hot_embeddings
    table_init = initializer(TruncatedNormal(initializer_range), [self.vocab_size, self.depth])
    self.embeddings_table = Parameter(table_init)
    self.relative_positions_matrix = RelaPosMatrixGenerator(
        length=length, max_relative_position=max_relative_position)
    self.reshape = P.Reshape()
    self.one_hot = nn.OneHot(depth=self.vocab_size)
    self.shape = P.Shape()
    self.gather = P.Gather()  # index_select
    self.matmul = P.BatchMatMul()
def construct(self, input_ids, token_type_id, input_mask):
    """construct BertPoetryModel

    Builds a pairwise attention mask from ``input_mask`` and multiplies it by
    ``self.lower_triangle_mask``. It then runs BERT and projects the sequence
    output onto BERT's own embedding table, plus ``self.biasadd``.

    Returns:
        ``self.softmax`` of the logits, reshaped to
        ``(-1, self.seq_length, self.num_tokens)``.
    """
    mask_shape = ops.Shape()(input_mask)
    input_mask = self.cast(input_mask, mstype.float32)
    # Per-batch outer product of the mask with itself, assuming input_mask is
    # (batch, seq): (B, S, 1) x (B, 1, S) -> (B, S, S).
    column_mask = self.reshape(input_mask, mask_shape + (1,))
    row_mask = self.reshape(input_mask, (mask_shape[0], 1, mask_shape[1]))
    pair_mask = self.batch_matmul(column_mask, row_mask)
    attention_mask = self.multiply(pair_mask, self.lower_triangle_mask)

    sequence_output, _, embedding_tables = self.bert(input_ids, token_type_id, attention_mask)

    hidden = ops.Reshape()(sequence_output, (-1, self.hidden_size))
    hidden = self.layer_norm(self.MLM_Dense(hidden))
    # Score against the embedding table in float16, then continue in float32.
    table_fp16 = ops.Cast()(embedding_tables, mstype.float16)
    scores = ops.Cast()(self.matmul(hidden, table_fp16), mstype.float32)
    scores = scores + self.biasadd
    scores = ops.Reshape()(scores, (-1, self.seq_length, self.num_tokens))
    return self.softmax(scores)
def __init__(self, net_config, K=100, enable_nms_fp16=True):
    """Evaluation wrapper: detection feature network followed by DetectionDecode.

    Args:
        net_config: passed to both the feature cell and the decoder.
        K: forwarded to ``DetectionDecode``.
        enable_nms_fp16: forwarded to ``DetectionDecode``.
    """
    super(CenterNetDetEval, self).__init__()
    self.network = GatherDetectionFeatureCell(net_config)
    self.decode = DetectionDecode(net_config, K=K, enable_nms_fp16=enable_nms_fp16)
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()
def __init__(self, net_config, K=100, enable_nms_fp16=True):
    """Evaluation wrapper: multi-pose feature network followed by MultiPoseDecode.

    Args:
        net_config: passed to both the feature cell and the decoder.
        K: forwarded to ``MultiPoseDecode``.
        enable_nms_fp16: forwarded to ``MultiPoseDecode``.
    """
    super(CenterNetMultiPoseEval, self).__init__()
    self.network = GatherMultiPoseFeatureCell(net_config)
    self.decode = MultiPoseDecode(net_config, K=K, enable_nms_fp16=enable_nms_fp16)
    self.shape = ops.Shape()
    self.reshape = ops.Reshape()