def construct(self, x):
    """Rescale every entry of ``x`` using a shared, lower-bounded norm.

    Args:
        x: Collection of tensors accepted by ``self.hyper_map``
            (presumably a tuple of gradients -- confirm against callers).

    Returns:
        The result of mapping ``apply_global_norm`` over ``x``, with the
        same structure ``self.hyper_map`` produces for ``x``.
    """
    squared_parts = self.hyper_map(get_square_sum, x)
    norm = F.sqrt(F.addn(squared_parts))

    # Use the computed norm only where it is >= clip_norm; elsewhere fall
    # back to clip_norm itself, so the value handed to apply_global_norm
    # is never smaller than clip_norm.
    at_or_above_limit = self.greater_equal(norm, self.clip_norm)
    effective_norm = F.select(at_or_above_limit, norm, self.clip_norm)

    # Bind the two scalars once; hyper_map supplies each entry of x as the
    # remaining argument.
    rescale = F.partial(apply_global_norm, self.clip_norm, effective_norm)
    return self.hyper_map(rescale, x)
def construct(self, grads):
    """Compute a norm over ``grads`` that is reduced across two groups.

    Args:
        grads: Collection of tensors accepted by ``self.hyper_map``; each
            entry is paired with the matching entry of
            ``self.allreduce_group_size`` when calling ``get_square_sum``.

    Returns:
        Square root of the locally summed ``get_square_sum`` results after
        they pass through ``self.allreduce`` and then ``self.allreduce2``.
    """
    per_grad_squares = self.hyper_map(
        get_square_sum, grads, self.allreduce_group_size
    )
    local_total = F.addn(per_grad_squares)

    # Two reductions in sequence: self.allreduce first, self.allreduce2 on
    # its output (presumably within-stage then cross-stage -- confirm
    # against where these operators are created).
    reduced_total = self.allreduce2(self.allreduce(local_total))
    return F.sqrt(reduced_total)
def construct(self, grads):
    """Compute a count-normalised norm over ``grads``.

    Args:
        grads: Collection of tensors accepted by ``self.hyper_map``.

    Returns:
        Square root of the summed ``get_square_sum`` results divided by the
        number of entries in that mapped result.
    """
    per_grad_squares = self.hyper_map(get_square_sum, grads)
    total = F.addn(per_grad_squares)

    # The sum is divided by the number of mapped entries before the square
    # root, so this is not a plain sum-then-sqrt.
    entry_count = F.scalar_to_array(len(per_grad_squares))
    return F.sqrt(total / entry_count)
def construct(self, grads):
    """Compute a norm over ``grads`` with an all-reduce on the summed squares.

    Args:
        grads: Collection of tensors accepted by ``self.hyper_map``; each
            entry is paired with the matching entry of ``self.values`` when
            calling ``get_square_sum``.

    Returns:
        Square root of the all-reduced sum of the ``get_square_sum`` results.
    """
    per_grad_squares = self.hyper_map(get_square_sum, grads, self.values)

    # The AllReduce primitive is created here with its default arguments,
    # exactly as the original expression did, and applied to the local sum.
    reduced_total = P.AllReduce()(F.addn(per_grad_squares))
    return F.sqrt(reduced_total)