def construct(self, x):
    square_sum = self.hyper_map(get_square_sum, x)
    global_norm = F.sqrt(F.addn(square_sum))
    cond = self.greater_equal(global_norm, self.clip_norm)
    global_norm = F.select(cond, global_norm, self.clip_norm)
    clip_x = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), x)
    return clip_x

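# The clipping construct above assumes `get_square_sum` / `apply_global_norm`
# helpers and an `__init__` that builds `hyper_map`, `greater_equal` and
# `clip_norm`. A minimal sketch of those pieces, assuming the usual
# MultitypeFuncGraph pattern (names and wiring here are assumptions, not the
# original code):
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

get_square_sum = C.MultitypeFuncGraph("get_square_sum")

@get_square_sum.register("Tensor")
def _get_square_sum(x):
    # sum of squared elements, returned as a 1-element float32 tensor
    norm = P.ReduceSum(False)(F.square(x), ())
    return F.expand_dims(F.cast(norm, mstype.float32), 0)

apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")

@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, x):
    # rescale each tensor so the global norm does not exceed clip_norm
    return x * clip_norm / global_norm

class ClipByGlobalNorm(nn.Cell):
    """Hypothetical wrapper cell for the construct above."""
    def __init__(self, clip_norm=1.0):
        super().__init__()
        self.hyper_map = C.HyperMap()
        self.greater_equal = P.GreaterEqual()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
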
def energy_fn(R: Tensor):
    dr = pairwise_displacement(R)
    # TODO: adding \epsilon is not accurate; a safe mask would be better
    dr = F.sqrt(reduce_sum(dr * dr, -1) + 1.1920928955078125e-07)  # float32 machine epsilon
    U = relu(1 - dr)
    U = reduce_sum(U * U) * 0.5 * 0.5
    return U

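# One way to resolve the TODO above: mask out zero distances before the sqrt so
# its gradient stays finite at r = 0, instead of shifting every distance by an
# epsilon. A hedged sketch, assuming the same `F` and `reduce_sum` bindings as
# the snippet above:
def safe_norm(dr, axis=-1):
    sq = reduce_sum(dr * dr, axis)
    mask = sq > 0
    # take sqrt only where sq > 0; the masked branch keeps the gradient well defined
    safe_sq = F.select(mask, sq, F.ones_like(sq))
    return F.select(mask, F.sqrt(safe_sq), F.zeros_like(sq))
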
def construct(self, x):
    mean = self.mean(x, -1)
    # reduce over the last axis to match `mean`; the original omitted the axis,
    # which would average the squared deviations over all dimensions
    variance = self.mean(F.square(self.sub(x, mean)), -1)
    output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
    rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
    return rescaled_output

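# A minimal `__init__` this LayerNorm-style construct appears to assume; the
# op choices (ReduceMean with keep_dims, elementwise P.* ops) and parameter
# shapes are assumptions:
import mindspore.nn as nn
from mindspore import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P

class LayerNorm(nn.Cell):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.gamma = Parameter(initializer('ones', normalized_shape), name='gamma')
        self.beta = Parameter(initializer('zeros', normalized_shape), name='beta')
        self.mean = P.ReduceMean(keep_dims=True)
        self.sub = P.Sub()
        self.add = P.Add()
        self.mul = P.Mul()
        self.div = P.RealDiv()
        self.eps = eps
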
def construct(self, grads):
    square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size)
    square_reduce_sum = F.addn(square_sum)
    # reduce within the pipeline stage first, then across stages
    stage_square_reduce_sum = self.allreduce(square_reduce_sum)
    global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
    global_norms = F.sqrt(global_square_reduce_sum)
    return global_norms

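# Hedged sketch of the two communication ops the construct above assumes: one
# AllReduce inside a pipeline stage and a second across stages. The group names
# are placeholders for groups created elsewhere (e.g. via
# mindspore.communication.create_group):
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P

class PipelineGlobalNorm(nn.Cell):
    def __init__(self, stage_group, cross_stage_group, allreduce_group_size):
        super().__init__()
        self.hyper_map = C.HyperMap()
        self.allreduce_group_size = allreduce_group_size
        self.allreduce = P.AllReduce(group=stage_group)
        self.allreduce2 = P.AllReduce(group=cross_stage_group)
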
def construct(self, x, X_train):
    # Tile input x to match the number of samples in X_train
    x_tile = self.tile(x, (128, 1))
    square_diff = F.square(x_tile - X_train)
    square_dist = self.sum(square_diff, 1)
    dist = F.sqrt(square_dist)
    # negate dist so that a larger value means a nearer sample, then take the top k
    values, indices = self.topk(-dist, self.k)
    return indices

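# Minimal setup this k-NN construct appears to assume; the hard-coded 128 above
# is the number of training samples, and the names here are assumptions:
import mindspore.nn as nn
from mindspore.ops import operations as P

class KnnNet(nn.Cell):
    def __init__(self, k):
        super().__init__()
        self.k = k
        self.tile = P.Tile()
        self.sum = P.ReduceSum()
        self.topk = P.TopK(sorted=True)
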
def construct(self, output_hm, output_wh, output_off, output_kps, hm,
              reg_mask, ind, wh, wight_mask, hm_offset, hps_mask, landmarks):
    """Construct method."""
    hm_loss = self.cls_loss(output_hm, hm)                               # 1. focal loss on center points
    wh_loss = self.reg_loss(output_wh, ind, wh, wight_mask)              # 2. width and height
    off_loss = self.reg_loss(output_off, ind, hm_offset, wight_mask)     # 3. offset
    lm_loss = self.reg_loss_cmask(output_kps, hps_mask, ind, landmarks)  # 4. landmark loss
    loss = self.hm_weight * hm_loss + self.wh_weight * wh_loss + \
           self.off_weight * off_loss + self.lm_weight * lm_loss
    # F.depend keeps wight_mask and reg_mask in the graph when they are otherwise
    # unused; the result must be assigned back for the dependency to take effect
    loss = F.depend(loss, F.sqrt(F.cast(wight_mask, mstype.float32)))
    loss = F.depend(loss, F.sqrt(F.cast(reg_mask, mstype.float32)))
    # uncomment to print loss details when debugging
    # self.print('hm_loss=', hm_loss, 'wh_loss=', wh_loss, 'off_loss=', off_loss,
    #            'lm_loss=', lm_loss, 'loss=', loss)
    return loss

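# Isolated illustration of the F.depend pattern used above: depend returns its
# first argument with an execution dependency on the second, so the returned
# value must be reassigned or the dependency is silently dropped. Toy sketch:
from mindspore import Tensor
from mindspore.ops import functional as F

def keep_inputs_alive(loss: Tensor, mask: Tensor) -> Tensor:
    return F.depend(loss, mask)  # use the return value, not the bare call
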
def construct(self, x1, x2, y):
    F.same_type_shape(x1, x2)
    _check_reduced_shape_valid(F.shape(x1), F.shape(y), (1,), self.cls_name)
    # if target == 1:  loss = 1 - cosine(x1, x2)
    # if target == -1: loss = max(0, cosine(x1, x2) - margin)
    prod_sum = self.reduce_sum(x1 * x2, (1,))
    square1 = self.reduce_sum(F.square(x1), (1,))
    square2 = self.reduce_sum(F.square(x2), (1,))
    denom = F.sqrt(square1) * F.sqrt(square2)
    cosine = prod_sum / denom
    pos_value = 1.0 - cosine
    neg_value = self.maximum(cosine - self.margin, 0.0)
    zeros = F.zeros_like(cosine)
    pos_part = F.select(y == 1, pos_value, zeros)
    neg_part = F.select(y == -1, neg_value, zeros)
    output_unreduced = pos_part + neg_part
    return self.get_loss(output_unreduced)

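# Usage sketch for the cosine embedding loss above (the wrapping cell and its
# constructor arguments are assumptions; the call is left commented out since
# the class is not defined here):
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype

x1 = Tensor(np.random.randn(4, 8), mstype.float32)  # (batch, features)
x2 = Tensor(np.random.randn(4, 8), mstype.float32)
y = Tensor(np.array([1, -1, 1, -1]), mstype.int32)  # 1 = similar pair, -1 = dissimilar
# loss = CosineEmbeddingLoss(margin=0.0)(x1, x2, y)
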
def construct(self, x):
    # squeeze the trailing singleton axis, normalize over the new last axis,
    # then restore the original shape; this only works when x's last dim is 1
    x_origin_shape = self.shape(x)
    x_target_shape = x_origin_shape[:-1]
    x_shape = x_origin_shape + (1,)
    x = self.reshape(x, x_shape)
    x = self.reshape(x, x_target_shape)
    mean = self.mean(x, -1)
    variance = self.mean(F.square(self.sub(x, mean)), -1)  # axis added to match `mean`
    output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
    rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
    output_shape = self.shape(rescaled_output) + (1,)
    rescaled_output = self.reshape(rescaled_output, output_shape)
    return rescaled_output

def construct(self, grads):
    square_sum = self.hyper_map(get_square_sum, grads)
    # this variant divides by the number of gradient tensors before the sqrt,
    # i.e. it returns a mean-normalized global norm rather than the plain one
    global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
    return global_norms

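# NumPy sanity check of the formula above: sqrt of the mean of per-tensor
# squared sums (the values below are illustrative):
import numpy as np

grads = [np.array([3.0, 4.0]), np.array([1.0, 2.0, 2.0])]
square_sums = [float((g * g).sum()) for g in grads]      # [25.0, 9.0]
avg_norm = np.sqrt(sum(square_sums) / len(square_sums))  # sqrt(34 / 2) ~= 4.123
print(avg_norm)
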
def construct(self, x):
    mean = self.mean(x, -1)
    variance = self.mean(F.square(x - mean), -1)
    output = (x - mean) / F.sqrt(variance + self.eps)
    rescaled_output = output * self.gamma + self.beta
    return rescaled_output

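# NumPy reference for the normalization above, assuming self.mean keeps reduced
# dims and gamma/beta broadcast over the last axis:
import numpy as np

def layer_norm_ref(x, gamma, beta, eps=1e-5):
    mean = x.mean(-1, keepdims=True)
    var = ((x - mean) ** 2).mean(-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta
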
def construct(self, logits, label):
    rmse_loss = F.sqrt(self.MSELoss(logits, label))
    return rmse_loss

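# Self-contained version of the RMSE loss above; `self.MSELoss` is assumed to
# be nn.MSELoss with the default mean reduction, so its sqrt is the RMSE:
import mindspore.nn as nn
from mindspore.ops import functional as F

class RMSELoss(nn.Cell):
    def __init__(self):
        super().__init__()
        self.MSELoss = nn.MSELoss(reduction='mean')

    def construct(self, logits, label):
        return F.sqrt(self.MSELoss(logits, label))
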
def construct(self, grads):
    square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
    # note: P.AllReduce() is instantiated inline here; MindSpore generally
    # expects operators to be created in __init__ rather than inside construct
    global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
    return global_norms

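# Sketch of the same computation with the AllReduce created in __init__; the
# group defaults to the world group, and `get_square_sum` is assumed to be the
# helper sketched earlier:
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

class GlobalNorm(nn.Cell):
    def __init__(self, values):
        super().__init__()
        self.hyper_map = C.HyperMap()
        self.values = values
        self.allreduce = P.AllReduce()

    def construct(self, grads):
        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
        return F.sqrt(self.allreduce(F.addn(square_sum_dp)))
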
def construct(self, logits, label):
    _check_shape(logits.shape, label.shape)
    rmse_loss = F.sqrt(self.MSELoss(logits, label))
    return rmse_loss