def _compute_is_bus_day_table(self):
    """Computes and caches "is business day" table.

    The table has one entry per calendar day covered by the calendar, plus
    one fake business day prepended and one appended (see below). An entry is
    1 for a business day and 0 for a weekend day or holiday.

    Returns:
      The (cached) "is business day" table.
    """
    if self._table_cache.is_bus_day is not None:
        return self._table_cache.is_bus_day

    with tf.init_scope():
        ordinals = tf.range(self._ordinal_offset,
                            self._ordinal_offset + self._calendar_size)
        # Apply weekend mask
        week_days = (ordinals - 1) % 7
        is_holiday = tf.gather(self._weekend_mask, week_days)

        # Apply holidays
        if self._holidays is not None:
            indices = self._holidays.ordinal() - self._ordinal_offset
            # Set holiday entries to exactly 1. `tf.scatter_nd` would *sum*
            # duplicate indices (a holiday listed twice would produce 2), and
            # `1 - is_holiday` below would then go negative.
            is_holiday = tf.tensor_scatter_nd_update(
                is_holiday, tf.expand_dims(indices, axis=-1),
                tf.ones_like(indices, dtype=is_holiday.dtype))

        # Add a business day at the beginning and at the end, i.e. at 31 Dec of
        # start_year-1 and at 1 Jan of end_year+1. This trick is to avoid dealing
        # with special cases on boundaries.
        # For example, for Following and Preceding conventions we'd need a special
        # value that means "unknown" in the tables. More complicated conventions
        # then combine the Following and Preceding tables, and would need special
        # treatment of the "unknown" values.
        # With these "fake" business days, all computations are automatically
        # correct, unless we land on those extra days - for this reason we add
        # assertions in all API calls before returning.
        is_bus_day_table = tf.concat([[1], 1 - is_holiday, [1]], axis=0)
        self._table_cache.is_bus_day = is_bus_day_table
    return is_bus_day_table
def _get_q_slice(q, k, ind, b=None, batch_shape=None):
    """Returns `q1[i]` or `q0[j]` for a batch of indices `i` or `j`.

    For each batch index in `ind`, the column index is read from `k` at that
    position and the corresponding element of `q` is fetched.

    Args:
      q: Tensor to read values from.
      k: Tensor of per-batch-element column indices into `q`.
      ind: 2-D Tensor of batch indices (one row per selected element).
      b: Optional Tensor to update in place at `ind`. If None, a fresh
        zero Tensor of shape `batch_shape` is filled instead.
      batch_shape: Output shape, used only when `b` is None.

    Returns:
      A Tensor holding the gathered values at positions `ind`.
    """
    # Pair each batch index with the column stored for it in `k`.
    columns = tf.expand_dims(tf.gather_nd(k, ind), -1)
    gathered = tf.gather_nd(q, tf.concat([ind, columns], axis=1))
    if b is not None:
        return tf.tensor_scatter_nd_update(b, ind, gathered)
    return tf.scatter_nd(ind, gathered, batch_shape)
def _assemble(self, global_scatter=True):
    """Assembles the global stiffness matrix `A` and load vector `b`.

    Args:
      global_scatter: If False, return the local (per-element) stiffness and
        load values without scattering them into global matrices.

    Returns:
      If `global_scatter` is False: `(a_local_values, b_local_values)` as
      produced by the local assemblers.
      Otherwise `(A, b)`:
      - `A` has shape `[n_nodes, n_nodes]` when `self.batch_shape` is empty,
        and `[batch, n_nodes, n_nodes]` for a 1-D `self.batch_shape`.
      - `b` has shape `[n_nodes, 1]`.
      Contributions to the same global entry are summed (`tf.scatter_nd`
      accumulates duplicate indices).
    """
    # NOTE(review): assumes a_local_values is [..., n_elements, k, k] and the
    # load values flatten element by element in `self.mesh.elements` row
    # order - TODO confirm against the local assemblers.
    a_local_values = self._local_stiffness_matrix_assembler()
    b_local_values = self._local_load_vector_assembler()

    # Don't scatter into global matrices
    if not global_scatter:
        return a_local_values, b_local_values

    # Global (row, col) index of every local stiffness entry, element by
    # element, in the same order as the flattened local values.
    global_stiffness_indices = [
        pair for row in self.mesh.elements
        for pair in itertools.product(row, row)
    ]

    pseudo_batch = len(self.batch_shape) == 0
    if pseudo_batch:
        # add an extra leading dimension onto a_local so map_fn has a batch axis
        a_local_values = a_local_values[tf.newaxis, ...]
        batch_shape = [1]
    else:
        batch_shape = self.batch_shape

    # flatten a_local_values to batch_shape + [n_elements * k * k]
    flat_shape = tf.concat(
        (batch_shape,
         [a_local_values.shape[-3] * a_local_values.shape[-1]**2]),
        axis=0)
    a_local_values = tf.reshape(a_local_values, flat_shape,
                                name='flatten_local_stiffness_matrix')

    A = tf.map_fn(
        lambda x: tf.scatter_nd(
            global_stiffness_indices, x,
            shape=[self.mesh.n_nodes, self.mesh.n_nodes]),
        a_local_values)
    if pseudo_batch:
        # Drop only the axis added above; a genuine batch of size 1 is kept.
        A = tf.squeeze(A, axis=0)

    # (node, 0) for every node of every element; works for any number of
    # nodes per element.
    global_load_indices = [
        (node, 0) for row in self.mesh.elements for node in row
    ]
    b = tf.scatter_nd(global_load_indices,
                      tf.reshape(b_local_values, [-1],
                                 name='flatten_local_load'),
                      shape=[self.mesh.n_nodes, 1],
                      name='scatter_nd_to_global_load')
    return A, b
def testIndexedSlices(self):
    # Converting IndexedSlices must produce the equivalent dense array:
    # rows 1 and 9 are ones, all other rows are zeros.
    int_dtype = tf.int64
    sparse_rows = tf.IndexedSlices(
        values=tf.ones([2, 3], dtype=int_dtype),
        indices=tf.constant([1, 9]),
        dense_shape=[10, 3])
    converted = array_ops.array(sparse_rows, copy=False)
    dense_reference = tf.scatter_nd(
        [[1], [9]], tf.ones([2, 3], dtype=int_dtype), [10, 3])
    self.assertAllEqual(dense_reference, converted)
def get_output_shape(self, input_shape):
    """Returns the output shape for `input_shape`.

    If `self._output_shape` is set it is returned unchanged. Otherwise the
    result is `input_shape` with the entry at `self._axis` (wrapped modulo
    the rank) replaced by `len(self._quantiles)`.
    """
    if self._output_shape is None:
        rank = tf.shape(input_shape)[0]
        axis = tf.math.mod(self._axis, rank)
        # One-hot correction that turns input_shape[axis] into len(quantiles).
        correction = tf.scatter_nd(
            indices=[[axis]],
            updates=[len(self._quantiles) - input_shape[axis]],
            shape=tf.shape(input_shape))
        return input_shape + correction
    return self._output_shape
def compute_output_shape(self, input_shape):
    """Returns the output shape: `input_shape` with `self._axis` set to `self._topk`.

    If `self._topk` is None the input shape is returned unchanged (as a
    `tf.TensorShape`).

    Args:
      input_shape: A `tf.TensorShape` or anything convertible to one.

    Returns:
      A `tf.TensorShape`.
    """
    if self._topk is None:
        return tf.TensorShape(input_shape)
    # Work on a plain list of dims. `TensorShape + Tensor` concatenates
    # rather than adding elementwise, so tensor arithmetic on a TensorShape
    # input gives the wrong result. A list also allows unknown (None) dims.
    dims = tf.TensorShape(input_shape).as_list()
    # Python % wraps negative axes like tf.math.mod for a positive rank.
    dims[self._axis % len(dims)] = self._topk
    return tf.TensorShape(dims)
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
    """Raises ValueError if source, destination not in (-ndim(a), ndim(a)).

    Also raises ValueError if `source` and `destination` have different
    numbers of elements and, for inputs of statically known rank, if either
    contains repeated axes.
    """
    if not source and not destination:
        # Nothing to move; still return an ndarray like every other path.
        return asarray(a)

    a = asarray(a).data

    if isinstance(source, int):
        source = (source, )
    if isinstance(destination, int):
        destination = (destination, )
    if len(source) != len(destination):
        raise ValueError(
            '`source` and `destination` must have the same number of elements, '
            'got {} and {}.'.format(len(source), len(destination)))

    a_rank = utils._maybe_static(tf.rank(a))  # pylint: disable=protected-access

    def _correct_axis(axis, rank):
        if axis < 0:
            return axis + rank
        return axis

    source = tuple(_correct_axis(axis, a_rank) for axis in source)
    destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

    if a.shape.rank is not None:
        # Validate explicitly: an `assert` would be stripped under `python -O`.
        for axes in (source, destination):
            if any(not 0 <= axis < a_rank for axis in axes):
                raise ValueError(
                    'Axes {} out of range for an array of rank {}.'.format(
                        axes, a_rank))
            if len(set(axes)) != len(axes):
                raise ValueError('Repeated axis in {}.'.format(axes))
        perm = [i for i in range(a_rank) if i not in source]
        # Inserting in ascending destination order keeps every earlier
        # placement in its final position.
        for dest, src in sorted(zip(destination, source)):
            perm.insert(dest, src)
    else:
        r = tf.range(a_rank)

        def _remove_indices(a, b):
            """Remove indices (`b`) from `a`."""
            items = tf.unstack(tf.sort(tf.stack(b)), num=len(b))

            i = 0
            result = []

            for item in items:
                result.append(a[i:item])
                i = item + 1

            result.append(a[i:])

            return tf.concat(result, 0)

        minus_sources = _remove_indices(r, source)
        minus_dest = _remove_indices(r, destination)

        perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources,
                             [a_rank])
        perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1),
                                           source)
    a = tf.transpose(a, perm)

    return utils.tensor_to_ndarray(a)
