コード例 #1
0
 def construct(self, x):
     """Apply ``apply_global_norm`` to every element of ``x`` with a shared norm.

     The norm is the square root of the sum of ``get_square_sum`` applied
     to each element of ``x``. Before use it is passed through ``F.select``
     together with ``self.clip_norm``, and the selected value is handed to
     ``apply_global_norm`` along with ``self.clip_norm``.

     Args:
         x: Collection of values mapped over by ``self.hyper_map``
             (presumably a tuple of gradient tensors -- confirm with callers).

     Returns:
         The result of mapping the partially-applied ``apply_global_norm``
         over ``x``.
     """
     squared_terms = self.hyper_map(get_square_sum, x)
     raw_norm = F.sqrt(F.addn(squared_terms))
     # Where the comparison holds the computed norm is kept; elsewhere
     # self.clip_norm is used instead (assumes self.greater_equal is an
     # element-wise ">=" -- verify against the class constructor).
     keep_raw = self.greater_equal(raw_norm, self.clip_norm)
     effective_norm = F.select(keep_raw, raw_norm, self.clip_norm)
     clip_fn = F.partial(apply_global_norm, self.clip_norm, effective_norm)
     return self.hyper_map(clip_fn, x)
コード例 #2
0
 def construct(self, grads):
     """Compute a norm of ``grads`` after two successive reductions.

     ``get_square_sum`` is mapped over ``grads`` together with
     ``self.allreduce_group_size``; the results are summed, passed through
     ``self.allreduce`` and then ``self.allreduce2``, and square-rooted.

     Args:
         grads: Collection mapped over by ``self.hyper_map`` (presumably
             gradient tensors -- confirm with callers).

     Returns:
         Square root of the twice-reduced sum.
     """
     per_grad_squares = self.hyper_map(
         get_square_sum, grads, self.allreduce_group_size)
     local_total = F.addn(per_grad_squares)
     # Order matters: self.allreduce runs first, self.allreduce2 second.
     reduced_total = self.allreduce2(self.allreduce(local_total))
     return F.sqrt(reduced_total)
コード例 #3
0
ファイル: utils.py プロジェクト: pingping1122/mindspore
 def construct(self, grads):
     """Return the square root of the averaged ``get_square_sum`` results.

     ``get_square_sum`` is mapped over ``grads``; the results are summed,
     divided by the number of results, and square-rooted.

     Args:
         grads: Collection mapped over by ``self.hyper_map`` (presumably
             gradient tensors -- confirm with callers).

     Returns:
         ``sqrt(sum(results) / len(results))``.
     """
     squares = self.hyper_map(get_square_sum, grads)
     total = F.addn(squares)
     # The divisor is the number of mapped results, wrapped as a tensor.
     count = F.scalar_to_array(len(squares))
     mean_square = total / count
     return F.sqrt(mean_square)
コード例 #4
0
 def construct(self, grads):
     """Compute a norm of ``grads`` using a ``P.AllReduce`` of the local sum.

     ``get_square_sum`` is mapped over ``grads`` together with
     ``self.values``; the results are summed, passed through a freshly
     created ``P.AllReduce()``, and square-rooted.

     Args:
         grads: Collection mapped over by ``self.hyper_map`` (presumably
             gradient tensors -- confirm with callers).

     Returns:
         Square root of the all-reduced sum.
     """
     # Instantiated with default arguments, before the local sum is formed.
     reduce_op = P.AllReduce()
     local_squares = self.hyper_map(get_square_sum, grads, self.values)
     local_total = F.addn(local_squares)
     return F.sqrt(reduce_op(local_total))