def trace(state, fn, num_steps, parallel_iterations=10):
  """TF implementation of `trace` operator, without the calling convention."""
  if tf.config.experimental_functions_run_eagerly() or tf.executing_eagerly():
    state, first_untraced, first_traced = fn(state)
    arrays = tf.nest.map_structure(
        lambda v: tf.TensorArray(  # pylint: disable=g-long-lambda
            v.dtype, size=num_steps, element_shape=v.shape).write(0, v),
        first_traced)
    start_idx = 1
  else:
    # We need the shapes and dtypes of the outputs of `fn` to create the
    # `TensorArray`s etc.; we can get them by pre-compiling the wrapped
    # function.
    input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
    fn, (_, untraced_spec, traced_spec) = _eval_shape(fn, input_spec)

    arrays = tf.nest.map_structure(
        lambda spec: tf.TensorArray(  # pylint: disable=g-long-lambda
            spec.dtype, size=num_steps, element_shape=spec.shape),
        traced_spec)
    first_untraced = tf.nest.map_structure(
        lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec)
    start_idx = 0

  def body(i, state, _, arrays):
    state, untraced, traced = fn(state)
    arrays = tf.nest.map_structure(lambda a, e: a.write(i, e), arrays, traced)
    return i + 1, state, untraced, arrays

  def cond(i, *_):
    return i < num_steps

  _, state, untraced, arrays = tf.while_loop(
      cond=cond,
      body=body,
      loop_vars=(start_idx, state, first_untraced, arrays),
      parallel_iterations=parallel_iterations,
  )

  traced = tf.nest.map_structure(lambda a: a.stack(), arrays)

  static_length = tf.get_static_value(num_steps)

  def _merge_static_length(x):
    x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
    return x

  traced = tf.nest.map_structure(_merge_static_length, traced)
  return state, untraced, traced
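# Hedged usage sketch (assumption, not part of the original source): `fn` must
# return a `(state, untraced, traced)` triple; `trace` stacks the `traced`
# outputs over `num_steps` iterations and returns the final state, the last
# untraced value, and the stacked trace.
def _count_up(state):
  new_state = state + 1.
  return new_state, new_state, new_state

# state, untraced, traced = trace(tf.constant(0.), _count_up, num_steps=5)
# `traced` would then hold [1., 2., 3., 4., 5.] and `state`/`untraced` 5.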
def _initialize_accumulated_quantities(observations, num_timesteps):
  """Initialize arrays passed through the filter loop."""
  initial_arrays = [
      tf.nest.map_structure(
          lambda x: tf.TensorArray(dtype=x.dtype, size=num_timesteps),
          observations) for _ in range(7)
  ]
  initial_arrays.append(
      tf.nest.map_structure(
          lambda _: tf.TensorArray(dtype=tf.int32, size=num_timesteps),
          observations))
  return KalmanFilterState(*initial_arrays)
def _initialize_loop_variables(initial_step_results,
                               num_timesteps,
                               trace_fn,
                               step_indices_to_trace):
  """Initialize arrays and other quantities passed through the filter loop."""
  # Create arrays to store traced values (particles, likelihoods, etc).
  num_steps_to_trace = (num_timesteps
                        if step_indices_to_trace is None
                        else ps.size0(step_indices_to_trace))
  traced_results = trace_fn(initial_step_results)
  trace_arrays = tf.nest.map_structure(
      lambda x: tf.TensorArray(dtype=x.dtype, size=num_steps_to_trace),
      traced_results)
  # If we are supposed to trace at step 0, write the traced values.
  num_steps_traced, trace_arrays = ps.cond(
      (True if step_indices_to_trace is None
       else ps.equal(step_indices_to_trace[0], 0)),
      lambda: (1,  # pylint: disable=g-long-lambda
               tf.nest.map_structure(lambda ta, x: ta.write(0, x),
                                     trace_arrays, traced_results)),
      lambda: (0, trace_arrays))

  return ParticleFilterLoopVariables(
      step=1,
      previous_step_results=initial_step_results,
      accumulated_traced_results=trace_arrays,
      num_steps_traced=num_steps_traced)
def call(self, inputs):
  samples = tf.TensorArray(dtype=tf.float32, size=tf.shape(inputs)[0])
  i = 0
  for sample in inputs:
    samples = samples.write(i, tf.square(sample))
    i += 1
  return samples.stack()
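# Hedged usage note (assumption): when used as a Keras layer's `call` under
# AutoGraph, this squares each slice of `inputs` along axis 0, e.g.
# [[1., 2.], [3., 4.]] -> [[1., 4.], [9., 16.]].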
def _sample(dim, drift_fn, volatility_fn, times, time_step, keep_mask,
            times_shape, num_samples, initial_state, random_type, seed,
            swap_memory, skip, dtype):
  """Returns a sample of paths from the process using the Euler method."""
  dt = times[1:] - times[:-1]
  sqrt_dt = tf.sqrt(dt)
  current_state = initial_state + tf.zeros([num_samples, dim],
                                           dtype=initial_state.dtype)
  if dt.shape.is_fully_defined():
    steps_num = dt.shape.as_list()[-1]
  else:
    steps_num = tf.shape(dt)[-1]
    # TODO(b/148133811): Re-enable Sobol test when TF 2.2 is released.
    if random_type == random.RandomType.SOBOL:
      raise ValueError('Sobol sequence for Euler sampling is temporarily '
                       'unsupported when `time_step` or `times` have a '
                       'non-constant value')
  # In order to use low-discrepancy random_type we need to generate the
  # sequence of independent random normals upfront.
  if random_type in (random.RandomType.SOBOL,
                     random.RandomType.HALTON,
                     random.RandomType.HALTON_RANDOMIZED):
    normal_draws = utils.generate_mc_normal_draws(
        num_normal_draws=dim, num_time_steps=steps_num,
        num_sample_paths=num_samples, random_type=random_type,
        dtype=dtype, seed=seed, skip=skip)
    wiener_mean = None
  else:
    # If pseudo or antithetic sampling is used, proceed with random sampling
    # at each step.
    wiener_mean = tf.zeros((dim,), dtype=dtype, name='wiener_mean')
    normal_draws = None

  cond_fn = lambda i, *args: i < steps_num
  # Maximum number of iterations is passed to the while loop below. It improves
  # performance of the while loop on a GPU and is needed for XLA-compilation
  # compatibility.
  def step_fn(i, written_count, current_state, result):
    return _euler_step(i, written_count, current_state, result, drift_fn,
                       volatility_fn, wiener_mean, num_samples, times, dt,
                       sqrt_dt, keep_mask, random_type, seed, normal_draws)

  maximum_iterations = (tf.cast(1. / time_step, dtype=tf.int32)
                        + tf.size(times))
  result = tf.TensorArray(dtype=dtype, size=times_shape[-1])
  _, _, _, result = tf.while_loop(
      cond_fn, step_fn, (0, 0, current_state, result),
      maximum_iterations=maximum_iterations,
      swap_memory=swap_memory)
  result = tf.transpose(result.stack(), (1, 0, 2))
  # Shape of `rate_paths` is dynamic in `times` dimension because of
  # `TensorArray`. In order to make the shape static, use the `set_shape`
  # method.
  # TODO(b/148854825): Consider removing TensorArray to make all shapes static.
  result.set_shape(current_state.shape[:1] + times_shape
                   + current_state.shape[-1:])
  return result
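# Minimal self-contained sketch (an assumption, not the library's `_euler_step`)
# of the Euler-Maruyama update applied at every iteration above, using a scalar
# (elementwise) volatility for brevity:
#   X_{t+dt} = X_t + mu(t, X_t) * dt + sigma(t, X_t) * dW,  dW ~ Normal(0, dt).
import tensorflow as tf

def euler_maruyama_step(t, dt, state, drift_fn, volatility_fn, seed=None):
  """One Euler step for dX = drift * dt + volatility * dW."""
  dw = tf.random.normal(
      tf.shape(state), stddev=tf.sqrt(dt), dtype=state.dtype, seed=seed)
  return state + drift_fn(t, state) * dt + volatility_fn(t, state) * dw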
def SanitizedAutoCorrelationMean(x, axis, reduce_axis, max_lags=None, **kwargs):
  shape_arr = np.array(list(x.shape))
  axes = list(sorted(set(range(len(shape_arr))) - set([reduce_axis])))
  mean_shape = shape_arr[axes]
  if max_lags is not None:
    mean_shape[axis] = max_lags + 1
  mean_state = fun_mc.running_mean_init(mean_shape, x.dtype)
  new_order = list(range(len(shape_arr)))
  new_order[0] = new_order[reduce_axis]
  new_order[reduce_axis] = 0
  x = tf.transpose(x, new_order)
  x_arr = tf.TensorArray(x.dtype, x.shape[0]).unstack(x)
  mean_state, _ = fun_mc.trace(
      state=mean_state,
      fn=lambda state: fun_mc.running_mean_step(  # pylint: disable=g-long-lambda
          state,
          SanitizedAutoCorrelation(
              x_arr.read(state.num_points), axis, max_lags=max_lags,
              **kwargs)),
      num_steps=x.shape[0],
      trace_fn=lambda *_: ())
  return mean_state.mean
def _initialize_loop_variables(initial_step_results,
                               num_steps_state_history_to_pass,
                               num_timesteps):
  """Initialize arrays and other quantities passed through the filter loop."""
  # Create arrays to store particles, indices, and likelihoods, and write
  # their initial values.
  step_results_arrays = tf.nest.map_structure(
      lambda x: tf.TensorArray(dtype=x.dtype, size=num_timesteps).write(0, x),
      initial_step_results)

  # Because `while_loop` requires Tensor values, we'll represent the lack of
  # state history by a static-shape empty Tensor.
  # This can be detected elsewhere by branching on
  #  `tf.is_tensor(state_history) and state_history.shape[0] == 0`.
  state_history = tf.zeros([0])
  if num_steps_state_history_to_pass:
    # Repeat the initial state, so that `state_history` always has length
    # `num_steps_state_history_to_pass`.
    state_history = tf.nest.map_structure(
        lambda x: tf.broadcast_to(  # pylint: disable=g-long-lambda
            x[tf.newaxis, ...],
            prefer_static.concat([[num_steps_state_history_to_pass],
                                  prefer_static.shape(x)], axis=0)),
        initial_step_results.particles)

  return ParticleFilterLoopVariables(
      step=1,
      previous_step_results=initial_step_results,
      accumulated_step_results=step_results_arrays,
      state_history=state_history)
def nested_for_loops(m):
  l = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=())
  for i in m:
    s = 0
    for j in i:
      s = s * 10 + j
    l = l.write(l.size(), s)
  return l.stack()
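# Hypothetical usage sketch (assumption, not in the original source): each row
# of a dense integer matrix is folded into a single base-10 integer, e.g.
#   nested_for_loops(tf.constant([[1, 2, 3], [4, 5, 6]]))  # -> [123, 456]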
def _map_body(trace_state):
  if not tf.is_tensor(trace_state):
    trace_state = tf.convert_to_tensor(trace_state)
  return tf.TensorArray(
      dtype=trace_state.dtype,
      size=size,
      dynamic_size=dynamic_size,
      element_shape=trace_state.shape,
      clear_after_read=False)
def _initialize_arrays(initial_values, num_steps):
  """Construct a structure of `TraceArray`s from initial values."""
  trace_arrays = tf.nest.map_structure(
      lambda t: tf.TensorArray(  # pylint: disable=g-long-lambda
          dtype=t.dtype,
          size=num_steps,  # Initial size.
          clear_after_read=False,  # Allow reading->tiling final value.
          element_shape=t.shape),
      initial_values)
  return tf.nest.map_structure(lambda ta, t: ta.write(0, t),
                               trace_arrays, initial_values)
def nested_while_loops(n1, n2):
  i = 0
  l = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=())
  while i < n1:
    j = 0
    s = 0
    while j < n2:
      s = s * 10 + i * j
      j += 1
    l = l.write(i, s)
    i += 1
  return l.stack()
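# Hypothetical usage note (assumption): nested_while_loops(2, 3) folds
# s = s * 10 + i * j over j in [0, 3) for each i in [0, 2), returning [0, 12].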
def _init(shape_and_dtype):
  if USE_TENSORARRAY:
    return [
        tf.TensorArray(dtype=d,  # pylint: disable=g-complex-comprehension
                       size=self.max_tree_depth + 1,
                       element_shape=s,
                       clear_after_read=False)
        for (s, d) in shape_and_dtype]
  else:
    return [
        tf.zeros(  # pylint: disable=g-complex-comprehension
            tf.TensorShape([self.max_tree_depth + 1]).concatenate(s),
            dtype=d)
        for (s, d) in shape_and_dtype]
def _initialize_arrays(initial_values, num_steps, truncate_at_convergence):
  """Construct a structure of `TraceArray`s from initial values."""
  num_steps_ = tf.get_static_value(tf.convert_to_tensor(num_steps))
  size_is_dynamic = (num_steps_ is None or truncate_at_convergence)
  trace_arrays = tf.nest.map_structure(
      lambda t: tf.TensorArray(  # pylint: disable=g-long-lambda
          dtype=t.dtype,
          size=1 if size_is_dynamic else num_steps_,  # Initial size.
          dynamic_size=size_is_dynamic,
          clear_after_read=False,  # Allow reading->tiling final value.
          element_shape=t.shape),
      initial_values)
  return tf.nest.map_structure(lambda ta, t: ta.write(0, t),
                               trace_arrays, initial_values)
def update_backward_differences(backward_differences, next_backward_difference,
                                next_state_vec, order):
  """Returns the backward differences for the next time."""
  backward_differences_array = tf.TensorArray(
      backward_differences.dtype,
      size=MAX_ORDER + 3,
      clear_after_read=False,
      element_shape=next_backward_difference.shape).unstack(
          backward_differences)
  new_backward_differences_array = tf.TensorArray(
      backward_differences.dtype,
      size=MAX_ORDER + 3,
      clear_after_read=False,
      element_shape=next_backward_difference.shape)
  new_backward_differences_array = new_backward_differences_array.write(
      order + 2,
      next_backward_difference - backward_differences_array.read(order + 1))
  new_backward_differences_array = new_backward_differences_array.write(
      order + 1, next_backward_difference)

  def body(k, new_backward_differences_array_):
    new_backward_differences_array_k = (
        new_backward_differences_array_.read(k + 1) +
        backward_differences_array.read(k))
    new_backward_differences_array_ = new_backward_differences_array_.write(
        k, new_backward_differences_array_k)
    return k - 1, new_backward_differences_array_

  _, new_backward_differences_array = tf.while_loop(
      lambda k, new_backward_differences_array: k > 0, body,
      [order, new_backward_differences_array])
  new_backward_differences_array = new_backward_differences_array.write(
      0, next_state_vec)
  new_backward_differences = new_backward_differences_array.stack()
  tensorshape_util.set_shape(new_backward_differences,
                             tf.TensorShape([MAX_ORDER + 3, None]))
  return new_backward_differences
def _initialize_accumulated_quantities(
    initial_kalman_filter_state,
    observations,
    num_timesteps,
):
  """Initialize quantities to accumulate, specifying dtype/shape."""
  initial_arrays = []
  for x in initial_kalman_filter_state:
    # pylint: disable=cell-var-from-loop
    initial_arrays.append(
        tf.nest.map_structure(
            lambda _: tf.TensorArray(  # pylint: disable=g-long-lambda
                dtype=x.dtype, element_shape=x.shape, size=num_timesteps),
            observations))
    # pylint: enable=cell-var-from-loop
  return linear_gaussian_ssm.KalmanFilterState(*initial_arrays)
def _sample_paths(self, times, grid_step, keep_mask, times_size, num_samples,
                  initial_state, random_type, seed, swap_memory):
  """Returns a sample of paths from the process."""
  dt = times[1:] - times[:-1]
  sqrt_dt = tf.sqrt(dt)
  current_state = initial_state + tf.zeros([num_samples, self.dim()],
                                           dtype=initial_state.dtype)
  steps_num = tf.shape(dt)[-1]
  wiener_mean = tf.zeros((self.dim(), 1), dtype=self._dtype)

  cond_fn = lambda i, *args: i < steps_num

  def step_fn(i, written_count, current_state, result):
    """Performs one step of the Euler scheme."""
    current_time = times[i + 1]
    dw = random_ops.mv_normal_sample((num_samples,),
                                     mean=wiener_mean,
                                     random_type=random_type,
                                     seed=seed)
    dw = dw * sqrt_dt[i]
    dt_inc = dt[i] * self.drift_fn()(current_time, current_state)  # pylint: disable=not-callable
    dw_inc = tf.squeeze(
        tf.matmul(self.volatility_fn()(current_time, current_state), dw), -1)  # pylint: disable=not-callable
    next_state = current_state + dt_inc + dw_inc
    # Keep only states for times requested by the user.
    result = tf.cond(keep_mask[i + 1],
                     (lambda: result.write(written_count, next_state)),
                     (lambda: result))
    written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
    return (i + 1, written_count, next_state, result)

  # Maximum number of iterations is passed to the while loop below. It improves
  # performance of the while loop on a GPU and is needed for XLA-compilation
  # compatibility.
  maximum_iterations = (tf.cast(1. / grid_step, dtype=tf.int32)
                        + tf.size(times))
  result = tf.TensorArray(dtype=self._dtype, size=times_size)
  _, _, _, result = tf.compat.v1.while_loop(
      cond_fn, step_fn, (0, 0, current_state, result),
      maximum_iterations=maximum_iterations,
      swap_memory=swap_memory)
  return tf.transpose(result.stack(), (1, 0, 2))
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               parallel_iterations=10,
               name=None):
  """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of `loop_fn`
  for every iteration thereafter. `elem` is a slice of `elems` along the first
  dimension, accessed in order. Additionally, it calls `trace_fn` on the return
  value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are stacked
  and returned from this function, such that the first dimension of those
  `Tensor`s matches the size of `elems`.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and
      returns a `Tensor` or a nested collection of `Tensor`s.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each slice of `elems`.
  """
  with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
      tf1.get_variable_scope()) as vs:
    if vs.caching_device is None and not tf.executing_eagerly():
      vs.set_caching_device(lambda op: op.device)

    initial_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='initial_state'),
        initial_state)
    elems = tf.convert_to_tensor(elems, name='elems')

    length = prefer_static.size0(elems)
    static_length = length if prefer_static.is_numpy(length) else None

    # This is a TensorArray in part because of XLA, which had trouble with
    # non-statically known indices. I.e. elems[i] errored, but
    # elems_array.read(i) worked.
    elems_array = tf.TensorArray(
        elems.dtype, size=length, element_shape=elems.shape[1:])
    elems_array = elems_array.unstack(elems)

    trace_arrays = tf.nest.map_structure(
        lambda x: tf.TensorArray(x.dtype, size=length, element_shape=x.shape),
        trace_fn(initial_state))

    def _body(i, state, trace_arrays):
      state = loop_fn(state, elems_array.read(i))
      trace_arrays = tf.nest.pack_sequence_as(trace_arrays, [
          a.write(i, v) for a, v in zip(
              tf.nest.flatten(trace_arrays),
              tf.nest.flatten(trace_fn(state)))
      ])
      return i + 1, state, trace_arrays

    _, final_state, trace_arrays = tf.while_loop(
        cond=lambda i, *args: i < length,
        body=_body,
        loop_vars=(0, initial_state, trace_arrays),
        parallel_iterations=parallel_iterations)

    stacked_trace = tf.nest.map_structure(lambda x: x.stack(), trace_arrays)

    # Restore the static length if we know it.
    def _merge_static_length(x):
      x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
      return x

    stacked_trace = tf.nest.map_structure(_merge_static_length, stacked_trace)
    return final_state, stacked_trace
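# Hedged usage sketch (assumption, not part of the original source): a running
# sum written with `trace_scan`, tracing the carried state at every step.
import tensorflow as tf

def _trace_scan_cumsum_example():
  final_state, partial_sums = trace_scan(
      loop_fn=lambda state, elem: state + elem,
      initial_state=tf.constant(0.),
      elems=tf.constant([1., 2., 3., 4.]),
      trace_fn=lambda state: state)
  return final_state, partial_sums  # 10. and [1., 3., 6., 10.].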
def grad_fn(*dresults, **kwargs): """Adjoint sensitivity method to compute gradients.""" dresults = tf.nest.pack_sequence_as(results, dresults) dstates = dresults.states # The signature grad_fn(*dresults, variables=None) is not valid Python 2 # so use kwargs instead. variables = kwargs.pop('variables', []) assert not kwargs # This assert should never fail. # TODO(b/138304303): Support complex types. with tf.name_scope('{}Gradients'.format(self._name)): get_dtype = lambda x: x.dtype def error_if_complex(dtype): if dtype.is_complex: raise NotImplementedError( 'The adjoint sensitivity method does ' 'not support complex dtypes.') state_dtypes = tf.nest.map_structure( get_dtype, initial_state) tf.nest.map_structure(error_if_complex, state_dtypes) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) # We add initial_time to ensure that we know where to stop. result_times = tf.concat( [[tf.cast(initial_time, real_dtype)], results.times], 0) num_result_times = tf.size(result_times) # First two components correspond to reverse and adjoint states. # the last component is adjoint state for variables. terminal_augmented_state = tuple([ rk_util.nest_constant(initial_state, 0.0), rk_util.nest_constant(initial_state, 0.0), tuple( rk_util.nest_constant(variable, 0.0) for variable in variables) ]) # The XLA compiler does not compile code which slices/indexes using # integer `Tensor`s. `TensorArray`s are used to get around this. result_time_array = tf.TensorArray( results.times.dtype, clear_after_read=False, size=num_result_times, element_shape=[]).unstack(result_times) # TensorArray shape should not include time dimension, hence shape[1:] result_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, element_shape=component.shape[1:]).unstack( component) for component in tf.nest.flatten(results.states) ] result_state_arrays = tf.nest.pack_sequence_as( results.states, result_state_arrays) dresult_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, element_shape=component.shape[1:]).unstack( component) for component in tf.nest.flatten(dstates) ] dresult_state_arrays = tf.nest.pack_sequence_as( results.states, dresult_state_arrays) def augmented_ode_fn(backward_time, augmented_state): """Dynamics function for the augmented system. Describes a differential equation that evolves the augmented state backwards in time to compute gradients using the adjoint method. Augmented state consists of 3 components `(state, adjoint_state, vars)` all evaluated at time `backward_time`: state: represents the solution of user provided `ode_fn`. The structure coincides with the `initial_state`. adjoint_state: represents the solution of adjoint sensitivity differential equation as discussed below. Has the same structure and shape as `state`. vars: represent the solution of the adjoint equation for variable gradients. Represented as a `Tuple(Tensor, ...)` with as many tensors as there are `variables`. Adjoint sensitivity equation describes the gradient of the solution with respect to the value of the solution at previous time t. 
Its dynamics are given by d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z) Which is computed as: d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j)] d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)] where in the last line we moved adj(t)_j under derivative by removing gradient from it. Adjoint equation for the gradient with respect to every `tf.Variable` theta follows: d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta) = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)] Args: backward_time: Floating `Tensor` representing current time. augmented_state: `Tuple(state, adjoint_state, variable_grads)` Returns: negative_derivatives: Structure of `Tensor`s equal to backwards time derivative of the `state` componnent. adjoint_ode: Structure of `Tensor`s equal to backwards time derivative of the `adjoint_state` component. adjoint_variables_ode: Structure of `Tensor`s equal to backwards time derivative of the `vars` component. """ # The negative signs disappears after the change of variables. # The ODE solver cannot handle the case initial_time > final_time # and hence a change of variables backward_time = -time is used. time = -backward_time state, adjoint_state, _ = augmented_state with tf.GradientTape() as tape: tape.watch(variables) tape.watch(state) derivatives = ode_fn(time, state) adjoint_no_grad = tf.nest.map_structure( tf.stop_gradient, adjoint_state) negative_derivatives = rk_util.weighted_sum( [-1.0], [derivatives]) def dot_prod(tensor_a, tensor_b): return tf.reduce_sum(tensor_a * tensor_b) # See docstring for details. adjoint_dot_derivatives = tf.nest.map_structure( dot_prod, adjoint_no_grad, derivatives) adjoint_dot_derivatives = tf.squeeze( tf.add_n( tf.nest.flatten(adjoint_dot_derivatives))) adjoint_ode, adjoint_variables_ode = tape.gradient( adjoint_dot_derivatives, (state, tuple(variables)), unconnected_gradients=tf.UnconnectedGradients.ZERO) return negative_derivatives, adjoint_ode, adjoint_variables_ode def reverse_to_result_time(n, augmented_state, _): """Integrates the augmented system backwards in time.""" lower_bound_of_integration = result_time_array.read(n) upper_bound_of_integration = result_time_array.read(n - 1) _, adjoint_state, adjoint_variable_state = augmented_state initial_state = _read_solution_components( result_state_arrays, input_state_structure, n - 1) initial_adjoint = _read_solution_components( dresult_state_arrays, input_state_structure, n - 1) initial_adjoint_state = rk_util.weighted_sum( [1.0, 1.0], [adjoint_state, initial_adjoint]) initial_augmented_state = (initial_state, initial_adjoint_state, adjoint_variable_state) # TODO(b/138304303): Allow the user to specify the Hessian of # `ode_fn` so that we can get the Jacobian of the adjoint system. # TODO(b/143624114): Support higher order derivatives. augmented_results = self._solve( ode_fn=augmented_ode_fn, initial_time=-lower_bound_of_integration, initial_state=initial_augmented_state, solution_times=[-upper_bound_of_integration], batch_ndims=batch_ndims) # Results added an extra time dim of size 1, squeeze it. 
select_result = lambda x: tf.squeeze(x, [0]) result_state = augmented_results.states result_state = tf.nest.map_structure( select_result, result_state) status = augmented_results.diagnostics.status return n - 1, result_state, status _, augmented_state, _ = tf.while_loop( lambda n, _, status: (n >= 1) & tf.equal(status, 0), reverse_to_result_time, (num_result_times - 1, terminal_augmented_state, 0), back_prop=False) _, adjoint_state, adjoint_variables = augmented_state return adjoint_state, list(adjoint_variables)
def diag_jacobian(xs, ys=None, sample_shape=None, fn=None, parallel_iterations=10, name=None): """Computes diagonal of the Jacobian matrix of `ys=fn(xs)` wrt `xs`. If `ys` is a tensor or a list of tensors of the form `(ys_1, .., ys_n)` and `xs` is of the form `(xs_1, .., xs_n)`, the function `jacobians_diag` computes the diagonal of the Jacobian matrix, i.e., the partial derivatives `(dys_1/dxs_1,.., dys_n/dxs_n`). For definition details, see https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant #### Example ##### Diagonal Hessian of the log-density of a 3D Gaussian distribution In this example we sample from a standard univariate normal distribution using MALA with `step_size` equal to 0.75. ```python import tensorflow as tf import tensorflow_probability as tfp import numpy as np tfd = tfp.distributions dtype = np.float32 with tf.Session(graph=tf.Graph()) as sess: true_mean = dtype([0, 0, 0]) true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]]) chol = tf.linalg.cholesky(true_cov) target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol) # Assume that the state is passed as a list of tensors `x` and `y`. # Then the target function is defined as follows: def target_fn(x, y): # Stack the input tensors together z = tf.concat([x, y], axis=-1) - true_mean return target.log_prob(z) sample_shape = [3, 5] state = [tf.ones(sample_shape + [2], dtype=dtype), tf.ones(sample_shape + [1], dtype=dtype)] fn_val, grads = tfp.math.value_and_gradient(target_fn, state) # We can either pass the `sample_shape` of the `state` or not, which impacts # computational speed of `diag_jacobian` _, diag_jacobian_shape_passed = diag_jacobian( xs=state, ys=grads, sample_shape=tf.shape(fn_val)) _, diag_jacobian_shape_none = diag_jacobian( xs=state, ys=grads) diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed) diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none) print('hessian computed through `diag_jacobian`, sample_shape passed: ', np.concatenate(diag_jacobian_shape_passed_, -1)) print('hessian computed through `diag_jacobian`, sample_shape skipped', np.concatenate(diag_jacobian_shape_none_, -1)) ``` Args: xs: `Tensor` or a python `list` of `Tensors` of real-like dtypes and shapes `sample_shape` + `event_shape_i`, where `event_shape_i` can be different for different tensors. ys: `Tensor` or a python `list` of `Tensors` of the same dtype as `xs`. Must broadcast with the shape of `xs`. Can be omitted if `fn` is provided. sample_shape: A common `sample_shape` of the input tensors of `xs`. If not, provided, assumed to be `[1]`, which may result in a slow performance of `jacobians_diag`. fn: Python callable that takes `xs` as an argument (or `*xs`, if it is a list) and returns `ys`. Might be skipped if `ys` is provided and `tf.enable_eager_execution()` is disabled. parallel_iterations: `int` that specifies the allowed number of coordinates of the input tensor `xs`, for which the partial derivatives `dys_i/dxs_i` can be computed in parallel. name: Python `str` name prefixed to `Ops` created by this function. Default value: `None` (i.e., "diag_jacobian"). Returns: ys: a list, which coincides with the input `ys`, when provided. If the input `ys` is None, `fn(*xs)` gets computed and returned as a list. jacobians_diag_res: a `Tensor` or a Python list of `Tensor`s of the same dtypes and shapes as the input `xs`. This is the diagonal of the Jacobian of ys wrt xs. 
Raises: ValueError: if lists `xs` and `ys` have different length or both `ys` and `fn` are `None`, or `fn` is None in the eager execution mode. """ with tf.name_scope(name or 'jacobians_diag'): if sample_shape is None: sample_shape = [1] # Output Jacobian diagonal jacobians_diag_res = [] # Convert input `xs` to a list xs = list(xs) if _is_list_like(xs) else [xs] xs = [tf.convert_to_tensor(x) for x in xs] if not tf.executing_eagerly(): if ys is None: if fn is None: raise ValueError('Both `ys` and `fn` can not be `None`') else: ys = fn(*xs) # Convert ys to a list ys = list(ys) if _is_list_like(ys) else [ys] if len(xs) != len(ys): raise ValueError('`xs` and `ys` should have the same length') for y, x in zip(ys, xs): # Broadcast `y` to the shape of `x`. y_ = y + tf.zeros_like(x) # Change `event_shape` to one-dimension y_ = tf.reshape(y, tf.concat([sample_shape, [-1]], -1)) # Declare an iterator and tensor array loop variables for the gradients. n = tf.size(x) / tf.cast(tf.reduce_prod(sample_shape), dtype=tf.int32) n = tf.cast(n, dtype=tf.int32) loop_vars = [0, tf.TensorArray(x.dtype, n)] def loop_body(j): """Loop function to compute gradients of the each direction.""" # Gradient along direction `j`. res = tf.gradients(ys=y_[..., j], xs=x)[0] # pylint: disable=cell-var-from-loop if res is None: # Return zero, if the gradient is `None`. res = tf.zeros(tf.concat([sample_shape, [1]], -1), dtype=x.dtype) # pylint: disable=cell-var-from-loop else: # Reshape `event_shape` to 1D res = tf.reshape(res, tf.concat([sample_shape, [-1]], -1)) # Add artificial dimension for the case of zero shape input tensor res = res[tf.newaxis, ..., j] return res # pylint: disable=cell-var-from-loop # Iterate over all elements of the gradient and compute second order # derivatives. _, jacobian_diag_res = tf.while_loop( cond=lambda j, _: j < n, # pylint: disable=cell-var-from-loop body=lambda j, result: (j + 1, result.write(j, loop_body(j))), loop_vars=loop_vars, parallel_iterations=parallel_iterations) shape_x = ps.shape(x) # Stack gradients together and move flattened `event_shape` to the # zero position reshaped_jacobian_diag = tf.transpose( a=jacobian_diag_res.stack()) # Reshape to the original tensor reshaped_jacobian_diag = tf.reshape(reshaped_jacobian_diag, shape_x) jacobians_diag_res.append(reshaped_jacobian_diag) else: if fn is None: raise ValueError( '`fn` can not be `None` when eager execution is ' 'enabled') if ys is None: ys = fn(*xs) def fn_slice(i, j): """Broadcast y[i], flatten event shape of y[i], return y[i][..., j].""" def fn_broadcast(*state): res = fn(*state) res = list(res) if _is_list_like(res) else [res] if len(res) != len(state): res *= len(state) res = [ tf.reshape(r + tf.zeros_like(s), ps.concat([sample_shape, [-1]], -1)) for r, s in zip(res, state) ] return res # Expand dimensions before returning in order to support 0D input `xs` return lambda *state: tf.expand_dims( fn_broadcast(*state)[i], 0)[..., j] def make_loop_body(i, x): """Loop function to compute gradients of the each direction.""" def _fn(j, result): res = value_and_gradient(fn_slice(i, j), xs)[1][i] if res is None: res = tf.zeros(sample_shape, dtype=x.dtype) else: res = tf.reshape(res, ps.concat([sample_shape, [-1]], -1)) res = res[..., j] return j + 1, result.write(j, res) return _fn for i, x in enumerate(xs): # Declare an iterator and tensor array loop variables for the gradients. 
n = ps.size(x) / ps.cast(ps.reduce_prod(sample_shape), dtype=tf.int32) n = ps.cast(n, dtype=tf.int32) loop_vars = (0, tf.TensorArray(x.dtype, n, element_shape=sample_shape)) # Iterate over all elements of the gradient and compute second order # derivatives. _, jacobian_diag_res = tf.while_loop( cond=lambda j, _: j < n, body=make_loop_body(i, x), loop_vars=loop_vars, parallel_iterations=parallel_iterations) shape_x = ps.shape(x) # Stack gradients together and move flattened `event_shape` to the # zero position reshaped_jacobian_diag = tf.transpose( jacobian_diag_res.stack()) # Reshape to the original tensor reshaped_jacobian_diag = tf.reshape(reshaped_jacobian_diag, shape_x) jacobians_diag_res.append(reshaped_jacobian_diag) return ys, jacobians_diag_res
def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
  """Internal function to map over.

  Consumes a batch of input examples and produces a variable number of output
  examples.

  Args:
    x: a single example

  Returns:
    a tf.data.Dataset
  """
  keys = list(feature_lengths)
  partial = empty_example.copy()
  first_key, *_ = keys
  dynamic_batch_size = tf.shape(x[first_key])[0]
  outputs = {}
  for k in keys:
    outputs[k] = tf.TensorArray(
        tf.int32, size=0, dynamic_size=True,
        element_shape=[feature_lengths[k]])
    outputs[k + "_positions"] = tf.TensorArray(
        tf.int32, size=0, dynamic_size=True,
        element_shape=[feature_lengths[k]])

  for i in tf.range(0, dynamic_batch_size):
    tf.autograph.experimental.set_loop_options(
        shape_invariants=[(partial, {k: tf.TensorShape([None])
                                     for k in keys_etc}),
                          (outputs, {k: tf.TensorShape(None)
                                     for k in keys_etc})])
    can_append = True
    one_example = {}
    for k in keys:
      val = tf.cast(x[k][i], tf.int32)
      val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
      one_example[k] = val
    for k in keys:
      can_append = tf.logical_and(
          can_append,
          tf.less_equal(
              tf.size(partial[k]) + tf.size(one_example[k]),
              feature_lengths[k]))

    if not can_append:
      partial, outputs = _write_packed_example(partial, outputs)

    new_partial = {}
    for k in keys:
      new_seq = one_example[k][:feature_lengths[k]]
      new_seq_len = tf.size(new_seq)
      new_partial[k] = tf.concat([partial[k], new_seq], 0)
      new_partial[k + "_positions"] = tf.concat(
          [partial[k + "_positions"],
           tf.range(new_seq_len, dtype=tf.int32)], 0)
    partial = new_partial

  partial, outputs = _write_packed_example(partial, outputs)
  packed = {k: outputs[k].stack() for k in keys_etc}
  for k in keys:
    packed[k + "_segment_ids"] = (
        tf.cumsum(
            tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32), axis=1) *
        tf.cast(tf.not_equal(packed[k], 0), tf.int32))
  return packed
def _solve( self, ode_fn, initial_time, initial_state, solution_times, jacobian_fn=None, jacobian_sparsity=None, batch_ndims=None, previous_solver_internal_state=None, ): # Static assertions del jacobian_fn, jacobian_sparsity # not used by DormandPrince if batch_ndims is not None and batch_ndims != 0: raise NotImplementedError( 'For homogeneous batching use `batch_ndims=0`.') solution_times_by_solver = isinstance(solution_times, base.ChosenBySolver) with tf.name_scope(self._name): # (2) Convert to tensors, determine dtypes. p = self._prepare_common_params(initial_state, initial_time) max_num_steps = self._max_num_steps max_ode_fn_evals = self._max_num_steps if max_num_steps is not None: max_num_steps = tf.convert_to_tensor(max_num_steps, dtype=tf.int32) max_ode_fn_evals = max_num_steps * self.ODE_FN_EVALS_PER_STEP step_size = tf.convert_to_tensor(self._step_size, dtype=p.real_dtype) rtol = tf.convert_to_tensor(tf.cast(self._rtol, p.real_dtype)) atol = tf.convert_to_tensor(tf.cast(self._atol, p.real_dtype)) # Use i(d)factor notation for increasing and decreasing factors. solver_internal_state = previous_solver_internal_state if solver_internal_state is None: solver_internal_state = self._initialize_solver_internal_state( ode_fn=ode_fn, initial_state=p.initial_state, initial_time=p.initial_time, ) num_solution_times = 0 if solution_times_by_solver: final_time = tf.cast(solution_times.final_time, p.real_dtype) times_array = tf.TensorArray(p.real_dtype, size=num_solution_times, dynamic_size=True, element_shape=tf.TensorShape([])) else: solution_times = tf.cast(solution_times, p.real_dtype) util.error_if_not_vector(solution_times, 'solution_times') num_solution_times = tf.size(solution_times) times_array = tf.TensorArray( p.real_dtype, size=num_solution_times, dynamic_size=False, element_shape=[]).unstack(solution_times) solutions_arrays = [ tf.TensorArray(dtype=component_dtype, size=num_solution_times, dynamic_size=solution_times_by_solver) for component_dtype in tf.nest.flatten(p.state_dtypes) ] solutions_arrays = tf.nest.pack_sequence_as( p.initial_state, solutions_arrays) rk_step = functools.partial(self._step, max_ode_fn_evals=max_ode_fn_evals, ode_fn=ode_fn, atol=atol, rtol=rtol, safety=safety, ifactor=ifactor, dfactor=dfactor) advance_to_solution_time = functools.partial( _advance_to_solution_time, times_array=solution_times, step_fn=rk_step, validate_args=self._validate_args) assert_ops = self._assert_ops( ode_fn=ode_fn, initial_time=p.initial_time, initial_state=p.initial_state, solution_times=solution_times, previous_solver_state=previous_solver_internal_state, rtol=rtol, atol=atol, first_step_size=step_size, safety_factor=safety, min_step_size_factor=ifactor, max_step_size_factor=dfactor, max_num_steps=max_num_steps, solution_times_by_solver=solution_times_by_solver) with tf.control_dependencies(assert_ops): ode_evals_by_now = 1 if self._validate_args else 0 ode_evals_by_now += 1 if solver_internal_state is None else 0 diagnostics = _DopriDiagnostics( num_ode_fn_evaluations=ode_evals_by_now, num_jacobian_evaluations=0, num_matrix_factorizations=0, status=0) if solution_times_by_solver: r = _dense_solutions_to_final_time( final_time=final_time, solver_state=solver_internal_state, diagnostics=diagnostics, step_fn=rk_step, ode_fn=ode_fn, times_array=times_array, solutions_arrays=solutions_arrays, validate_args=self._validate_args) solver_internal_state, diagnostics, times_array, solutions_arrays = r else: def iterate_cond(time_id, *_): return time_id < num_solution_times [ _, 
solver_internal_state, diagnostics, solutions_arrays ] = tf.while_loop(iterate_cond, advance_to_solution_time, [ 0, solver_internal_state, diagnostics, solutions_arrays ]) times = times_array.stack() stack_components = lambda x: x.stack() states = tf.nest.map_structure(stack_components, solutions_arrays) return base.Results( times=times, states=states, diagnostics=diagnostics, solver_internal_state=solver_internal_state)
def vjp_bwd(results_constants, dresults, variables=()): """Adjoint sensitivity method to compute gradients.""" results, constants = results_constants adjoint_solver = self._make_adjoint_solver_fn() dstates = dresults.states # TODO(b/138304303): Support complex types. with tf.name_scope('{}Gradients'.format(self._name)): get_dtype = lambda x: x.dtype def error_if_complex(dtype): if dtype_util.is_complex(dtype): raise NotImplementedError( 'The adjoint sensitivity method does ' 'not support complex dtypes.') state_dtypes = tf.nest.map_structure(get_dtype, initial_state) tf.nest.map_structure(error_if_complex, state_dtypes) common_state_dtype = dtype_util.common_dtype(initial_state) real_dtype = dtype_util.real_dtype(common_state_dtype) # We add initial_time to ensure that we know where to stop. result_times = tf.concat( [[tf.cast(initial_time, real_dtype)], results.times], 0) num_result_times = tf.size(result_times) # First two components correspond to reverse and adjoint states. # the last two component is adjoint state for variables and constants. terminal_augmented_state = tuple([ rk_util.nest_constant(initial_state, 0.0), rk_util.nest_constant(initial_state, 0.0), tuple( rk_util.nest_constant(variable, 0.0) for variable in variables), rk_util.nest_constant(constants, 0.0), ]) # The XLA compiler does not compile code which slices/indexes using # integer `Tensor`s. `TensorArray`s are used to get around this. result_time_array = tf.TensorArray( results.times.dtype, clear_after_read=False, size=num_result_times, element_shape=[]).unstack(result_times) # TensorArray shape should not include time dimension, hence shape[1:] result_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, clear_after_read=False, element_shape=component.shape[1:]).unstack(component) for component in tf.nest.flatten(results.states) ] result_state_arrays = tf.nest.pack_sequence_as( results.states, result_state_arrays) dresult_state_arrays = [ tf.TensorArray( # pylint: disable=g-complex-comprehension dtype=component.dtype, size=num_result_times - 1, clear_after_read=False, element_shape=component.shape[1:]).unstack(component) for component in tf.nest.flatten(dstates) ] dresult_state_arrays = tf.nest.pack_sequence_as( results.states, dresult_state_arrays) def augmented_ode_fn(backward_time, augmented_state): """Dynamics function for the augmented system. Describes a differential equation that evolves the augmented state backwards in time to compute gradients using the adjoint method. Augmented state consists of 4 components `(state, adjoint_state, vars, constants)` all evaluated at time `backward_time`: state: represents the solution of user provided `ode_fn`. The structure coincides with the `initial_state`. adjoint_state: represents the solution of the adjoint sensitivity differential equation as discussed below. Has the same structure and shape as `state`. variables: represent the solution of the adjoint equation for variable gradients. Represented as a `Tuple(Tensor, ...)` with as many tensors as there are `variables` variable outside this function. constants: represent the solution of the adjoint equation for constant gradients. Has the same structure and shape as `constants` variable outside this function. The adjoint sensitivity equation describes the gradient of the solution with respect to the value of the solution at a previous time t. 
Its dynamics are given by d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z) Which is computed as: d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j)] d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)] where in the last line we moved adj(t)_j under derivative by removing gradient from it. Adjoint equation for the gradient with respect to every `tf.Variable` and constant theta follows: d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta) = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)] Args: backward_time: Floating `Tensor` representing current time. augmented_state: `Tuple(state, adjoint_state, variable_grads)` Returns: negative_derivatives: Structure of `Tensor`s equal to backwards time derivative of the `state` componnent. adjoint_ode: Structure of `Tensor`s equal to backwards time derivative of the `adjoint_state` component. adjoint_variables_ode: Structure of `Tensor`s equal to backwards time derivative of the `vars` component. adjoint_constants_ode: Structure of `Tensor`s equal to backwards time derivative of the `constants` component. """ # The negative signs disappears after the change of variables. # The ODE solver cannot handle the case initial_time > final_time # and hence a change of variables backward_time = -time is used. time = -backward_time state, adjoint_state, _, _ = augmented_state # TODO(b/152464477): Doesn't work reliably in TF1. def grad_fn(state, variables, constants): del variables # We compute these gradients via the GradientTape # capturing them. derivatives = ode_fn(time, state, **constants) adjoint_no_grad = tf.nest.map_structure( tf.stop_gradient, adjoint_state) negative_derivatives = rk_util.weighted_sum( [-1.0], [derivatives]) def dot_prod(tensor_a, tensor_b): return tf.reduce_sum(tensor_a * tensor_b) # See docstring for details. adjoint_dot_derivatives = tf.nest.map_structure( dot_prod, adjoint_no_grad, derivatives) adjoint_dot_derivatives = tf.squeeze( tf.add_n(tf.nest.flatten(adjoint_dot_derivatives))) return adjoint_dot_derivatives, negative_derivatives values = (state, tuple(variables), constants) ((_, negative_derivatives), gradients) = tfp_gradient.value_and_gradient( grad_fn, values, has_aux=True, use_gradient_tape=True) (adjoint_ode, adjoint_variables_ode, adjoint_constants_ode) = tf.nest.map_structure( lambda v, g: tf.zeros_like(v) if g is None else g, values, tuple(gradients)) return (negative_derivatives, adjoint_ode, adjoint_variables_ode, adjoint_constants_ode) def make_augmented_state(n, prev_augmented_state): """Constructs the augmented state for step `n`.""" (_, adjoint_state, adjoint_variable_state, adjoint_constant_state) = prev_augmented_state initial_state = _read_solution_components( result_state_arrays, input_state_structure, n - 1, ) initial_adjoint = _read_solution_components( dresult_state_arrays, input_state_structure, n - 1, ) initial_adjoint_state = rk_util.weighted_sum( [1.0, 1.0], [adjoint_state, initial_adjoint]) augmented_state = ( initial_state, initial_adjoint_state, adjoint_variable_state, adjoint_constant_state, ) return augmented_state def reverse_to_result_time(n, augmented_state, solver_internal_state, _): """Integrates the augmented system backwards in time.""" lower_bound_of_integration = result_time_array.read(n) upper_bound_of_integration = result_time_array.read(n - 1) initial_augmented_state = make_augmented_state( n, augmented_state) # TODO(b/138304303): Allow the user to specify the Hessian of # `ode_fn` so that we can get the Jacobian of the adjoint system. 
# TODO(b/143624114): Support higher order derivatives. solver_internal_state = ( adjoint_solver. _adjust_solver_internal_state_for_state_jump( # pylint: disable=protected-access ode_fn=augmented_ode_fn, initial_time=-lower_bound_of_integration, initial_state=initial_augmented_state, previous_solver_internal_state= solver_internal_state, previous_state=augmented_state, )) augmented_results = adjoint_solver.solve( ode_fn=augmented_ode_fn, initial_time=-lower_bound_of_integration, initial_state=initial_augmented_state, solution_times=[-upper_bound_of_integration], batch_ndims=batch_ndims, previous_solver_internal_state=solver_internal_state, ) # Results added an extra time dim of size 1, squeeze it. select_result = lambda x: tf.squeeze(x, [0]) result_state = augmented_results.states result_state = tf.nest.map_structure( select_result, result_state) status = augmented_results.diagnostics.status return (n - 1, result_state, augmented_results.solver_internal_state, status) initial_n = num_result_times - 1 solver_internal_state = adjoint_solver._initialize_solver_internal_state( # pylint: disable=protected-access ode_fn=augmented_ode_fn, initial_time=result_time_array.read(initial_n), initial_state=make_augmented_state( initial_n, terminal_augmented_state), ) _, augmented_state, _, _ = tf.while_loop( lambda n, _as, _sis, status: (n >= 1) & tf.equal(status, 0), reverse_to_result_time, (initial_n, terminal_augmented_state, solver_internal_state, 0), back_prop=False, ) (_, adjoint_state, adjoint_variables, adjoint_constants) = augmented_state if variables: return (adjoint_state, adjoint_constants), list(adjoint_variables) else: return adjoint_state, adjoint_constants
def scan(f, init, xs, length=None, reverse=False): """Scan a function over leading array axes while carrying along state. See the docstring of `jax.lax.scan` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) for details. Args: f: a Python function to be scanned of type ``c -> a -> (c, b)``, meaning that ``f`` accepts two arguments where the first is a value of the loop carry and the second is a slice of ``xs`` along its leading axis, and that ``f`` returns a pair where the first element represents a new value for the loop carry and the second represents a slice of the output. Note that the input and output carry must have the same dtype. init: an initial loop carry value of type ``c``, which can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof, representing the initial loop carry value. This value must have the same structure as the first element of the pair returned by ``f``. xs: the value of type ``[a]`` over which to scan along the leading axis, where ``[a]`` can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading axis sizes. length: optional integer specifying the number of loop iterations, which must agree with the sizes of leading axes of the arrays in ``xs`` (but can be used to perform scans where no input ``xs`` are needed). reverse: optional boolean specifying whether to run the scan iteration forward (the default) or in reverse, equivalent to reversing the leading axes of the arrays in both ``xs`` and in ``ys``. Returns: A pair of type ``(c, [b])`` where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of ``f`` when scanned over the leading axis of the inputs. """ init, xs = tf.nest.map_structure( lambda x: tf_np.asarray(x) if x is not None else None, (init, xs)) init, xs = _np_to_tf((init, xs)) def get_length(x): if x is None: return None if x.shape.rank == 0: raise ValueError( "Some array in `xs` doesn't have a leading dimension") return x.shape[0] lengths = tf.nest.flatten(tf.nest.map_structure(get_length, xs)) for l in lengths: if l is not None: if length is None: length = l elif length != l: raise ValueError( "There are two different leading-dimension lengths: " f"{length} and {l}") if length is None: raise ValueError( "Can't determine length. Please set the `length` argument.") xs_ta = tf.nest.map_structure( lambda t: ( tf.TensorArray(t.dtype, size=0, dynamic_size=True).unstack(t) # pylint: disable=g-long-lambda if t is not None else None), xs) def body(i, carry, ys_ta): if reverse: i_ = length - 1 - i else: i_ = i xs = tf.nest.map_structure( lambda x_ta: x_ta.read(i_) if x_ta is not None else None, xs_ta) carry, ys = _np_to_tf(f(*_tf_to_np((carry, xs)))) ys_ta = tf.nest.map_structure( lambda y_ta, y: (y_ta.write(i_, y) if y is not None else y_ta), ys_ta, ys) i = i + 1 return i, carry, ys_ta xs_spec = tf.nest.map_structure( lambda t: tf.TensorSpec(t.shape[1:], t.dtype) if t is not None else None, xs) _, ys_spec = eval_on_shapes(f)(init, xs_spec) # ys_ta can't contain None because tf.while_loop doesn't allow None in # loop_vars. 
ys_ta = tf.nest.map_structure( lambda y: tf.TensorArray( y.dtype if y is not None else tf.float32, size=0, # pylint: disable=g-long-lambda dynamic_size=True), ys_spec) _, carry, ys_ta = tf.while_loop(lambda i, *_: i < length, body, (0, init, ys_ta)) def _stack(a, spec): if spec is None: return None a = a.stack() a.set_shape((length, ) + a.shape[1:]) return a ys = tf.nest.map_structure(_stack, ys_ta, ys_spec) return _tf_to_np((carry, ys))
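# Hedged usage sketch (assumption, not part of the original source): a running
# sum expressed with `scan`, mirroring the `jax.lax.scan` calling convention
# documented above.
def _cumsum_step(carry, x):
  new_carry = carry + x
  return new_carry, new_carry

# final, partials = scan(_cumsum_step, 0., [1., 2., 3., 4.])
# `final` is 10. and `partials` stacks [1., 3., 6., 10.].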
def _sample_multinomial_as_iterated_binomial(
    num_samples, num_classes, probs, num_trials, dtype, seed):
  """Sample a multinomial by drawing one binomial sample per class.

  The batch shape is given by broadcasting num_trials with
  remove_last_dimension(probs).

  The loop over binomial samples is a `tf.while_loop`, thus supporting a
  dynamic number of classes.

  Args:
    num_samples: Singleton integer Tensor: number of multinomial samples to
      draw.
    num_classes: Singleton integer Tensor: number of classes.
    probs: Floating Tensor with last dimension `num_classes`, of normalized
      probabilities per class.
    num_trials: Tensor of number of categorical trials each multinomial
      consists of.  num_trials[..., tf.newaxis] must broadcast with probs.
    dtype: dtype at which to emit samples.
    seed: Random seed.

  Returns:
    samples: Tensor of given dtype and shape
      [num_samples] + batch_shape + [num_classes].
  """
  with tf.name_scope('draw_sample'):
    # `convert_to_tensor(num_classes)` here to avoid unstacking inside
    # `split_seed`. We can't take advantage of the Python-list code path anyway
    # because the index at which we will take the seed is a Tensor.
    seeds = samplers.split_seed(
        seed, n=tf.convert_to_tensor(num_classes),
        salt='multinomial_draw_sample')

    def fn(i, num_trials, consumed_prob, accum):
      """Sample the counts for one class using binomial."""
      probs_here = tf.gather(probs, i, axis=-1)
      binomial_probs = tf.clip_by_value(probs_here / (1. - consumed_prob),
                                        0, 1)
      seed_here = tf.gather(seeds, i, axis=0)
      binom = binomial.Binomial(total_count=num_trials, probs=binomial_probs)
      # Not passing `num_samples` to `binom.sample`, as it's already in
      # `num_trials.shape`.
      sample = binom.sample(seed=seed_here)
      accum = accum.write(i, tf.cast(sample, dtype=dtype))
      return i + 1, num_trials - sample, consumed_prob + probs_here, accum

    num_trials = tf.cast(num_trials, probs.dtype)
    # Pre-broadcast with probs
    num_trials += tf.zeros_like(probs[..., 0])
    # Pre-enlarge for different output samples
    num_trials = _replicate_along_left(num_trials, num_samples)
    i = tf.constant(0)
    consumed_prob = tf.zeros_like(probs[..., 0])
    accum = tf.TensorArray(
        dtype, size=num_classes, element_shape=num_trials.shape)
    _, num_trials_left, _, accum = tf.while_loop(
        cond=lambda index, _0, _1, _2: tf.less(index, num_classes - 1),
        body=fn,
        loop_vars=(i, num_trials, consumed_prob, accum))
    # Force the last iteration to put all the trials into the last bucket,
    # because probs[..., -1] / (1. - consumed_prob) might numerically not be 1.
    # Also saves one iteration around the while_loop and one run of the
    # binomial sampler.
    accum = accum.write(num_classes - 1, tf.cast(num_trials_left, dtype=dtype))
    # This stop_gradient is necessary to prevent spurious zero gradients coming
    # from b/138796859, and a spurious gradient through num_trials_left.
    results = tf.stop_gradient(accum.stack())
    return distribution_util.move_dimension(results, 0, -1)
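# Illustrative sketch (assumption, independent of the TF code above) of the
# same chain-rule decomposition in plain NumPy: each class count is a binomial
# draw from the trials and probability mass that remain.
import numpy as np

def multinomial_via_binomials(num_trials, probs, rng=None):
  """Draws one multinomial sample by iterating binomials over classes."""
  rng = rng or np.random.default_rng()
  remaining, consumed, counts = num_trials, 0.0, []
  for p in probs[:-1]:
    draw = rng.binomial(remaining, min(max(p / (1.0 - consumed), 0.0), 1.0))
    counts.append(draw)
    remaining -= draw
    consumed += p
  counts.append(remaining)  # The last class receives all remaining trials.
  return np.array(counts)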
def map_fn(x):
  """Internal function to flat_map over.

  Consumes a batch of input examples and produces a variable number of output
  examples.

  Args:
    x: a single example

  Returns:
    a tf.data.Dataset
  """
  partial = empty_example.copy()
  i = tf.zeros([], dtype=tf.int32)
  dynamic_batch_size = tf.shape(x[keys[0]])[0]
  outputs = {}
  for k in keys:
    outputs[k] = tf.TensorArray(
        tf.int32, size=0, dynamic_size=True, element_shape=[length[k]])
    outputs[k + '_position'] = tf.TensorArray(
        tf.int32, size=0, dynamic_size=True, element_shape=[length[k]])

  def cond_fn(i, partial, outputs):
    del partial, outputs
    return i < dynamic_batch_size

  def body_fn(i, partial, outputs):
    """Body function for while_loop.

    Args:
      i: integer scalar
      partial: dictionary of Tensor (partially-constructed example)
      outputs: dictionary of TensorArray

    Returns:
      A triple containing the new values of the inputs.
    """
    can_append = True
    one_example = {}
    for k in keys:
      val = tf.cast(x[k][i], tf.int32)
      val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
      one_example[k] = val
    for k in keys:
      can_append = tf.logical_and(
          can_append,
          tf.less_equal(
              tf.size(partial[k]) + tf.size(one_example[k]), length[k]))

    def false_fn():
      return write_packed_example(partial, outputs)

    def true_fn():
      return partial, outputs

    partial, outputs = tf.cond(can_append, true_fn, false_fn)
    new_partial = {}
    for k in keys:
      new_seq = one_example[k][:length[k]]
      new_seq_len = tf.size(new_seq)
      new_partial[k] = tf.concat([partial[k], new_seq], 0)
      new_partial[k + '_position'] = tf.concat(
          [partial[k + '_position'],
           tf.range(new_seq_len, dtype=tf.int32)], 0)
    partial = new_partial
    return i + 1, partial, outputs

  i, partial, outputs = tf.while_loop(
      cond_fn,
      body_fn, (i, partial, outputs),
      shape_invariants=(
          tf.TensorShape([]),
          {k: tf.TensorShape([None]) for k in keys_etc},
          {k: tf.TensorShape(None) for k in keys_etc},
      ))
  partial, outputs = write_packed_example(partial, outputs)
  packed = {k: outputs[k].stack() for k in keys_etc}
  for k in keys:
    packed[k + '_segmentation'] = (
        tf.cumsum(
            tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) *
        tf.cast(tf.not_equal(packed[k], 0), tf.int32))
  return packed
def slice_first_element_with_from_tensor_high_rank(self, t):
  ta = tf.TensorArray(
      dtype=tf.float32, size=STATIC_SIZE, element_shape=[STATIC_SIZE])
  ta = ta.unstack(t)
  return ta.read(0)
def one_step(self, current_state, previous_kernel_results, seed=None):
  seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
  start_trajectory_seed, loop_seed = samplers.split_seed(seed)

  with tf.name_scope(self.name + '.one_step'):
    state_structure = current_state
    current_state = tf.nest.flatten(current_state)
    if (tf.nest.is_nested(state_structure)
        and (not mcmc_util.is_list_like(state_structure)
             or len(current_state) != len(state_structure))):
      # TODO(b/170865194): Support dictionaries and other non-list-like state.
      raise TypeError('NUTS does not currently support nested or '
                      'non-list-like state structures (saw: {}).'.format(
                          state_structure))

    current_target_log_prob = previous_kernel_results.target_log_prob
    [init_momentum, init_energy, log_slice_sample
    ] = self._start_trajectory_batched(
        current_state, current_target_log_prob, seed=start_trajectory_seed)

    def _copy(v):
      return v * ps.ones(
          ps.pad([2], paddings=[[0, ps.rank(v)]], constant_values=1),
          dtype=v.dtype)

    initial_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)

    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction)

    _, _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, seed, state, metastate: (  # pylint: disable=g-long-lambda
            (iter_ < self.max_tree_depth) &
            tf.reduce_any(metastate.continue_tree)),
        body=lambda iter_, seed, state, metastate: self._loop_tree_doubling(  # pylint: disable=g-long-lambda
            previous_kernel_results.step_size,
            previous_kernel_results.momentum_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate,
            seed),
        loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                   loop_seed,
                   initial_step_state,
                   initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )

    kernel_results = NUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        momentum_state_memory=previous_kernel_results.momentum_state_memory,
        step_size=previous_kernel_results.step_size,
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        leapfrogs_taken=(
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps),
        is_accepted=new_step_metastate.is_accepted,
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy,
        seed=seed,
    )

    result_state = tf.nest.pack_sequence_as(
        state_structure, new_step_metastate.candidate_state.state)
    return result_state, kernel_results
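# --- Hypothetical usage sketch (not part of the original code). ---
# A minimal illustration of the instruction-table pattern used above: a
# static Python list is unstacked into a fixed-size TensorArray with
# `clear_after_read=False`, so it can be indexed by a traced loop counter
# (and re-read) inside `tf.while_loop` under XLA. The instruction values and
# the `sum_instructions` helper are assumptions for illustration only.
import tensorflow as tf


@tf.function(jit_compile=True)
def sum_instructions(instructions, num_steps):
  table = tf.TensorArray(
      tf.int32, size=len(instructions),
      clear_after_read=False).unstack(instructions)

  def body(i, total):
    # Reading with a dynamic index works because the array has a static size
    # and is never cleared after a read.
    return i + 1, total + table.read(i)

  _, total = tf.while_loop(lambda i, _: i < num_steps, body, (0, 0))
  return total


print(sum_instructions([3, 1, 4, 1, 5], tf.constant(5)))  # -> 14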
def concat_with_tensorlist_stack(self, a, b):
  ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[])
  ta = ta.write(0, a)
  ta = ta.write(1, b)
  return ta.stack()
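# --- Hypothetical usage sketch (not part of the original code). ---
# Stand-alone equivalent of the method above: two scalar writes followed by
# `stack()` behave like `tf.stack([a, b])`; the fixed `size=2` and scalar
# `element_shape=[]` let the stacked result keep a static shape of [2].
import tensorflow as tf


@tf.function
def concat_scalars(a, b):
  ta = tf.TensorArray(dtype=tf.float32, size=2, element_shape=[])
  ta = ta.write(0, a)
  ta = ta.write(1, b)
  return ta.stack()


print(concat_scalars(tf.constant(1.0), tf.constant(2.0)))  # -> [1., 2.]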
def one_step(self, current_state, previous_kernel_results):
  with tf.name_scope(self.name + '.one_step'):
    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]

    current_target_log_prob = previous_kernel_results.target_log_prob
    [init_momentum, init_energy, log_slice_sample
    ] = self._start_trajectory_batched(current_state, current_target_log_prob)

    def _copy(v):
      return v * prefer_static.ones(
          prefer_static.pad(
              [2], paddings=[[0, prefer_static.rank(v)]], constant_values=1),
          dtype=v.dtype)

    initial_state = TreeDoublingState(
        momentum=init_momentum,
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob)
    initial_step_state = tf.nest.map_structure(_copy, initial_state)

    if MULTINOMIAL_SAMPLE:
      init_weight = tf.zeros_like(init_energy)  # log(exp(H0 - H0))
    else:
      init_weight = tf.ones_like(init_energy, dtype=TREE_COUNT_DTYPE)

    candidate_state = TreeDoublingStateCandidate(
        state=current_state,
        target=current_target_log_prob,
        target_grad_parts=previous_kernel_results.grads_target_log_prob,
        energy=init_energy,
        weight=init_weight)

    initial_step_metastate = TreeDoublingMetaState(
        candidate_state=candidate_state,
        is_accepted=tf.zeros_like(init_energy, dtype=tf.bool),
        momentum_sum=init_momentum,
        energy_diff_sum=tf.zeros_like(init_energy),
        leapfrog_count=tf.zeros_like(init_energy, dtype=TREE_COUNT_DTYPE),
        continue_tree=tf.ones_like(init_energy, dtype=tf.bool),
        not_divergence=tf.ones_like(init_energy, dtype=tf.bool))

    # Convert the write/read instruction into TensorArray so that it is
    # compatible with XLA.
    write_instruction = tf.TensorArray(
        TREE_COUNT_DTYPE,
        size=len(self._write_instruction),
        clear_after_read=False).unstack(self._write_instruction)
    read_instruction = tf.TensorArray(
        tf.int32,
        size=len(self._read_instruction),
        clear_after_read=False).unstack(self._read_instruction)

    current_step_meta_info = OneStepMetaInfo(
        log_slice_sample=log_slice_sample,
        init_energy=init_energy,
        write_instruction=write_instruction,
        read_instruction=read_instruction)

    _, _, new_step_metastate = tf.while_loop(
        cond=lambda iter_, state, metastate: (  # pylint: disable=g-long-lambda
            (iter_ < self.max_tree_depth) &
            tf.reduce_any(metastate.continue_tree)),
        body=lambda iter_, state, metastate: self.loop_tree_doubling(  # pylint: disable=g-long-lambda
            previous_kernel_results.step_size,
            previous_kernel_results.momentum_state_memory,
            current_step_meta_info,
            iter_,
            state,
            metastate),
        loop_vars=(tf.zeros([], dtype=tf.int32, name='iter'),
                   initial_step_state,
                   initial_step_metastate),
        parallel_iterations=self.parallel_iterations,
    )

    kernel_results = NUTSKernelResults(
        target_log_prob=new_step_metastate.candidate_state.target,
        grads_target_log_prob=(
            new_step_metastate.candidate_state.target_grad_parts),
        momentum_state_memory=previous_kernel_results.momentum_state_memory,
        step_size=previous_kernel_results.step_size,
        log_accept_ratio=tf.math.log(
            new_step_metastate.energy_diff_sum /
            tf.cast(new_step_metastate.leapfrog_count,
                    dtype=new_step_metastate.energy_diff_sum.dtype)),
        leapfrogs_taken=(
            new_step_metastate.leapfrog_count * self.unrolled_leapfrog_steps),
        is_accepted=new_step_metastate.is_accepted,
        reach_max_depth=new_step_metastate.continue_tree,
        has_divergence=~new_step_metastate.not_divergence,
        energy=new_step_metastate.candidate_state.energy)

    result_state = new_step_metastate.candidate_state.state
    if unwrap_state_list:
      result_state = result_state[0]
    return result_state, kernel_results
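# --- Hypothetical usage sketch (not part of the original code). ---
# The `_copy` helper above duplicates each state part along a new leading
# axis of size 2 (one slot per trajectory direction) by multiplying with a
# ones tensor of shape [2, 1, ..., 1]. A standalone sketch with toy shapes,
# assuming TFP's internal `prefer_static` module as used in the snippet:
import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static


def duplicate_leading(v):
  # Broadcasts `v` (shape [...]) against ones of shape [2, 1, ..., 1],
  # producing shape [2, ...] without copying data eagerly.
  return v * prefer_static.ones(
      prefer_static.pad(
          [2], paddings=[[0, prefer_static.rank(v)]], constant_values=1),
      dtype=v.dtype)


x = tf.random.normal([4, 3])
print(duplicate_leading(x).shape)  # -> (2, 4, 3)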
def _sample_paths(self, times, times_size, current_log_spot, current_var,
                  num_samples, random_type, keep_mask, seed, tolerance):
  """Returns a sample of paths from the process."""
  # Note: all the notations below are the same as in [1].
  # Add zeros as a starting location
  dt = times[1:] - times[:-1]
  kappa, theta, epsilon, rho = _get_parameters(  # pylint: disable=unbalanced-tuple-unpacking
      times + tf.reduce_min(dt) / 2,
      self._kappa, self._theta, self._epsilon, self._rho)
  cond_fn = lambda i, *args: i < tf.size(dt)

  def body_fn(i, written_count, current_var, current_log_spot, vol_paths,
              log_spot_paths):
    """Simulate Heston process to the next time point."""
    time_step = dt[i]

    def _next_vol_fn():
      return _update_variance(
          i, kappa[i], theta[i], epsilon[i], rho[i],
          current_var, time_step, num_samples, random_type, seed)

    # Update the variance only if `time_step > tolerance`.
    next_vol = tf.cond(
        time_step > tolerance,
        lambda: _next_vol_fn(),  # pylint: disable=unnecessary-lambda
        lambda: current_var)

    def _next_log_spot_fn():
      return _update_log_spot(
          i, kappa[i], theta[i], epsilon[i], rho[i],
          current_var, next_vol, current_log_spot, time_step,
          num_samples, random_type, seed)

    # Update the log-spot only if `time_step > tolerance`.
    next_log_spot = tf.cond(
        time_step > tolerance,
        lambda: _next_log_spot_fn(),  # pylint: disable=unnecessary-lambda
        lambda: current_log_spot)

    vol_paths = tf.cond(
        keep_mask[i + 1],
        lambda: vol_paths.write(written_count, next_vol),
        lambda: vol_paths)
    log_spot_paths = tf.cond(
        keep_mask[i + 1],
        lambda: log_spot_paths.write(written_count, next_log_spot),
        lambda: log_spot_paths)
    written_count += tf.cast(keep_mask[i + 1], dtype=tf.int32)
    return (i + 1, written_count, next_vol, next_log_spot,
            vol_paths, log_spot_paths)

  log_spot_paths = tf.TensorArray(dtype=self._dtype, size=times_size)
  vol_paths = tf.TensorArray(dtype=self._dtype, size=times_size)
  _, _, _, _, vol_paths, log_spot_paths = tf.compat.v2.while_loop(
      cond_fn, body_fn,
      (0, 0, current_var, current_log_spot, vol_paths, log_spot_paths))
  return tf.stack(
      [tf.transpose(log_spot_paths.stack()),
       tf.transpose(vol_paths.stack())], -1)
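# --- Hypothetical usage sketch (not part of the original code). ---
# A minimal illustration of the `keep_mask` pattern above: the loop advances
# over every step, but results are written to the TensorArray (and
# `written_count` incremented) only where the mask is True, so the stacked
# output contains just the requested times. The toy values and mask are
# assumptions for illustration only.
import tensorflow as tf


@tf.function
def masked_collect(values, keep_mask):
  out = tf.TensorArray(
      tf.float32, size=tf.reduce_sum(tf.cast(keep_mask, tf.int32)))

  def body(i, written_count, out):
    out = tf.cond(
        keep_mask[i],
        lambda: out.write(written_count, values[i]),
        lambda: out)
    written_count += tf.cast(keep_mask[i], tf.int32)
    return i + 1, written_count, out

  _, _, out = tf.while_loop(
      lambda i, *_: i < tf.size(values), body, (0, 0, out))
  return out.stack()


values = tf.constant([10., 20., 30., 40.])
keep_mask = tf.constant([True, False, True, True])
print(masked_collect(values, keep_mask))  # -> [10., 30., 40.]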