def _prepare_grid(times, times_grid, *params):
    """Prepares grid of times for path generation.

    Args:
      times: Rank 1 `Tensor` of increasing positive real values. The times at
        which the path points are to be evaluated.
      times_grid: An optional rank 1 `Tensor` representing time discretization
        grid. If `times` are not on the grid, then the nearest points from the
        grid are used.
      *params: Parameters of the Heston model. Either scalar `Tensor`s of the
        same `dtype` or instances of `PiecewiseConstantFunc`.

    Returns:
      Tuple `(all_times, mask)`.
      `all_times` is a 1-D real `Tensor`:
      - If `times_grid` is None, it holds 0, all points from `times` and all
        jump locations of piecewise constant parameters, sorted in ascending
        order; it may contain duplicates. In this case `all_times[0] = 0` and,
        for strictly positive `times`, `mask[0] = False`.
      - Otherwise it is `times_grid` itself.
      `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`. It
      shows which elements of `all_times` correspond to the values from
      `times` (the nearest grid point when `times_grid` is given).
    """
    if times_grid is None:
        additional_times = []
        for param in params:
            if hasattr(param, 'is_piecewise_constant'):
                if param.is_piecewise_constant:
                    # Flatten all jump locations
                    additional_times.append(
                        tf.reshape(param.jump_locations(), [-1]))
        zeros = tf.constant([0], dtype=times.dtype)
        all_times = tf.concat([zeros] + [times] + additional_times, axis=0)
        all_times = tf.sort(all_times)
        time_indices = tf.searchsorted(all_times, times, out_type=tf.int32)
    else:
        all_times = times_grid
        time_indices = tf.searchsorted(times_grid, times, out_type=tf.int32)
        # A time past the last grid point gets index len(times_grid). Clamp it
        # so the gathers below stay in bounds; the nearest grid point is then
        # the last one.
        time_indices = tf.minimum(time_indices, tf.size(times_grid) - 1)
        # Adjust indices to bring `times` closer to `times_grid`.
        times_diff_1 = tf.gather(times_grid, time_indices) - times
        times_diff_2 = tf.gather(
            times_grid, tf.nn.relu(time_indices - 1)) - times
        time_indices = tf.where(
            tf.math.abs(times_diff_2) > tf.math.abs(times_diff_1),
            time_indices,
            tf.nn.relu(time_indices - 1))
    # Create a boolean mask to identify the iterations that have to be recorded.
    mask = tf.scatter_nd(
        indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
        updates=tf.fill(tf.shape(times), True),
        shape=tf.shape(all_times, out_type=tf.int64))
    return all_times, mask
def _get_reset_state_indices():
    """Builds fresh initial states for the entries selected by `reset_indices`.

    The states are scattered into zero tensors whose leading dims match
    `done`. `reset_indices`, `observation`, `done` and `self` come from the
    enclosing scope.
    """
    # shape: [num_indices_to_reset, ...]
    observations_to_reset = tf.nest.map_structure(
        lambda t: tf.gather_nd(t, reset_indices), observation)
    fresh_states = self.get_initial_state(
        observations_to_reset, batch_size=tf.shape(reset_indices)[0])

    def _scatter_to_full(fresh_tensor):
        # Target shape: [num_timesteps, batch_size, ...] (done's dims + state dims).
        full_shape = done.shape.as_list() + fresh_tensor.shape.as_list()[1:]
        return tf.scatter_nd(indices=reset_indices,
                             updates=fresh_tensor,
                             shape=full_shape)

    return tf.nest.map_structure(_scatter_to_full, fresh_states)
def test_sanity_check_sweep_over_features(self):
    # A full sweep over features should recover the true sparsity pattern and
    # approximately the true weights.
    n_outputs = 100
    n_features = 3
    regression_task = self._random_regression_task(
        num_outputs=n_outputs,
        num_features=n_features,
        # The middle feature is irrelevant: a clear sparsity pattern.
        weights=tf.convert_to_tensor([10., 0., -10.]),
        seed=test_util.test_seed())
    design_matrix, true_weights, targets = self.evaluate(regression_task)

    # A tiny prior inclusion probability makes keeping an irrelevant feature
    # very unlikely.
    sampler = dynamic_spike_and_slab.DynamicSpikeSlabSampler(
        design_matrix, nonzero_prior_prob=1e-6)
    start_state = sampler._initialize_sampler_state(
        targets=targets,
        nonzeros=tf.convert_to_tensor([True, True, True]),
        observation_noise_variance=1.)
    end_state = self.evaluate(
        sampler._resample_all_features(start_state,
                                       seed=test_util.test_seed()))

    posterior_precision = (
        sampler.x_transpose_x + sampler.weights_prior_precision)
    sparse_mean = _compute_conditional_weights_mean(
        end_state.nonzeros, posterior_precision, end_state.x_transpose_y)
    self.assertAllEqual(end_state.nonzeros, [True, False, True])
    # Expand the mean over active features back to the full weight vector.
    active = tf.where(end_state.nonzeros)
    dense_mean = tf.scatter_nd(active, sparse_mean, true_weights.shape)
    self.assertAllClose(dense_mean, true_weights, rtol=0.05, atol=0.15)

    conditional_posterior = sampler._get_conditional_posterior(end_state)
    sampled_variances, sampled_weights = self.evaluate(
        conditional_posterior.sample(seed=test_util.test_seed()))
    self.assertAllFinite(sampled_variances)
    self.assertAllFinite(sampled_weights)
def test_categorical_resampler_chi2(self):
    # Chi-squared goodness-of-fit test of the categorical resampler.
    num_probs = 50
    num_distributions = 3
    uniform = tfd.Uniform(low=np.float64(0), high=np.float64(1.))
    raw_probs = uniform.sample([num_distributions, num_probs],
                               seed=test_util.test_seed())
    probs = raw_probs / tf.reduce_sum(raw_probs, axis=-1, keepdims=True)

    # The chi-squared test is valid as long as the number of draws
    # (`num_particles`) is large compared to `num_probs`.
    num_particles = 10000
    num_samples = 2
    log_probs = tf.math.log(
        dist_util.move_dimension(probs, source_idx=-1, dest_idx=0))
    draws = resample_independent(log_probs, num_particles, [num_samples],
                                 seed=test_util.test_seed())
    draws = dist_util.move_dimension(draws, source_idx=0, dest_idx=-1)

    one_per_particle = tf.ones(num_particles, dtype=tf.int32)
    reference = tfd.Chi2(df=np.float64(num_probs - 1))
    # TODO(dpiponi): reimplement this test in vectorized form rather than with
    # loops.
    for s in range(num_samples):
        for d in range(num_distributions):
            # Histogram of the drawn category indices.
            counts = tf.scatter_nd(indices=draws[s, d][:, tf.newaxis],
                                   updates=one_per_particle,
                                   shape=[num_probs])
            expected = probs[d] * num_particles
            residuals = tf.cast(counts, tf.float64) - expected
            chi2 = tf.reduce_sum(residuals**2 / expected, axis=-1)
            self.assertAllLess(reference.cdf(chi2), 0.9999)
