def tf_hilbert(x, axis=-1):
  '''Performs 1d hilbert similar to scipy'''
  # Move the target axis to be the innermost one for the fft.
  axis = tf.constant(axis)
  if axis < 0:
    axis = tf.rank(x) + axis
  axes = tf.range(tf.rank(x))
  axes = tf.math.mod(axes - tf.reduce_max(axes) + axis, tf.size(axes))
  x = tf.transpose(x, perm=axes)

  # Apply fft.
  x = tf.cast(x, dtype=tf.complex64)
  Xf = tf.signal.fft(x)

  # Create 2U, the filter that zeroes out negative frequencies.
  N = tf.shape(Xf)[-1]
  h = tf.cast(tf.ones([N // 2 + 1]) * 2, Xf.dtype)
  if tf.math.mod(N, 2) == 0:
    h = tf.tensor_scatter_nd_update(h, [[0], [tf.size(h) - 1]], [1, 1])
  else:
    h = tf.tensor_scatter_nd_update(h, [[0]], [1])
  h = tf.concat([h, tf.zeros(N - tf.size(h), dtype=h.dtype)], axis=0)

  # Apply ifft to obtain the analytic signal.
  x = tf.signal.ifft(Xf * h)

  # Move the axis back to its original position.
  x = tf.transpose(x, perm=tf.argsort(axes))
  return x

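# Minimal usage sketch (shapes chosen for illustration): the imaginary part
# of the analytic signal is the Hilbert transform of the input, and its
# magnitude is the amplitude envelope, mirroring scipy.signal.hilbert.
t = tf.linspace(0.0, 10.0, 128)
analytic = tf_hilbert(tf.sin(t))            # complex64, shape [128]
envelope = tf.abs(analytic)                 # amplitude envelope
hilbert_transform = tf.math.imag(analytic)
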
def _parse_record(self, record) -> NestedMap:
  """Reads and parses a single record."""
  p = self.params
  name_to_features = {
      'input_ids':
          tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
      'input_mask':
          tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
      'masked_lm_positions':
          tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
      'masked_lm_ids':
          tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
      'masked_lm_weights':
          tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.float32),
  }
  example = tf.io.parse_single_example(record, name_to_features)
  mask_length = tf.cast(
      tf.reduce_sum(example['masked_lm_weights']), dtype=tf.int32)
  masked_lm_positions = tf.slice(example['masked_lm_positions'], [0],
                                 [mask_length])
  masked_lm_ids = tf.cast(
      tf.slice(example['masked_lm_ids'], [0], [mask_length]), dtype=tf.int32)

  ret = py_utils.NestedMap()
  ret.masked_ids = tf.cast(example['input_ids'], dtype=tf.int32)
  # Get back non-masked, original ids.
  ret.labels = tf.tensor_scatter_nd_update(
      tensor=ret.masked_ids,
      indices=tf.reshape(masked_lm_positions, [-1, 1]),
      updates=masked_lm_ids)
  ret.masked_pos = tf.tensor_scatter_nd_update(
      tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
      indices=tf.reshape(masked_lm_positions, [-1, 1]),
      updates=tf.ones_like(masked_lm_ids, dtype=tf.float32))
  ret.segment_ids = tf.cast(example['input_mask'], dtype=tf.float32)

  first_eos_idx = tf.where(tf.math.equal(ret.labels, p.eos_token_id))[0][0]

  def remove_first_eos(x):
    # We remove the element at position `first_eos_idx`, and pad with 0
    # to keep length unchanged.
    zero = tf.constant(0, shape=(1,), dtype=x.dtype)
    return tf.concat([x[:first_eos_idx], x[first_eos_idx + 1:], zero], axis=0)

  ret = ret.Transform(remove_first_eos)
  ret.paddings = 1.0 - ret.segment_ids
  pos = tf.cast(tf.range(p.max_sequence_length), dtype=tf.float32)
  ret.segment_pos = tf.cast(ret.segment_ids * pos, dtype=tf.int32)

  if p.remask:
    new_masked_ids, new_masked_pos = self.mlm.FProp(None, ret.labels,
                                                    ret.paddings)
    ret.masked_ids = new_masked_ids
    ret.masked_pos = new_masked_pos
  return ret

def loop_body(index, samples):
  """Loop for iterative pixel sampling.

  Args:
    index: 0D `Tensor` of type `int32`. Index of the current pixel.
    samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
      pixel `[index]`, with dimensions `[batch_size, height, width,
      num_channels]`.

  Returns:
    samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
      and including pixel `[index]`, with dimensions `[batch_size, height,
      width, num_channels]`.
  """
  inputs = samples if conditional_input is None else [samples, h]
  params = self.network(inputs, training=training)
  samples_new = self._sample_channels(*params, seed=seed)

  # Update the current pixel.
  samples = tf.transpose(samples, [1, 2, 3, 0])
  samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
  row, col = index // image_width, index % image_width
  updates = samples_new[row, col, ...][tf.newaxis, ...]
  samples = tf.tensor_scatter_nd_update(samples, [[row, col]], updates)
  samples = tf.transpose(samples, [3, 0, 1, 2])

  return index + 1, samples

def contrastive_loss(similarity_matrix,
                     metric_values,
                     temperature,
                     coupling_temperature=1.0,
                     use_coupling_weights=True):
  """Contrastive Loss with soft coupling."""
  logging.info('Using alternative contrastive loss.')
  metric_shape = tf.shape(metric_values)
  similarity_matrix /= temperature
  neg_logits1 = similarity_matrix

  col_indices = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
  pos_indices1 = tf.stack(
      (tf.range(metric_shape[0], dtype=tf.int32), col_indices), axis=1)
  pos_logits1 = tf.gather_nd(similarity_matrix, pos_indices1)

  if use_coupling_weights:
    metric_values /= coupling_temperature
    coupling = tf.exp(-metric_values)
    pos_weights1 = -tf.gather_nd(metric_values, pos_indices1)
    pos_logits1 += pos_weights1
    negative_weights = tf.math.log((1.0 - coupling) + EPS)
    neg_logits1 += tf.tensor_scatter_nd_update(negative_weights, pos_indices1,
                                               pos_weights1)

  neg_logits1 = tf.math.reduce_logsumexp(neg_logits1, axis=1)
  return tf.reduce_mean(neg_logits1 - pos_logits1)

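# Hypothetical usage sketch for contrastive_loss: assumes module-level `EPS`
# (e.g. 1e-9) and absl `logging`, as referenced inside the function. Both
# arguments are square [batch, batch] matrices; `metric_values` holds
# pairwise non-negative metric distances.
sim_demo = tf.random.uniform([8, 8])
dists_demo = tf.random.uniform([8, 8])
loss_demo = contrastive_loss(sim_demo, dists_demo, temperature=0.1)
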
def _inverse_event_shape_tensor(self, output_shapes):
  """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

  Args:
    output_shapes: An iterable of `Tensor`, `int32` vectors indicating
      event-shapes passed into `inverse` function. The length of the iterable
      must be equal to the number of splits.

  Returns:
    inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
      event-portion shape after applying `inverse`.
  """
  # Validate `output_shapes` statically if possible and get assertions.
  is_validated = self._validate_output_shapes(
      [tensorshape_util.constant_value_as_shape(s) for s in output_shapes])
  if is_validated or not self.validate_args:
    assertions = []
  else:
    assertions = self._validate_output_shape_tensors(output_shapes)

  with tf.control_dependencies(assertions):
    total_size = tf.reduce_sum([t[self.axis] for t in output_shapes])
    inverse_event_shape = tf.tensor_scatter_nd_update(
        output_shapes[0],
        [[prefer_static.rank_from_shape(output_shapes[0]) + self.axis]],
        [total_size])
    return tf.identity(
        tf.convert_to_tensor(
            inverse_event_shape, dtype_hint=tf.int32,
            name='inverse_event_shape'))

def _get_q_slice(q, k, ind, b=None, batch_shape=None):
  """Returns `q1[i]` or `q0[j]` for a batch of indices `i` or `j`."""
  q_ind = tf.concat([ind, tf.expand_dims(tf.gather_nd(k, ind), -1)], axis=1)
  b_updates = tf.gather_nd(q, q_ind)
  if b is None:
    return tf.scatter_nd(ind, b_updates, batch_shape)
  return tf.tensor_scatter_nd_update(b, ind, b_updates)

def inplace_update_i(inp_tensor, updates, i):
  """Inplace update a tensor. B: batch_size, L: tensor length."""
  batch_size = inp_tensor.shape[0]
  indices = tf.stack([
      tf.range(batch_size, dtype=tf.int32),
      tf.fill([batch_size], tf.cast(i, tf.int32))
  ], axis=-1)
  return tf.tensor_scatter_nd_update(inp_tensor, indices, updates)

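# Usage sketch: write one value per batch row into column `i` of a [B, L]
# tensor, without in-place mutation.
out = inplace_update_i(tf.zeros([3, 5]), tf.constant([1., 2., 3.]), i=2)
# out[:, 2] == [1., 2., 3.]
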
def _update_i(tensor_BxNxT, updates_BxN, i):
  """Writes `updates_BxN` into position `i` along the T axis of the tensor."""
  B, N, T = tensor_BxNxT.shape
  tensor_BNxT = tf.reshape(tensor_BxNxT, [-1, T])
  updates_BN = tf.reshape(updates_BxN, [-1])
  batch_BN = tf.range(B * N, dtype=tf.int32)
  i_BN = tf.fill([B * N], i)
  ind_BNx2 = tf.stack([batch_BN, i_BN], axis=-1)
  tensor_BNxT = tf.tensor_scatter_nd_update(tensor_BNxT, ind_BNx2, updates_BN)
  return tf.reshape(tensor_BNxT, [B, N, T])

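# Usage sketch: the [B, N, T] analogue of the column update above, e.g.
# writing one token per beam at decode step `i`.
out = _update_i(tf.zeros([2, 3, 4], tf.int32), tf.ones([2, 3], tf.int32), i=1)
# out[:, :, 1] is all ones
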
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
  """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
  if not source and not destination:
    return a
  a = asarray(a).data

  if isinstance(source, int):
    source = (source,)
  if isinstance(destination, int):
    destination = (destination,)

  a_rank = utils._maybe_static(tf.rank(a))  # pylint: disable=protected-access

  def _correct_axis(axis, rank):
    if axis < 0:
      return axis + rank
    return axis

  source = tuple(_correct_axis(axis, a_rank) for axis in source)
  destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

  if a.shape.rank is not None:
    perm = [i for i in range(a_rank) if i not in source]
    for dest, src in sorted(zip(destination, source)):
      assert dest <= len(perm)
      perm.insert(dest, src)
  else:
    r = tf.range(a_rank)

    def _remove_indices(a, b):
      """Remove indices (`b`) from `a`."""
      items = tf.unstack(tf.sort(tf.stack(b)), num=len(b))
      i = 0
      result = []
      for item in items:
        result.append(a[i:item])
        i = item + 1
      result.append(a[i:])
      return tf.concat(result, 0)

    minus_sources = _remove_indices(r, source)
    minus_dest = _remove_indices(r, destination)

    perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources,
                         [a_rank])
    perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1),
                                       source)
  a = tf.transpose(a, perm)

  return utils.tensor_to_ndarray(a)

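# Plain-TF illustration of what moveaxis computes for a static example (the
# function above generalizes this to the module's ndarray wrapper and to
# dynamic ranks): np.moveaxis(a, 0, -1) on shape [2, 3, 4] gives [3, 4, 2].
a_demo = tf.zeros([2, 3, 4])
moved = tf.transpose(a_demo, perm=[1, 2, 0])  # shape [3, 4, 2]
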
def _get_endpoint_a(i, j, q1_i, q0_j, batch_shape):
  """Determine the beginning of the interval, `a`."""
  # if i < 0: a = q0[j]
  i_lt_0 = tf.less(i, 0)
  ind = tf.where(i_lt_0)
  a_update = tf.gather_nd(q0_j, ind)
  a = tf.scatter_nd(ind, a_update, batch_shape)

  # elif j < 0: a = q1[i]
  j_lt_0 = tf.less(j, 0)
  ind = tf.where(j_lt_0)
  a_update = tf.gather_nd(q1_i, ind)
  a = tf.tensor_scatter_nd_update(a, ind, a_update)

  # else: a = max(q0[j], q1[i])
  ind = tf.where(~(i_lt_0 | j_lt_0))
  q_max = tf.maximum(q0_j, q1_i)
  a_update = tf.gather_nd(q_max, ind)
  a = tf.tensor_scatter_nd_update(a, ind, a_update)
  return a

def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)
  a_rank = tf.rank(a)
  if axis1 < 0:
    axis1 += a_rank
  if axis2 < 0:
    axis2 += a_rank
  perm = tf.range(a_rank)
  perm = tf.tensor_scatter_nd_update(perm, [[axis1], [axis2]],
                                     [axis2, axis1])
  a = tf.transpose(a, perm)
  return utils.tensor_to_ndarray(a)

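# Usage sketch: the scatter above builds the permutation by exchanging the
# two axis entries of tf.range(rank). For rank 3 and axes (0, 2):
perm_demo = tf.tensor_scatter_nd_update(tf.range(3), [[0], [2]], [2, 0])
# perm_demo == [2, 1, 0], so swapaxes on shape [2, 3, 4] gives [4, 3, 2].
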
def _masked_tensor_scatter_nd_update(tensor, indices, updates, mask=None):
  """Performs tensor_scatter_nd_update with masked updates."""
  if mask is None:
    return tf.tensor_scatter_nd_update(tensor, indices, updates)

  # We have to handle two cases: (1) all updates are masked and (2) there is
  # at least one update not masked. We do not want to unnecessarily recreate
  # the cache, and to support TPU we cannot use conditional statements.
  pred = tf.reduce_any(tf.cast(mask, tf.bool))
  pred_indices = tf.cast(pred, indices.dtype)
  pred_updates = tf.cast(pred, updates.dtype)
  indices_mask = tf.expand_dims(tf.cast(mask, dtype=indices.dtype), axis=1)
  updates_mask = _get_broadcastable_mask(mask, updates)

  # If there is at least one unmasked element we will get it here. We will
  # replace all masked indices and updates with this element.
  _, not_masked = tf.math.top_k(indices_mask[:, 0])
  not_masked = not_masked[0]
  indices = indices * indices_mask + indices[not_masked] * (1 - indices_mask)
  updates = updates * updates_mask + updates[not_masked] * (1 - updates_mask)

  # If all elements are masked, then indices will become all zero and we
  # "update" the tensor with the value at the zeroth index, effectively not
  # changing the tensor.
  indices = pred_indices * indices
  updates = pred_updates * updates + (1 - pred_updates) * tensor[0]
  return tf.tensor_scatter_nd_update(tensor, indices, updates)

def body(i, triangular_factor):
  # Rescale column `i` below the diagonal by the pivot, then subtract the
  # rank-1 outer-product correction from the trailing lower triangle.
  column_head = triangular_factor[..., i, i, tf.newaxis]
  column_tail = triangular_factor[..., i+1:, i]
  rescaled_tail = column_tail / column_head
  triangular_factor = tf.tensor_scatter_nd_update(
      triangular_factor, slix[..., i+1:, i], rescaled_tail)
  triangular_factor = tf.tensor_scatter_nd_sub(
      triangular_factor, slix[..., i+1:, i+1:],
      tf.linalg.band_part(
          tf.einsum('...i,...j->...ij', column_tail, rescaled_tail),
          num_lower=-1, num_upper=0))
  return i+1, triangular_factor

def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = tf.rank(m)
  ax1, ax2 = utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = tf.range(m_rank)
    perm = tf.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)

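# Plain-TF illustration of the k=1 branch above: rot90 of a [2, 3] matrix is
# transpose(flip(m, axis=1)), matching np.rot90.
m_demo = tf.constant([[1, 2, 3], [4, 5, 6]])
rotated = tf.transpose(tf.reverse(m_demo, axis=[1]))
# rotated == [[3, 6], [2, 5], [1, 4]]
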
def call(self, inputs):
  value, index = inputs
  if self.cache.shape == inputs[0].shape:
    self.cache = value
    return value

  shape = self.cache.shape.as_list()
  num_index_axes = index.shape[0]
  num_batch_axes = self.num_batch_axes
  num_feature_axes = len(shape) - num_index_axes - num_batch_axes
  features_shape = shape[num_batch_axes + num_index_axes:]
  batch_shape = shape[:num_batch_axes]

  value_index_shape = tf.shape(value)[num_batch_axes:-num_feature_axes]
  if tf.reduce_max(value_index_shape) > 1:
    # This is a block update starting at index.
    value_ranges = []
    for i, s in enumerate(tf.unstack(value_index_shape)):
      curr_range = tf.range(index[i], index[i] + s)
      value_ranges.append(curr_range)

    batch_ranges = [tf.range(s) for s in batch_shape]

    mesh = tf.meshgrid(*(batch_ranges + value_ranges), indexing='ij')
    indices = tf.stack(mesh, axis=-1)
    indices = tf.reshape(indices, [-1, num_index_axes + num_batch_axes])
  else:
    # This is a single update at index position.
    batch_ranges = [tf.range(s) for s in batch_shape]
    mesh = tf.meshgrid(*batch_ranges, indexing='ij')
    batch_indices = tf.stack(mesh, axis=-1)
    batch_indices = tf.reshape(batch_indices, [-1, num_batch_axes])

    # Add leading axes to nd-index and tile to get batched indices.
    shape_indices = tf.reshape(index, [1] * num_batch_axes + [-1])
    shape_indices = tf.tile(shape_indices, batch_shape + [1])
    shape_indices = tf.reshape(shape_indices, [-1, num_index_axes])

    indices = tf.concat([batch_indices, shape_indices], axis=-1)

  # We need to squeeze nd-axes from value before updating.
  value = tf.reshape(value, [-1] + features_shape)
  self.cache = tf.tensor_scatter_nd_update(self.cache, indices, value)
  return self.cache

def test_arrays_all_close(self, dtype):
  """Tests the _arrays_all_close() helper function."""
  with self.subTest("Expected pass"):
    a = tf.ones([4, 2, 3], dtype)
    b = a * 0.99
    atol = a * 0.02
    gmb_utils.arrays_all_close(self, a, b, atol)

  with self.subTest("Shape mismatch"):
    a = tf.ones([4, 2, 3], dtype)
    b = tf.ones([4, 1, 3], dtype) * 0.99
    atol = a * 0.02
    with self.assertRaises(ValueError):
      gmb_utils.arrays_all_close(self, a, b, atol)

  with self.subTest("Values not close"):
    a = tf.ones([4, 2, 3], dtype)
    b = tf.ones([4, 2, 3], dtype) * 0.99
    c = tf.tensor_scatter_nd_update(b, [(1, 1, 1)], [0.95])
    atol = a * 0.02
    with self.assertRaises(ValueError):
      gmb_utils.arrays_all_close(self, a, c, atol)

def _set_vector_index_unbatched(v, idx, x):
  """Mutation-free equivalent of `v[idx] = x`."""
  return tf.tensor_scatter_nd_update(v, indices=[[idx]], updates=[x])

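# Usage sketch: functional single-element assignment.
v_demo = tf.constant([1.0, 2.0, 3.0])
_set_vector_index_unbatched(v_demo, 1, 9.0)  # -> [1., 9., 3.]
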
def _safe_tensor_scatter_nd_update(tensor, indices, updates):
  """Like tensor_scatter_nd_update, but a no-op for zero-element tensors."""
  if tensorshape_util.num_elements(tensor.shape) == 0:
    return tensor
  return tf.tensor_scatter_nd_update(tensor, indices, updates)

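# Usage sketch: scattering into a zero-element tensor would be an
# out-of-bounds update, so the guard just returns the input unchanged.
empty = tf.zeros([0])
_safe_tensor_scatter_nd_update(empty, [[0]], [1.0])  # returns `empty` as-is
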
def set_negative_scores(scores, indices):
  """Sets scores[b, indices[b]] = -1.0 per row; `bsz` comes from the enclosing scope."""
  indices_2d = tf.stack(
      [tf.range(bsz, dtype=indices.dtype), indices], axis=1)
  return tf.tensor_scatter_nd_update(
      scores, indices_2d, tf.fill(tf.shape(indices), -1.0))

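# Usage sketch: `bsz` is a free variable captured from the enclosing scope in
# the original code; here we bind it explicitly for illustration.
bsz = 2
scores_demo = tf.zeros([bsz, 4])
set_negative_scores(scores_demo, tf.constant([1, 3]))
# -> row 0 gets -1.0 at column 1, row 1 gets -1.0 at column 3
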
def call(self, x, training=False):
  x_flat = tf.reshape(x, shape=(-1, self.depth))

  # Split each input vector into one segment per head.
  x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
  x_flat = tf.concat(x_flat_split, axis=0)

  if training:
    # Figure out which centroids we want to keep, and which we want to
    # restart.
    n = x_flat.shape[0]
    keep = self.counts * self.k > self.restart_threshold * n
    restart = tf.math.logical_not(keep)

    # Replace centroids to restart with elements from the batch, using samples
    # from a uniform distribution as a fallback in case we need to restart
    # more centroids than we have elements in the batch.
    restart_idx = tf.squeeze(tf.where(restart), -1)
    n_replace = tf.minimum(tf.shape(restart_idx)[0], x_flat.shape[0])
    e_restart = tf.tensor_scatter_nd_update(
        tf.random.uniform([self.k, self.depth // self.num_heads]),
        tf.expand_dims(restart_idx[:n_replace], 1),
        tf.random.shuffle(x_flat)[:n_replace])

    # Compute the values of the centroids we want to keep by dividing the
    # summed vectors by the corresponding counts.
    e = tf.where(
        tf.expand_dims(keep, 1),
        tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1)),
        e_restart)
  else:
    # If not training, just use the centroids as is with no restarts.
    e = tf.math.divide_no_nan(self.sums, tf.expand_dims(self.counts, 1))

  # Compute distance between each input vector and each cluster center.
  distances = (
      tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1) -
      2 * tf.matmul(x_flat, tf.transpose(e)) +
      tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0))

  # Find nearest cluster center for each input vector.
  c = tf.argmin(distances, axis=1)

  # Quantize input vectors with straight-through estimator.
  z = tf.nn.embedding_lookup(e, c)
  z_split = tf.split(z, self.num_heads, axis=0)
  z = tf.concat(z_split, axis=1)
  z = tf.reshape(z, tf.shape(x))
  z = x + tf.stop_gradient(z - x)

  if training:
    # Compute cluster counts and vector sums over the batch.
    oh = tf.one_hot(indices=c, depth=self.k)
    counts = tf.reduce_sum(oh, axis=0)
    sums = tf.matmul(oh, x_flat, transpose_a=True)

    # Apply exponential moving average to cluster counts and vector sums.
    self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
    self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

  c_split = tf.split(c, self.num_heads, axis=0)
  c = tf.stack(c_split, axis=1)
  c = tf.reshape(c, tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))

  return z, c

def collater_fn(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  batch = mm_collater_fn(batch)
  retrieve_masked = config.get('retrieve_masked', False)

  # Subselect mentions for which to retrieve corresponding memory.
  # We want to sample mentions which are linked, not masked, and not padded.
  scores = tf.random.uniform(
      tf.shape(batch['mention_target_is_masked'])) + 2 * tf.cast(
          batch['mention_target_weights'], tf.float32)
  if not retrieve_masked:
    scores -= tf.cast(batch['mention_target_is_masked'], tf.float32)

  _, mention_target_retrieval_indices = tf.math.top_k(
      scores, k=max_retrieval_indices)

  mention_retrieval_indices = tf.gather(
      batch['mention_target_indices'], mention_target_retrieval_indices)
  retrieval_mention_mask = tf.gather(
      batch['mention_target_weights'], mention_target_retrieval_indices)
  # Set weight to 0 for masked retrievals if we do not want to include these.
  if not retrieve_masked:
    retrieval_mention_mask *= tf.gather(
        1 - tf.cast(batch['mention_target_is_masked'], tf.int32),
        mention_target_retrieval_indices)

  retrieval_mention_start_positions = tf.gather(
      batch['mention_start_positions'], mention_retrieval_indices)
  retrieval_text_identifiers = tf.gather(batch['text_identifiers'],
                                         mention_retrieval_indices)
  retrieval_mention_hash = mention_preprocess_utils.modified_cantor_pairing(
      tf.cast(retrieval_mention_start_positions, tf.int64),
      retrieval_text_identifiers)
  retrieval_mention_hash = tf.cast(retrieval_mention_hash, tf.int32)

  retrieval_mention_sort_ids = tf.searchsorted(memory_hash_sorted,
                                               retrieval_mention_hash)

  # Searchsorted does not check whether the value is present in the array; it
  # just finds the insertion point. Here we check and set to the default
  # retrieval if not present.
  hash_not_present_mask = tf.not_equal(
      retrieval_mention_hash,
      tf.gather(memory_hash_sorted, retrieval_mention_sort_ids))
  hash_not_present = tf.where(hash_not_present_mask)
  update_values = tf.fill((tf.shape(hash_not_present)[0],),
                          tf.shape(hash_sorted_idx)[0] - 1)
  retrieval_mention_sort_ids = tf.tensor_scatter_nd_update(
      retrieval_mention_sort_ids, hash_not_present, update_values)

  # Set mask to 0 if no mention is found.
  batch['retrieval_mention_mask'] = retrieval_mention_mask * (
      1 - tf.cast(hash_not_present_mask, tf.int32))

  retrieval_mention_ids = tf.gather(hash_sorted_idx,
                                    retrieval_mention_sort_ids)
  retrieval_mention_values = tf.gather(memory_table, retrieval_mention_ids)

  # Match passage entity_ids with memory entity ids as sanity check.
  if memory_entity_pattern:
    retrieval_memory_entity_ids = tf.gather(memory_entity_ids,
                                            retrieval_mention_ids)
    retrieval_passage_entity_ids = tf.gather(
        tf.cast(batch['mention_target_ids'], tf.int32),
        mention_target_retrieval_indices)
    entity_does_not_match = tf.not_equal(retrieval_memory_entity_ids,
                                         retrieval_passage_entity_ids)
    batch['entity_does_not_match'] = tf.logical_and(
        entity_does_not_match,
        tf.cast(batch['retrieval_mention_mask'], tf.bool))

  batch['retrieval_mention_values'] = retrieval_mention_values
  batch['retrieval_mention_scores'] = tf.ones_like(
      batch['retrieval_mention_mask'])
  batch['retrieval_mention_batch_positions'] = tf.gather(
      batch['mention_batch_positions'], mention_retrieval_indices)
  batch['retrieval_mention_start_positions'] = retrieval_mention_start_positions  # pylint: disable=line-too-long
  batch['retrieval_mention_end_positions'] = tf.gather(
      batch['mention_end_positions'], mention_retrieval_indices)
  batch['mention_retrieval_indices'] = mention_retrieval_indices

  return batch

def _update_batch(ind, b_update, b=None, batch_shape=None):
  """Updates a batch of `i`, `j`, `q1[i]` or `q0[j]`."""
  updates = tf.gather_nd(b_update, ind)
  if b is None:
    return tf.scatter_nd(ind, updates, batch_shape)
  return tf.tensor_scatter_nd_update(b, ind, updates)

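# Usage sketch: gather rows of `b_update` at `ind`, then scatter them into
# `b` (or into a fresh tensor of `batch_shape` when `b` is None).
b_demo = tf.zeros([4])
_update_batch(tf.constant([[1], [3]]),
              tf.constant([10., 20., 30., 40.]),
              b=b_demo)  # -> [0., 20., 0., 40.]
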
def _replace_at_index(x, index, replacement):
  """Replaces an element at supplied index."""
  return tf.tensor_scatter_nd_update(x, [[index]], [replacement])

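# Usage sketch: scalar replacement at a single position.
_replace_at_index(tf.constant([0, 0, 0]), 1, 7)  # -> [0, 7, 0]
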
def scatter_update_2D_slice(self, tensor, indices, updates):
  """Thin wrapper around tf.tensor_scatter_nd_update for 2D slices."""
  return tf.tensor_scatter_nd_update(tensor, indices, updates)

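# Usage sketch (shown with the raw op, since the method above is a thin
# wrapper): indices select whole rows of a 2-D tensor, each replaced by the
# matching row of `updates`.
out = tf.tensor_scatter_nd_update(tf.zeros([3, 2]), [[0], [2]],
                                  tf.ones([2, 2]))  # rows 0 and 2 become ones
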
def _loop_build_sub_tree(
    self, direction, log_slice_sample, iter_, prev_tree_state,
    candidate_tree_state, continue_tree_previous, trace_arrays):
  """Base case in tree doubling."""
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence.
    directions_expanded = [
        _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
        for state in prev_tree_state.state
    ]
    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(direction, ss, -ss)
            for direction, ss in zip(directions_expanded, self.step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)

    # Save state and momentum at odd steps, check U turn at even steps.
    # Note that here we also write to a placeholder slot at even steps to
    # avoid using tf.cond.
    index = iter_ // 2
    if USE_RAGGED_TENSOR:
      write_index_ = self.write_instruction[index]
    else:
      write_index_ = tf.switch_case(index, self.write_instruction)

    write_index = tf.where(tf.equal(iter_ % 2, 0),
                           write_index_, self.max_tree_depth)

    if USE_TENSORARRAY:
      trace_arrays = TraceArrays(
          momentum_swap=[
              old.write(write_index, new) for old, new in
              zip(trace_arrays.momentum_swap, next_momentum_parts)
          ],
          state_swap=[
              old.write(write_index, new) for old, new in
              zip(trace_arrays.state_swap, next_state_parts)
          ])
    else:
      trace_arrays = TraceArrays(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(
                  trace_arrays.momentum_swap, next_momentum_parts)
          ],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(
                  trace_arrays.state_swap, next_state_parts)
          ])
    batch_size = prefer_static.size(next_target)
    has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

    if USE_RAGGED_TENSOR:
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
              self.read_instruction, iter_ // 2, directions_expanded,
              trace_arrays, next_momentum_parts, next_state_parts))
    else:
      f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
          self.read_instruction, int_iter, directions_expanded, trace_arrays,
          next_momentum_parts, next_state_parts)
      branch_execution = {x: functools.partial(f, x)
                          for x in range(len(self.read_instruction))}
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: tf.switch_case(iter_ // 2, branch_execution))

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    valid_candidate = log_slice_sample <= energy

    # Uniform sampling on the trajectory within the subtree.
    sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
    weight_sum = candidate_tree_state.weight + sample_weight
    log_accept_thresh = tf.math.log(
        tf.cast(sample_weight, tf.float32) /
        tf.cast(weight_sum, tf.float32))
    log_accept_thresh = tf.where(
        tf.math.is_nan(log_accept_thresh),
        tf.zeros([], log_accept_thresh.dtype),
        log_accept_thresh)
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size],
        dtype=tf.float32,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(s0)), s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(is_sample_accepted,
                        next_target,
                        candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        weight=weight_sum)

    not_divergent = log_slice_sample - energy < self.max_energy_diff
    continue_tree = not_divergent & no_u_turns_within_tree
    continue_tree_next = continue_tree_previous & continue_tree

    return (
        iter_ + 1,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        trace_arrays,
    )

def loop_tree_doubling(self, step_size, log_slice_sample, init_energy,
                       momentum_state_memory, iter_, initial_step_state,
                       initial_step_metastate):
  """Main loop for tree doubling."""
  with tf.name_scope('loop_tree_doubling'):
    batch_size = prefer_static.size(init_energy)
    direction = tf.cast(
        tf.random.uniform(
            shape=[batch_size],
            minval=0,
            maxval=2,
            dtype=tf.int32,
            seed=self._seed_stream()),
        dtype=tf.bool)

    left_right_index = tf.concat([
        tf.cast(direction, tf.int32)[..., tf.newaxis],
        tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
    ], axis=1)
    tree_start_states = tf.nest.map_structure(
        # Alternatively: `lambda v: tf.where(direction, v[1], v[0])`
        lambda v: tf.gather_nd(v, left_right_index),
        initial_step_state)

    directions_expanded = [
        _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
        for state in tree_start_states.state
    ]

    integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
        self.target_log_prob_fn,
        step_sizes=[
            tf.where(direction, ss, -ss)
            for direction, ss in zip(directions_expanded, step_size)
        ],
        num_steps=self.unrolled_leapfrog_steps)

    [
        candidate_tree_state,
        tree_final_states,
        final_not_divergence,
        continue_tree_final,
        energy_diff_tree_sum,
        leapfrogs_taken,
    ] = self._build_sub_tree(
        directions_expanded,
        integrator,
        log_slice_sample,
        init_energy,
        # num_steps_at_this_depth = 2**iter_ = 1 << iter_
        tf.bitwise.left_shift(1, iter_),
        tree_start_states,
        initial_step_metastate.continue_tree,
        initial_step_metastate.not_divergence,
        momentum_state_memory)

    last_candidate_state = initial_step_metastate.candidate_state
    tree_weight = candidate_tree_state.weight
    if MULTINOMIAL_SAMPLE:
      weight_sum = log_add_exp(tree_weight, last_candidate_state.weight)
      log_accept_thresh = tree_weight - last_candidate_state.weight
    else:
      weight_sum = tree_weight + last_candidate_state.weight
      log_accept_thresh = tf.math.log(
          tf.cast(tree_weight, tf.float32) /
          tf.cast(last_candidate_state.weight, tf.float32))
      log_accept_thresh = tf.where(
          tf.math.is_nan(log_accept_thresh),
          tf.zeros([], log_accept_thresh.dtype),
          log_accept_thresh)
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size],
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    choose_new_state = is_sample_accepted & continue_tree_final

    new_candidate_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    choose_new_state, prefer_static.rank(s0)), s0, s1)
            for s0, s1 in zip(candidate_tree_state.state,
                              last_candidate_state.state)
        ],
        target=tf.where(choose_new_state,
                        candidate_tree_state.target,
                        last_candidate_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    choose_new_state, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(candidate_tree_state.target_grad_parts,
                                    last_candidate_state.target_grad_parts)
        ],
        weight=weight_sum)

    # Update left/right information of the trajectory, and check
    # trajectory-level U turn.
    #
    # Alternative approach:
    # left_right_mask = tf.transpose(
    #     tf.tile(tf.one_hot(tf.cast(direction, tf.int32), 2),
    #             [1, initial_step_metastate.candidate_state[0].shape[-1], 1]),
    #     [2, 0, 1])
    #
    # trajectory_state_left_right = tf.where(
    #     tf.equal(left_right_mask, 0.),
    #     trajectory_state_left_right,
    #     tf.tile(tree_final_states[1][0][tf.newaxis, ...], [2, 1, 1]))
    new_step_state = tf.nest.pack_sequence_as(initial_step_state, [
        # Alternative approach:
        # tf.where(tf.equal(left_right_mask, 0.),
        #          v,
        #          tf.tile(r[tf.newaxis],
        #                  tf.concat([[2], tf.ones_like(tf.shape(r))], 0)))
        tf.tensor_scatter_nd_update(v, left_right_index, r)
        for v, r in zip(tf.nest.flatten(initial_step_state),
                        tf.nest.flatten(tree_final_states))
    ])

    no_u_turns_trajectory = has_not_u_turn(
        [s[0] for s in new_step_state.state],
        [m[0] for m in new_step_state.momentum],
        [s[1] for s in new_step_state.state],
        [m[1] for m in new_step_state.momentum])

    new_step_metastate = TreeDoublingMetaState(
        candidate_state=new_candidate_state,
        is_accepted=choose_new_state | initial_step_metastate.is_accepted,
        energy_diff_sum=(
            energy_diff_tree_sum + initial_step_metastate.energy_diff_sum),
        continue_tree=continue_tree_final & no_u_turns_trajectory,
        not_divergence=final_not_divergence,
        leapfrog_count=(initial_step_metastate.leapfrog_count +
                        leapfrogs_taken))

    return iter_ + 1, new_step_state, new_step_metastate

def _loop_build_sub_tree(self, directions, integrator, log_slice_sample,
                         init_energy, iter_, energy_diff_sum_previous,
                         leapfrogs_taken, prev_tree_state,
                         candidate_tree_state, continue_tree_previous,
                         not_divergent_previous, momentum_state_memory):
  """Base case in tree doubling."""
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence.
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)

    # If the tree has not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

    # Save state and momentum at odd steps, check U turn at even steps.
    # Note that here we also write to a placeholder slot at even steps to
    # avoid using tf.cond.
    index = iter_ // 2
    if USE_RAGGED_TENSOR:
      write_index_ = self.write_instruction[index]
    else:
      write_index_ = tf.switch_case(index, self.write_instruction)

    write_index = tf.where(
        tf.equal(iter_ % 2, 0), write_index_, self.max_tree_depth)

    if USE_TENSORARRAY:
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              old.write(write_index, new) for old, new in zip(
                  momentum_state_memory.momentum_swap, next_momentum_parts)
          ],
          state_swap=[
              old.write(write_index, new) for old, new in zip(
                  momentum_state_memory.state_swap, next_state_parts)
          ])
    else:
      momentum_state_memory = MomentumStateSwap(
          momentum_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(momentum_state_memory.momentum_swap,
                                  next_momentum_parts)
          ],
          state_swap=[
              tf.tensor_scatter_nd_update(old, [[write_index]], [new])
              for old, new in zip(momentum_state_memory.state_swap,
                                  next_state_parts)
          ])
    batch_size = prefer_static.size(next_target)
    has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

    if USE_RAGGED_TENSOR:
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
              self.read_instruction, iter_ // 2, directions,
              momentum_state_memory, next_momentum_parts, next_state_parts))
    else:
      f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
          self.read_instruction, int_iter, directions, momentum_state_memory,
          next_momentum_parts, next_state_parts)
      branch_execution = {
          x: functools.partial(f, x)
          for x in range(len(self.read_instruction))
      }
      no_u_turns_within_tree = tf.cond(
          tf.equal(iter_ % 2, 0),
          lambda: has_not_u_turn_at_even_step,
          lambda: tf.switch_case(iter_ // 2, branch_execution))

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    energy = tf.where(tf.math.is_nan(energy),
                      tf.constant(-np.inf, dtype=energy.dtype),
                      energy)
    energy_diff = energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    u = tf.math.log1p(-tf.random.uniform(
        shape=[batch_size],
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(s0)), s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(is_sample_accepted,
                        next_target,
                        candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _expand_dims_under_batch_dim(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        tf.ones([batch_size], dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        energy_diff_sum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )

def _loop_build_sub_tree(self, directions, integrator,
                         current_step_meta_info, iter_,
                         energy_diff_sum_previous, momentum_cumsum_previous,
                         leapfrogs_taken, prev_tree_state,
                         candidate_tree_state, continue_tree_previous,
                         not_divergent_previous, momentum_state_memory):
  """Base case in tree doubling."""
  with tf.name_scope('loop_build_sub_tree'):
    # Take one leapfrog step in the direction v and check divergence.
    [
        next_momentum_parts,
        next_state_parts,
        next_target,
        next_target_grad_parts
    ] = integrator(prev_tree_state.momentum,
                   prev_tree_state.state,
                   prev_tree_state.target,
                   prev_tree_state.target_grad_parts)

    next_tree_state = TreeDoublingState(
        momentum=next_momentum_parts,
        state=next_state_parts,
        target=next_target,
        target_grad_parts=next_target_grad_parts)
    momentum_cumsum = [
        p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                  next_momentum_parts)
    ]

    # If the tree has not yet terminated previously, we count this leapfrog.
    leapfrogs_taken = tf.where(
        continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)

    write_instruction = current_step_meta_info.write_instruction
    read_instruction = current_step_meta_info.read_instruction
    init_energy = current_step_meta_info.init_energy

    if GENERALIZED_UTURN:
      state_to_write = momentum_cumsum_previous
      state_to_check = momentum_cumsum
    else:
      state_to_write = next_state_parts
      state_to_check = next_state_parts

    batch_shape = prefer_static.shape(next_target)
    has_not_u_turn_init = prefer_static.ones(batch_shape, dtype=tf.bool)

    read_index = read_instruction.gather([iter_])[0]
    no_u_turns_within_tree = has_not_u_turn_at_all_index(
        read_index,
        directions,
        momentum_state_memory,
        next_momentum_parts,
        state_to_check,
        has_not_u_turn_init,
        log_prob_rank=prefer_static.rank(next_target))

    # Get index to write state into memory swap.
    write_index = write_instruction.gather([iter_])
    momentum_state_memory = MomentumStateSwap(
        momentum_swap=[
            tf.tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.momentum_swap,
                                next_momentum_parts)
        ],
        state_swap=[
            tf.tensor_scatter_nd_update(old, [write_index], [new])
            for old, new in zip(momentum_state_memory.state_swap,
                                state_to_write)
        ])

    energy = compute_hamiltonian(next_target, next_momentum_parts)
    current_energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype),
                              energy)
    energy_diff = current_energy - init_energy

    if MULTINOMIAL_SAMPLE:
      not_divergent = -energy_diff < self.max_energy_diff
      weight_sum = log_add_exp(candidate_tree_state.weight, energy_diff)
      log_accept_thresh = energy_diff - weight_sum
    else:
      log_slice_sample = current_step_meta_info.log_slice_sample
      not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
      # Uniform sampling on the trajectory within the subtree across valid
      # samples.
      is_valid = log_slice_sample <= energy_diff
      weight_sum = tf.where(is_valid,
                            candidate_tree_state.weight + 1,
                            candidate_tree_state.weight)
      log_accept_thresh = tf.where(
          is_valid,
          -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
          tf.constant(-np.inf, dtype=tf.float32))
    u = tf.math.log1p(-tf.random.uniform(
        shape=batch_shape,
        dtype=log_accept_thresh.dtype,
        seed=self._seed_stream()))
    is_sample_accepted = u <= log_accept_thresh

    next_candidate_tree_state = TreeDoublingStateCandidate(
        state=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    is_sample_accepted, prefer_static.rank(s0)), s0, s1)
            for s0, s1 in zip(next_state_parts, candidate_tree_state.state)
        ],
        target=tf.where(
            _rightmost_expand_to_rank(
                is_sample_accepted, prefer_static.rank(next_target)),
            next_target,
            candidate_tree_state.target),
        target_grad_parts=[
            tf.where(  # pylint: disable=g-complex-comprehension
                _rightmost_expand_to_rank(
                    is_sample_accepted, prefer_static.rank(grad0)),
                grad0, grad1)
            for grad0, grad1 in zip(next_target_grad_parts,
                                    candidate_tree_state.target_grad_parts)
        ],
        energy=tf.where(
            _rightmost_expand_to_rank(
                is_sample_accepted, prefer_static.rank(next_target)),
            current_energy,
            candidate_tree_state.energy),
        weight=weight_sum)

    continue_tree = not_divergent & continue_tree_previous
    continue_tree_next = no_u_turns_within_tree & continue_tree

    not_divergent_tokeep = tf.where(
        continue_tree_previous,
        not_divergent,
        prefer_static.ones(batch_shape, dtype=tf.bool))

    # min(1., exp(energy_diff)).
    exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
    energy_diff_sum = tf.where(continue_tree,
                               energy_diff_sum_previous + exp_energy_diff,
                               energy_diff_sum_previous)

    return (
        iter_ + 1,
        energy_diff_sum,
        momentum_cumsum,
        leapfrogs_taken,
        next_tree_state,
        next_candidate_tree_state,
        continue_tree_next,
        not_divergent_previous & not_divergent_tokeep,
        momentum_state_memory,
    )

def projection(self, X, y, resp, weights, t_range):
  """Iteratively updates `self.lower`/`self.upper` until the no-overlap test passes."""
  print("----Tracing___projection")

  @tf.function
  def expec_ll(alpha1, alpha2):
    # Temporarily swap in the candidate bounds, evaluate the expected
    # log-likelihood, then restore the original bounds.
    temp_lower = tf.identity(self.lower)
    temp_upper = tf.identity(self.upper)
    self.lower.assign(alpha1)
    self.upper.assign(alpha2)
    log_cond = self.expected_ll(X, y, resp, weights)
    self.lower.assign(temp_lower)
    self.upper.assign(temp_upper)
    return log_cond

  tmp_indexes = tf.where(tf.less(self.no_ovelap_test(), -self.theta / 50.))

  while not (tf.equal(tf.size(tmp_indexes), 0)):
    classes = tf.cast(
        tf.math.floordiv(tmp_indexes, self.n_components), tf.int32)
    good_indexes = tf.cast(
        tf.math.floormod(tmp_indexes, self.n_components), tf.int32)

    score = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,
                           name="score", clear_after_read=False)
    # Matrices of candidate updates.
    alpha1 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,
                            name="alpha1", clear_after_read=False)
    alpha2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True,
                            name="alpha2", clear_after_read=False)

    # For each candidate update, compute the expected log-likelihood.
    for it in tf.range(tf.minimum(tf.constant(self.data_dim), 20)):
      d = t_range[it]
      if (self.upper[classes[0, 0], good_indexes[0, 0], d] >
          self.upper[classes[0, 1], good_indexes[0, 1], d]):
        alpha1 = alpha1.write(
            2 * it,
            tf.tensor_scatter_nd_update(
                self.lower, [[classes[0, 0], good_indexes[0, 0], d]],
                [self.upper[classes[0, 1], good_indexes[0, 1], d]]))
        alpha2 = alpha2.write(2 * it, self.upper)
        score = score.write(
            2 * it, expec_ll(alpha1.read(2 * it), alpha2.read(2 * it)))
      else:
        alpha1 = alpha1.write(
            2 * it,
            tf.tensor_scatter_nd_update(
                self.lower, [[classes[0, 1], good_indexes[0, 1], d]],
                [self.upper[classes[0, 0], good_indexes[0, 0], d]]))
        alpha2 = alpha2.write(2 * it, self.upper)
        score = score.write(
            2 * it, expec_ll(alpha1.read(2 * it), alpha2.read(2 * it)))

      if (self.lower[classes[0, 0], good_indexes[0, 0], d] <
          self.lower[classes[0, 1], good_indexes[0, 1], d]):
        alpha2 = alpha2.write(
            2 * it + 1,
            tf.tensor_scatter_nd_update(
                self.upper, [[classes[0, 0], good_indexes[0, 0], d]],
                [self.lower[classes[0, 1], good_indexes[0, 1], d]]))
        alpha1 = alpha1.write(2 * it + 1, self.lower)
        score = score.write(
            2 * it + 1,
            expec_ll(alpha1.read(2 * it + 1), alpha2.read(2 * it + 1)))
      else:
        alpha2 = alpha2.write(
            2 * it + 1,
            tf.tensor_scatter_nd_update(
                self.upper, [[classes[0, 1], good_indexes[0, 1], d]],
                [self.lower[classes[0, 0], good_indexes[0, 0], d]]))
        alpha1 = alpha1.write(2 * it + 1, self.lower)
        score = score.write(
            2 * it + 1,
            expec_ll(alpha1.read(2 * it + 1), alpha2.read(2 * it + 1)))

    # Apply the candidate update with the lowest score.
    true_score = score.stack()
    # Alternative: restrict argmin to strictly positive scores:
    # ind = tf.cast(
    #     tf.math.argmin(tf.boolean_mask(true_score,
    #                                    tf.greater(true_score, 0))), tf.int32)
    ind = tf.cast(tf.math.argmin(true_score), tf.int32)
    self.lower.assign(alpha1.read(ind))
    self.upper.assign(alpha2.read(ind))

    # Re-compute the no-overlap test.
    tmp_indexes = tf.where(tf.less(self.no_ovelap_test(), -self.theta))
    tf.print("number of remaining overlapping: ", tf.size(tmp_indexes))