def _reduce_multiple_steps():
  """Perform `reduce_max` operation when `num_steps` > 1."""

  def forward_step(previous_step_pair, log_prob_observation):
    log_prob_previous = previous_step_pair[0]
    log_prob = (log_prob_previous[..., tf.newaxis] +
                self._log_trans +
                log_prob_observation[..., tf.newaxis, :])
    most_likely_given_successor = tf.argmax(log_prob, axis=-2)
    max_log_p_given_successor = tf.reduce_max(log_prob, axis=-2)
    return (max_log_p_given_successor, most_likely_given_successor)

  forward_log_probs, all_most_likely_given_successor = tf.scan(
      forward_step,
      observation_log_probs[1:],
      initializer=(log_prob,
                   tf.zeros(tf.shape(log_prob), dtype=tf.int64)),
      name="forward_log_probs")

  most_likely_end = tf.argmax(forward_log_probs[-1], axis=-1)

  # We require the operation that gives C from A and B where
  #   C[i...j] = A[i...j, B[i...j]]
  # and A = most_likely_given_successor
  #     B = most_likely_successor.
  # tf.gather requires indices of known shape so instead we use
  # reduction with tf.one_hot(B) to pick out elements from A.
  def backward_step(most_likely_successor, most_likely_given_successor):
    return tf.reduce_sum(
        (most_likely_given_successor *
         tf.one_hot(most_likely_successor,
                    self._num_states,
                    dtype=tf.int64)),
        axis=-1)

  backward_scan = tf.scan(
      backward_step,
      all_most_likely_given_successor,
      most_likely_end,
      reverse=True)
  most_likely_sequences = tf.concat([backward_scan, [most_likely_end]],
                                    axis=0)
  return distribution_util.move_dimension(most_likely_sequences, 0, -1)
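# A minimal standalone sketch (not part of the library code above) of the
# one-hot reduction trick used in `backward_step`: it computes
# C[i] = A[i, B[i]] without `tf.gather`. The tensor values here are
# illustrative only.
import tensorflow as tf

A = tf.constant([[10, 11, 12], [20, 21, 22]], dtype=tf.int64)  # e.g. most_likely_given_successor
B = tf.constant([2, 0], dtype=tf.int64)                        # e.g. most_likely_successor
C = tf.reduce_sum(A * tf.one_hot(B, 3, dtype=tf.int64), axis=-1)
# C == [12, 20], i.e. C[i] == A[i, B[i]].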
def _two_loop_algorithm():
  """L-BFGS two-loop algorithm."""
  # Correction pairs are always appended to the end, so only the latest
  # `num_elements` vectors have valid position/gradient deltas.
  position_deltas = state.position_deltas[-num_elements:]
  gradient_deltas = state.gradient_deltas[-num_elements:]

  # Pre-compute all `inv_rho[i]`s.
  inv_rhos = tf.reduce_sum(
      input_tensor=gradient_deltas * position_deltas, axis=-1)

  def first_loop(acc, args):
    _, q_direction = acc
    position_delta, gradient_delta, inv_rho = args
    alpha = tf.reduce_sum(
        input_tensor=position_delta * q_direction, axis=-1) / inv_rho
    direction_delta = tf.expand_dims(alpha, axis=-1) * gradient_delta
    return (alpha, q_direction - direction_delta)

  # Run first loop body computing and collecting `alpha[i]`s, while also
  # computing the updated `q_direction` at each step.
  zero = tf.zeros_like(inv_rhos[0])
  alphas, q_directions = tf.scan(
      first_loop, [position_deltas, gradient_deltas, inv_rhos],
      initializer=(zero, state.objective_gradient),
      reverse=True)

  # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse
  # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`.
  gamma_k = inv_rhos[-1] / tf.reduce_sum(
      input_tensor=gradient_deltas[-1] * gradient_deltas[-1], axis=-1)
  r_direction = tf.expand_dims(gamma_k, axis=-1) * q_directions[0]

  def second_loop(r_direction, args):
    alpha, position_delta, gradient_delta, inv_rho = args
    beta = tf.reduce_sum(
        input_tensor=gradient_delta * r_direction, axis=-1) / inv_rho
    direction_delta = tf.expand_dims(alpha - beta, axis=-1) * position_delta
    return r_direction + direction_delta

  # Finally, run second loop body computing the updated `r_direction` at each
  # step.
  r_directions = tf.scan(
      second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos],
      initializer=r_direction)
  return -r_directions[-1]
def _sample_and_log_prob_helper(self,
                                sample_shape,
                                seed=None,
                                compute_log_prob=False):
  """Draws samples from the chain and optionally accumulates the log_prob."""
  prior_seed, loop_seed = samplers.split_seed(
      n=2, seed=seed, salt='markov_chain_sample')

  if compute_log_prob:
    sample_attr = 'experimental_sample_and_log_prob'
    extract_sample_fn = lambda x_and_lp: x_and_lp[0]
    extract_lp_fn = lambda x_and_lp: self._sum_fn(x_and_lp[1], axis=0)
  else:
    sample_attr = 'sample'
    extract_sample_fn = lambda x: x
    extract_lp_fn = lambda x: 0.

  prior_result = getattr(self.initial_state_prior, sample_attr)(
      sample_shape, seed=prior_seed)

  loop_body = _make_sample_loop_body(
      self.transition_fn,
      sample_attr=sample_attr,
      extract_sample_fn=extract_sample_fn)
  _, results = tf.scan(
      loop_body,
      elems=tf.range(1, self.num_steps),
      initializer=(loop_seed, prior_result))

  # Concatenate prior sample (and lp) with remaining samples (and lps).
  results = tf.nest.map_structure(concat_initial, prior_result, results)
  samples, lp = extract_sample_fn(results), extract_lp_fn(results)

  # Move leftmost `num_steps` dimension into the event shape.
  samples = move_dimensions(samples, 0, self._step_axes())
  return samples, lp
def reconstruct_trajectories(particles, parent_indices, name=None):
  """Reconstructs the ancestor trajectory that generated each final particle."""
  with tf.name_scope(name or 'reconstruct_trajectories'):
    indices_shape = prefer_static.shape(parent_indices)
    batch_shape, num_trajectories = indices_shape[1:-1], indices_shape[-1]
    batch_rank = prefer_static.rank_from_shape(batch_shape)

    # Walk backwards to compute the ancestor of each final particle at time t.
    final_indices = tf.broadcast_to(
        tf.range(0, num_trajectories), indices_shape[1:])
    ancestor_indices = tf.scan(
        fn=lambda ancestor, parent: tf.gather(  # pylint: disable=g-long-lambda
            parent, ancestor, axis=batch_rank, batch_dims=batch_rank),
        elems=parent_indices[1:],
        initializer=final_indices,
        reverse=True)
    ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)

  return tf.nest.map_structure(
      lambda part: tf.gather(part, ancestor_indices,  # pylint: disable=g-long-lambda
                             axis=batch_rank + 1, batch_dims=batch_rank + 1),
      particles)
def _marginal_hidden_probs(self):
  """Compute marginal pdf for each individual observable."""

  initial_log_probs = tf.broadcast_to(
      self._log_init,
      tf.concat([self.batch_shape_tensor(), [self._num_states]], axis=0))
  # initial_log_probs :: batch_shape num_states

  if self._num_steps > 1:
    transition_log_probs = self._log_trans

    def forward_step(log_probs, _):
      return _log_vector_matrix(log_probs, transition_log_probs)

    dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

    forward_log_probs = tf.scan(forward_step, dummy_index,
                                initializer=initial_log_probs,
                                name="forward_log_probs")

    forward_log_probs = tf.concat([[initial_log_probs], forward_log_probs],
                                  axis=0)
  else:
    forward_log_probs = initial_log_probs[tf.newaxis, ...]

  # returns :: num_steps batch_shape num_states
  return tf.exp(forward_log_probs)
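# Hedged sketch of the `_log_vector_matrix` helper assumed above (the helper
# itself is not shown in this excerpt): a numerically stable log-space
# vector-matrix product, i.e. log(v @ M) computed from log(v) and log(M) via
# logsumexp over the shared state axis.
import tensorflow as tf

def _log_vector_matrix_sketch(log_vs, log_ms):
  """log_vs: [..., num_states]; log_ms: [..., num_states, num_states]."""
  return tf.reduce_logsumexp(log_vs[..., tf.newaxis] + log_ms, axis=-2)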
def _scan_multiple_steps():
  """Perform `scan` operation when `num_steps` > 1."""
  transition_log_probs = _extract_log_probs(num_states,
                                            self.transition_distribution)

  def forward_step(log_probs, _):
    result = _log_vector_matrix(log_probs, transition_log_probs)
    # We know that `forward_step` must preserve the shape of the
    # tensor of probabilities of each state. This is because
    # the transition matrix must be square. But TensorFlow might
    # not know this so we explicitly tell it that the result has the
    # same shape.
    tensorshape_util.set_shape(result, log_probs.shape)
    return result

  dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

  forward_log_probs = tf.scan(forward_step, dummy_index,
                              initializer=initial_log_probs,
                              name='forward_log_probs')

  result = tf.concat([[initial_log_probs], forward_log_probs],
                     axis=0)
  return result
def call(self, inputs: tf.Tensor, initial_state: tf.Tensor):
  """`inputs` is of shape [batch, seq_length, num_filters]."""
  w = tf.clip_by_value(self._weights, clip_value_min=0.0, clip_value_max=1.0)
  result = tf.scan(lambda a, x: w * x + (1.0 - w) * a,
                   tf.transpose(inputs, (1, 0, 2)),
                   initializer=initial_state)
  return tf.transpose(result, (1, 0, 2))
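# Standalone usage sketch (illustrative shapes and values, not library code):
# the same exponential-moving-average recurrence as the layer above, applied
# directly with `tf.scan` to a [batch, seq_length, num_filters] tensor.
import tensorflow as tf

inputs = tf.random.normal([4, 10, 3])     # [batch, seq_length, num_filters]
initial_state = tf.zeros([4, 3])          # EMA state per batch element
w = tf.constant(0.9)                      # smoothing weight in [0, 1]
ema = tf.scan(lambda a, x: w * x + (1.0 - w) * a,
              tf.transpose(inputs, (1, 0, 2)),  # scan over the time axis
              initializer=initial_state)
result = tf.transpose(ema, (1, 0, 2))     # back to [batch, seq, filters]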
def fit_with_gibbs_sampling(model,
                            observed_time_series,
                            num_results=2000,
                            num_warmup_steps=200,
                            initial_state=None,
                            seed=None):
  """Fits parameters for an STS model using Gibbs sampling."""
  if not hasattr(model, 'supports_gibbs_sampling'):
    raise ValueError('This STS model does not support Gibbs sampling. Models '
                     'for Gibbs sampling must be created using the '
                     'method `build_model_for_gibbs_fitting`.')

  [observed_time_series,
   is_missing] = sts_util.canonicalize_observed_time_series_with_mask(
       observed_time_series)
  dtype = observed_time_series.dtype

  # The canonicalized time series always has trailing dimension `1`,
  # because although LinearGaussianSSMs support vector observations, STS models
  # describe scalar time series only. For our purposes it'll be cleaner to
  # remove this dimension.
  observed_time_series = observed_time_series[..., 0]
  batch_shape = prefer_static.shape(observed_time_series)[:-1]

  # Treat a LocalLevel model as the special case of LocalLinearTrend where
  # the slope_scale is always zero.
  initial_slope_scale = 0.
  initial_slope = 0.
  if isinstance(model.components[0], sts.LocalLinearTrend):
    initial_slope_scale = 1. * tf.ones(batch_shape, dtype=dtype)
    initial_slope = tf.zeros_like(observed_time_series)

  if initial_state is None:
    initial_state = GibbsSamplerState(
        observation_noise_scale=tf.ones(batch_shape, dtype=dtype),
        level_scale=tf.ones(batch_shape, dtype=dtype),
        slope_scale=initial_slope_scale,
        weights=tf.zeros(prefer_static.concat(
            [batch_shape, _get_design_matrix(model).shape[-1:]], axis=0),
                         dtype=dtype),
        level=tf.zeros_like(observed_time_series),
        slope=initial_slope,
        seed=None)  # Set below.

  if isinstance(seed, six.integer_types):
    tf.random.set_seed(seed)

  # Always use the passed-in `seed` arg, ignoring any seed in the initial
  # state.
  initial_state = initial_state._replace(
      seed=samplers.sanitize_seed(seed, salt='initial_GibbsSamplerState'))

  sampler_loop_body = _build_sampler_loop_body(model,
                                               observed_time_series,
                                               is_missing)

  samples = tf.scan(sampler_loop_body,
                    np.arange(num_warmup_steps + num_results),
                    initial_state)
  return tf.nest.map_structure(lambda x: x[num_warmup_steps:], samples)
def _scan_multiple_steps():
  dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
  hidden_states = tf.scan(generate_step, dummy_index,
                          initializer=init_state)

  # TODO(b/115618503): add/use prepend_initializer to tf.scan
  return tf.concat([[init_state], hidden_states], axis=0)
def _scan_multiple_steps():
  """Take multiple steps with tf.scan."""
  dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
  if seed is not None:
    # Force parallel_iterations to 1 to ensure reproducibility.
    # b/139210489
    hidden_states = tf.scan(generate_step, dummy_index,
                            initializer=init_state,
                            parallel_iterations=1)
  else:
    # Invoke default parallel_iterations behavior.
    hidden_states = tf.scan(generate_step, dummy_index,
                            initializer=init_state)

  # TODO(b/115618503): add/use prepend_initializer to tf.scan
  return tf.concat([[init_state], hidden_states], axis=0)
def _scan_multiple_steps_forwards():
  def forward_step(log_previous_step, log_prob_observation):
    return (_log_vector_matrix(log_previous_step, log_transition) +
            log_prob_observation)

  forward_log_probs = tf.scan(forward_step, observation_log_probs[1:],
                              initializer=log_prob,
                              name='forward_log_probs')

  return ps.concat([[log_prob], forward_log_probs], axis=0)
def _scan_multiple_steps():
  """Take multiple steps with tf.scan."""
  dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
  hidden_states, _ = tf.scan(generate_step, dummy_index,
                             initializer=(init_state, scan_seed))

  # TODO(b/115618503): add/use prepend_initializer to tf.scan
  return tf.concat([[init_state], hidden_states], axis=0)
def test_scan_with_struct_elems(self):
  elems = (np.arange(5).astype(np.int32),
           np.arange(10).astype(np.int32).reshape(5, 2))
  init = (np.int32([7, 8]), np.int32([9, 1]))
  self.assertAllEqual(
      self.evaluate(tf.scan(
          lambda x, y: (x[0] + y[0], x[1] - y[1]), elems, initializer=init)),
      nptf.scan(
          lambda x, y: (x[0] + y[0], x[1] - y[1]), elems, initializer=init))
def fit_with_gibbs_sampling(model,
                            observed_time_series,
                            num_results=2000,
                            num_warmup_steps=200,
                            compile_steps_with_xla=False,
                            initial_state=None,
                            seed=None):
  """Fits parameters for an STS model using Gibbs sampling."""
  if not hasattr(model, 'supports_gibbs_sampling'):
    raise ValueError('This STS model does not support Gibbs sampling. Models '
                     'for Gibbs sampling must be created using the '
                     'method `build_model_for_gibbs_fitting`.')

  [observed_time_series,
   is_missing] = sts_util.canonicalize_observed_time_series_with_mask(
       observed_time_series)
  dtype = observed_time_series.dtype

  # The canonicalized time series always has trailing dimension `1`,
  # because although LinearGaussianSSMs support vector observations, STS models
  # describe scalar time series only. For our purposes it'll be cleaner to
  # remove this dimension.
  observed_time_series = observed_time_series[..., 0]

  batch_shape = prefer_static.shape(observed_time_series)[:-1]
  if initial_state is None:
    initial_state = GibbsSamplerState(
        observation_noise_scale=tf.ones(batch_shape, dtype=dtype),
        level_scale=tf.ones(batch_shape, dtype=dtype),
        weights=tf.zeros(prefer_static.concat([
            batch_shape, _get_design_matrix(model).shape[-1:]], axis=0),
                         dtype=dtype),
        level=tf.zeros_like(observed_time_series),
        seed=None)  # Set below.

  if seed and isinstance(seed, six.integer_types):
    tf.random.set_seed(seed)

  # Always use the passed-in `seed` arg, ignoring any seed in the initial
  # state.
  seeded_state = initial_state._asdict()
  seeded_state['seed'] = samplers.sanitize_seed(
      seed, salt='initial_GibbsSamplerState')
  initial_state = GibbsSamplerState(**seeded_state)

  sampler_loop_body = _build_sampler_loop_body(
      model,
      observed_time_series,
      is_missing,
      compile_steps_with_xla=compile_steps_with_xla,
      seed=seed)  # This is still an `int` seed, because the InverseGamma
                  # sampler currently requires stateful semantics.

  samples = tf.scan(sampler_loop_body,
                    np.arange(num_warmup_steps + num_results),
                    initial_state)
  return tf.nest.map_structure(lambda x: x[num_warmup_steps:], samples)
def test_scan_with_struct(self):
  elems = np.arange(5).astype(np.int32)
  self.assertAllEqual(
      self.evaluate(
          tf.scan(lambda x, y: (x[0] + y, x[1] - y), elems,
                  initializer=(7, 3))),
      nptf.scan(lambda x, y: (x[0] + y, x[1] - y), elems, initializer=(7, 3)))
def reconstruct_trajectories(particles, parent_indices, name=None):
  """Reconstructs the ancestor trajectory that generated each final particle."""
  with tf.name_scope(name or 'reconstruct_trajectories'):
    # Walk backwards to compute the ancestor of each final particle at time t.
    final_indices = _dummy_indices_like(parent_indices[-1])
    ancestor_indices = tf.scan(
        fn=lambda ancestor, parent: _batch_gather(parent, ancestor, axis=0),
        elems=parent_indices[1:],
        initializer=final_indices,
        reverse=True)
    ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)

  return tf.nest.map_structure(
      lambda part: _batch_gather(part, ancestor_indices, axis=1),
      particles)
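# Standalone sketch (illustrative data, no batch dimensions) of the reverse
# scan used above: tracing each final particle back through per-step parent
# indices to recover its full ancestry.
import tensorflow as tf

parent_indices = tf.constant([[0, 1, 2],    # t=0 (unused by the scan)
                              [0, 0, 1],    # parent_indices[t][i]: index at
                              [2, 1, 1]])   # time t-1 of particle i's parent
final_indices = tf.range(3)
ancestors = tf.scan(lambda ancestor, parent: tf.gather(parent, ancestor),
                    parent_indices[1:],
                    initializer=final_indices,
                    reverse=True)
trajectories = tf.concat([ancestors, [final_indices]], axis=0)
# trajectories[t][i] is the time-t ancestor of final particle i.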
def generate_exchanges(self, exchange_proposed_fn, num_replica, seed):

  def _scan_fn(*_):
    exchange = exchange_proposed_fn(num_replica, seed)
    flat_replicas = tf.reshape(exchange, [-1])
    with tf.control_dependencies([
        tf1.assert_equal(
            tf.size(input=flat_replicas),
            tf.size(input=tf.unique(flat_replicas)[0])),
        tf1.assert_greater_equal(flat_replicas, 0),
        tf1.assert_less(flat_replicas, num_replica),
    ]):
      return tf.shape(input=exchange)[0]

  return self.evaluate(
      tf.scan(_scan_fn, tf.range(1000), initializer=0, parallel_iterations=1))
def _scan_multiple_steps():
  """Perform `scan` operation when `num_steps` > 1."""
  transition_log_probs = self._log_trans

  def forward_step(log_probs, _):
    return _log_vector_matrix(log_probs, transition_log_probs)

  dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)

  forward_log_probs = tf.scan(forward_step, dummy_index,
                              initializer=initial_log_probs,
                              name="forward_log_probs")

  return tf.concat([[initial_log_probs], forward_log_probs],
                   axis=0)
def reconstruct_trajectories(particles, parent_indices, name=None):
  """Reconstructs the ancestor trajectory that generated each final particle."""
  with tf.name_scope(name or 'reconstruct_trajectories'):
    # Walk backwards to compute the ancestor of each final particle at time t.
    final_indices = smc_kernel._dummy_indices_like(parent_indices[-1])  # pylint: disable=protected-access
    ancestor_indices = tf.scan(
        fn=lambda ancestor, parent: mcmc_util.index_remapping_gather(  # pylint: disable=g-long-lambda
            parent, ancestor, axis=0),
        elems=parent_indices[1:],
        initializer=final_indices,
        reverse=True)
    ancestor_indices = tf.concat([ancestor_indices, [final_indices]], axis=0)

  return tf.nest.map_structure(
      lambda part: mcmc_util.index_remapping_gather(  # pylint: disable=g-long-lambda
          part, ancestor_indices, axis=1, indices_axis=1),
      particles)
def _scan_multiple_steps_backwards():
  """Perform `scan` operation when `num_steps` > 1."""

  def backward_step(log_previous_step, log_prob_observation):
    return _log_matrix_vector(
        log_transition,
        log_prob_observation + log_previous_step)

  backward_log_adjoint_probs = tf.scan(
      backward_step,
      observation_log_probs[1:],
      initializer=log_adjoint_prob,
      reverse=True,
      name='backward_log_adjoint_probs')

  return tf.concat([backward_log_adjoint_probs,
                    [log_adjoint_prob]], axis=0)
def test_samples_from_weights_prior(self):
  nonzero_prior_prob = 0.7
  num_outputs, num_features = 200, 4

  # Setting the design matrix to zero, the targets provide no information
  # about weights, so the sampler should sample from the prior.
  design_matrix = tf.zeros([num_outputs, num_features])
  targets = 0.42 * samplers.normal([num_outputs], seed=test_util.test_seed())
  sampler = spike_and_slab.SpikeSlabSampler(
      design_matrix=design_matrix,
      weights_prior_precision=tf.eye(num_features),
      nonzero_prior_prob=nonzero_prior_prob)

  # Draw 100 posterior samples. Since all state needed for the
  # internal feature sweep is a function of the sparsity pattern, it's
  # sufficient to pass the sparsity pattern (by way of the weights) as
  # the outer-loop state.
  @tf.function(autograph=False)
  def loop_body(var_weights_seed, _):
    _, weights, seed = var_weights_seed
    seed, next_seed = samplers.split_seed(seed, n=2)
    variance, weights = sampler.sample_noise_variance_and_weights(
        initial_nonzeros=tf.not_equal(weights, 0.),
        targets=targets,
        seed=seed)
    return variance, weights, next_seed

  init_seed = test_util.test_seed(sampler_type='stateless')
  variance_samples, weight_samples, _ = tf.scan(
      fn=loop_body,
      initializer=(1., tf.ones([num_features]), init_seed),
      elems=tf.range(100))

  # With the default (relatively uninformative) prior, the noise variance
  # posterior mean should be close to the most-likely value.
  self.assertAllClose(tf.reduce_mean(variance_samples),
                      tf.math.reduce_std(targets)**2,
                      atol=0.03)
  # Since there is no evidence for the weights, the sparsity of our samples
  # should match the prior.
  nonzero_weight_samples = tf.cast(tf.not_equal(weight_samples, 0.),
                                   tf.float32)
  self.assertAllClose(nonzero_prior_prob,
                      tf.reduce_mean(nonzero_weight_samples),
                      atol=0.03)
def log_volatility_noncentered_fn(white_noise_shock_scale,
                                  persistence_of_volatility):
  """Noncentered parameterization of log_volatility random variable."""
  # The non-centered parameterization for log_volatility improves geometry
  # but is slower (catastrophically so if FFT is not used).
  std_log_volatility = yield root(
      tfd.Sample(
          tfd.Normal(0., 1.),
          num_timesteps,
          name='std_log_volatility',
      ))

  if use_fft:
    return (white_noise_shock_scale[..., tf.newaxis] *
            _fft_conv_center(std_log_volatility, persistence_of_volatility))
  else:
    log_volatility = (
        std_log_volatility * white_noise_shock_scale[..., tf.newaxis])

    log_volatility_0 = (
        log_volatility[..., 0] /
        tf.sqrt(1 - persistence_of_volatility**2))

    # Make the time axis be first, for scan to work.
    log_volatility = distribution_util.move_dimension(log_volatility, -1, 0)
    # I.e.
    # log_volatility[t] += (persistence_of_volatility *
    #                       log_volatility[t-1])
    log_volatility = tf.concat(
        [
            log_volatility_0[tf.newaxis],
            tf.scan(
                lambda v_prev, v: persistence_of_volatility * v_prev + v,
                log_volatility[1:], log_volatility_0)
        ],
        axis=0,
    )

    return distribution_util.move_dimension(log_volatility, 0, -1)
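# Standalone sketch (made-up scalar values) of the AR(1) accumulation done by
# the scan above: v[t] = persistence * v[t-1] + shock[t], with the time axis
# first.
import tensorflow as tf

persistence = tf.constant(0.9)
shocks = tf.constant([0.5, -0.2, 0.1, 0.3])   # shocks for t = 1..4
v0 = tf.constant(1.0)                         # value at t = 0
v_rest = tf.scan(lambda v_prev, v: persistence * v_prev + v,
                 shocks, initializer=v0)
v = tf.concat([[v0], v_rest], axis=0)         # [v[0], v[1], ..., v[4]]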
def test_categorical_resampler_chi2(self):
  strm = test_util.test_seed_stream()
  # Test categorical resampler using chi-squared test.
  if self.use_xla and tf.executing_eagerly():
    self.skipTest('No need to test XLA under all execution regimes.')

  num_probs = 50
  num_distributions = 3
  unnormalized_probs = tfd.Uniform(
      low=self.dtype(0),
      high=self.dtype(1.)).sample([num_distributions, num_probs], seed=strm)
  probs = unnormalized_probs / tf.reduce_sum(
      unnormalized_probs, axis=-1, keepdims=True)

  # chi-squared test is valid as long as `num_samples` is
  # large compared to `num_probs`.
  num_particles = 10000
  num_samples = 2

  sample = self.maybe_compiler(resample_independent)(
      tf.math.log(dist_util.move_dimension(
          probs, source_idx=-1, dest_idx=0)),
      num_particles,
      [num_samples],
      seed=strm)
  elems = tf.range(num_probs)
  initializer = tf.zeros([num_samples, num_distributions],
                         dtype=sample.dtype)
  counts = tf.scan(
      lambda _, x: tf.reduce_sum(  # pylint: disable=g-long-lambda
          tf.cast(tf.equal(sample, x), sample.dtype),
          axis=0),
      elems,
      initializer)
  counts = dist_util.move_dimension(
      tf.cast(counts, self.dtype), source_idx=0, dest_idx=-1)

  expected_samples = probs * num_particles
  chi2 = tf.reduce_sum(
      (counts - expected_samples)**2 / expected_samples,
      axis=-1)
  self.assertAllLess(
      tfd.Chi2(df=self.dtype(num_probs - 1)).cdf(chi2),
      0.99995)
def _apply_forward_scan(self, fn, x0, xs):
  """Runs the chain forward, accumulating `fn(b, x, y)` vals at every step.

  Args:
    fn: Callable with signature `result = fn(b, x, y)`.
    x0: Structure of initial state `Tensors`, each of shape
      `concat([[batch_shape], unconstrained_prior_event_shape])`.
    xs: Structure of `Tensors`, each of shape
      `concat([[batch_shape], [num_steps - 1],
      unconstrained_transition_event_shape])`.
  Returns:
    fs: Result `Tensor` of shape
      `concat([[num_steps], batch_shape, result_shape])`, where
      `result_shape` is the shape of the result from an unbatched call to
      `fn`.
  """
  xs_step_axes = tf.nest.map_structure(
      lambda nd: -nd,
      self.transition_bijector.inverse_event_ndims(
          # Outputs `y` have the num_steps axis at `-inverse_min_event_ndims`.
          self.inverse_min_event_ndims))
  xs = move_dimensions(xs, source=xs_step_axes, dest=0)

  # Evaluate the initial state.
  y0 = self.initial_bijector.forward(x0)
  f0 = fn(self.initial_bijector, x0, y0)

  # Evaluate the rest of the chain.
  def loop_body(previous_y_and_result, idx):
    previous_y, _ = previous_y_and_result
    bij = self.bijector_fn(self.chain.transition_fn(idx, previous_y))
    x_i = tf.nest.map_structure(lambda x: x[idx - 1], xs)
    y_i = bij.forward(x_i)
    f_i = fn(bij, x_i, y_i)
    return (y_i,
            tf.nest.map_structure(lambda a, b: tf.cast(a, b.dtype), f_i, f0))

  _, fs = tf.scan(loop_body,
                  elems=tf.range(1, self.chain.num_steps),
                  initializer=(y0, f0))
  return concat_initial(f0, fs)
def _inner_apply(x1, x2):
  order = ps.shape(self.amplitudes)[-1]

  def scan_fn(esp, i):
    s = self.kernel[..., i].apply(
        x1[..., i][..., tf.newaxis],
        x2[..., i][..., tf.newaxis],
        example_ndims=example_ndims)
    next_esp = esp[..., 1:] + s[..., tf.newaxis] * esp[..., :-1]
    # Add the zero-th polynomial.
    next_esp = tf.concat(
        [tf.ones_like(esp[..., 0][..., tf.newaxis]), next_esp], axis=-1)
    return next_esp

  batch_shape = ps.broadcast_shape(
      ps.shape(x1)[:-self.kernel.feature_ndims],
      ps.shape(x2)[:-self.kernel.feature_ndims])
  batch_shape = ps.broadcast_shape(
      batch_shape,
      ps.concat([self.batch_shape_tensor(), [1] * example_ndims], axis=0))
  initializer = tf.concat(
      [tf.ones(ps.concat([batch_shape, [1]], axis=0), dtype=self.dtype),
       tf.zeros(ps.concat([batch_shape, [order]], axis=0), dtype=self.dtype)],
      axis=-1)
  esps = tf.scan(
      scan_fn,
      elems=ps.range(0, ps.shape(x1)[-1], dtype=tf.int32),
      parallel_iterations=32,
      initializer=initializer)[-1, ..., 1:]
  amplitudes = util.pad_shape_with_ones(
      self.amplitudes, ndims=example_ndims, start=-2)
  return tf.reduce_sum(esps * tf.math.square(amplitudes), axis=-1)
def fit_with_gibbs_sampling(model,
                            observed_time_series,
                            num_chains=(),
                            num_results=2000,
                            num_warmup_steps=200,
                            initial_state=None,
                            seed=None):
  """Fits parameters for an STS model using Gibbs sampling.

  Args:
    model: A `tfp.sts.StructuralTimeSeries` model instance returned by
      `build_model_for_gibbs_fitting`.
    observed_time_series: `float` `Tensor` of shape `[..., T, 1]` (omitting
      the trailing unit dimension is also supported when `T > 1`), specifying
      an observed time series. May optionally be an instance of
      `tfp.sts.MaskedTimeSeries`, which includes a mask `Tensor` to specify
      timesteps with missing observations.
    num_chains: Optional int to indicate the number of parallel MCMC chains.
      Defaults to an empty tuple to sample a single chain.
    num_results: Optional int to indicate number of MCMC samples.
    num_warmup_steps: Optional int to indicate number of warmup steps.
    initial_state: A `GibbsSamplerState` structure of the initial states of
      the MCMC chains.
    seed: Optional `Python` `int` seed controlling the sampled values.
  Returns:
    samples: A `GibbsSamplerState` structure of posterior samples.
  """
  if not hasattr(model, 'supports_gibbs_sampling'):
    raise ValueError('This STS model does not support Gibbs sampling. Models '
                     'for Gibbs sampling must be created using the '
                     'method `build_model_for_gibbs_fitting`.')
  if not tf.nest.is_nested(num_chains):
    num_chains = [num_chains]

  [observed_time_series,
   is_missing] = sts_util.canonicalize_observed_time_series_with_mask(
       observed_time_series)
  dtype = observed_time_series.dtype

  # The canonicalized time series always has trailing dimension `1`,
  # because although LinearGaussianSSMs support vector observations, STS models
  # describe scalar time series only. For our purposes it'll be cleaner to
  # remove this dimension.
  observed_time_series = observed_time_series[..., 0]
  batch_shape = prefer_static.concat(
      [num_chains, prefer_static.shape(observed_time_series)[:-1]], axis=-1)
  level_slope_shape = prefer_static.concat(
      [num_chains, prefer_static.shape(observed_time_series)], axis=-1)

  # Treat a LocalLevel model as the special case of LocalLinearTrend where
  # the slope_scale is always zero.
  initial_slope_scale = 0.
  initial_slope = 0.
  if isinstance(model.components[0], sts.LocalLinearTrend):
    initial_slope_scale = 1. * tf.ones(batch_shape, dtype=dtype)
    initial_slope = tf.zeros(level_slope_shape, dtype=dtype)

  if initial_state is None:
    initial_state = GibbsSamplerState(
        observation_noise_scale=tf.ones(batch_shape, dtype=dtype),
        level_scale=tf.ones(batch_shape, dtype=dtype),
        slope_scale=initial_slope_scale,
        weights=tf.zeros(prefer_static.concat(
            [batch_shape, _get_design_matrix(model).shape[-1:]], axis=0),
                         dtype=dtype),
        level=tf.zeros(level_slope_shape, dtype=dtype),
        slope=initial_slope,
        seed=None)  # Set below.

  if isinstance(seed, six.integer_types):
    tf.random.set_seed(seed)

  # Always use the passed-in `seed` arg, ignoring any seed in the initial
  # state.
  initial_state = initial_state._replace(
      seed=samplers.sanitize_seed(seed, salt='initial_GibbsSamplerState'))

  sampler_loop_body = _build_sampler_loop_body(model,
                                               observed_time_series,
                                               is_missing)

  samples = tf.scan(sampler_loop_body,
                    np.arange(num_warmup_steps + num_results),
                    initial_state)
  return tf.nest.map_structure(lambda x: x[num_warmup_steps:], samples)
def segment_cumsum(x, segment_ids, exclusive=False, dtype=None, name=None):
  """Computes cumulative sum of elements in a segment.

  For a complete description of segment_* ops see documentation of
  `tf.segment_sum`. This op extends the `tf.math.cumsum` functionality to
  segmented inputs. The behaviour of this op is the same as that of the op
  `tf.math.cumsum` within each segment. The result is effectively a
  concatenation of the results of `tf.math.cumsum` applied to each segment
  with the same interpretation for the argument `exclusive`.

  ## Example

  ```python
    x = tf.constant([2, 5, 1, 7, 9] + [32, 10, 12, 3] + [4, 8, 5])
    segments = tf.constant([0, 0, 0, 0, 0] + [1, 1, 1, 1] + [2, 2, 2])
    # Inclusive cumulative sum.
    # Expected result: [2, 7, 8, 15, 24, 32, 42, 54, 57, 4, 12, 17]
    cumsum1 = segment_cumsum(
        x, segment_ids=segments, exclusive=False)
    # Exclusive cumsum.
    # Expected result: [0, 2, 7, 8, 15, 0, 32, 42, 54, 0, 4, 12]
    cumsum2 = segment_cumsum(
        x, segment_ids=segments, exclusive=True)
  ```

  Args:
    x: A rank 1 `Tensor` of any dtype for which arithmetic operations are
      permitted.
    segment_ids: A `Tensor`. Must be one of the following types: int32,
      int64. A 1-D tensor whose size is equal to the size of `x`. Values
      should be sorted and can be repeated. Values must range from `0` to
      `num segments - 1`.
    exclusive: Python bool. See description above.
      Default value: False
    dtype: Optional `tf.Dtype`. If supplied, the dtype for `x` to use when
      converting to `Tensor`.
      Default value: None which maps to the default dtype inferred by TF.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: None which is mapped to the default name
        'segment_cumsum'.

  Returns:
    cumsums: A `Tensor` of the same shape and dtype as `x` containing the
      per-segment cumulative sums (exclusive cumulative sums if `exclusive`
      is True).
  """
  with tf.compat.v1.name_scope(name, default_name='segment_cumsum',
                               values=[x]):
    x = tf.convert_to_tensor(x, dtype=dtype)
    raw_cumsum = tf.math.cumsum(x, exclusive=exclusive)
    if segment_ids is None:
      return raw_cumsum

    # It is quite tedious to do a vectorized version without a while loop so
    # we skip that for now.
    # TODO(b/137940928): Replace these ops with more efficient C++ kernels.
    def scanner(accumulators, args):
      cumsum, prev_segment, prev_value = accumulators
      value, segment = args
      if exclusive:
        initial_value, inc_value = tf.zeros_like(value), cumsum + prev_value
      else:
        initial_value, inc_value = value, cumsum + value
      next_cumsum = tf.where(
          tf.equal(prev_segment, segment), inc_value, initial_value)
      return next_cumsum, segment, value

    return tf.scan(scanner, (x, segment_ids),
                   initializer=(tf.zeros_like(x[0]),
                                tf.zeros_like(segment_ids[0]) - 1,
                                tf.zeros_like(x[0])))[0]
def estimate_parameters(self,
                        observations,
                        num_iterations,
                        num_particles,
                        initial_perturbation_scale,
                        cooling_schedule,
                        seed=None,
                        name=None,
                        **kwargs):
  """Runs multiple iterations of filtering following a cooling schedule.

  Args:
    observations: observed `Tensor` value(s) on which to condition the
      parameter estimate.
    num_iterations: `int` `Tensor` number of filtering iterations to run.
    num_particles: scalar int `Tensor` number of particles to use.
    initial_perturbation_scale: scalar float `Tensor`, or any structure of
      float `Tensor`s broadcasting to the same shape as the (unconstrained)
      parameters, specifying the scale (standard deviation) of Gaussian
      perturbations to each parameter at the first timestep.
    cooling_schedule: callable with signature
      `cooling_factor = cooling_schedule(iteration)` for `iteration` in
      `[0, ..., num_iterations - 1]`. The filter is invoked with
      perturbations of scale
      `initial_perturbation_scale * cooling_schedule(iteration)`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
    name: `str` name for ops constructed by this method.
    **kwargs: additional keyword arguments passed to
      `tfp.experimental.mcmc.infer_trajectories`.
  Returns:
    final_parameter_particles: structure of `Tensor`s matching
      `self.parameter_prior`, each with batch shape
      `[num_iterations, num_particles]`. These are the populations of
      particles representing the parameter estimate after each iteration of
      filtering.
  """
  seed = SeedStream(seed, 'iterated_filter_estimate_parameters')
  with self._name_scope(name or 'estimate_parameters'):
    initial_perturbation_scale = tf.convert_to_tensor(
        initial_perturbation_scale, name='initial_perturbation_scale')

    # Get initial parameter particles from the first filtering iteration.
    initial_unconstrained_parameters = self.one_step(
        observations=observations,
        num_particles=num_particles,
        perturbation_scale=initial_perturbation_scale,
        seed=seed,
        **kwargs)

    # Run the remaining iterations and accumulate the results.
    @tf.function(autograph=False)
    def loop_body(unconstrained_parameters, cooling_fraction):
      return self.one_step(
          observations=observations,
          num_particles=num_particles,
          perturbation_scale=tf.nest.map_structure(
              lambda s: cooling_fraction * s, initial_perturbation_scale),
          initial_unconstrained_parameters=unconstrained_parameters,
          seed=seed,
          **kwargs)

    estimated_unconstrained_parameters = tf.scan(
        fn=loop_body,
        elems=cooling_schedule(tf.range(1, num_iterations)),
        initializer=initial_unconstrained_parameters)

    return self.parameter_constraining_bijector.forward(
        estimated_unconstrained_parameters)
def minimize(loss_fn,
             num_steps,
             optimizer,
             trainable_variables=None,
             trace_fn=_trace_loss,
             name='minimize'):
  """Minimize a loss function using a provided optimizer.

  Args:
    loss_fn: Python callable with signature `loss = loss_fn()`, where `loss`
      is a `Tensor` loss to be minimized.
    num_steps: Python `int` number of steps to run the optimizer.
    optimizer: Optimizer instance to use. This may be a TF1-style
      `tf.train.Optimizer`, TF2-style `tf.optimizers.Optimizer`, or any
      Python object that implements
      `optimizer.apply_gradients(grads_and_vars)`.
    trainable_variables: list of `tf.Variable` instances to optimize with
      respect to. If `None`, defaults to the set of all variables accessed
      during the execution of `loss_fn()`.
      Default value: `None`.
    trace_fn: Python callable with signature `state = trace_fn(loss, grads,
      variables)`, where `state` may be a `Tensor` or nested structure of
      `Tensor`s. The state values are accumulated (by `tf.scan`) and
      returned. The default `trace_fn` simply returns the loss, but in
      general can depend on the gradients and variables (if
      `trainable_variables` is not `None` then
      `variables==trainable_variables`; otherwise it is the list of all
      variables accessed during execution of `loss_fn()`), as well as any
      other quantities captured in the closure of `trace_fn`, for example,
      statistics of a variational distribution.
      Default value: `lambda loss, grads, variables: loss`.
    name: Python `str` name prefixed to ops created by this function.
      Default value: 'minimize'.

  Returns:
    trace: `Tensor` or nested structure of `Tensor`s, according to the
      return type of `trace_fn`. Each `Tensor` has an added leading dimension
      of size `num_steps`, packing the trajectory of the result over the
      course of the optimization.

  ### Examples

  To minimize the scalar function `(x - 5)**2`:

  ```python
  x = tf.Variable(0.)
  loss_fn = lambda: (x - 5.)**2
  losses = tfp.math.minimize(loss_fn,
                             num_steps=100,
                             optimizer=tf.optimizers.Adam(learning_rate=0.1))

  # In TF2/eager mode, the optimization runs immediately.
  print("optimized value is {} with loss {}".format(x, losses[-1]))
  ```

  In graph mode (e.g., inside of `tf.function` wrapping), retrieving any
  Tensor that depends on the minimization op will trigger the optimization:

  ```python
  with tf.control_dependencies([losses]):
    optimized_x = tf.identity(x)  # Use a dummy op to attach the dependency.
  ```

  In some cases, we may want to track additional context inside the
  optimization. We can do this by defining a custom `trace_fn`. Note that
  the `trace_fn` is passed the loss and gradients, but it may also report
  the values of trainable variables or other derived quantities by capturing
  them in its closure. For example, we can capture `x` and track its value
  over the optimization:

  ```python
  # `x` is the tf.Variable instance defined above.
  trace_fn = lambda loss, grads, variables: {'loss': loss, 'x': x}
  trace = tfp.math.minimize(loss_fn, num_steps=100,
                            optimizer=tf.optimizers.Adam(0.1),
                            trace_fn=trace_fn)
  print(trace['loss'].shape,  # => [100]
        trace['x'].shape)     # => [100]
  ```
  """

  @tf.function(autograph=False)
  def train_loop_body(old_result, step):  # pylint: disable=unused-argument
    """Run a single optimization step."""
    with tf.GradientTape(
        watch_accessed_variables=trainable_variables is None) as tape:
      for v in trainable_variables or []:
        tape.watch(v)
      loss = loss_fn()
    watched_variables = tape.watched_variables()
    grads = tape.gradient(loss, watched_variables)
    train_op = optimizer.apply_gradients(zip(grads, watched_variables))
    with tf.control_dependencies([train_op]):
      state = trace_fn(tf.identity(loss),
                       [tf.identity(g) for g in grads],
                       [tf.identity(v) for v in watched_variables])
    return state

  with tf.name_scope(name) as name:
    # Compute the shape of the trace without executing the graph, if possible.
    concrete_loop_body = train_loop_body.get_concrete_function(
        tf.TensorSpec([]), tf.TensorSpec([]))  # Inputs ignored.
    if all([tensorshape_util.is_fully_defined(shape)
            for shape in tf.nest.flatten(concrete_loop_body.output_shapes)]):
      state_initializer = tf.nest.map_structure(
          lambda shape, dtype: tf.zeros(shape, dtype=dtype),
          concrete_loop_body.output_shapes,
          concrete_loop_body.output_dtypes)
      initial_trace_step = None
    else:
      state_initializer = concrete_loop_body(
          tf.convert_to_tensor(0.),
          tf.convert_to_tensor(0.))  # Inputs ignored.
      num_steps = num_steps - 1
      initial_trace_step = state_initializer

    # TODO(b/136103064): Rewrite as explicit `while_loop` to support custom
    # convergence criteria and Tensor-valued `num_steps`, and avoid
    # re-tracing the train loop body.
    trace = tf.scan(train_loop_body,
                    elems=np.arange(num_steps),
                    initializer=state_initializer)
    if initial_trace_step is not None:
      trace = tf.nest.map_structure(
          lambda a, b: tf.concat([a[tf.newaxis, ...], b], axis=0),
          initial_trace_step, trace)
    return trace
def test_scan_with_initializer(self):
  elems = np.arange(5).astype(np.int32)
  self.assertAllEqual(
      self.evaluate(tf.scan(lambda x, y: x + y, elems, initializer=7)),
      nptf.scan(lambda x, y: x + y, elems, initializer=7))