def _get_endpoint_a(i, j, q1_i, q0_j, batch_shape):
    """Determine the beginning of the interval, `a`.

    Elementwise over `batch_shape`:
    - `a = q0[j]` where `i < 0`;
    - `a = q1[i]` where `j < 0` (this also wins when both are negative);
    - `a = max(q0[j], q1[i])` everywhere else.
    """
    i_negative = tf.less(i, 0)
    j_negative = tf.less(j, 0)
    neither_negative = ~(i_negative | j_negative)

    # Start from q0[j] at positions with i < 0 (zeros elsewhere).
    where_i_negative = tf.where(i_negative)
    a = tf.scatter_nd(where_i_negative, tf.gather_nd(q0_j, where_i_negative),
                      batch_shape)

    # Later updates overwrite earlier ones, so `j < 0` takes precedence.
    overrides = ((j_negative, q1_i),
                 (neither_negative, tf.maximum(q0_j, q1_i)))
    for condition, values in overrides:
        positions = tf.where(condition)
        a = tf.tensor_scatter_nd_update(a, positions,
                                        tf.gather_nd(values, positions))
    return a
def test_categorical_resampler_chi2(self):
    # Chi-squared goodness-of-fit test of the categorical resampler.
    if self.use_xla and tf.executing_eagerly():
        self.skipTest('No need to test XLA under all execution regimes.')

    num_probs = 50
    num_distributions = 3
    uniform = tfd.Uniform(low=self.dtype(0), high=self.dtype(1.))
    raw_probs = uniform.sample([num_distributions, num_probs], seed=42)
    probs = raw_probs / tf.reduce_sum(raw_probs, axis=-1, keepdims=True)

    # The chi-squared test is valid as long as the number of draws
    # (`num_particles`) is large compared to `num_probs`.
    num_particles = 10000
    num_samples = 2
    resampler = self.maybe_compiler(resample_independent)
    log_probs = tf.math.log(
        dist_util.move_dimension(probs, source_idx=-1, dest_idx=0))
    draws = resampler(log_probs, num_particles, [num_samples])
    draws = dist_util.move_dimension(draws, source_idx=0, dest_idx=-1)

    one_per_particle = tf.ones(num_particles, dtype=tf.int32)
    reference = tfd.Chi2(df=self.dtype(num_probs - 1))
    # TODO(dpiponi): reimplement this test in vectorized form rather than with
    # loops.
    for s in range(num_samples):
        for d in range(num_distributions):
            # Histogram of the drawn category indices.
            counts = tf.scatter_nd(indices=draws[s, d][:, tf.newaxis],
                                   updates=one_per_particle,
                                   shape=[num_probs])
            expected = probs[d] * num_particles
            residuals = tf.cast(counts, self.dtype) - expected
            chi2 = tf.reduce_sum(residuals**2 / expected, axis=-1)
            self.assertAllLess(reference.cdf(chi2), 0.99995)
def _solve_dirichlet(self):
    """Solves the assembled system with Dirichlet boundary conditions.

    Boundary nodes are fixed to `self.domain.boundary.g`. The interior values
    solve `A_int u_int = b_int - A_int,bnd g`.

    Returns:
      The nodal solution with a trailing unit axis: `[n_nodes, 1]` when
      `self.batch_shape` is empty, otherwise `[batch, n_nodes, 1]`.
      NOTE(review): size-1 batch axes are also squeezed out by `tf.squeeze`.

    Raises:
      ValueError: If the boundary condition type is not 'Dirichlet'.
    """
    # check the boundary conditions are correct
    if self.domain.boundary.boundary_condition_type != 'Dirichlet':
        raise ValueError(
            "_solve_dirichlet must be used with Dirichlet boundary conditions."
        )

    A, b = self._assemble()

    # Add a pseudo-batch leading dimension so that tf.map_fn below always has
    # a batch axis to iterate over; it is squeezed out at the end.
    if len(self.batch_shape) == 0:
        A = A[tf.newaxis, ...]

    interior = self.mesh.interior_node_indices
    boundary = self.mesh.boundary_node_indices
    n_interior = len(interior)
    n_boundary = len(boundary)

    # get the interior-interior block of A
    global_stiffness_interior_indices = [
        *itertools.product(interior, repeat=2)
    ]
    Ainterior = tf.map_fn(
        lambda x: tf.reshape(
            tf.gather_nd(x, global_stiffness_interior_indices),
            [n_interior, n_interior]),
        A)

    b_interior = tf.gather_nd(b, [*zip(interior, [0] * n_interior)])

    # get the interior-boundary block of A
    interior_bound_indices = [*itertools.product(interior, boundary)]
    Aint_bnd = tf.map_fn(
        lambda x: tf.reshape(tf.gather_nd(x, interior_bound_indices),
                             [n_interior, n_boundary]),
        A)

    bnd_node_indices = np.array(boundary, dtype=np.intp)
    int_node_indices = np.array(interior, dtype=np.intp)

    # get the value on the boundary
    g = self.domain.boundary.g

    # convert the stiffness matrices to operators for batch matmul
    Ainterior_op = tf.linalg.LinearOperatorFullMatrix(Ainterior)
    Aint_bnd_op = tf.linalg.LinearOperatorFullMatrix(Aint_bnd)

    # add the fixed dirichlet conditions to sol
    # ToDo: Batch boundary values
    sol = tf.scatter_nd(bnd_node_indices[:, None], g,
                        shape=[self.mesh.n_nodes, 1])

    # Move the known boundary contribution to the right-hand side.
    b_ = b_interior[..., tf.newaxis] - Aint_bnd_op.matmul(g)
    sol_interior = Ainterior_op.solve(b_)

    # sol_interior has a batched shape [b, n_interior_nodes, 1]
    return tf.squeeze(
        tf.map_fn(
            lambda x: tf.tensor_scatter_nd_add(sol, int_node_indices[:, None],
                                               x),
            sol_interior)
    )[..., tf.newaxis]  # kills pseudo-batch dimensions, but keeps output a vector
def _accumulating_for_loop(body_fn, initial_state, params, num_iterations,
                           name=None):
    """Version of for_loop with multiple values of num_iterations.

    Runs `body_fn` for `max(num_iterations)` steps. It records the state (and
    its Jacobian with respect to `initial_state` and `params`) after each
    requested iteration count in an extra leading "accumulating" dimension.
    The gradient is supplied via `tf.custom_gradient`, using the accumulated
    Jacobians.

    Args:
      body_fn: Callable `(i, state) -> next_state`.
      initial_state: Sequence of state `Tensor`s (concatenated with `params`
        via `+`, so presumably a tuple or list).
      params: Sequence of parameter `Tensor`s to differentiate with respect to.
      num_iterations: 1-D integer `Tensor`; its static `shape[0]` is the size
        of the accumulating dimension.
        NOTE(review): the `mask` construction below appears to assume the
        values are strictly increasing. `tf.scatter_nd` sums duplicate
        indices, so repeated values would make `acc_index` skip a slot, and
        unsorted values would fill slots in sorted rather than given order.
        TODO confirm callers guarantee this.
      name: Optional name scope.

    Returns:
      `final_acc_state`: the state tensors, each with a leading accumulating
      dimension of size `len(num_iterations)`.
    """
    # Every tensor in nested tensors (state and Jacobian) gets an extra
    # "accumulating" dimension in front. Functions _create_accumulators etc. below
    # help to work with this dimension.

    with tf.name_scope(name or "accumulating_for_loop"):
        max_iterations = tf.math.reduce_max(num_iterations)
        acc_size = num_iterations.shape[0]

        # num_iteration = [2, 5] -> mask = [0, 0, 1, 0, 0, 1]. Tells when we should
        # increment acc index before writing. Last element won't be used (i = 0..4).
        mask = tf.scatter_nd(indices=tf.expand_dims(num_iterations, axis=-1),
                             updates=tf.ones_like(num_iterations),
                             shape=(max_iterations + 1,))

        n = len(initial_state)

        @tf.custom_gradient
        def inner(*args):
            # Split the flat argument list back into state and params.
            initial_state, params = args[:n], args[n:]

            def while_cond(i, acc_index, acc_state, acc_jac):
                del acc_index, acc_state, acc_jac
                return i < max_iterations

            def while_body(i, acc_index, acc_state, acc_jac):
                # Continue from the slot currently being filled.
                state = _read_from_accumulators(acc_state, acc_index)
                jac = _read_from_accumulators(acc_jac, acc_index)
                with tf.GradientTape(persistent=True) as tape:
                    tape.watch(state)
                    tape.watch(params)
                    next_state = tuple(body_fn(i, state))
                step_jac = _compute_step_jacobian(state, next_state, params, tape)
                # Chain rule: Jacobian after this step = step Jacobian x previous.
                next_jac = _multiply_jacobians(step_jac, jac)
                # When mask[i] == 1 the current slot has its final value, so write
                # the new state into the next slot instead.
                acc_index += mask[i]
                acc_state = _write_to_accumulators(acc_state, next_state, acc_index)
                acc_jac = _write_to_accumulators(acc_jac, next_jac, acc_index)

                return i + 1, acc_index, acc_state, acc_jac

            initial_acc_state = _create_accumulators(initial_state, acc_size)
            initial_acc_state = _write_to_accumulators(initial_acc_state,
                                                       initial_state, 0)

            initial_jac = _make_unit_jacobian(initial_state, params)
            initial_acc_jac = _create_accumulators(initial_jac, acc_size)
            initial_acc_jac = _write_to_accumulators(initial_acc_jac, initial_jac, 0)

            # (iteration counter, accumulator slot index, states, Jacobians)
            loop_vars = (0, 0, initial_acc_state, initial_acc_jac)

            _, _, final_acc_state, final_acc_jac = tf.compat.v2.while_loop(
                while_cond, while_body, loop_vars=loop_vars,
                maximum_iterations=max_iterations)

            def gradient(*ws):
                # Same as in for_loop, except we need to sum over the accumulating
                # dimension. E.g. if x = for_loop(... num_iterations=[2, 5]) and
                # y = 2*x[0] + 3*x[1], then taking gradient of y will lead to ws having
                # coeffs 2 and 3 in the acc dimension, and we should sum over it.
                ws = [tf.expand_dims(w, axis=-2) for w in ws]
                ws = [ws]  # expand dims on block level as well.

                # Jacobians w.r.t. initial state (js) and params (jp).
                js, jp = final_acc_jac
                ws_js, ws_jp = _block_matmul(ws, js), _block_matmul(ws, jp)
                ws_js, ws_jp = ws_js[0], ws_jp[0]
                ws_js = [tf.squeeze(t, axis=-2) for t in ws_js]
                ws_jp = [tf.squeeze(t, axis=[-2, -1]) for t in ws_jp]

                # Sum over acc axis.
                ws_js = [tf.math.reduce_sum(t, axis=0) for t in ws_js]
                # ws_jp should be 0-dimensional
                ws_jp = [tf.math.reduce_sum(t) for t in ws_jp]

                return ws_js + ws_jp

            return final_acc_state, gradient

        # tf.custom_gradient can only handle a flat sequence of args.
        args = tuple(initial_state + params)
        return inner(*args)
