def consume(running_stat, elems, chunk_axis=None):
  def body(running_stat, elem):
    if chunk_axis is None:
      return running_stat.update(elem)
    else:
      return running_stat.update(elem, axis=chunk_axis)
  return tf.foldl(body, elems, running_stat)
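# A minimal, hypothetical usage sketch for `consume` above (not from the
# original source). Any object with a functional `update(elem)` method that
# returns a new statistic works as the accumulator; `_ToySum` below is an
# illustrative stand-in for a real running statistic.
import collections

import tensorflow as tf


class _ToySum(collections.namedtuple('_ToySum', ['total'])):
  """Toy running-sum statistic with a functional `update`."""

  def update(self, elem, axis=None):
    if axis is not None:
      elem = tf.reduce_sum(elem, axis=axis)
    return _ToySum(self.total + elem)


elems = tf.reshape(tf.range(12., dtype=tf.float32), [4, 3])
out = consume(_ToySum(total=tf.zeros([3])), elems)
# out.total holds the column sums of `elems` (shape [3]).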
def _sample_n(self, n, seed=None):
  distribution0 = self._get_distribution0()

  if self._num_steps is not None:
    num_steps = tf.convert_to_tensor(self._num_steps)
    num_steps_static = tf.get_static_value(num_steps)
  else:
    num_steps_static = tensorshape_util.num_elements(
        distribution0.event_shape)
    if num_steps_static is None:
      num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

  seed = SeedStream(seed, salt='Autoregressive')()
  samples = distribution0.sample(n, seed=seed)
  if num_steps_static is not None:
    for _ in range(num_steps_static):
      # pylint: disable=not-callable
      samples = self.distribution_fn(samples).sample(seed=seed)
  else:
    # pylint: disable=not-callable
    samples = tf.foldl(
        lambda s, _: self.distribution_fn(s).sample(seed=seed),
        elems=tf.range(0, num_steps),
        initializer=samples)
  return samples
def no_pivot_ldl(matrix, name='no_pivot_ldl'):
  """Non-pivoted batched LDL factorization.

  Performs the LDL factorization, using the outer product algorithm from [1].
  No pivoting (or block pivoting) is done, so this should be less stable than
  e.g. Bunch-Kaufman sytrf. This is implemented as a tf.foldl, so should have
  gradients and be accelerator-friendly, but is not particularly performant.

  If compiling with XLA, make sure any surrounding GradientTape is also
  XLA-compiled (b/193584244).

  #### References

  [1]: Gene H. Golub, Charles F. Van Loan. Matrix Computations, 4th ed., 2013.

  Args:
    matrix: A batch of symmetric square matrices, with shape `[..., n, n]`.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: 'no_pivot_ldl'.

  Returns:
    triangular_factor: The unit lower triangular L factor of the LDL
      factorization of `matrix`, with the same shape `[..., n, n]`. Callers
      should check for `nans` and other indicators of instability.
    diag: The diagonal from the LDL factorization, with shape `[..., n]`.
  """
  with tf.name_scope(name) as name:
    matrix = tf.convert_to_tensor(matrix)
    triangular_factor = tf.linalg.band_part(matrix, num_lower=-1, num_upper=0)
    # TODO(b/182276317) Deal with dynamic ranks better.
    slix = _Slice2Idx(triangular_factor)

    def fn(triangular_factor, i):
      column_head = triangular_factor[..., i, i, tf.newaxis]
      column_tail = triangular_factor[..., i+1:, i]
      rescaled_tail = column_tail / column_head
      triangular_factor = tf.tensor_scatter_nd_update(
          triangular_factor, slix[..., i+1:, i], rescaled_tail)
      triangular_factor = tf.tensor_scatter_nd_sub(
          triangular_factor,
          slix[..., i+1:, i+1:],
          tf.linalg.band_part(
              tf.einsum('...i,...j->...ij', column_tail, rescaled_tail),
              num_lower=-1, num_upper=0))
      return triangular_factor

    triangular_factor = tf.foldl(
        fn=fn,
        elems=tf.range(tf.shape(triangular_factor)[-1]),
        initializer=triangular_factor)

    diag = tf.linalg.diag_part(triangular_factor)
    triangular_factor = tf.linalg.set_diag(
        triangular_factor, tf.ones_like(diag))
    return triangular_factor, diag
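# A minimal usage sketch for `no_pivot_ldl` above (assumes `tensorflow` is
# imported as `tf` and the function is in scope): factor a small batch of
# symmetric matrices and rebuild them as L @ diag(d) @ L^T.
import tensorflow as tf

matrix = tf.constant([[[4., 2.], [2., 3.]],
                      [[9., 3.], [3., 5.]]])
triangular_factor, diag = no_pivot_ldl(matrix)
reconstructed = tf.einsum('...ij,...j,...kj->...ik',
                          triangular_factor, diag, triangular_factor)
# `reconstructed` should match `matrix` up to floating-point error; as the
# docstring notes, callers should still check for NaNs since no pivoting is
# performed.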
def _do_flips():
  state = sampler._initialize_sampler_state(
      targets=targets,
      nonzeros=initial_nonzeros,
      observation_noise_variance=1.)

  def _do_flip(state, i):
    new_state = sampler._flip_feature(state, tf.gather(flip_idxs, i))
    return mcmc_util.choose(tf.gather(should_flip, i), new_state, state)

  return tf.foldl(_do_flip, elems=tf.range(num_flips), initializer=state)
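# The snippet above depends on a sampler's internal state; the same
# "conditionally apply an update at each fold step" pattern can be sketched
# with plain TF ops. This is an illustrative stand-in, not the sampler's API:
# it toggles entries of a boolean mask only where `should_flip` is True.
import tensorflow as tf

nonzeros = tf.constant([True, False, False, True])
flip_idxs = tf.constant([1, 3, 0])
should_flip = tf.constant([True, False, True])

def _flip_one(mask, i):
  idx = flip_idxs[i]
  flipped = tf.tensor_scatter_nd_update(mask, [[idx]], [~mask[idx]])
  # Plays the role of `mcmc_util.choose`: keep the old mask when not flipping.
  return tf.where(should_flip[i], flipped, mask)

final_mask = tf.foldl(_flip_one,
                      elems=tf.range(tf.size(flip_idxs)),
                      initializer=nonzeros)
# final_mask == [False, True, False, True]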
def _log_prob(self, value):
  # The argument `value` is a tensor of sequences of observations.
  # `observation_batch_shape` is the shape of that tensor with the
  # sequence part removed.
  # `observation_batch_shape` is then broadcast to the full batch shape
  # to give the `batch_shape` that defines the shape of the result.
  observation_tensor_shape = tf.shape(value)
  observation_distribution = self.observation_distribution
  underlying_event_rank = tf.size(
      observation_distribution.event_shape_tensor())
  observation_batch_shape = observation_tensor_shape[
      :-1 - underlying_event_rank]
  # value :: observation_batch_shape num_steps observation_event_shape
  batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                           self.batch_shape_tensor())
  num_states = self.transition_distribution.batch_shape_tensor()[-1]
  log_init = _extract_log_probs(num_states, self.initial_distribution)
  # log_init :: batch_shape num_states
  log_init = tf.broadcast_to(
      log_init, tf.concat([batch_shape, [num_states]], axis=0))
  log_transition = _extract_log_probs(num_states,
                                      self.transition_distribution)

  # `observation_event_shape` is the shape of each sequence of observations
  # emitted by the model.
  observation_event_shape = observation_tensor_shape[
      -1 - underlying_event_rank:]
  working_obs = tf.broadcast_to(
      value, tf.concat([batch_shape, observation_event_shape], axis=0))
  # working_obs :: batch_shape observation_event_shape
  r = underlying_event_rank

  # Move index into sequence of observations to front so we can apply
  # tf.foldl.
  working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
  # working_obs :: num_steps batch_shape underlying_event_shape
  working_obs = tf.expand_dims(working_obs, -1 - r)
  # working_obs :: num_steps batch_shape 1 underlying_event_shape

  observation_probs = observation_distribution.log_prob(working_obs)
  # observation_probs :: num_steps batch_shape num_states

  def forward_step(log_prev_step, log_prob_observation):
    return _log_vector_matrix(log_prev_step,
                              log_transition) + log_prob_observation

  fwd_prob = tf.foldl(forward_step, observation_probs, initializer=log_init)
  # fwd_prob :: batch_shape num_states

  log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
  # log_prob :: batch_shape

  return log_prob
def _log_prob(self, value):
  with tf.control_dependencies(self._runtime_assertions):
    # The argument `value` is a tensor of sequences of observations.
    # `observation_batch_shape` is the shape of that tensor with the
    # sequence part removed.
    # `observation_batch_shape` is then broadcast to the full batch shape
    # to give the `batch_shape` that defines the shape of the result.
    observation_tensor_shape = tf.shape(input=value)
    observation_batch_shape = observation_tensor_shape[
        :-1 - self._underlying_event_rank]
    # value :: observation_batch_shape num_steps observation_event_shape
    batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                             self.batch_shape_tensor())
    log_init = tf.broadcast_to(
        self._log_init,
        tf.concat([batch_shape, [self._num_states]], axis=0))
    # log_init :: batch_shape num_states
    log_transition = self._log_trans

    # `observation_event_shape` is the shape of each sequence of observations
    # emitted by the model.
    observation_event_shape = observation_tensor_shape[
        -1 - self._underlying_event_rank:]
    working_obs = tf.broadcast_to(
        value, tf.concat([batch_shape, observation_event_shape], axis=0))
    # working_obs :: batch_shape observation_event_shape
    r = self._underlying_event_rank

    # Move index into sequence of observations to front so we can apply
    # tf.foldl.
    working_obs = distribution_util.move_dimension(
        working_obs, -1 - r, 0)[..., tf.newaxis]
    # working_obs :: num_steps batch_shape underlying_event_shape
    observation_probs = (
        self._observation_distribution.log_prob(working_obs))

    def forward_step(log_prev_step, log_prob_observation):
      return _log_vector_matrix(log_prev_step,
                                log_transition) + log_prob_observation

    fwd_prob = tf.foldl(forward_step, observation_probs,
                        initializer=log_init)
    # fwd_prob :: batch_shape num_states

    log_prob = tf.reduce_logsumexp(input_tensor=fwd_prob, axis=-1)
    # log_prob :: batch_shape

    return log_prob
def _sample_n(self, n, seed=None):
  distribution0 = self._get_distribution0()

  if self._num_steps is not None:
    num_steps = tf.convert_to_tensor(self._num_steps)
    num_steps_static = tf.get_static_value(num_steps)
  else:
    num_steps_static = tensorshape_util.num_elements(
        distribution0.event_shape)
    if num_steps_static is None:
      num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

  stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
  stateful_seed = None
  try:
    samples = distribution0.sample(n, seed=stateless_seed)
    is_stateful_sampler = False
  except TypeError as e:
    if ('Expected int for argument' not in str(e) and
        TENSOR_SEED_MSG_PREFIX not in str(e)):
      raise
    msg = (
        'Falling back to stateful sampling for `distribution_fn(sample0)` of '
        'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
        'This fallback may be removed after 20-Aug-2020. ({})')
    warnings.warn(
        msg.format(distribution0.name, type(distribution0), str(e)))
    stateful_seed = SeedStream(seed, salt='Autoregressive')()
    samples = distribution0.sample(n, seed=stateful_seed)
    is_stateful_sampler = True

  seed = stateful_seed if is_stateful_sampler else stateless_seed

  if num_steps_static is not None:
    for _ in range(num_steps_static):
      # pylint: disable=not-callable
      samples = self.distribution_fn(samples).sample(seed=seed)
  else:
    # pylint: disable=not-callable
    samples = tf.foldl(
        lambda s, _: self.distribution_fn(s).sample(seed=seed),
        elems=tf.range(0, num_steps),
        initializer=samples)
  return samples
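# A minimal sketch (not from the original source) of the public API whose
# sampler is shown above: `tfd.Autoregressive` repeatedly re-draws from
# `distribution_fn` applied to the previous sample, which is exactly the
# Python-loop / `tf.foldl` logic in `_sample_n`. `_next_step_fn` is a
# hypothetical kernel chosen for illustration.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def _next_step_fn(sample):
  # Each step is a Normal centered at a damped copy of the previous sample.
  return tfd.Independent(tfd.Normal(loc=0.5 * sample, scale=1.),
                         reinterpreted_batch_ndims=1)

ar = tfd.Autoregressive(_next_step_fn,
                        sample0=tf.zeros([2]),
                        num_steps=5)
draws = ar.sample(3)  # shape [3, 2]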
def _log_prob(self, value):
  # The argument `value` is a tensor of sequences of observations.
  # `observation_batch_shape` is the shape of that tensor with the
  # sequence part removed.
  # `observation_batch_shape` is then broadcast to the full batch shape
  # to give the `batch_shape` that defines the shape of the result.
  observation_tensor_shape = ps.shape(value)
  observation_distribution = self.observation_distribution
  underlying_event_rank = ps.size(
      observation_distribution.event_shape_tensor())
  observation_batch_shape = observation_tensor_shape[
      :-1 - underlying_event_rank]
  # value :: observation_batch_shape num_steps observation_event_shape
  batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                           self.batch_shape_tensor())
  num_states = self.transition_distribution.batch_shape_tensor()[-1]
  log_init = _extract_log_probs(num_states, self.initial_distribution)
  # log_init :: batch_shape num_states
  log_init = tf.broadcast_to(log_init,
                             ps.concat([batch_shape, [num_states]], axis=0))
  log_transition = _extract_log_probs(num_states,
                                      self.transition_distribution)

  # `observation_event_shape` is the shape of each sequence of observations
  # emitted by the model.
  observation_event_shape = observation_tensor_shape[
      -1 - underlying_event_rank:]
  working_obs = tf.broadcast_to(
      value, ps.concat([batch_shape, observation_event_shape], axis=0))
  # working_obs :: batch_shape observation_event_shape
  r = underlying_event_rank

  # Move index into sequence of observations to front so we can apply
  # tf.foldl.
  if self._time_varying_observation_distribution:
    working_obs = tf.expand_dims(working_obs, -1 - r)
    # working_obs :: batch_shape num_steps 1 underlying_event_shape
    observation_probs = observation_distribution.log_prob(working_obs)
    # observation_probs :: batch_shape num_steps num_states
    observation_probs = distribution_util.move_dimension(
        observation_probs, -2, 0)
    # observation_probs :: num_steps batch_shape num_states
  else:
    working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
    # working_obs :: num_steps batch_shape underlying_event_shape
    working_obs = tf.expand_dims(working_obs, -1 - r)
    # working_obs :: num_steps batch_shape 1 underlying_event_shape
    observation_probs = observation_distribution.log_prob(working_obs)
    # observation_probs :: num_steps batch_shape num_states

  def forward_step(log_prev_step, log_prob_observation):
    return _log_vector_matrix(log_prev_step,
                              log_transition) + log_prob_observation

  # TODO(davmre): Delete this warning after Dec 31, 2020.
  warnings.warn(
      'HiddenMarkovModel.log_prob in TFP versions < 0.12.0 had a bug '
      'in which the transition model was applied prior to the initial step. '
      'This bug has been fixed. You may observe a slight change in behavior.')

  fwd_prob = tf.foldl(forward_step, observation_probs[1:],
                      initializer=log_init + observation_probs[0])
  # fwd_prob :: batch_shape num_states

  log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
  # log_prob :: batch_shape

  return log_prob
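# A self-contained sketch of the same forward recursion on a concrete
# two-state HMM (not taken from the library; `_log_vec_mat` here is a plain
# log-space vector-matrix product standing in for the private
# `_log_vector_matrix` helper used above).
import tensorflow as tf

log_init = tf.math.log(tf.constant([0.6, 0.4]))          # p(z_0)
log_transition = tf.math.log(tf.constant([[0.7, 0.3],
                                           [0.2, 0.8]]))  # p(z_t | z_{t-1})
# Per-step observation log-likelihoods, shape [num_steps, num_states].
observation_log_probs = tf.math.log(tf.constant([[0.9, 0.1],
                                                 [0.2, 0.8],
                                                 [0.9, 0.1]]))

def _log_vec_mat(vs, ms):
  # logsumexp-based analogue of `vs @ ms`, operating on log probabilities.
  return tf.reduce_logsumexp(vs[..., tf.newaxis] + ms, axis=-2)

def forward_step(log_prev, log_obs):
  return _log_vec_mat(log_prev, log_transition) + log_obs

fwd = tf.foldl(forward_step, observation_log_probs[1:],
               initializer=log_init + observation_log_probs[0])
log_prob = tf.reduce_logsumexp(fwd, axis=-1)  # log p(observations)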