def _truncated_normal(self, shape, dtype):
  """Draws truncated-normal samples of `shape` and `dtype` from this generator.

  A forward-compatibility date gate chooses between two kernels:
  when `compat.forward_compatible(2020, 10, 25)` holds, a key and counter are
  obtained from `self._prepare_key_counter(shape)` and passed to the stateless
  `stateless_truncated_normal_v2` op; otherwise the stateful
  `stateful_truncated_normal` op runs directly on `self.state.handle`.
  In both cases `self.algorithm` selects the RNG algorithm.

  Args:
    shape: Output shape, forwarded unchanged to the underlying op.
    dtype: Output dtype, forwarded unchanged to the underlying op.

  Returns:
    The tensor produced by whichever op was selected.
  """
  if not compat.forward_compatible(2020, 10, 25):
    # NOTE(review): presumably the legacy path for runtimes that predate the
    # v2 stateless kernel; confirm against the forward-compatibility horizon.
    return gen_stateful_random_ops.stateful_truncated_normal(
        self.state.handle, self.algorithm, shape, dtype=dtype)
  # NOTE(review): _prepare_key_counter is defined elsewhere; it presumably
  # also advances the generator's state so successive calls differ — verify.
  rng_key, rng_counter = self._prepare_key_counter(shape)
  return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
      shape=shape,
      key=rng_key,
      counter=rng_counter,
      dtype=dtype,
      alg=self.algorithm,
  )
def _truncated_normal(self, shape, dtype):
  """Draws truncated-normal samples via the stateful RNG op.

  The `stateful_truncated_normal` op is run on this generator's state
  resource (`self.state.handle`), with `self.algorithm` selecting the RNG
  algorithm.

  Args:
    shape: Output shape, forwarded unchanged to the op.
    dtype: Output dtype, forwarded unchanged to the op.

  Returns:
    The tensor produced by `stateful_truncated_normal`.
  """
  state_handle = self.state.handle
  return gen_stateful_random_ops.stateful_truncated_normal(
      state_handle, self.algorithm, shape, dtype=dtype)