def repeat(a, repeats, axis=None):
    """Repeat elements of the array along specified axes.

    Args:
      a: array_like. Could be an ndarray, a Tensor or any object that can
        be converted to a Tensor using `tf.convert_to_tensor`.
      repeats: 0-d or 1-d array_like. The number of times each element along
        `axis` will be repeated. If this has size 1, each element along the axis
        is repeated the same number of times. Zero counts (including trailing
        zeros) drop the corresponding elements.
      axis: Optional. The axis along which to repeat. If None, the input array
        is flattened.

    Returns:
      An ndarray with same type as `a`.

    Raises:
      ValueError: If `repeats` has rank > 1 or an incompatible shape.
    """
    a = array_creation.asarray(a)
    repeats = array_creation.asarray(repeats)
    if repeats.ndim > 1:
        raise ValueError('repeats must be a scalar or 1-d array.')
    repeats = ravel(repeats)  # Convert to 1-d array.
    # As per documentation, if axis is None, the input is flattened
    # and a flattened output is returned.
    if axis is None:
        a = ravel(a)
        axis = 0
    elif axis < 0:
        axis += a.ndim

    # Broadcast repeats to match shape of axis.
    if len(repeats) == 1:
        repeats = utils.tensor_to_ndarray(tf.tile(repeats.data, [a.shape[axis]]))

    if a.shape[axis] != len(repeats):
        raise ValueError(
            'Shape mismatch. `repeats` expected to have shape ({},)'
            ' but has ({},)'.format(a.shape[axis], len(repeats)))

    # Example:
    #
    # a: [[1, 2, 3],
    #     [4, 5, 6]]
    # axis: 1
    # repeats: [3, 1, 2]
    # Output: [[1, 1, 1, 2, 3, 3],
    #          [4, 4, 4, 5, 6, 6]]
    #
    # Algorithm:
    # 1. Calculate cumulative sum of repeats.
    repeats_cumsum = cumsum(repeats)  # [3, 4, 6]

    # 2. Use `scatter_nd` to generate an indices list for use in `tf.gather`.
    scatter_indices_t = repeats_cumsum[:-1].data  # [3, 4]
    scatter_indices_t = tf.expand_dims(scatter_indices_t, 1)  # [[3], [4]]
    scatter_updates_t = tf.ones([len(repeats) - 1], dtype=tf.int32)  # [1, 1]
    total_t = ravel(repeats_cumsum[-1]).data  # [6]
    # Scatter into one extra slot. When trailing repeats are zero, a scatter
    # index equals the total (e.g. repeats [1, 0] -> index 1, total 1), which
    # would be out of bounds for a buffer of exactly `total` elements.
    # Duplicate indices (interior zero repeats) are summed by `scatter_nd`,
    # which correctly skips the dropped elements.
    # `tf.scatter_nd([[3], [4]], [1, 1], [7])` -> `[0, 0, 0, 1, 1, 0, 0]`
    indices_t = tf.scatter_nd(scatter_indices_t, scatter_updates_t,
                              total_t + 1)
    # Running sum, then trim the extra slot: [0, 0, 0, 1, 2, 2]
    indices_t = tf.cumsum(indices_t)[:total_t[0]]

    # 3. Use `tf.gather` to gather indices along `axis`.
    result_t = tf.gather(a, indices_t, axis=axis)

    return utils.tensor_to_ndarray(result_t)
def _get_spl_tensor(bins: int, frequencies: tf.Tensor, spls: tf.Tensor) -> tf.Tensor:
    """Scatters `spls` into a dense vector of length `bins`.

    `frequencies` are cast to int32 and used as bin positions. Bins without
    an entry are zero, and repeated positions are summed (`tf.scatter_nd`
    semantics).
    """
    bin_positions = tf.cast(frequencies, dtype=tf.int32)
    dense_shape = tf.constant([bins])
    return tf.scatter_nd(tf.expand_dims(bin_positions, axis=1),
                         updates=spls,
                         shape=dense_shape)
