示例#1
0
 def _truncated_normal(self, shape, dtype):
     """Samples truncated-normal random values using this generator.

     The kernel is chosen by the forward-compatibility check
     `compat.forward_compatible(2020, 10, 25)`:

     * if it is false, the legacy stateful op is run directly on
       `self.state.handle`;
     * if it is true, a `(key, counter)` pair is obtained from
       `self._prepare_key_counter(shape)` and fed to the stateless v2 op.

     Both paths use `self.algorithm` as the RNG algorithm.

     Args:
       shape: Output shape. It is forwarded to the op and, on the stateless
         path, also to `_prepare_key_counter`.
       dtype: Output dtype, forwarded to the op.

     Returns:
       The result of whichever truncated-normal op was selected.
     """
     if not compat.forward_compatible(2020, 10, 25):
         # Fallback to the stateful kernel; presumably kept so graphs remain
         # runnable by binaries that lack the v2 op — confirm.
         return gen_stateful_random_ops.stateful_truncated_normal(
             self.state.handle, self.algorithm, shape, dtype=dtype)

     key, counter = self._prepare_key_counter(shape)
     return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
         shape=shape, key=key, counter=counter, dtype=dtype,
         alg=self.algorithm)
 def _truncated_normal(self, shape, dtype):
     """Samples truncated-normal random values via the stateful op.

     Args:
       shape: Output shape, forwarded unchanged to the op.
       dtype: Output dtype, forwarded as the op's `dtype` keyword.

     Returns:
       Whatever `gen_stateful_random_ops.stateful_truncated_normal` returns
       for this generator's state handle and algorithm.
     """
     state_handle = self.state.handle
     algorithm = self.algorithm
     return gen_stateful_random_ops.stateful_truncated_normal(
         state_handle, algorithm, shape, dtype=dtype)