def _parameter_control_dependencies(self, is_init):
  if not self.validate_args:
    return []
  sample_shape = tf.concat(
      [self._batch_shape_tensor(), self._event_shape_tensor()], axis=0)
  low = None if self._low is None else tf.convert_to_tensor(self._low)
  high = None if self._high is None else tf.convert_to_tensor(self._high)
  assertions = []
  if self._low is not None and is_init != tensor_util.is_ref(self._low):
    low_shape = ps.shape(low)
    broadcast_shape = ps.broadcast_shape(sample_shape, low_shape)
    assertions.extend(
        [distribution_util.assert_integer_form(
            low, message='`low` has non-integer components.'),
         assert_util.assert_equal(
             tf.reduce_prod(broadcast_shape),
             tf.reduce_prod(sample_shape),
             message=('Shape of `low` adds extra batch dimensions to '
                      'sample shape.'))])
  if self._high is not None and is_init != tensor_util.is_ref(self._high):
    high_shape = ps.shape(high)
    broadcast_shape = ps.broadcast_shape(sample_shape, high_shape)
    assertions.extend(
        [distribution_util.assert_integer_form(
            high, message='`high` has non-integer components.'),
         assert_util.assert_equal(
             tf.reduce_prod(broadcast_shape),
             tf.reduce_prod(sample_shape),
             message=('Shape of `high` adds extra batch dimensions to '
                      'sample shape.'))])
  if (self._low is not None and self._high is not None and
      (is_init != (tensor_util.is_ref(self._low) or
                   tensor_util.is_ref(self._high)))):
    assertions.append(assert_util.assert_less(
        low, high,
        message='`low` must be strictly less than `high`.'))
  return assertions
def _calculate_mean_and_var(self, x, axes, keep_dims):
  with backend.name_scope('moments'):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16.
    y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
    replica_ctx = tf.distribute.get_replica_context()
    if replica_ctx:
      local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
      local_squared_sum = tf.reduce_sum(tf.square(y), axis=axes,
                                        keepdims=True)
      batch_size = tf.cast(tf.shape(y)[axes[0]], tf.float32)
      # TODO(b/163099951): batch the all-reduces once we sort out the
      # ordering issue for NCCL. We don't have a mechanism to launch NCCL in
      # the same order in each replica nowadays, so we limit NCCL to batch
      # all-reduces.
      y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
      y_squared_sum = replica_ctx.all_reduce(
          tf.distribute.ReduceOp.SUM, local_squared_sum)
      global_batch_size = replica_ctx.all_reduce(
          tf.distribute.ReduceOp.SUM, batch_size)

      axes_vals = [(tf.shape(y))[axes[i]] for i in range(1, len(axes))]
      multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
      multiplier = multiplier * global_batch_size

      mean = y_sum / multiplier
      y_squared_mean = y_squared_sum / multiplier
      # var = E(x^2) - E(x)^2
      variance = y_squared_mean - tf.square(mean)
    else:
      # Compute true mean while keeping the dims for proper broadcasting.
      mean = tf.reduce_mean(y, axes, keepdims=True, name='mean')
      # Sample variance, not unbiased variance.
      # Note: stop_gradient does not change the gradient that gets
      # backpropagated to the mean from the variance calculation, because
      # that gradient is zero.
      variance = tf.reduce_mean(
          tf.math.squared_difference(y, tf.stop_gradient(mean)),
          axes,
          keepdims=True,
          name='variance')
    if not keep_dims:
      mean = tf.compat.v1.squeeze(mean, axes)
      variance = tf.compat.v1.squeeze(variance, axes)
    if x.dtype == tf.float16:
      return (tf.cast(mean, tf.float16), tf.cast(variance, tf.float16))
    else:
      return (mean, variance)
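# Standalone sanity check (illustrative, no tf.distribute involved) of the
# streamed formula var = E[x^2] - E[x]^2 used in the replica branch above.
import tensorflow as tf

y = tf.constant([[1., 2.], [3., 4.]])
mean = tf.reduce_mean(y)
var = tf.reduce_mean(tf.square(y)) - tf.square(mean)
print(mean.numpy(), var.numpy())  # ==> 2.5 1.25 (population variance of 1..4)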
def _entropy(self, **kwargs):
  if not self.bijector.is_constant_jacobian:
    raise NotImplementedError('`entropy` is not implemented.')
  if not self.bijector._is_injective:  # pylint: disable=protected-access
    raise NotImplementedError('`entropy` is not implemented when '
                              '`bijector` is not injective.')
  distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
  override_event_shape = tf.convert_to_tensor(self._override_event_shape)
  override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
  base_batch_shape_tensor = self.distribution.batch_shape_tensor()
  base_event_shape_tensor = self.distribution.event_shape_tensor()
  # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It
  # can be shown that:
  #   H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)].
  # If is_constant_jacobian then:
  #   E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c)
  # where c can be anything.
  entropy = self.distribution.entropy(**distribution_kwargs)
  if self._is_maybe_event_override:
    # H[X] = sum_i H[X_i] if X_i are mutually independent.
    # This means that a reduce_sum is a simple rescaling.
    entropy = entropy * tf.cast(
        tf.reduce_prod(override_event_shape),
        dtype=dtype_util.base_dtype(entropy.dtype))
  if self._is_maybe_batch_override:
    new_shape = tf.concat([
        prefer_static.ones_like(override_batch_shape),
        base_batch_shape_tensor
    ], 0)
    entropy = tf.reshape(entropy, new_shape)
    multiples = tf.concat([
        override_batch_shape,
        prefer_static.ones_like(base_batch_shape_tensor)
    ], 0)
    entropy = tf.tile(entropy, multiples)
  # Create a dummy event of zeros to pass to
  # `bijector.inverse_log_det_jacobian` to extract the constant Jacobian.
  event_shape_tensor = self._event_shape_tensor(override_event_shape,
                                                base_event_shape_tensor)
  event_ndims = tf.nest.map_structure(prefer_static.rank_from_shape,
                                      event_shape_tensor, self.event_shape)
  dummy = tf.nest.map_structure(prefer_static.zeros, event_shape_tensor,
                                self.dtype)
  ildj = self.bijector.inverse_log_det_jacobian(
      dummy, event_ndims=event_ndims, **bijector_kwargs)
  entropy = entropy - tf.cast(ildj, entropy.dtype)
  tensorshape_util.set_shape(entropy, self.batch_shape)
  return entropy
def get_jacobian_fn_mat(jacobian_fn, ode_fn_vec, state_shape, use_pfor):
  """Returns a wrapper around the user-specified `jacobian_fn` argument.

  `jacobian_fn` is an optional argument that can either be a constant
  `Tensor` or a function of the form `jacobian_fn(time, state)`. This function
  returns a wrapper `jacobian_fn_mat(time, state_vec)` whose second argument
  and output are 1-D and 2-D `Tensor`s, respectively, corresponding to
  reshaped versions of `state` and `jacobian_fn(time, state)`.

  Args:
    jacobian_fn: User-specified `jacobian_fn` passed to `solve`.
    ode_fn_vec: User-specified `ode_fn` passed to `solve`.
    state_shape: The shape of the second argument and output of `ode_fn`.
    use_pfor: User-specified `use_pfor` passed to `solve`.

  Returns:
    The wrapper described above.
  """
  if jacobian_fn is None:

    def automatic_jacobian_fn_mat(time, state_vec):
      with tf.GradientTape(
          watch_accessed_variables=False, persistent=not use_pfor) as tape:
        tape.watch(state_vec)
        outputs = ode_fn_vec(time, state_vec)
      jacobian_mat = tape.jacobian(
          outputs, state_vec, experimental_use_pfor=use_pfor)
      if jacobian_mat is None:
        return tf.zeros([tf.size(state_vec)] * 2, dtype=state_vec.dtype)
      return jacobian_mat

    return automatic_jacobian_fn_mat

  if not callable(jacobian_fn):
    constant_jacobian_mat = tf.reshape(
        tf.convert_to_tensor(jacobian_fn), [-1, tf.reduce_prod(state_shape)])

    def constant_jacobian_fn_mat(*_):
      return constant_jacobian_mat

    return constant_jacobian_fn_mat

  def jacobian_fn_mat(time, state_vec):
    state = tf.reshape(state_vec, state_shape)
    jacobian_mat = tf.reshape(jacobian_fn(time, state), [-1, tf.size(state)])
    return jacobian_mat

  return jacobian_fn_mat
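# Minimal usage sketch (illustrative, names below are not from the library):
# with `jacobian_fn=None` the wrapper differentiates `ode_fn_vec` with a
# GradientTape. For the linear ODE dy/dt = -y the Jacobian is -I.
import tensorflow as tf

ode_fn_vec = lambda time, state_vec: -state_vec
jac_fn_mat = get_jacobian_fn_mat(
    jacobian_fn=None, ode_fn_vec=ode_fn_vec, state_shape=[2], use_pfor=True)
print(jac_fn_mat(tf.constant(0.), tf.constant([1., 2.])))
# ==> [[-1.  0.]
#      [ 0. -1.]]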
def _event_shape_tensor(self):
  event_sizes = tf.nest.map_structure(tensorshape_util.num_elements,
                                      self._distribution.event_shape)
  if any(s is None for s in tf.nest.flatten(event_sizes)):
    event_sizes = tf.nest.map_structure(
        lambda static_size, shape_tensor:  # pylint: disable=g-long-lambda
        (tf.reduce_prod(shape_tensor)
         if static_size is None else static_size),
        event_sizes,
        self._distribution.event_shape_tensor())
  return tf.reduce_sum(tf.nest.flatten(event_sizes))[tf.newaxis]
def _get_leftmost_dim_size(x, name=None):
  """Returns the size of the leftmost dimension, statically if possible."""
  with tf.name_scope(name or 'get_leftmost_dim_size'):
    x = tf.convert_to_tensor(value=x, name='x')
    if x.shape.ndims is None:
      # If tf.shape(x) is scalar, the [:1] will produce the empty list, whose
      # reduce_prod is 1 as desired. Otherwise, the [:1] will select the first
      # dimension, and reduce_prod will not alter it.
      return tf.reduce_prod(input_tensor=tf.shape(input=x)[:1])
    if x.shape.ndims == 0:
      return 1
    leftmost = tf.compat.dimension_value(x.shape[0])
    return leftmost if leftmost is not None else tf.shape(input=x)[0]
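# Illustrative behavior check of the helper above (not library code):
import tensorflow as tf

print(_get_leftmost_dim_size(tf.zeros([5, 3])))  # ==> 5 (static Python int)
print(_get_leftmost_dim_size(tf.constant(7.)))   # ==> 1 for a scalar input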
def make_2d(tensor, split_dim):
  """Reshapes an N-dimensional tensor into a 2D tensor.

  Dimensions before (excluding) and after (including) `split_dim` are grouped
  together.

  Args:
    tensor: a tensor of shape `(d0, ..., d(N-1))`.
    split_dim: an integer from 1 to N-1, index of the dimension to group
      dimensions before (excluding) and after (including).

  Returns:
    Tensor of shape
    `(d0 * ... * d(split_dim-1), d(split_dim) * ... * d(N-1))`.
  """
  shape = tf.compat.v1.shape(tensor)
  in_dims = shape[:split_dim]
  out_dims = shape[split_dim:]

  in_size = tf.reduce_prod(in_dims)
  out_size = tf.reduce_prod(out_dims)

  return tf.reshape(tensor, (in_size, out_size))
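# Quick usage sketch (illustrative): grouping a (2, 3, 4) tensor at
# split_dim=1 folds the leading dims into rows and the remaining dims into
# columns.
import tensorflow as tf

x = tf.zeros([2, 3, 4])
print(make_2d(x, split_dim=1).shape)  # ==> (2, 12)
print(make_2d(x, split_dim=2).shape)  # ==> (6, 4)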
def _mean(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat(
        [self.batch_shape_tensor(), [self._num_states],
         self._observation_distribution.event_shape_tensor()],
        axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape
    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size num_states observation_event_size
    flat_mean = tf.einsum("ijk,jkl->jil", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps observation_event_size
    unflat_mean_shape = tf.concat(
        [self.batch_shape_tensor(), [self._num_steps],
         observation_event_shape],
        axis=0)
    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_mean, unflat_mean_shape)
def _finish_prob_for_one_fiber(self, y, x, ildj, event_ndims,
                               **distribution_kwargs):
  """Finish computation of prob on one element of the inverse image."""
  x = self._maybe_rotate_dims(x, rotate_right=True)
  prob = self.distribution.prob(x, **distribution_kwargs)
  if self._is_maybe_event_override:
    prob = tf.reduce_prod(prob, axis=self._reduce_event_indices)
  prob = prob * tf.exp(tf.cast(ildj, prob.dtype))
  if self._is_maybe_event_override and isinstance(event_ndims, int):
    tensorshape_util.set_shape(
        prob,
        tf.broadcast_static_shape(
            tensorshape_util.with_rank_at_least(y.shape, 1)[:-event_ndims],
            self.batch_shape))
  return prob
def _mean(self):
  observation_distribution = self.observation_distribution
  batch_shape = self.batch_shape_tensor()
  num_states = self.transition_distribution.batch_shape_tensor()[-1]
  probs = self._marginal_hidden_probs()
  # probs :: num_steps batch_shape num_states
  means = observation_distribution.mean()
  # means :: observation_batch_shape[:-1] num_states
  #          observation_event_shape
  means_shape = tf.concat(
      [batch_shape, [num_states],
       observation_distribution.event_shape_tensor()],
      axis=0)
  means = tf.broadcast_to(means, means_shape)
  # means :: batch_shape num_states observation_event_shape
  observation_event_shape = (
      observation_distribution.event_shape_tensor())
  batch_size = tf.reduce_prod(batch_shape)
  flat_probs_shape = [self._num_steps, batch_size, num_states]
  flat_means_shape = [
      batch_size, num_states,
      tf.reduce_prod(observation_event_shape)
  ]

  flat_probs = tf.reshape(probs, flat_probs_shape)
  # flat_probs :: num_steps batch_size num_states
  flat_means = tf.reshape(means, flat_means_shape)
  # flat_means :: batch_size num_states observation_event_size
  flat_mean = tf.einsum('ijk,jkl->jil', flat_probs, flat_means)
  # flat_mean :: batch_size num_steps observation_event_size
  unflat_mean_shape = tf.concat(
      [batch_shape, [self._num_steps], observation_event_shape], axis=0)
  # returns :: batch_shape num_steps observation_event_shape
  return tf.reshape(flat_mean, unflat_mean_shape)
def test_docstring_example(self):
  # Produce the first 1000 members of the Halton sequence in 3 dimensions.
  num_results = 1000
  dim = 3
  sample, params = random.halton.sample(
      dim, num_results=num_results, seed=127)

  # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional
  # hypercube.
  powers = tf.range(1., limit=dim + 1)
  integral = tf.reduce_mean(
      input_tensor=tf.reduce_prod(input_tensor=sample**powers, axis=-1))
  true_value = 1. / tf.reduce_prod(input_tensor=powers + 1.)

  # Produces a relative absolute error of 1.7%.
  self.assertAllClose(
      self.evaluate(integral), self.evaluate(true_value), rtol=0.02)

  # Now skip the first 1000 samples and recompute the integral with the next
  # thousand samples. The sequence_indices argument can be used to do this.
  sequence_indices = tf.range(
      start=1000, limit=1000 + num_results, dtype=tf.int32)
  sample_leaped, _ = random.halton.sample(
      dim, sequence_indices=sequence_indices, randomization_params=params)

  integral_leaped = tf.reduce_mean(
      input_tensor=tf.reduce_prod(
          input_tensor=sample_leaped**powers, axis=-1))
  self.assertAllClose(
      self.evaluate(integral_leaped), self.evaluate(true_value), rtol=0.05)
def _make_flatten_unflatten_fns_tf(batch_shape):
  """Returns functions to flatten and unflatten a batch shape."""
  batch_shape = tf.cast(batch_shape, dtype=tf.int32)
  batch_rank = batch_shape.shape[0]
  batch_ndims = tf.reduce_prod(batch_shape)

  @tf.function
  def flatten_fn(x):
    flat_shape = tf.concat([[batch_ndims], tf.shape(x)[batch_rank:]], axis=0)
    return tf.reshape(x, flat_shape)

  @tf.function
  def unflatten_fn(x):
    full_shape = tf.concat([batch_shape, tf.shape(x)[1:]], axis=0)
    return tf.reshape(x, full_shape)

  return flatten_fn, unflatten_fn
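# Illustrative usage sketch: flatten a [2, 3] batch into a single axis of
# size 6, then recover the original batch shape.
import tensorflow as tf

flatten_fn, unflatten_fn = _make_flatten_unflatten_fns_tf([2, 3])
x = tf.zeros([2, 3, 4])
flat = flatten_fn(x)
print(flat.shape)                # ==> (6, 4)
print(unflatten_fn(flat).shape)  # ==> (2, 3, 4)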
def _mode(self, samples=None):
  # Samples count can vary by batch member. Use map_fn to compute mode for
  # each batch separately.
  def _get_mode(samples):
    _, idx, count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0])
    # TODO(b/161402486): Remove this hack for fixing the wrong static shape
    # of `idx` in graph mode.
    idx = tf.vectorized_map(lambda x: tf.reshape(x, [-1])[0], idx)
    # NOTE:
    # - `count` has shape `[K]`, where `K` is the number of unique elements,
    #   and `count[j]` is the number of times the j-th unique element occurs
    #   in `samples`.
    # - `idx` has shape `[samples.shape[0]]`, and `idx[i] == j` means that
    #   `samples[i]` is equal to the `j`-th unique element.
    max_count_idx = tf.argmax(count, output_type=tf.int32)
    # Return an index `i` for which `idx[i] == max_count_idx`.
    return tf.argmax(
        tf.cast(tf.math.equal(idx, max_count_idx), dtype=tf.int32),
        output_type=tf.int32)

  if samples is None:
    samples = tf.convert_to_tensor(self._samples)
  num_samples = self._compute_num_samples(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    flattened_samples = tf.reshape(samples, [-1, num_samples])
    mode_shape = self._batch_shape_tensor(samples)
  else:
    event_size = tf.reduce_prod(self._event_shape_tensor(samples))
    mode_shape = ps.concat(
        [self._batch_shape_tensor(samples),
         self._event_shape_tensor(samples)],
        axis=0)
    flattened_samples = tf.reshape(samples, [-1, num_samples, event_size])

  indices = tf.map_fn(_get_mode, flattened_samples,
                      fn_output_signature=tf.int32)
  full_indices = tf.stack([tf.range(tf.shape(indices)[0]), indices], axis=1)

  mode = tf.gather_nd(flattened_samples, full_indices)
  return tf.reshape(mode, mode_shape)
def _entropy(self, **kwargs): if not self.bijector.is_constant_jacobian: raise NotImplementedError("entropy is not implemented") if not self.bijector._is_injective: # pylint: disable=protected-access raise NotImplementedError("entropy is not implemented when " "bijector is not injective.") distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs) # Suppose Y = g(X) where g is a diffeomorphism and X is a continuous rv. It # can be shown that: # H[Y] = H[X] + E_X[(log o abs o det o J o g)(X)]. # If is_constant_jacobian then: # E_X[(log o abs o det o J o g)(X)] = (log o abs o det o J o g)(c) # where c can by anything. entropy = self.distribution.entropy(**distribution_kwargs) if self._is_maybe_event_override: # H[X] = sum_i H[X_i] if X_i are mutually independent. # This means that a reduce_sum is a simple rescaling. entropy = entropy * tf.cast( tf.reduce_prod(self._override_event_shape), dtype=dtype_util.base_dtype(entropy.dtype)) if self._is_maybe_batch_override: new_shape = tf.concat([ prefer_static.ones_like(self._override_batch_shape), self.distribution.batch_shape_tensor() ], 0) entropy = tf.reshape(entropy, new_shape) multiples = tf.concat([ self._override_batch_shape, prefer_static.ones_like(self.distribution.batch_shape_tensor()) ], 0) entropy = tf.tile(entropy, multiples) dummy = prefer_static.zeros( shape=tf.concat( [self.batch_shape_tensor(), self.event_shape_tensor()], 0), dtype=self.dtype) event_ndims = ( tensorshape_util.rank(self.event_shape) if tensorshape_util.rank(self.event_shape) is not None else tf.size( self.event_shape_tensor())) ildj = self.bijector.inverse_log_det_jacobian( dummy, event_ndims=event_ndims, **bijector_kwargs) entropy = entropy - tf.cast(ildj, entropy.dtype) tensorshape_util.set_shape(entropy, self.batch_shape) return entropy
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
  # get ids as a [n]-shaped vector.
  distributions = self.poisson_and_mixture_distributions()
  dist, mixture_dist = distributions
  batch_size = tensorshape_util.num_elements(self.batch_shape)
  if batch_size is None:
    batch_size = tf.reduce_prod(
        self._batch_shape_tensor(distributions=distributions))
  # We need to 'sample extra' from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  mixture_seed, poisson_seed = samplers.split_seed(
      seed, salt='PoissonLogNormalQuadratureCompound')
  ids = mixture_dist.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              mixture_dist.is_scalar_batch(),
              [batch_size],
              np.int32([]))),
      seed=mixture_seed)
  # We need to flatten batch dims in case mixture_dist has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors([n],
                           distribution_util.pick_vector(
                               self.is_scalar_batch(),
                               np.int32([]),
                               np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids = ids + offset
  rate = tf.gather(tf.reshape(dist.rate_parameter(), shape=[-1]), ids)
  rate = tf.reshape(
      rate,
      shape=concat_vectors(
          [n], self._batch_shape_tensor(distributions=distributions)))
  return samplers.poisson(
      shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
def _sample_n(self, n, seed=None):
  distribution0 = self._get_distribution0()

  if self._num_steps is not None:
    num_steps = tf.convert_to_tensor(self._num_steps)
    num_steps_static = tf.get_static_value(num_steps)
  else:
    num_steps_static = tensorshape_util.num_elements(
        distribution0.event_shape)
    if num_steps_static is None:
      num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

  stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
  stateful_seed = None
  try:
    samples = distribution0.sample(n, seed=stateless_seed)
    is_stateful_sampler = False
  except TypeError as e:
    if ('Expected int for argument' not in str(e) and
        TENSOR_SEED_MSG_PREFIX not in str(e)):
      raise
    msg = (
        'Falling back to stateful sampling for `distribution_fn(sample0)` of '
        'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
        'This fallback may be removed after 20-Aug-2020. ({})')
    warnings.warn(
        msg.format(distribution0.name, type(distribution0), str(e)))
    stateful_seed = SeedStream(seed, salt='Autoregressive')()
    samples = distribution0.sample(n, seed=stateful_seed)
    is_stateful_sampler = True

  seed = stateful_seed if is_stateful_sampler else stateless_seed
  if num_steps_static is not None:
    for _ in range(num_steps_static):
      # pylint: disable=not-callable
      samples = self.distribution_fn(samples).sample(seed=seed)
  else:
    # pylint: disable=not-callable
    samples = tf.foldl(
        lambda s, _: self.distribution_fn(s).sample(seed=seed),
        elems=tf.range(0, num_steps),
        initializer=samples)
  return samples
def __call__(self, x):
  """Computes regularization given an input ed.RandomVariable."""
  if not isinstance(x, random_variable.RandomVariable):
    raise ValueError('Input must be an ed.RandomVariable.')
  # variance = (tr( sigma_q + mu_q mu_q^T ) + 2*beta) / (omega + 2*alpha + 2)
  trace_covariance = tf.reduce_sum(x.distribution.variance())
  trace_mean_outer_product = tf.reduce_sum(x.distribution.mean()**2)
  num_weights = tf.cast(tf.reduce_prod(x.shape), x.dtype)
  variance = ((trace_covariance + trace_mean_outer_product) +
              2. * self.variance_scale)
  variance /= num_weights + 2. * self.variance_concentration + 2.
  self.stddev = tf.sqrt(variance)

  variance_prior = generated_random_variables.InverseGamma(
      self.variance_concentration, self.variance_scale)
  regularization = super(NormalEmpiricalBayesKLDivergence, self).__call__(x)
  regularization -= (self.scale_factor *
                     variance_prior.distribution.log_prob(variance))
  return regularization
def _forward(self, x, **kwargs):
  static_event_size = tensorshape_util.num_elements(
      tensorshape_util.with_rank_at_least(
          x.shape, self._event_ndims)[-self._event_ndims:])

  if self._unroll_loop:
    if not static_event_size:
      raise ValueError(
          'The final {} dimensions of `x` must be known at graph '
          'construction time if `unroll_loop=True`. `x.shape: {!r}`'.format(
              self._event_ndims, x.shape))
    y = tf.zeros_like(x, name='y0')

    for _ in range(static_event_size):
      y = self._bijector_fn(y, **kwargs).forward(x)
    return y

  event_size = tf.reduce_prod(tf.shape(x)[-self._event_ndims:])
  y0 = tf.zeros_like(x, name='y0')
  # Call the template once to ensure creation.
  if not tf.executing_eagerly():
    _ = self._bijector_fn(y0, **kwargs).forward(y0)

  def _loop_body(index, y0):
    """While-loop body for autoregression calculation."""
    # Set caching device to avoid re-getting the tf.Variable for every while
    # loop iteration.
    with tf1.variable_scope(tf1.get_variable_scope()) as vs:
      if vs.caching_device is None and not tf.executing_eagerly():
        vs.set_caching_device(lambda op: op.device)
      bijector = self._bijector_fn(y0, **kwargs)
    y = bijector.forward(x)
    return index + 1, y

  # If the event size is available at graph construction time, we can inform
  # the graph compiler of the maximum number of steps. If not,
  # static_event_size will be None, and the maximum_iterations argument will
  # have no effect.
  _, y = tf.while_loop(
      cond=lambda index, _: index < event_size,
      body=_loop_body,
      loop_vars=(0, y0),
      maximum_iterations=static_event_size)
  return y
def update_state(self, data):
    if self.input_mean is not None:
        raise ValueError(
            "Cannot `adapt` a Normalization layer that is initialized with "
            "static `mean` and `variance`, "
            "you passed mean {} and variance {}.".format(
                self.input_mean, self.input_variance
            )
        )

    if not self.built:
        raise RuntimeError("`build` must be called before `update_state`.")

    data = self._standardize_inputs(data)
    data = tf.cast(data, self.adapt_mean.dtype)
    batch_mean, batch_variance = tf.nn.moments(data, axes=self._reduce_axis)
    batch_shape = tf.shape(data, out_type=self.count.dtype)
    if self._reduce_axis:
        batch_reduce_shape = tf.gather(batch_shape, self._reduce_axis)
        batch_count = tf.reduce_prod(batch_reduce_shape)
    else:
        batch_count = 1
    total_count = batch_count + self.count
    batch_weight = tf.cast(batch_count, dtype=self.compute_dtype) / tf.cast(
        total_count, dtype=self.compute_dtype
    )
    existing_weight = 1.0 - batch_weight

    total_mean = (
        self.adapt_mean * existing_weight + batch_mean * batch_weight
    )
    # The variance is computed using the lack-of-fit sum of squares formula
    # (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
    total_variance = (
        self.adapt_variance + (self.adapt_mean - total_mean) ** 2
    ) * existing_weight + (
        batch_variance + (batch_mean - total_mean) ** 2
    ) * batch_weight
    self.adapt_mean.assign(total_mean)
    self.adapt_variance.assign(total_variance)
    self.count.assign(total_count)
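# Numeric sanity check (illustrative) of the streaming variance merge above:
# merging stats of [0., 2.] (mean 1, var 1, count 2) with a batch [4.]
# (mean 4, var 0, count 1) should reproduce the variance of [0., 2., 4.].
existing_mean, existing_var, existing_count = 1.0, 1.0, 2.0
batch_mean, batch_var, batch_count = 4.0, 0.0, 1.0
total_count = existing_count + batch_count
batch_weight = batch_count / total_count
existing_weight = 1.0 - batch_weight
total_mean = existing_mean * existing_weight + batch_mean * batch_weight
total_var = (
    (existing_var + (existing_mean - total_mean) ** 2) * existing_weight
    + (batch_var + (batch_mean - total_mean) ** 2) * batch_weight)
print(total_mean, total_var)  # ==> 2.0 2.666... (matches the variance of [0, 2, 4])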
def _split_and_reshape_event(self, x):
  """Splits and reshapes a vector-valued event `x`."""
  splits = [
      tf.maximum(1, tf.reduce_prod(s))
      for s in tf.nest.flatten(self._target_density.event_shape)
  ]
  x = tf.nest.pack_sequence_as(self._target_density.event_shape,
                               tf.split(x, splits, axis=-1))

  def _reshape_part(part, dtype, event_shape):
    part = tf.cast(part, dtype)
    rank = event_shape.rank
    if rank == 1:
      return part
    new_shape = tf.concat([tf.shape(part)[:-1], event_shape], axis=-1)
    return tf.reshape(part, tf.cast(new_shape, tf.int32))

  x = tf.nest.map_structure(_reshape_part, x, self._target_density.dtype,
                            self._target_density.event_shape)
  return x
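# Standalone sketch of the split-and-pack step used above (illustrative,
# assuming a joint event made of a scalar part and a length-3 part):
import tensorflow as tf

event_shapes = [tf.TensorShape([]), tf.TensorShape([3])]
splits = [max(1, s.num_elements()) for s in event_shapes]  # [1, 3]
x = tf.constant([0.5, 1.0, 2.0, 3.0])
parts = tf.nest.pack_sequence_as(event_shapes, tf.split(x, splits, axis=-1))
print([p.shape for p in parts])  # ==> [TensorShape([1]), TensorShape([3])]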
def sample(self, sample_shape=(), seed=None, name=None):
  with tf.name_scope(name or 'sample'):
    # Grab the required number of values from the provided tensors.
    sample_shape = dist_util.expand_to_vector(sample_shape)
    n = tf.cast(tf.reduce_prod(sample_shape), dtype=tf.int32)

    # Check that we're not trying to draw too many samples.
    assertions = []
    will_overflow_ = tf.get_static_value(n > self.max_num_samples)
    if will_overflow_:
      raise ValueError('Trying to draw {} samples from a '
                       '`DeterministicEmpirical` instance for which only {} '
                       'samples were provided.'.format(
                           tf.get_static_value(n),
                           tf.get_static_value(self.max_num_samples)))
    elif (will_overflow_ is None  # Couldn't determine statically.
          and self.validate_args):
      assertions.append(
          tf.debugging.assert_less_equal(
              n, self.max_num_samples,
              message='Number of samples to draw '
              'from a `DeterministicEmpirical` instance must not exceed the '
              'number provided at construction.'))

    # Extract the appropriate number of sampled values.
    with tf.control_dependencies(assertions):
      sampled = tf.nest.map_structure(lambda x: x[:n, ...],
                                      self.values_with_sample_dim)

    # Reshape the values to the appropriate sample shape.
    return tf.nest.map_structure(
        lambda x: tf.reshape(x,  # pylint: disable=g-long-lambda
                             tf.concat([tf.cast(sample_shape, tf.int32),
                                        tf.cast(tf.shape(x)[1:], tf.int32)],
                                       axis=0)),
        sampled)
def _entropy(self):
  # Use map_fn to compute entropy for each batch separately.
  def _get_entropy(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    prob = count / self.num_samples
    entropy = tf.reduce_sum(input_tensor=-prob * tf.math.log(prob))
    return entropy

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    samples = tf.reshape(self.samples, [-1, self.num_samples])
  else:
    event_size = tf.reduce_prod(input_tensor=self.event_shape_tensor())
    samples = tf.reshape(self.samples, [-1, self.num_samples, event_size])

  entropy = tf.map_fn(_get_entropy, samples)
  entropy_shape = self.batch_shape_tensor()
  if self.dtype.is_floating:
    entropy = tf.cast(entropy, self.dtype)
  return tf.reshape(entropy, entropy_shape)
def basis(sample_paths):
  """Computes polynomial basis expansion at the given sample points.

  Args:
    sample_paths: A `Tensor` of either `float32` or `float64` dtype and of
      shape `[num_samples, dim]` where `dim` has to be statically known.

  Returns:
    A `Tensor` of shape `[(degree + 1)**dim, num_samples]`.
  """
  samples = tf.convert_to_tensor(sample_paths)
  dim = samples.shape.as_list()[-1]
  grid = tf.range(0, degree + 1, dtype=samples.dtype)

  samples_centered = samples - tf.math.reduce_mean(samples, axis=0)
  samples_centered = tf.expand_dims(samples_centered, -2)
  grid = tf.meshgrid(*(dim * [grid]))
  grid = tf.reshape(tf.stack(grid, -1), [-1, dim])
  # Shape [num_samples, (degree + 1)**dim]
  basis_expansion = tf.reduce_prod(samples_centered**grid, -1)
  return tf.transpose(basis_expansion)
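# Illustrative sketch of the power-grid expansion (assuming `degree` is the
# closure variable of the enclosing factory, here taken to be 1, with dim=2):
# the grid enumerates all exponent tuples (k_1, k_2) with 0 <= k_i <= degree,
# so each basis function is prod_i x_i**k_i, i.e. 1, x_1, x_2, x_1 * x_2.
import tensorflow as tf

degree, dim = 1, 2
grid = tf.range(0, degree + 1, dtype=tf.float64)
grid = tf.reshape(tf.stack(tf.meshgrid(*(dim * [grid])), -1), [-1, dim])
x = tf.constant([[2.0, 3.0]], dtype=tf.float64)          # one sample point
print(tf.reduce_prod(tf.expand_dims(x, -2)**grid, -1))   # ==> [[1. 2. 3. 6.]]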
def _entropy(self):
  samples = tf.convert_to_tensor(self.samples)
  num_samples = self._compute_num_samples(samples)
  entropy_shape = self._batch_shape_tensor(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    samples = tf.reshape(samples, [-1, num_samples])
  else:
    event_size = tf.reduce_prod(self.event_shape_tensor())
    samples = tf.reshape(samples, [-1, num_samples, event_size])

  # Use map_fn to compute entropy for each batch separately.
  def _get_entropy(samples):
    count = tf.raw_ops.UniqueWithCountsV2(x=samples, axis=[0]).count
    prob = tf.cast(count / num_samples, dtype=self.dtype)
    entropy = tf.reduce_sum(-prob * tf.math.log(prob))
    return entropy

  entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
  return tf.reshape(entropy, entropy_shape)
def _sample_n(self, n, seed=None):
  # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then
  # get ids as a [n]-shaped vector.
  batch_size = self.batch_shape.num_elements()
  if batch_size is None:
    batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
  # We need to "sample extra" from the mixture distribution if it doesn't
  # already specify a probs vector for each batch coordinate.
  # We only support this kind of reduced broadcasting, i.e., there is exactly
  # one probs vector for all batch dims or one for each.
  stream = seed_stream.SeedStream(
      seed, salt="PoissonLogNormalQuadratureCompound")
  ids = self._mixture_distribution.sample(
      sample_shape=concat_vectors(
          [n],
          distribution_util.pick_vector(
              self.mixture_distribution.is_scalar_batch(),
              [batch_size],
              np.int32([]))),
      seed=stream())
  # We need to flatten batch dims in case mixture_distribution has its own
  # batch dims.
  ids = tf.reshape(
      ids,
      shape=concat_vectors([n],
                           distribution_util.pick_vector(
                               self.is_scalar_batch(),
                               np.int32([]),
                               np.int32([-1]))))
  # Stride `quadrature_size` for `batch_size` number of times.
  offset = tf.range(
      start=0,
      limit=batch_size * self._quadrature_size,
      delta=self._quadrature_size,
      dtype=ids.dtype)
  ids += offset
  rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
  rate = tf.reshape(
      rate, shape=concat_vectors([n], self.batch_shape_tensor()))
  return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
def basis(sample_paths, time_index):
  """Computes polynomial basis expansion at the given sample points.

  Args:
    sample_paths: A `Tensor` of either `float32` or `float64` dtype and of
      either shape `[num_samples, num_times, dim]` or
      `[batch_size, num_samples, num_times, dim]`.
    time_index: An integer scalar `Tensor` that corresponds to the time
      coordinate at which the basis function is computed.

  Returns:
    A `Tensor` of shape `[batch_size, (degree + 1)**dim, num_samples]`.
  """
  sample_paths = tf.convert_to_tensor(sample_paths, name="sample_paths")
  if sample_paths.shape.rank == 3:
    sample_paths = tf.expand_dims(sample_paths, axis=0)
  shape = tf.shape(sample_paths)
  num_samples = shape[1]
  batch_size = shape[0]
  dim = sample_paths.shape[-1]  # Dimension should be statically known.
  # Shape [batch_size, num_samples, 1, dim]
  slice_samples = tf.slice(sample_paths, [0, 0, time_index, 0],
                           [batch_size, num_samples, 1, dim])
  # Shape [batch_size, num_samples, 1, dim]
  samples_centered = slice_samples - tf.math.reduce_mean(
      slice_samples, axis=1, keepdims=True)
  grid = tf.range(degree + 1, dtype=samples_centered.dtype)
  # Creates a grid of 'power' expansions, i.e., a `Tensor` of shape
  # [(degree + 1)**dim, dim] with entries [k_1, ..., k_dim] where
  # 0 <= k_i <= degree.
  grid = tf.meshgrid(*(dim * [grid]))
  # Shape [(degree + 1)**dim, dim]
  grid = tf.reshape(tf.stack(grid, -1), [-1, dim])
  # `samples_centered` has shape [batch_size, num_samples, 1, dim],
  # `samples_centered**grid` has shape
  # `[batch_size, num_samples, (degree + 1)**dim, dim]`
  # so that the output shape is
  # `[batch_size, num_samples, (degree + 1)**dim]`.
  basis_expansion = tf.reduce_prod(samples_centered**grid, axis=-1)
  return tf.transpose(basis_expansion, [0, 2, 1])
def testRank1ResNetV1(self, alpha_initializer, gamma_initializer,
                      random_sign_init, ensemble_size):
  tf.random.set_seed(83922)
  dataset_size = 10
  batch_size = 6
  input_shape = (32, 32, 2)  # TODO(dusenberrymw): (32, 32, 1) doesn't work...
  num_classes = 2

  features = tf.random.normal((dataset_size,) + input_shape)
  coeffs = tf.random.normal([tf.reduce_prod(input_shape), num_classes])
  net = tf.reshape(features, [dataset_size, -1])
  logits = tf.matmul(net, coeffs)
  labels = tf.random.categorical(logits, 1)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  dataset = dataset.repeat().shuffle(dataset_size).batch(batch_size)

  model = resnet_cifar_model.rank1_resnet_v1(
      input_shape=input_shape,
      depth=8,
      num_classes=num_classes,
      width_multiplier=1,
      alpha_initializer=alpha_initializer,
      gamma_initializer=gamma_initializer,
      alpha_regularizer=None,
      gamma_regularizer=None,
      use_additive_perturbation=False,
      ensemble_size=ensemble_size,
      random_sign_init=-0.5,
      dropout_rate=0.)
  model.compile(
      'adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
  history = model.fit(
      dataset, steps_per_epoch=dataset_size // batch_size, epochs=2)

  loss_history = history.history['loss']
  self.assertAllGreaterEqual(loss_history, 0.)
def _quasi_uniform(dim,
                   sample_shape,
                   random_type,
                   dtype,
                   seed=None,
                   **kwargs):
  """Quasi random draws from a uniform distribution on [0, 1)."""
  # Shape of the output
  output_shape = tf.concat([sample_shape] + [[dim]], -1)
  # Number of quasi random samples
  num_samples = tf.reduce_prod(sample_shape)
  # Number of initial low discrepancy sequence numbers to skip
  if 'skip' in kwargs:
    skip = kwargs['skip']
  else:
    skip = 0
  if random_type == RandomType.SOBOL:
    # Shape [num_samples, dim] of the Sobol samples
    low_discrepancy_seq = sobol.sample(
        dim=dim, num_results=num_samples, skip=skip, dtype=dtype)
    # TODO(b/148005344): Remove tf.reshape after the bug is fixed
    low_discrepancy_seq = tf.reshape(low_discrepancy_seq, [num_samples, dim])
  else:  # HALTON or HALTON_RANDOMIZED random type
    if 'randomization_params' in kwargs:
      randomization_params = kwargs['randomization_params']
    else:
      randomization_params = None
    randomized = random_type == RandomType.HALTON_RANDOMIZED
    # Shape [num_samples, dim] of the Halton samples
    low_discrepancy_seq, _ = halton.sample(
        dim=dim,
        sequence_indices=tf.range(skip, skip + num_samples),
        randomized=randomized,
        randomization_params=randomization_params,
        seed=seed,
        dtype=dtype)
  return tf.reshape(low_discrepancy_seq, output_shape)
def _entropy(self):
  samples = tf.convert_to_tensor(self.samples)
  num_samples = self._compute_num_samples(samples)
  entropy_shape = self._batch_shape_tensor(samples)

  # Flatten samples for each batch.
  if self._event_ndims == 0:
    samples = tf.reshape(samples, [-1, num_samples])
  else:
    event_size = tf.reduce_prod(self.event_shape_tensor())
    samples = tf.reshape(samples, [-1, num_samples, event_size])

  # Use map_fn to compute entropy for each batch separately.
  def _get_entropy(samples):
    # TODO(b/123985779): Switch to tf.unique_with_counts_v2 when exposed
    count = gen_array_ops.unique_with_counts_v2(samples, axis=[0]).count
    prob = tf.cast(count / num_samples, dtype=self.dtype)
    entropy = tf.reduce_sum(-prob * tf.math.log(prob))
    return entropy

  entropy = tf.map_fn(_get_entropy, samples, dtype=self.dtype)
  return tf.reshape(entropy, entropy_shape)
def easom(z):
  """The value of the two dimensional Easom function.

  The Easom function is a standard optimization test function. It has a
  single global minimum at (pi, pi) which is located inside a deep funnel.
  The expression for the function is:

  ```None
  f(x, y) = -cos(x) cos(y) exp(-(x-pi)**2 - (y-pi)**2)
  ```

  Args:
    z: `Tensor` of shape [2] and real dtype. The argument at which to
      evaluate the function.

  Returns:
    value: Scalar real `Tensor`. The value of the Easom function at the
      supplied argument.
  """
  f1 = tf.reduce_prod(tf.cos(z), axis=-1)
  f2 = tf.exp(-tf.reduce_sum((z - np.pi)**2, axis=-1))
  return -f1 * f2
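# Quick check (illustrative): at the global minimum z = (pi, pi) the value is
# -cos(pi) * cos(pi) * exp(0) = -1.
import numpy as np
import tensorflow as tf

print(easom(tf.constant([np.pi, np.pi], dtype=tf.float64)))  # ==> -1.0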