def op(x, kernel):
    """Applies a transposed 2-D convolution of `x` with a possibly batched `kernel`.

    Unbatched kernels with unit dilations are delegated to
    `_call_conv2d_transpose`. Otherwise the op is computed as a gather of
    image patches (`im2row_index`) followed by a matmul with a kernel matrix
    assembled by `tf.scatter_nd`. When `strides > 1`, a `depth_to_space`
    step follows.

    NOTE(review): `fh`, `fw`, `validate_args`, `dilations`, `filter_shape`,
    `strides`, `rank`, `padding`, `padding_vals`, `sub_fh`, `sub_fw`,
    `event_ind`, `dtype`, `truncate_top` and `truncate_left` are read from
    the enclosing scope (presumably a factory function) - confirm there.

    Args:
      x: Tensor whose last three dims are (height, width, channels_in); any
        leading dims are treated as batch dims.
      kernel: Tensor whose last dim is `c_out`; dims before the last two are
        treated as kernel batch dims.

    Returns:
      Tensor of shape broadcast(batch_shape, kernel_batch) +
      [out_height, out_width, c_out].
    """
    input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
    kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

    # Split x's shape into batch dims and the (height, width, channels) event.
    batch_shape, event_shape = ps.split(
        ps.shape(x), num_or_size_splits=[-1, 3])
    xh, xw, c_in = ps.unstack(event_shape, num=3)

    kernel_shape = ps.shape(kernel)
    c_out = kernel_shape[-1]
    kernel_batch = kernel_shape[:-2]
    assertions = _maybe_validate_input_shapes(
        kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
        validate_args=validate_args)

    with tf.control_dependencies(assertions):

        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
                and all(d == 1 for d in dilations)):
            return _call_conv2d_transpose(
                x, kernel=kernel, filter_shape=filter_shape,
                strides=(strides, ) * rank, padding=padding, dilations=dilations,
                c_out=c_out, batch_shape=batch_shape, event_shape=event_shape)

        # Pad only the height/width axes: n leading zero rows for the batch
        # dims, then `padding_vals` for height/width, then one zero row for
        # channels.
        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            padding_vals,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)
        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        x_pad_shape = ps.shape(x_pad)[:-3]
        # Collapse the event dims into one trailing axis (batch dims kept) so
        # the flat patch indices from `im2row_index` can be gathered from it.
        flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.reshape(x_pad, shape=flat_shape)

        idx, s = im2row_index(
            (xh + tf.reduce_sum(padding_vals[0]),
             xw + tf.reduce_sum(padding_vals[1]), c_in),
            block_shape=(sub_fh, sub_fw), slice_step=(1, 1), dilations=dilations)

        x_ = tf.gather(flat_x, indices=idx, axis=-1)
        im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))

        # Add channels to subkernel indices
        idx_event = event_ind * [[c_in, 1]]
        idx_event_channels = (
            idx_event[tf.newaxis]
            + tf.stack([ps.range(c_in), tf.zeros((c_in, ), dtype=dtype)],
                       axis=-1)[:, tf.newaxis, :])
        idx_event = tf.squeeze(
            tf.batch_to_space(
                idx_event_channels, block_shape=[c_in], crops=[[0, 0]]),
            axis=0)
        idx_event_broadcast = tf.broadcast_to(
            idx_event,
            shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0))

        # Add cartesian product of batch indices, since scatter_nd can only be
        # applied to leading dimensions.
        idx_batch = tf.stack(
            tf.meshgrid(
                *[ps.range(b_, delta=1, dtype=dtype)
                  for b_ in tf.unstack(kernel_batch)], indexing='ij'),
            axis=ps.size(kernel_batch))
        idx_batch = tf.cast(idx_batch, dtype=dtype)  # empty tensor is float

        idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
            (ps.shape(idx_event)[0], 1), dtype=dtype)
        idx_kernel = tf.concat(
            [idx_batch_broadcast, idx_event_broadcast], axis=-1)

        # Place each kernel entry into a (possibly batched) matrix of shape
        # kernel_batch + [sub_fh * sub_fw * c_in, strides**2, c_out].
        kernel_mat = tf.scatter_nd(
            idx_kernel,
            updates=kernel,
            shape=ps.cast(
                ps.concat([kernel_batch,
                           [sub_fh * sub_fw * c_in, strides**2, c_out]],
                          axis=0),
                dtype=dtype))

        kernel_mat = tf.reshape(
            kernel_mat,
            shape=ps.concat(
                [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]], axis=0))

        kernel_mat = kernel_mat[..., tf.newaxis, :, :]
        out = tf.matmul(im_x, kernel_mat)
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)

        if strides > 1:
            # depth_to_space needs a rank-4 input, so flatten batch dims first.
            tot_size = tf.reduce_prod(broadcast_batch_shape)
            flat_out = tf.reshape(
                out,
                shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0))
            out = tf.nn.depth_to_space(flat_out, block_size=strides)

        # NOTE(review): any other `padding` value leaves out_height/out_width
        # unbound (NameError below); presumably validated by the caller.
        if padding == 'VALID':
            out_height = fh + strides * (xh - 1)
            out_width = fw + strides * (xw - 1)
        elif padding == 'SAME':
            out_height = xh * strides
            out_width = xw * strides

        out = out[..., truncate_top:truncate_top + out_height,
                  truncate_left:truncate_left + out_width, :]
        out = tf.reshape(
            out, shape=ps.concat(
                [broadcast_batch_shape, [out_height, out_width, c_out]],
                axis=0))
        return out
def get_positions_shift(positions: tf.Tensor, exclusive: bool) -> tf.Tensor:
    """Cumulative sum along axis 1 of `mention_mask` scattered at `positions`.

    `mention_batch_positions`, `mention_mask` and `old_shape` come from the
    enclosing scope.
    """
    scatter_index = _get_2d_index(mention_batch_positions, positions)
    scattered_mask = tf.scatter_nd(scatter_index, mention_mask, old_shape)
    return tf.cumsum(scattered_mask, axis=1, exclusive=exclusive)
def _call_sample_n(self, *args, **kwargs):
    """Draws from the parent distribution and scatters into a dense vector.

    The draws are placed at `self._indices` in a zero tensor of shape
    `[self._size]`.
    """
    draws = super()._call_sample_n(*args, **kwargs)
    dense_shape = [self._size]
    return tf.scatter_nd(indices=self._indices, updates=draws,
                         shape=dense_shape)
def linear_interpolation_operator(self, index_points):
    """
    Returns the linear operator that carries out interpolation of the
    solution at node points.

    Parameters
    ----------
    index_points : array, shape=[nobs, 1]
        points at which we want to interpolate

    Returns
    -------
    O : array, shape = [nobs, n_nodes]
        linear operator such that O @ u interpolates u[mesh.points] onto
        index_points

    .. note::
        The matrix of the operator returned will have 2 entries per row,
        and so will be sparse when `n_nodes >> nobs`.
    """
    n_obs = index_points.shape[-2]

    # We need to find which element each of index points is in. We exploit
    # the fact that the interval mesh is ordered and broadcast `>`.
    # Reshape (rather than squeeze) to [n_obs, n_nodes] so that a single
    # index point (n_obs == 1) keeps its row axis.
    is_greater = tf.reshape(
        index_points[:, tf.newaxis, :] > self.nodes[tf.newaxis, ...],
        [n_obs, self.n_nodes])

    # Find the max along each row, corresponding to an index point, to find
    # the last interval [ta[i], tb[i]] for which index_point[m] > ta[i].
    # NOTE(review): points at or beyond the last node would index past the
    # end of `self.nodes` below - presumably excluded by the caller.
    index_point_elements = tf.argmax(
        tf.cast(is_greater, tf.int32) *
        tf.range(0, self.n_nodes, 1)[tf.newaxis, :],
        axis=-1)

    ta = tf.gather(self.nodes, index_point_elements)
    tb = tf.gather(self.nodes, index_point_elements + 1)
    index_point_element_volumes = tf.gather(
        self.element_volumes, index_point_elements)[:, tf.newaxis]

    # Now that we know which element each index point is in, evaluate the
    # two local (hat) basis functions on that element.
    phi1 = (tb - index_points) / index_point_element_volumes
    phi2 = (index_points - ta) / index_point_element_volumes

    rows = np.arange(n_obs)[:, tf.newaxis]

    # Weight of the left node of each element.
    Oa_indices = tf.concat([rows, index_point_elements[:, tf.newaxis]], axis=-1)
    Oa = tf.scatter_nd(Oa_indices, tf.reshape(phi1, [-1]),
                       shape=[n_obs, self.n_nodes])

    # Weight of the right node: increment index_point_elements by one,
    # taking advantage of the sorting.
    Ob_indices = tf.concat(
        [rows, (index_point_elements + 1)[:, tf.newaxis]], axis=-1)
    Ob = tf.scatter_nd(Ob_indices, tf.reshape(phi2, [-1]),
                       shape=[n_obs, self.n_nodes])

    return Oa + Ob
def mask_mentions_and_tokens_tf(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    dense_span_starts: tf.Tensor,
    dense_span_ends: tf.Tensor,
    non_mention_mask_rate: float,
    mention_mask_rate: float,
    max_mlm_targets: int,
    mask_token_id: int,
    vocab_size: int,
    random_replacement_prob: float = 0.1,
    identity_replacement_prob: float = 0.1,
) -> Dict[str, tf.Tensor]:
    """Randomly masks whole mentions and random tokens up to a maximum.

    First, mentions are masked according to mention mask rate. If a mention is
    masked, all tokens in the mention are replaced by the mask token. If the
    passage has at least one mention and the mention mask rate is greater than
    zero, we mask at least one mention.

    After masking mentions, if there are fewer masked tokens than maximum mlm
    targets, we additionally mask non-mention words. TODO: If a token in a word
    is masked, all tokens in the word are masked. Some proportion of targets are
    not masked to ameliorate pretrain-finetune mismatch. If there are
    insufficient masked tokens, the target array is padded up to max targets.

    Args:
      text_ids: [seq_length] tensor with token ids.
      text_mask: [seq_length] tensor with 1s for tokens and 0 for padding.
      dense_span_starts: [seq_length] tensor with 1s for mention start
        positions and 0 otherwise.
      dense_span_ends: [seq_length] tensor with 1s for mention end positions
        and 0 otherwise.
      non_mention_mask_rate: percentage of non mention tokens to be masked.
      mention_mask_rate: percentage of mentions to be masked.
      max_mlm_targets: total number of mlm targets.
      mask_token_id: token id for mask token.
      vocab_size: vocabulary size.
      random_replacement_prob: probability that to-be-masked token will be
        replaced with a random token instead of [MASK].
      identity_replacement_prob: probability that to-be-masked token will be
        replaced with itself instead of [MASK].

    Returns:
      Dictionary with keys:
        'masked_text_ids': `text_ids` with every target position replaced by
          [MASK], a random token or itself.
        'mlm_target_positions', 'mlm_target_ids', 'mlm_target_weights',
        'mlm_target_is_mention': per-target features, padded to
          `max_mlm_targets` by `dynamic_padding_1d`.
        'dense_is_masked': dense [seq_length] indicator of target positions.
    """
    # NOTE(review): outputs depend on the creation order of the random ops
    # (the two mask_tokens_by_spans calls, presumably random, and the two
    # tf.random.uniform calls). Keep statement order when editing.

    # Mask mentions
    mention_start_positions = non_zero_1d(dense_span_starts)
    mention_end_positions = non_zero_1d(dense_span_ends)
    mention_masked_positions = mask_tokens_by_spans(text_ids,
                                                    mention_start_positions,
                                                    mention_end_positions,
                                                    mention_mask_rate,
                                                    max_mlm_targets)

    dense_is_mention = get_dense_is_inside_for_dense_spans(
        dense_span_starts, dense_span_ends)
    dense_is_not_mention = 1 - dense_is_mention
    # Exclude padding positions from the non-mention candidates.
    dense_is_not_mention = dense_is_not_mention * text_mask

    # Mask non-mentions
    non_mention_start_positions = non_zero_1d(dense_is_not_mention)
    # TODO(urikz): Implement whole-word masking
    non_mention_end_positions = non_mention_start_positions
    # Non-mention budget is whatever the masked mentions left over.
    non_mention_masked_positions = mask_tokens_by_spans(
        text_ids, non_mention_start_positions, non_mention_end_positions,
        non_mention_mask_rate,
        max_mlm_targets - tf.shape(mention_masked_positions)[0])

    # Merge masked positions for mention and non-mention tokens
    mlm_target_positions = tf.concat(
        [mention_masked_positions, non_mention_masked_positions], -1)
    n_mlm_target_positions = tf.shape(mlm_target_positions)

    # Get target IDs, weights and other features
    mlm_target_ids = tf.gather(text_ids, mlm_target_positions)
    mlm_target_weights = tf.ones(n_mlm_target_positions, dtype=tf.int64)
    # Mention targets come first in the concat above, so ones for them padded
    # with zeros line up with mlm_target_positions.
    mlm_target_is_mention = tf.ones(tf.shape(mention_masked_positions),
                                    dtype=tf.int64)
    seq_length = tf.shape(text_ids)[0]
    dense_is_masked = sparse_to_dense_1d(mlm_target_positions, seq_length)

    # Replace masked tokens with [MASK], random or original tokens.
    # With scores uniform in [0, 1): score > random + identity -> [MASK];
    # random < score <= random + identity -> original token;
    # score <= random -> random token.
    replacement_scores = tf.random.uniform(n_mlm_target_positions)
    replacement_tokens = tf.where(
        replacement_scores > random_replacement_prob + identity_replacement_prob,
        # replace tokens with [MASK]
        tf.cast(tf.fill(n_mlm_target_positions, value=mask_token_id),
                dtype=tf.int64),
        tf.where(
            replacement_scores > random_replacement_prob,
            # keep original
            mlm_target_ids,
            # replace with random token
            tf.random.uniform(
                n_mlm_target_positions, maxval=vocab_size, dtype=tf.int64)))
    replacement_positions = tf.expand_dims(mlm_target_positions, 1)
    # Indices should be tf.int32 only.
    replacement_positions = tf.cast(replacement_positions, tf.int32)
    replacement_tokens = tf.scatter_nd(replacement_positions,
                                       replacement_tokens, tf.shape(text_ids))
    # Keep original ids where not masked, replacement ids where masked.
    masked_text_ids = (text_ids * (1 - dense_is_masked) +
                       replacement_tokens * dense_is_masked)
    return {
        'masked_text_ids':
            masked_text_ids,
        'mlm_target_positions':
            dynamic_padding_1d(mlm_target_positions, max_mlm_targets),
        'mlm_target_ids':
            dynamic_padding_1d(mlm_target_ids, max_mlm_targets),
        'mlm_target_weights':
            dynamic_padding_1d(mlm_target_weights, max_mlm_targets),
        'mlm_target_is_mention':
            dynamic_padding_1d(mlm_target_is_mention, max_mlm_targets),
        'dense_is_masked':
            dense_is_masked,
    }
def segment_diff(x,
                 segment_ids,
                 order=1,
                 exclusive=False,
                 dtype=None,
                 name=None):
  """Computes lagged differences of `x` independently within each segment.

  This is the segmented counterpart of `diff`: inside every segment the result
  is what `diff` would return for that segment on its own, and the overall
  result is the concatenation of those per-segment results. See
  `tf.segment_max` for the general conventions shared by segment ops.

  ## Example

  ```python
  x = tf.constant([2, 5, 1, 7, 9] + [32, 10, 12, 3] + [4, 8, 5])
  segments = tf.constant([0, 0, 0, 0, 0] + [1, 1, 1, 1] + [2, 2, 2])
  # First order diff. Expected result: [3, -4, 6, 2, -22, 2, -9, 4, -3]
  dx1 = segment_diff(
      x, segment_ids=segments, order=1, exclusive=True)
  # Non-exclusive, second order diff.
  # Expected result: [2, 5, -1, 2, 8, 32, 10, -20, -7, 4, 8, 1]
  dx2 = segment_diff(
      x, segment_ids=segments, order=2, exclusive=False)
  ```

  Args:
    x: A rank 1 `Tensor` of any dtype for which arithmetic operations are
      permitted.
    segment_ids: A `Tensor` of type int32 or int64, or None. A 1-D tensor
      whose size equals the size of `x`. Values should be sorted and may be
      repeated. If None, the whole of `x` is treated as a single segment.
    order: Positive Python int. The lag of the difference, i.e. the output
      element at position `i` is `x[i] - x[i - order]` within a segment.
      Default value: 1
    exclusive: Python bool. The first `order` elements of a segment have no
      element `order` steps earlier in the same segment. If True, those
      elements are dropped from the output. If False, they are copied through
      unchanged from `x` (see the second example above).
      Default value: False
    dtype: Optional `tf.Dtype`. If supplied, the dtype used when converting
      `x` to a `Tensor`.
      Default value: None which maps to the default dtype inferred by TF.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: None which is mapped to the default name 'segment_diff'.

  Returns:
    diffs: A `Tensor` of the same dtype as `x`.
      If `exclusive` is False, its size is `n`, the size of `x`.
      If `exclusive` is True and every segment has length >= `order`, its size
      is `n - order * k`. Here `k` is the number of distinct segment ids, or 1
      if `segment_ids` is None.
      In general (exclusive, with short segments allowed) the size is
      `n - sum_j min(order, length(segment_j))`.
  """
  with tf.compat.v1.name_scope(name, default_name='segment_diff', values=[x]):
    x = tf.convert_to_tensor(x, dtype=dtype)
    raw_diffs = diff_ops.diff(x, order=order, exclusive=exclusive)
    if segment_ids is None:
      return raw_diffs

    # A plain lagged diff is wrong for positions p, ..., min(p + order, end) - 1
    # of every segment except the first. Here p is the segment's first index
    # and x[i - order] would reach back into the previous segment. The first
    # segment is already handled correctly by `diff` itself.
    steps = segment_ids[1:] - segment_ids[:-1]
    is_boundary = tf.concat([[False], tf.not_equal(steps, 0)], axis=0)

    # Shape [k, 1]: first index of each segment after the first one.
    starts = tf.cast(tf.where(is_boundary), dtype=tf.int32)
    # Shape [k, 1]: exclusive end index of each of those segments.
    flat_starts = tf.reshape(starts, [-1])
    ends = tf.reshape(
        tf.concat([flat_starts[1:], [tf.size(segment_ids)]], axis=0), [-1, 1])

    # Shape [k, order]: candidate positions p, p + 1, ..., p + order - 1.
    candidates = starts + tf.range(order, dtype=starts.dtype)
    # Drop candidates that run past the end of a segment shorter than `order`.
    bad_positions = tf.boolean_mask(candidates, candidates < ends)

    # scatter_nd does not support bool updates on GPUs, so scatter int32 ones
    # and convert to bool afterwards.
    hit_counts = tf.scatter_nd(
        tf.expand_dims(bad_positions, -1),
        tf.ones_like(bad_positions, dtype=tf.int32),
        shape=tf.shape(x))
    needs_fix = tf.cast(hit_counts, dtype=tf.bool)

    if exclusive:
      # The exclusive diff has already dropped the first `order` elements of
      # `x`. Align the mask with it, then drop the remaining bad positions.
      return tf.boolean_mask(raw_diffs, tf.logical_not(needs_fix[order:]))
    # Non-exclusive: the leading elements of each segment are `x` itself.
    return tf.where(needs_fix, x, raw_diffs)
def add_entity_tokens(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    mention_mask: tf.Tensor,
    mention_batch_positions: tf.Tensor,
    mention_start_positions: tf.Tensor,
    mention_end_positions: tf.Tensor,
    new_length: int,
    mlm_target_positions: Optional[tf.Tensor] = None,
    mlm_target_weights: Optional[tf.Tensor] = None,
    entity_start_token_id: int = default_values.ENTITY_START_TOKEN,
    entity_end_token_id: int = default_values.ENTITY_END_TOKEN,
) -> Dict[str, tf.Tensor]:
  """Adds entity start / end tokens around mentions.

  Inserts `entity_start_token_id` and `entity_end_token_id` tokens around each
  mention and updates mention_start_positions / mention_end_positions to point
  to these inserted tokens.

  The new text length is `new_length`; texts are truncated if necessary. If a
  mention no longer fits into the new text, its mask (`mention_mask`) is set
  to 0 and its start / end positions are set to 0.

  The function can also update MLM positions and weights
  (`mlm_target_positions` and `mlm_target_weights`) if these arguments are
  provided. As with mentions, if an MLM position no longer fits into the new
  text, its weight (`mlm_target_weights`) is set to 0, and positions of
  zero-weight targets are set to 0.

  Args:
    text_ids: [batch_size, seq_length] tensor with token ids.
    text_mask: [batch_size, seq_length] tensor with 1s for tokens and 0s for
      padding.
    mention_mask: [n_mentions] mask indicating whether a mention is padding.
    mention_batch_positions: [n_mentions] sample ID of a mention in the batch.
    mention_start_positions: [n_mentions] position of a mention's first token
      within a sample.
    mention_end_positions: [n_mentions] position of a mention's last token
      within a sample.
    new_length: length of the texts after entity tokens are added.
    mlm_target_positions: [batch_size, max_mlm_targets] positions of tokens to
      be used for the MLM task.
    mlm_target_weights: [batch_size, max_mlm_targets] mask indicating whether
      an entry of `mlm_target_positions` is padding. Required if
      `mlm_target_positions` is provided.
    entity_start_token_id: token to be used as entity start token.
    entity_end_token_id: token to be used as entity end token.

  Returns:
    A dict with:
      'text_ids', 'text_mask': new [batch_size, new_length] tensors.
      'mention_start_positions', 'mention_end_positions': positions of the
        inserted entity start / end tokens (0 for dropped mentions).
      'mention_mask': updated mention mask, cast to `text_mask.dtype`.
      'mlm_target_positions', 'mlm_target_weights': only present when
        `mlm_target_positions` was provided.

  Raises:
    ValueError: if `mlm_target_positions` is provided without
      `mlm_target_weights`.
  """
  batch_size = tf.shape(text_ids)[0]
  old_length = tf.shape(text_ids)[1]
  new_shape = (batch_size, new_length)

  mentions_fit_mask = compute_which_mentions_fit_with_entity_tokens(
      mention_mask,
      mention_batch_positions,
      mention_start_positions,
      mention_end_positions,
      batch_size,
      old_length,
      new_length,
  )
  # Ignore mentions that do not fit into the new texts: their mask and their
  # start / end positions are zeroed.
  new_mention_mask = mention_mask * mentions_fit_mask
  mention_start_positions = mention_start_positions * new_mention_mask
  mention_end_positions = mention_end_positions * new_mention_mask

  # NOTE(review): presumably `positions` is [batch_size, old_length] and maps
  # each old token position to its position in the new text (after shifting
  # for inserted entity tokens). It is gathered per-sample below and compared
  # against `new_length`. Confirm against the helper.
  positions = compute_positions_shift_with_entity_tokens(
      new_mention_mask, mention_batch_positions, mention_start_positions,
      mention_end_positions, batch_size, old_length)

  # Pairs each mention's sample ID with a per-sample position.
  def get_2d_index(positions: tf.Tensor) -> tf.Tensor:
    return _get_2d_index(mention_batch_positions, positions)

  # Maps per-mention old positions to their shifted positions in the new text.
  def get_new_positions(old_positions: tf.Tensor) -> tf.Tensor:
    index_2d = get_2d_index(old_positions)
    return tf.gather_nd(positions, index_2d)

  # The entity start token goes one slot before the shifted first mention
  # token; the entity end token goes one slot after the shifted last token.
  new_mention_start_positions = get_new_positions(
      mention_start_positions) - 1
  new_mention_start_positions = new_mention_start_positions * new_mention_mask
  new_mention_end_positions = get_new_positions(mention_end_positions) + 1
  new_mention_end_positions = new_mention_end_positions * new_mention_mask

  if mlm_target_positions is not None:
    if mlm_target_weights is None:
      raise ValueError('`mlm_target_weights` must be specified if '
                       '`mlm_target_positions` is provided.')
    # Translate MLM targets to their positions in the new text, and drop
    # (zero-weight) the ones that were pushed beyond `new_length`.
    mlm_target_positions = tf.gather(
        positions, mlm_target_positions, batch_dims=1)
    mlm_target_positions_within_len = tf.less(mlm_target_positions,
                                              new_length)
    mlm_target_positions_within_len = tf.cast(
        mlm_target_positions_within_len, mlm_target_weights.dtype)
    mlm_target_weights = mlm_target_weights * mlm_target_positions_within_len
    # Zero-out positions for pad MLM targets
    mlm_target_positions = mlm_target_positions * mlm_target_weights

  # Cut texts that are longer than `new_length`: tokens that land past the end
  # are zeroed in both ids and mask.
  text_within_new_length = tf.less(positions, new_length)
  text_ids = text_ids * tf.cast(text_within_new_length, text_ids.dtype)
  text_mask = text_mask * tf.cast(text_within_new_length, text_mask.dtype)
  # Clamp so the scatter below stays in bounds. The clamped (truncated) tokens
  # were zeroed above, so scatter_nd's summation adds 0 to the last slot.
  positions = tf.minimum(positions, new_length - 1)

  # Prepare 2D index for tokens positions in the next text_ids and text_mask.
  # Note that we use flat 2D index and flat values
  # (e.g. `tf.reshape(text_ids, [-1])`) since `tf.scatter_nd` does not support
  # batch dimension.
  batch_positions = _batched_range(old_length, batch_size, 1, positions.dtype)
  batch_positions = tf.reshape(batch_positions, [-1])
  text_index_2d = _get_2d_index(batch_positions, tf.reshape(positions, [-1]))
  new_text_ids = tf.scatter_nd(text_index_2d, tf.reshape(text_ids, [-1]),
                               new_shape)
  new_text_mask = tf.scatter_nd(text_index_2d, tf.reshape(text_mask, [-1]),
                                new_shape)

  # Insert entity start / end tokens into the new text_ids and text_mask.
  # For dropped mentions the mask is 0, so the add below is a no-op at their
  # (zeroed) positions.
  new_mention_start_positions_2d = get_2d_index(new_mention_start_positions)
  new_mention_end_positions_2d = get_2d_index(new_mention_end_positions)

  new_text_ids = tf.tensor_scatter_nd_add(
      new_text_ids, new_mention_start_positions_2d,
      new_mention_mask * entity_start_token_id)
  new_text_ids = tf.tensor_scatter_nd_add(
      new_text_ids, new_mention_end_positions_2d,
      new_mention_mask * entity_end_token_id)

  new_mention_mask = tf.cast(new_mention_mask, dtype=text_mask.dtype)
  new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                           new_mention_start_positions_2d,
                                           new_mention_mask)
  new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                           new_mention_end_positions_2d,
                                           new_mention_mask)

  features = {
      'text_ids': new_text_ids,
      'text_mask': new_text_mask,
      'mention_start_positions': new_mention_start_positions,
      'mention_end_positions': new_mention_end_positions,
      'mention_mask': new_mention_mask,
  }

  if mlm_target_positions is not None:
    features['mlm_target_weights'] = mlm_target_weights
    features['mlm_target_positions'] = mlm_target_positions

  return features
def _update_batch(ind, b_update, b=None, batch_shape=None):
  """Scatters the rows of `b_update` selected by `ind` into a batch tensor.

  Used to update a batch of `i`, `j`, `q1[i]` or `q0[j]` values.

  Args:
    ind: Index tensor used both to gather from `b_update` and to scatter into
      the result.
    b_update: Tensor holding the candidate new values.
    b: Optional existing batch tensor to update in place (functionally).
    batch_shape: Shape of a freshly created result; used only when `b` is None.

  Returns:
    If `b` is given, `b` with the entries at `ind` replaced. Otherwise a new
    tensor of shape `batch_shape` with the gathered values at `ind`.
  """
  gathered = tf.gather_nd(b_update, ind)
  if b is not None:
    return tf.tensor_scatter_nd_update(b, ind, gathered)
  return tf.scatter_nd(ind, gathered, batch_shape)
def _cap_positives_mask(untiled_mask, diagonal_mask, num_views, positives_cap):
  r"""Cap positives in the provided untiled_mask.

  'positives_cap' specifies the maximum number of positives *other* than
  augmentations of the anchor. Positives will be evenly sampled from all views.

  Args:
    untiled_mask: Tensor of shape [local_batch_size, global_batch_size] that
      has entry (r, c) == 1 if feature entries in rows r and c are from the same
      class. Else (r, c) == 0.
    diagonal_mask: Tensor with the same shape as `untiled_mask`. When
      local_batch_size == global_batch_size this is just an identity matrix.
      Otherwise, it is an identity matrix of size `local_batch_size` that is
      padded with 0's in the 2nd dimension to match the target shape. This is
      used to indicate where the anchor views exist in the global batch of
      views.
    num_views: Integer number of total views.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of anchor. Must be a multiple of num_views. Including
      augmentations, a maximum of (positives_cap + num_views - 1) positives is
      possible. This parameter modifies the contrastive numerator by selecting
      which positives are present in the summation, and which positives
      contribute to the denominator if
      denominator_mode == enums.LossDenominatorMode.ALL. A negative value is
      meant to mean "no cap", but this helper does not special-case it: it
      expects a non-negative value, so uncapped callers should skip it.

  Returns:
    A tf.Tensor with the modified `untiled_mask`.
  """
  untiled_mask_no_diagonal = tf.math.minimum(untiled_mask, 1. - diagonal_mask)
  untiled_positives_per_anchor = positives_cap // num_views

  # Grabs top-k positives from each row in the mask. Can end up with negatives
  # incorrectly marked as positives if fewer than `untiled_positives_per_anchor`
  # exist in any row of `untiled_mask_no_diagonal`. However, these false
  # positives will be masked out by the final multiplication with
  # `untiled_mask`.
  _, top_k_col_idx = tf.math.top_k(untiled_mask_no_diagonal,
                                   untiled_positives_per_anchor)
  top_k_row_idx = tf.expand_dims(tf.range(tf.shape(untiled_mask)[0]), axis=1)

  # Construct |top_k_idx|, a tensor of shape
  # [local_batch_size * untiled_positives_per_anchor, 2]. Each row is a 2D
  # index into a [local_batch_size, global_batch_size] tensor that holds a
  # positive; all other entries are negatives.
  top_k_idx = tf.reshape(
      tf.stack([
          tf.tile(top_k_row_idx, (1, untiled_positives_per_anchor)),
          top_k_col_idx
      ],
               axis=-1), (-1, 2))

  # Construct the capped mask by setting positives to 1 at `top_k_idx`.
  # The dynamic shape is used (rather than the static `.shape`) so this also
  # works when the batch dimensions are unknown at trace time.
  untiled_mask_capped = tf.scatter_nd(
      top_k_idx,
      tf.ones(
          shape=tf.shape(top_k_idx)[0], dtype=untiled_mask_no_diagonal.dtype),
      tf.shape(untiled_mask_no_diagonal))
  # Always keep the anchor itself.
  untiled_mask_capped = tf.math.maximum(untiled_mask_capped, diagonal_mask)
  return untiled_mask * untiled_mask_capped
def prepare_grid(*, times, time_step, dtype, num_time_steps=None,
                 times_grid=None):
  """Prepares grid of times for path generation.

  Args:
    times: Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in the
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.
    num_time_steps: Number of points on the grid. If supplied, a uniform grid
      is constructed for `[time_step, times[-1] - time_step]` consisting of
      max(0, num_time_steps - len(times)) points, which is then concatenated
      with `times`. This guarantees that the number of points on the time grid
      is `max(len(times), num_time_steps)` and that `times` are included in
      the grid.
      Default value: `None`, which means that a uniform grid is created
      containing all points from `times` and the uniform grid of points
      between `[0, times[-1]]` with grid size equal to `time_step`.
    times_grid: An optional rank 1 `Tensor` representing the time
      discretization grid. If `times` are not on the grid, the nearest grid
      points are used; times beyond the last grid point map to the last grid
      point.
      Default value: `None`, which means that the times grid is computed using
      `time_step` and `num_time_steps`.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor`. If `times_grid` is supplied, it is
      returned unchanged. Otherwise, if `num_time_steps` is supplied, its
      shape is `max(num_time_steps, len(times))`; if not, it consists of all
      points from `times` and the uniform grid of points between
      `[0, times[-1]]` with grid size equal to `time_step`.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`,
      showing which elements of `all_times` correspond to values from `times`.
      For generated grids, `all_times[0] = 0` and `mask[0] = False`. With a
      user-supplied `times_grid` this depends on the supplied grid.
    `time_indices` is an integer `Tensor` of the same shape as `times`
      indicating the indices of `times` in `all_times`.
  """
  if times_grid is None:
    if num_time_steps is None:
      all_times, time_indices = _grid_from_time_step(
          times=times, time_step=time_step, dtype=dtype)
    else:
      all_times, time_indices = _grid_from_num_times(
          times=times, time_step=time_step, num_time_steps=num_time_steps)
  else:
    all_times = times_grid
    time_indices = tf.searchsorted(times_grid, times)
    # `searchsorted` returns `len(times_grid)` for times beyond the last grid
    # point. Clamp so the gathers below stay in range; such times then snap
    # to the last grid point.
    last_index = tf.shape(times_grid, out_type=time_indices.dtype)[0] - 1
    time_indices = tf.math.minimum(time_indices, last_index)
    prev_indices = tf.math.maximum(time_indices - 1, 0)
    # Adjust indices to bring `times` closer to `times_grid`: pick whichever
    # of the two neighbouring grid points is nearer (the lower one on ties).
    times_diff_1 = tf.gather(times_grid, time_indices) - times
    times_diff_2 = tf.gather(times_grid, prev_indices) - times
    time_indices = tf.where(
        tf.math.abs(times_diff_2) > tf.math.abs(times_diff_1),
        time_indices,
        prev_indices)
  # Create a boolean mask to identify the iterations that have to be recorded.
  # Use `tf.scatter_nd` because it handles duplicates (they are summed). We
  # first create an integer Tensor and then convert it to a boolean mask
  # because scatter_nd with booleans is currently not supported on GPUs.
  mask = tf.scatter_nd(
      indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
      updates=tf.fill(tf.shape(times), 1),
      shape=tf.shape(all_times, out_type=tf.int64))
  mask = mask > 0
  return all_times, mask, time_indices
def _delete_tf(tensor, idx, axis=0):
  """Deletes slices of `tensor` at the given positions along `axis`.

  Args:
    tensor: The `Tensor` to delete from.
    idx: 1-D integer `Tensor` (or list) of positions along `axis` to remove.
      Repeated positions are allowed; each such slice is removed once.
    axis: Python int, the axis along which to delete.

  Returns:
    `tensor` with the slices at `idx` along `axis` removed.
  """
  idx = tf.convert_to_tensor(idx)
  # Build the scatter shape in the index dtype; `tf.scatter_nd` requires the
  # shape to have the same dtype as the indices, so int64 `idx` also works.
  n = tf.shape(tensor, out_type=idx.dtype)[axis]
  # Count how often each position is listed. int32 updates are used because
  # `tf.scatter_nd` does not support bool updates on GPUs.
  hit_counts = tf.scatter_nd(
      tf.expand_dims(idx, 1), tf.ones_like(idx, dtype=tf.int32), [n])
  keep = tf.equal(hit_counts, 0)
  return tf.boolean_mask(tensor, keep, axis=axis)
def sparse_to_dense_1d(sparse_values: tf.Tensor, seq_length: int):
  """Converts a 1-D tensor of positions into a dense 0/1 indicator tensor.

  Example: `sparse_values=[0, 1, 4]`, `seq_length=5` gives `[1, 1, 0, 0, 1]`.

  Args:
    sparse_values: 1-D integer tensor of positions, expected in
      `[0, seq_length)`.
    seq_length: Length of the dense output. May also be a scalar integer
      tensor (e.g. `tf.shape(text_ids)[0]`).

  Returns:
    A 1-D tensor of shape [seq_length] and dtype `sparse_values.dtype`, with 1
    at every listed position and 0 elsewhere. Positions listed more than once
    still yield 1.
  """
  updates = tf.ones_like(sparse_values)
  # scatter_nd sums duplicate indices; clamp so the result stays a 0/1
  # indicator even if a position is repeated.
  counts = tf.scatter_nd(
      tf.expand_dims(sparse_values, 1), updates, [seq_length])
  return tf.minimum(counts, 1)