Example #1
    def tf_hilbert(x, axis=-1):
        '''Performs a 1-D Hilbert transform along `axis`, similar to scipy.signal.hilbert.'''

        # Move the target axis to the innermost position for the FFT.
        axis = tf.constant(axis)
        if axis < 0: axis = tf.rank(x) + axis
        axes = tf.range(tf.rank(x))
        axes = tf.math.mod(axes - tf.reduce_max(axes) + axis, tf.size(axes))
        x = tf.transpose(x, perm=axes)

        # Apply fft
        x = tf.cast(x, dtype=tf.complex64)
        Xf = tf.signal.fft(x)

        # Build the frequency-domain filter h = 2U (doubles positive frequencies).
        N = tf.shape(Xf)[-1]
        h = tf.cast(tf.ones([N // 2 + 1]) * 2, Xf.dtype)
        if tf.math.mod(N, 2) == 0:
            h = tf.tensor_scatter_nd_update(h, [[0], [tf.size(h) - 1]], [1, 1])
        else:
            h = tf.tensor_scatter_nd_update(h, [[0]], [1])
        h = tf.concat([h, tf.zeros(N - tf.size(h), dtype=h.dtype)], axis=0)

        # Apply the filter and inverse FFT to obtain the analytic signal.
        x = tf.signal.ifft(Xf * h)

        # Change axes back
        x = tf.transpose(x, perm=tf.argsort(axes))
        return x
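A quick eager-mode sanity check for the function above; this sketch assumes TensorFlow 2.x and SciPy are installed, and the shapes and tolerance are illustrative:

    import numpy as np
    import tensorflow as tf
    from scipy.signal import hilbert

    x = np.random.randn(4, 128).astype(np.float32)
    # tf_hilbert returns the complex analytic signal, like scipy.signal.hilbert.
    np.testing.assert_allclose(
        tf_hilbert(tf.constant(x), axis=-1).numpy(), hilbert(x, axis=-1),
        atol=1e-3)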
Example #2
  def _parse_record(self, record) -> NestedMap:
    """Reads and parses a single record."""
    p = self.params
    name_to_features = {
        'input_ids':
            tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
        'input_mask':
            tf.io.FixedLenFeature([p.max_sequence_length], tf.int64),
        'masked_lm_positions':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
        'masked_lm_ids':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.int64),
        'masked_lm_weights':
            tf.io.FixedLenFeature([p.max_predictions_per_seq], tf.float32),
    }
    example = tf.io.parse_single_example(record, name_to_features)
    mask_length = tf.cast(
        tf.reduce_sum(example['masked_lm_weights']), dtype=tf.int32)
    masked_lm_positions = tf.slice(example['masked_lm_positions'], [0],
                                   [mask_length])
    masked_lm_ids = tf.cast(
        tf.slice(example['masked_lm_ids'], [0], [mask_length]), dtype=tf.int32)
    ret = py_utils.NestedMap()
    ret.masked_ids = tf.cast(example['input_ids'], dtype=tf.int32)
    # Get back non-masked, original ids.
    ret.labels = tf.tensor_scatter_nd_update(
        tensor=ret.masked_ids,
        indices=tf.reshape(masked_lm_positions, [-1, 1]),
        updates=masked_lm_ids)
    ret.masked_pos = tf.tensor_scatter_nd_update(
        tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
        indices=tf.reshape(masked_lm_positions, [-1, 1]),
        updates=tf.ones_like(masked_lm_ids, dtype=tf.float32))
    ret.segment_ids = tf.cast(example['input_mask'], dtype=tf.float32)

    first_eos_idx = tf.where(tf.math.equal(ret.labels, p.eos_token_id))[0][0]

    def remove_first_eos(x):
      # We remove the element at position `first_eos_idx`, and pad with 0
      # to keep length unchanged.
      zero = tf.constant(0, shape=(1,), dtype=x.dtype)
      return tf.concat([x[:first_eos_idx], x[first_eos_idx + 1:], zero], axis=0)

    ret = ret.Transform(remove_first_eos)
    ret.paddings = 1.0 - ret.segment_ids
    pos = tf.cast(tf.range(p.max_sequence_length), dtype=tf.float32)
    ret.segment_pos = tf.cast(ret.segment_ids * pos, dtype=tf.int32)

    if p.remask:
      new_masked_ids, new_masked_pos = self.mlm.FProp(None, ret.labels,
                                                      ret.paddings)
      ret.masked_ids = new_masked_ids
      ret.masked_pos = new_masked_pos
    return ret
Example #3
        def loop_body(index, samples):
            """Loop for iterative pixel sampling.
            Args:
            index: 0D `Tensor` of type `int32`. Index of the current pixel.
            samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
              pixel `[index]`, with dimensions `[batch_size, height, width,
              num_channels]`.
            Returns:
            samples: 4D `Tensor`. Images with pixels sampled in raster order, up to
              and including pixel `[index]`, with dimensions `[batch_size, height,
              width, num_channels]`.
            """
            inputs = samples if conditional_input is None else [samples, h]
            params = self.network(inputs, training=training)
            samples_new = self._sample_channels(*params, seed=seed)

            # Update the current pixel
            samples = tf.transpose(samples, [1, 2, 3, 0])
            samples_new = tf.transpose(samples_new, [1, 2, 3, 0])
            row, col = index // image_width, index % image_width
            updates = samples_new[row, col, ...][tf.newaxis, ...]
            samples = tf.tensor_scatter_nd_update(samples, [[row, col]],
                                                  updates)
            samples = tf.transpose(samples, [3, 0, 1, 2])

            return index + 1, samples
Example #4
def contrastive_loss(similarity_matrix,
                     metric_values,
                     temperature,
                     coupling_temperature=1.0,
                     use_coupling_weights=True):
  """Contrative Loss with soft coupling."""
  logging.info('Using alternative contrastive loss.')
  metric_shape = tf.shape(metric_values)
  similarity_matrix /= temperature
  neg_logits1 = similarity_matrix

  col_indices = tf.cast(tf.argmin(metric_values, axis=1), dtype=tf.int32)
  pos_indices1 = tf.stack(
      (tf.range(metric_shape[0], dtype=tf.int32), col_indices), axis=1)
  pos_logits1 = tf.gather_nd(similarity_matrix, pos_indices1)

  if use_coupling_weights:
    metric_values /= coupling_temperature
    coupling = tf.exp(-metric_values)
    pos_weights1 = -tf.gather_nd(metric_values, pos_indices1)
    pos_logits1 += pos_weights1
    negative_weights = tf.math.log((1.0 - coupling) + EPS)
    neg_logits1 += tf.tensor_scatter_nd_update(negative_weights, pos_indices1,
                                               pos_weights1)
  neg_logits1 = tf.math.reduce_logsumexp(neg_logits1, axis=1)
  return tf.reduce_mean(neg_logits1 - pos_logits1)
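A hedged invocation sketch: `EPS` and `logging` live in the enclosing module of the original source, so they are stubbed here, and the shapes are illustrative only:

    import tensorflow as tf
    from absl import logging

    EPS = 1e-9  # assumed stand-in for the module-level constant

    sim = tf.random.normal([8, 8])       # pairwise similarities
    metrics = tf.random.uniform([8, 8])  # pairwise metric distances
    loss = contrastive_loss(sim, metrics, temperature=0.1)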
Example #5
    def _inverse_event_shape_tensor(self, output_shapes):
        """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

        Args:
          output_shapes: An iterable of `Tensor`, `int32` vectors indicating
            event-shapes passed into `inverse` function. The length of the iterable
            must be equal to the number of splits.

        Returns:
          inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
            event-portion shape after applying `inverse`.
        """
        # Validate `output_shapes` statically if possible and get assertions.
        is_validated = self._validate_output_shapes([
            tensorshape_util.constant_value_as_shape(s) for s in output_shapes
        ])
        if is_validated or not self.validate_args:
            assertions = []
        else:
            assertions = self._validate_output_shape_tensors(output_shapes)

        with tf.control_dependencies(assertions):
            total_size = tf.reduce_sum([t[self.axis] for t in output_shapes])
            inverse_event_shape = tf.tensor_scatter_nd_update(
                output_shapes[0],
                [[prefer_static.rank_from_shape(output_shapes[0]) + self.axis]
                 ], [total_size])
            return tf.identity(
                tf.convert_to_tensor(inverse_event_shape,
                                     dtype_hint=tf.int32,
                                     name='inverse_event_shape'))
Example #6
def _get_q_slice(q, k, ind, b=None, batch_shape=None):
    """Returns `q1[i]` or `q0[j]` for a batch of indices `i` or `j`."""
    q_ind = tf.concat([ind, tf.expand_dims(tf.gather_nd(k, ind), -1)], axis=1)
    b_updates = tf.gather_nd(q, q_ind)
    if b is None:
        return tf.scatter_nd(ind, b_updates, batch_shape)
    return tf.tensor_scatter_nd_update(b, ind, b_updates)
Example #7
def inplace_update_i(inp_tensor, updates, i):
  """Inplace update a tensor. B: batch_size, L: tensor length."""
  batch_size = inp_tensor.shape[0]
  indices = tf.stack([
      tf.range(batch_size, dtype=tf.int32),
      tf.fill([batch_size], tf.cast(i, tf.int32))
  ], axis=-1)
  return tf.tensor_scatter_nd_update(inp_tensor, indices, updates)
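A small eager usage sketch (toy shapes): every row `b` receives `updates[b]` at column `i`, and all other entries are untouched.

    import tensorflow as tf

    inp = tf.zeros([3, 5])  # B=3, L=5
    out = inplace_update_i(inp, tf.constant([1., 2., 3.]), i=2)
    # out[:, 2] == [1., 2., 3.]; every other column is still zero.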
Example #8
def _update_i(tensor_BxNxT, updates_BxN, i):
  B, N, T = tensor_BxNxT.shape
  tensor_BNxT = tf.reshape(tensor_BxNxT, [-1, T])
  updates_BN = tf.reshape(updates_BxN, [-1])
  batch_BN = tf.range(B * N, dtype=tf.int32)
  i_BN = tf.fill([B * N], i)
  ind_BNx2 = tf.stack([batch_BN, i_BN], axis=-1)
  tensor_BNxT = tf.tensor_scatter_nd_update(tensor_BNxT, ind_BNx2, updates_BN)
  return tf.reshape(tensor_BNxT, [B, N, T])
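The same column-update idea lifted to a [B, N, T] tensor by flattening the leading axes; a hedged eager check with toy shapes:

    import tensorflow as tf

    t = tf.zeros([2, 3, 5])  # B=2, N=3, T=5
    out = _update_i(t, tf.ones([2, 3]), i=4)
    # out[..., 4] is all ones; the rest of the tensor is unchanged.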
Example #9
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
    """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
    if not source and not destination:
        return a

    a = asarray(a).data

    if isinstance(source, int):
        source = (source, )
    if isinstance(destination, int):
        destination = (destination, )

    a_rank = utils._maybe_static(tf.rank(a))  # pylint: disable=protected-access

    def _correct_axis(axis, rank):
        if axis < 0:
            return axis + rank
        return axis

    source = tuple(_correct_axis(axis, a_rank) for axis in source)
    destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

    if a.shape.rank is not None:
        perm = [i for i in range(a_rank) if i not in source]
        for dest, src in sorted(zip(destination, source)):
            assert dest <= len(perm)
            perm.insert(dest, src)
    else:
        r = tf.range(a_rank)

        def _remove_indices(a, b):
            """Remove indices (`b`) from `a`."""
            items = tf.unstack(tf.sort(tf.stack(b)), num=len(b))

            i = 0
            result = []

            for item in items:
                result.append(a[i:item])
                i = item + 1

            result.append(a[i:])

            return tf.concat(result, 0)

        minus_sources = _remove_indices(r, source)
        minus_dest = _remove_indices(r, destination)

        perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources,
                             [a_rank])
        perm = tf.tensor_scatter_nd_update(perm,
                                           tf.expand_dims(destination,
                                                          1), source)
    a = tf.transpose(a, perm)

    return utils.tensor_to_ndarray(a)
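The dynamic-rank branch builds the permutation with two scatters; here is a condensed eager trace of that trick with concrete values assumed (no helper module required):

    import tensorflow as tf

    a_rank = 4
    source, destination = tf.constant([1]), tf.constant([3])
    minus_sources = tf.constant([0, 2, 3])  # range(4) with source 1 removed
    minus_dest = tf.constant([0, 1, 2])     # range(4) with destination 3 removed
    perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank])
    perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1), source)
    # perm == [0, 2, 3, 1], the permutation np.moveaxis applies for 1 -> 3.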
Example #10
def _get_endpoint_a(i, j, q1_i, q0_j, batch_shape):
    """Determine the beginning of the interval, `a`."""
    # if i < 0: a = q0[j]
    i_lt_0 = tf.less(i, 0)
    ind = tf.where(i_lt_0)
    a_update = tf.gather_nd(q0_j, ind)
    a = tf.scatter_nd(ind, a_update, batch_shape)

    # elif j < 0: a = q1[i]
    j_lt_0 = tf.less(j, 0)
    ind = tf.where(j_lt_0)
    a_update = tf.gather_nd(q1_i, ind)
    a = tf.tensor_scatter_nd_update(a, ind, a_update)

    # else: a = max(q0[j], q1[i])
    ind = tf.where(~(i_lt_0 | j_lt_0))
    q_max = tf.maximum(q0_j, q1_i)
    a_update = tf.gather_nd(q_max, ind)
    a = tf.tensor_scatter_nd_update(a, ind, a_update)
    return a
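The function is a piecewise select assembled from three disjoint scatters; a tiny eager check with toy values:

    import tensorflow as tf

    i = tf.constant([-1, 0, 2])
    j = tf.constant([1, -1, 3])
    q1_i = tf.constant([10., 20., 30.])
    q0_j = tf.constant([1., 2., 3.])
    a = _get_endpoint_a(i, j, q1_i, q0_j, batch_shape=[3])
    # a == [1., 20., 30.]: q0[j] where i < 0, q1[i] where j < 0, else the max.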
Example #11
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a = asarray(a)

  a_rank = tf.rank(a)
  if axis1 < 0:
    axis1 += a_rank
  if axis2 < 0:
    axis2 += a_rank

  perm = tf.range(a_rank)
  perm = tf.tensor_scatter_nd_update(perm, [[axis1], [axis2]], [axis2, axis1])
  a = tf.transpose(a, perm)

  return utils.tensor_to_ndarray(a)
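swapaxes above leans on `asarray` and `utils` from its numpy-emulation module. A self-contained sketch of the core trick, swapping two entries of the identity permutation via a two-element scatter:

    import tensorflow as tf

    def tf_swapaxes(a, axis1, axis2):
        rank = a.shape.rank
        axis1, axis2 = axis1 % rank, axis2 % rank
        perm = tf.tensor_scatter_nd_update(
            tf.range(rank), [[axis1], [axis2]], [axis2, axis1])
        return tf.transpose(a, perm)

    x = tf.reshape(tf.range(24), [2, 3, 4])
    assert tuple(tf_swapaxes(x, 0, -1).shape) == (4, 3, 2)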
Example #12
def _masked_tensor_scatter_nd_update(tensor, indices, updates, mask=None):
    """Performs tensor_scatter_nd_update with masked updates."""
    if mask is None:
        return tf.tensor_scatter_nd_update(tensor, indices, updates)
    # We have to handle two cases: (1) all updates are masked and (2) there is
    # at least one update not masked. We do not want to unnecessarily recreate
    # the cache, and to support TPU we cannot use conditional statements.
    pred = tf.reduce_any(tf.cast(mask, tf.bool))
    pred_indices = tf.cast(pred, indices.dtype)
    pred_updates = tf.cast(pred, updates.dtype)
    indices_mask = tf.expand_dims(tf.cast(mask, dtype=indices.dtype), axis=1)
    updates_mask = _get_broadcastable_mask(mask, updates)
    # If there is at least one unmasked element we will get it here. We will
    # replace all masked indices and updates with this element.
    _, not_masked = tf.math.top_k(indices_mask[:, 0])
    not_masked = not_masked[0]
    indices = indices * indices_mask + indices[not_masked] * (1 - indices_mask)
    updates = updates * updates_mask + updates[not_masked] * (1 - updates_mask)
    # If all elements are masked, then indices will become all zero and we
    # "update" the tensor with the value at the zeroth index, effectively not
    # changing the tensor.
    indices = pred_indices * indices
    updates = pred_updates * updates + (1 - pred_updates) * tensor[0]
    return tf.tensor_scatter_nd_update(tensor, indices, updates)
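The heart of the function is redirecting masked rows onto one unmasked row so the scatter degenerates into a harmless duplicate write; below is a condensed eager sketch of just that step (mask semantics assumed: 1 means the update is applied):

    import tensorflow as tf

    tensor = tf.zeros([5])
    indices = tf.constant([[1], [3]])
    updates = tf.constant([10., 30.])
    mask = tf.constant([1, 0])                # the second update is masked out
    imask = tf.cast(mask[:, None], indices.dtype)
    _, keep = tf.math.top_k(imask[:, 0])      # index of one unmasked update
    keep = keep[0]
    indices = indices * imask + indices[keep] * (1 - imask)
    fmask = tf.cast(mask, updates.dtype)
    updates = updates * fmask + updates[keep] * (1 - fmask)
    print(tf.tensor_scatter_nd_update(tensor, indices, updates))
    # [ 0. 10.  0.  0.  0.] -- the masked update collapsed onto index 1.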
Example #13
 def body(i, triangular_factor):
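   # NOTE: `slix` (a helper mapping batched slices to scatter indices) comes
   # from the enclosing scope of the original source. Each iteration appears to
   # perform one column-elimination step of a triangular factorization: rescale
   # the sub-diagonal column, then subtract the outer-product correction from
   # the trailing lower-triangular block.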
   column_head = triangular_factor[..., i, i, tf.newaxis]
   column_tail = triangular_factor[..., i+1:, i]
   rescaled_tail = column_tail / column_head
   triangular_factor = tf.tensor_scatter_nd_update(
       triangular_factor,
       slix[..., i+1:, i],
       rescaled_tail)
   triangular_factor = tf.tensor_scatter_nd_sub(
       triangular_factor,
       slix[..., i+1:, i+1:],
       tf.linalg.band_part(
           tf.einsum('...i,...j->...ij', column_tail, rescaled_tail),
           num_lower=-1, num_upper=0))
   return i+1, triangular_factor
Example #14
def rot90(m, k=1, axes=(0, 1)):  # pylint: disable=missing-docstring
  m_rank = tf.rank(m)
  ax1, ax2 = utils._canonicalize_axes(axes, m_rank)  # pylint: disable=protected-access

  k = k % 4
  if k == 0:
    return m
  elif k == 2:
    return flip(flip(m, ax1), ax2)
  else:
    perm = tf.range(m_rank)
    perm = tf.tensor_scatter_nd_update(perm, [[ax1], [ax2]], [ax2, ax1])

    if k == 1:
      return transpose(flip(m, ax2), perm)
    else:
      return flip(transpose(m, perm), ax2)
Example #15
    def call(self, inputs):
        value, index = inputs
        if self.cache.shape == inputs[0].shape:
            self.cache = value
            return value

        shape = self.cache.shape.as_list()
        num_index_axes = index.shape[0]
        num_batch_axes = self.num_batch_axes
        num_feature_axes = len(shape) - num_index_axes - num_batch_axes
        features_shape = shape[num_batch_axes + num_index_axes:]
        batch_shape = shape[:num_batch_axes]

        value_index_shape = tf.shape(value)[num_batch_axes:-num_feature_axes]
        if tf.reduce_max(value_index_shape) > 1:
            # This is a block update starting at index.
            value_ranges = []
            for i, s in enumerate(tf.unstack(value_index_shape)):
                curr_range = tf.range(index[i], index[i] + s)
                value_ranges.append(curr_range)

            batch_ranges = [tf.range(s) for s in batch_shape]

            mesh = tf.meshgrid(*(batch_ranges + value_ranges), indexing='ij')
            indices = tf.stack(mesh, axis=-1)
            indices = tf.reshape(indices,
                                 [-1, num_index_axes + num_batch_axes])
        else:
            # This is a single update at index position.
            batch_ranges = [tf.range(s) for s in batch_shape]
            mesh = tf.meshgrid(*batch_ranges, indexing='ij')
            batch_indices = tf.stack(mesh, axis=-1)
            batch_indices = tf.reshape(batch_indices, [-1, num_batch_axes])

            # Add leading axes to nd-index and tile to get batched indices.
            shape_indices = tf.reshape(index, [1] * num_batch_axes + [-1])
            shape_indices = tf.tile(shape_indices, batch_shape + [1])
            shape_indices = tf.reshape(shape_indices, [-1, num_index_axes])

            indices = tf.concat([batch_indices, shape_indices], axis=-1)

        # We need to squeeze nd-axes from value before updating.
        value = tf.reshape(value, [-1] + features_shape)
        self.cache = tf.tensor_scatter_nd_update(self.cache, indices, value)
        return self.cache
Example #16
  def test_arrays_all_close(self, dtype):
    """Tests the _arrays_all_close() helper function."""
    with self.subTest("Expected pass"):
      a = tf.ones([4, 2, 3], dtype)
      b = a * 0.99
      atol = a * 0.02
      gmb_utils.arrays_all_close(self, a, b, atol)

    with self.subTest("Shape mismatch"):
      a = tf.ones([4, 2, 3], dtype)
      b = tf.ones([4, 1, 3], dtype) * 0.99
      atol = a * 0.02
      with self.assertRaises(ValueError):
        gmb_utils.arrays_all_close(self, a, b, atol)

    with self.subTest("Values not close"):
      a = tf.ones([4, 2, 3], dtype)
      b = tf.ones([4, 2, 3], dtype) * 0.99
      c = tf.tensor_scatter_nd_update(b, [(1, 1, 1)], [0.95])
      atol = a * 0.02
      with self.assertRaises(ValueError):
        gmb_utils.arrays_all_close(self, a, c, atol)
Example #17
def _set_vector_index_unbatched(v, idx, x):
    """Mutation-free equivalent of `v[idx] = x."""
    return tf.tensor_scatter_nd_update(v, indices=[[idx]], updates=[x])
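A one-line eager check with illustrative values:

    import tensorflow as tf

    print(_set_vector_index_unbatched(tf.constant([10, 20, 30]), 1, 99))
    # tf.Tensor([10 99 30], shape=(3,), dtype=int32)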
Example #18
def _safe_tensor_scatter_nd_update(tensor, indices, updates):
  if tensorshape_util.num_elements(tensor.shape) == 0:
    return tensor
  return tf.tensor_scatter_nd_update(tensor, indices, updates)
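`tensorshape_util` is TFP's internal shape helper. A standalone sketch of the same guard using only core TF; the assumption is that a statically known zero-element tensor should pass through untouched, since any non-empty scatter into it would raise:

    import tensorflow as tf

    def safe_tensor_scatter_nd_update(tensor, indices, updates):
        # Zero-element tensors have nothing to update; return them unchanged.
        if tensor.shape.num_elements() == 0:
            return tensor
        return tf.tensor_scatter_nd_update(tensor, indices, updates)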
Example #19
 def set_negative_scores(scores, indices):
     indices_2d = tf.stack(
         [tf.range(bsz, dtype=indices.dtype), indices], axis=1)
     return tf.tensor_scatter_nd_update(
         scores, indices_2d, tf.fill(tf.shape(indices), -1.0))
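`bsz` is captured from the enclosing scope in the original; a self-contained variant that derives it from `scores`, plus a toy call:

    import tensorflow as tf

    def set_negative_scores(scores, indices):
        bsz = tf.shape(scores)[0]
        indices_2d = tf.stack(
            [tf.range(bsz, dtype=indices.dtype), indices], axis=1)
        return tf.tensor_scatter_nd_update(
            scores, indices_2d, tf.fill(tf.shape(indices), -1.0))

    s = set_negative_scores(tf.ones([3, 4]), tf.constant([0, 2, 3]))
    # s[0, 0] == s[1, 2] == s[2, 3] == -1.0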
Example #20
File: nn.py  Project: next-mooon/ddsp
    def call(self, x, training=False):
        x_flat = tf.reshape(x, shape=(-1, self.depth))

        # Split each input vector into one segment per head.
        x_flat_split = tf.split(x_flat, self.num_heads, axis=1)
        x_flat = tf.concat(x_flat_split, axis=0)

        if training:
            # Figure out which centroids we want to keep, and which we want to
            # restart.
            n = x_flat.shape[0]
            keep = self.counts * self.k > self.restart_threshold * n
            restart = tf.math.logical_not(keep)

            # Replace centroids to restart with elements from the batch, using samples
            # from a uniform distribution as a fallback in case we need to restart
            # more centroids than we have elements in the batch.
            restart_idx = tf.squeeze(tf.where(restart), -1)
            n_replace = tf.minimum(tf.shape(restart_idx)[0], x_flat.shape[0])
            e_restart = tf.tensor_scatter_nd_update(
                tf.random.uniform([self.k, self.depth // self.num_heads]),
                tf.expand_dims(restart_idx[:n_replace], 1),
                tf.random.shuffle(x_flat)[:n_replace])

            # Compute the values of the centroids we want to keep by dividing the
            # summed vectors by the corresponding counts.
            e = tf.where(
                tf.expand_dims(keep, 1),
                tf.math.divide_no_nan(self.sums,
                                      tf.expand_dims(self.counts, 1)),
                e_restart)

        else:
            # If not training, just use the centroids as is with no restarts.
            e = tf.math.divide_no_nan(self.sums,
                                      tf.expand_dims(self.counts, 1))

        # Compute distance between each input vector and each cluster center.
        distances = (tf.expand_dims(tf.reduce_sum(x_flat**2, axis=1), 1) -
                     2 * tf.matmul(x_flat, tf.transpose(e)) +
                     tf.expand_dims(tf.reduce_sum(e**2, axis=1), 0))

        # Find nearest cluster center for each input vector.
        c = tf.argmin(distances, axis=1)

        # Quantize input vectors with straight-through estimator.
        z = tf.nn.embedding_lookup(e, c)
        z_split = tf.split(z, self.num_heads, axis=0)
        z = tf.concat(z_split, axis=1)
        z = tf.reshape(z, tf.shape(x))
        z = x + tf.stop_gradient(z - x)

        if training:
            # Compute cluster counts and vector sums over the batch.
            oh = tf.one_hot(indices=c, depth=self.k)
            counts = tf.reduce_sum(oh, axis=0)
            sums = tf.matmul(oh, x_flat, transpose_a=True)

            # Apply exponential moving average to cluster counts and vector sums.
            self.counts.assign_sub((1 - self.gamma) * (self.counts - counts))
            self.sums.assign_sub((1 - self.gamma) * (self.sums - sums))

        c_split = tf.split(c, self.num_heads, axis=0)
        c = tf.stack(c_split, axis=1)
        c = tf.reshape(c,
                       tf.concat([tf.shape(x)[:-1], [self.num_heads]], axis=0))

        return z, c
Example #21
        def collater_fn(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
            batch = mm_collater_fn(batch)

            retrieve_masked = config.get('retrieve_masked', False)

            # Subselect mentions for which to retrieve corresponding memory.
            # We want to sample mentions which are linked, not masked, and not padded.
            scores = tf.random.uniform(
                tf.shape(batch['mention_target_is_masked'])) + 2 * tf.cast(
                    batch['mention_target_weights'], tf.float32)

            if not retrieve_masked:
                scores -= tf.cast(batch['mention_target_is_masked'],
                                  tf.float32)

            _, mention_target_retrieval_indices = tf.math.top_k(
                scores, k=max_retrieval_indices)

            mention_retrieval_indices = tf.gather(
                batch['mention_target_indices'],
                mention_target_retrieval_indices)
            retrieval_mention_mask = tf.gather(
                batch['mention_target_weights'],
                mention_target_retrieval_indices)
            # Set weight to 0 for masked retrievals if we do not want to include them.
            if not retrieve_masked:
                retrieval_mention_mask *= tf.gather(
                    1 - tf.cast(batch['mention_target_is_masked'], tf.int32),
                    mention_target_retrieval_indices)

            retrieval_mention_start_positions = tf.gather(
                batch['mention_start_positions'], mention_retrieval_indices)
            retrieval_text_identifiers = tf.gather(batch['text_identifiers'],
                                                   mention_retrieval_indices)
            retrieval_mention_hash = mention_preprocess_utils.modified_cantor_pairing(
                tf.cast(retrieval_mention_start_positions, tf.int64),
                retrieval_text_identifiers)
            retrieval_mention_hash = tf.cast(retrieval_mention_hash, tf.int32)

            retrieval_mention_sort_ids = tf.searchsorted(
                memory_hash_sorted, retrieval_mention_hash)

            # tf.searchsorted does not check whether the value is present in the
            # array; it only finds the insertion point. Here we check, and fall
            # back to the default retrieval index if the hash is not present.
            hash_not_present_mask = tf.not_equal(
                retrieval_mention_hash,
                tf.gather(memory_hash_sorted, retrieval_mention_sort_ids))
            hash_not_present = tf.where(hash_not_present_mask)
            update_values = tf.fill((tf.shape(hash_not_present)[0], ),
                                    tf.shape(hash_sorted_idx)[0] - 1)
            retrieval_mention_sort_ids = tf.tensor_scatter_nd_update(
                retrieval_mention_sort_ids, hash_not_present, update_values)

            # Set mask to 0 if no mention is found
            batch['retrieval_mention_mask'] = retrieval_mention_mask * (
                1 - tf.cast(hash_not_present_mask, tf.int32))

            retrieval_mention_ids = tf.gather(hash_sorted_idx,
                                              retrieval_mention_sort_ids)
            retrieval_mention_values = tf.gather(memory_table,
                                                 retrieval_mention_ids)
            # Match passage entity_ids with memory entity ids as sanity check.
            if memory_entity_pattern:
                retrieval_memory_entity_ids = tf.gather(
                    memory_entity_ids, retrieval_mention_ids)
                retrieval_passage_entity_ids = tf.gather(
                    tf.cast(batch['mention_target_ids'], tf.int32),
                    mention_target_retrieval_indices)
                entity_does_not_match = tf.not_equal(
                    retrieval_memory_entity_ids, retrieval_passage_entity_ids)

                batch['entity_does_not_match'] = tf.logical_and(
                    entity_does_not_match,
                    tf.cast(batch['retrieval_mention_mask'], tf.bool))

            batch['retrieval_mention_values'] = retrieval_mention_values
            batch['retrieval_mention_scores'] = tf.ones_like(
                batch['retrieval_mention_mask'])
            batch['retrieval_mention_batch_positions'] = tf.gather(
                batch['mention_batch_positions'], mention_retrieval_indices)
            batch['retrieval_mention_start_positions'] = retrieval_mention_start_positions  # pylint: disable=line-too-long
            batch['retrieval_mention_end_positions'] = tf.gather(
                batch['mention_end_positions'], mention_retrieval_indices)
            batch['mention_retrieval_indices'] = mention_retrieval_indices

            return batch
Example #22
def _update_batch(ind, b_update, b=None, batch_shape=None):
    """Updates a batch of `i`, `j`, `q1[i]` or `q0[j]`."""
    updates = tf.gather_nd(b_update, ind)
    if b is None:
        return tf.scatter_nd(ind, updates, batch_shape)
    return tf.tensor_scatter_nd_update(b, ind, updates)
Example #23
def _replace_at_index(x, index, replacement):
  """Replaces an element at supplied index."""
  return tf.tensor_scatter_nd_update(x, [[index]], [replacement])
Example #24
 def scatter_update_2D_slice(self, tensor, indices, updates):
     return tf.tensor_scatter_nd_update(tensor, indices, updates)
Example #25
  def _loop_build_sub_tree(
      self, direction, log_slice_sample,
      iter_, prev_tree_state, candidate_tree_state,
      continue_tree_previous, trace_arrays):
    """Base case in tree doubling."""
    with tf.name_scope('loop_build_sub_tree'):
      # Take one leapfrog step in the direction v and check divergence
      directions_expanded = [
          _expand_dims_under_batch_dim(direction, prefer_static.rank(state))
          for state in prev_tree_state.state]
      integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
          self.target_log_prob_fn,
          step_sizes=[tf.where(direction, ss, -ss)
                      for direction, ss in zip(
                          directions_expanded, self.step_size)],
          num_steps=self.unrolled_leapfrog_steps)
      [
          next_momentum_parts,
          next_state_parts,
          next_target,
          next_target_grad_parts
      ] = integrator(prev_tree_state.momentum,
                     prev_tree_state.state,
                     prev_tree_state.target,
                     prev_tree_state.target_grad_parts)

      next_tree_state = TreeDoublingState(
          momentum=next_momentum_parts,
          state=next_state_parts,
          target=next_target,
          target_grad_parts=next_target_grad_parts)

      # Save state and momentum at odd step, check U turn at even step.
      # Note that here we also write to a Placeholder at even step to avoid
      # using tf.cond
      index = iter_ // 2
      if USE_RAGGED_TENSOR:
        write_index_ = self.write_instruction[index]
      else:
        write_index_ = tf.switch_case(index, self.write_instruction)

      write_index = tf.where(tf.equal(iter_ % 2, 0),
                             write_index_, self.max_tree_depth)

      if USE_TENSORARRAY:
        trace_arrays = TraceArrays(
            momentum_swap=[
                old.write(write_index, new) for old, new in
                zip(trace_arrays.momentum_swap, next_momentum_parts)],
            state_swap=[
                old.write(write_index, new) for old, new in
                zip(trace_arrays.state_swap, next_state_parts)])
      else:
        trace_arrays = TraceArrays(
            momentum_swap=[
                tf.tensor_scatter_nd_update(old, [[write_index]], [new])
                for old, new in zip(
                    trace_arrays.momentum_swap, next_momentum_parts)],
            state_swap=[
                tf.tensor_scatter_nd_update(old, [[write_index]], [new])
                for old, new in zip(
                    trace_arrays.state_swap, next_state_parts)])
      batch_size = prefer_static.size(next_target)
      has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

      if USE_RAGGED_TENSOR:
        no_u_turns_within_tree = tf.cond(
            tf.equal(iter_ % 2, 0),
            lambda: has_not_u_turn_at_even_step,
            lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                self.read_instruction, iter_ // 2, directions_expanded,
                trace_arrays, next_momentum_parts, next_state_parts))
      else:
        f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
            self.read_instruction, int_iter, directions_expanded, trace_arrays,
            next_momentum_parts, next_state_parts)
        branch_execution = {x: functools.partial(f, x)
                            for x in range(len(self.read_instruction))}
        no_u_turns_within_tree = tf.cond(
            tf.equal(iter_ % 2, 0),
            lambda: has_not_u_turn_at_even_step,
            lambda: tf.switch_case(iter_ // 2, branch_execution))

      energy = compute_hamiltonian(next_target, next_momentum_parts)
      valid_candidate = log_slice_sample <= energy

      # Uniform sampling on the trajectory within the subtree
      sample_weight = tf.cast(valid_candidate, TREE_COUNT_DTYPE)
      weight_sum = candidate_tree_state.weight + sample_weight
      log_accept_thresh = tf.math.log(
          tf.cast(sample_weight, tf.float32) /
          tf.cast(weight_sum, tf.float32))
      log_accept_thresh = tf.where(
          tf.math.is_nan(log_accept_thresh),
          tf.zeros([], log_accept_thresh.dtype),
          log_accept_thresh)
      u = tf.math.log1p(-tf.random.uniform(
          shape=[batch_size],
          dtype=tf.float32,
          seed=self._seed_stream()))
      is_sample_accepted = u <= log_accept_thresh

      next_candidate_tree_state = TreeDoublingStateCandidate(
          state=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  _expand_dims_under_batch_dim(
                      is_sample_accepted, prefer_static.rank(s0)), s0, s1)
              for s0, s1 in zip(next_state_parts,
                                candidate_tree_state.state)
          ],
          target=tf.where(is_sample_accepted,
                          next_target,
                          candidate_tree_state.target),
          target_grad_parts=[
              tf.where(  # pylint: disable=g-complex-comprehension
                  _expand_dims_under_batch_dim(
                      is_sample_accepted, prefer_static.rank(grad0)),
                  grad0, grad1)
              for grad0, grad1 in zip(next_target_grad_parts,
                                      candidate_tree_state.target_grad_parts)
          ],
          weight=weight_sum)

      not_divergent = log_slice_sample - energy < self.max_energy_diff
      continue_tree = not_divergent & no_u_turns_within_tree
      continue_tree_next = continue_tree_previous & continue_tree

      return (
          iter_ + 1,
          next_tree_state,
          next_candidate_tree_state,
          continue_tree_next,
          trace_arrays,
      )
Example #26
    def loop_tree_doubling(self, step_size, log_slice_sample, init_energy,
                           momentum_state_memory, iter_, initial_step_state,
                           initial_step_metastate):
        """Main loop for tree doubling."""
        with tf.name_scope('loop_tree_doubling'):
            batch_size = prefer_static.size(init_energy)
            direction = tf.cast(tf.random.uniform(shape=[batch_size],
                                                  minval=0,
                                                  maxval=2,
                                                  dtype=tf.int32,
                                                  seed=self._seed_stream()),
                                dtype=tf.bool)

            left_right_index = tf.concat([
                tf.cast(direction, tf.int32)[..., tf.newaxis],
                tf.range(batch_size, dtype=tf.int32)[..., tf.newaxis]
            ],
                                         axis=1)
            tree_start_states = tf.nest.map_structure(
                # Alternatively: `lambda v: tf.where(direction, v[1], v[0])`
                lambda v: tf.gather_nd(v, left_right_index),
                initial_step_state)

            directions_expanded = [
                _expand_dims_under_batch_dim(direction,
                                             prefer_static.rank(state))
                for state in tree_start_states.state
            ]
            integrator = leapfrog_impl.SimpleLeapfrogIntegrator(
                self.target_log_prob_fn,
                step_sizes=[
                    tf.where(direction, ss, -ss)
                    for direction, ss in zip(directions_expanded, step_size)
                ],
                num_steps=self.unrolled_leapfrog_steps)

            [
                candidate_tree_state,
                tree_final_states,
                final_not_divergence,
                continue_tree_final,
                energy_diff_tree_sum,
                leapfrogs_taken,
            ] = self._build_sub_tree(
                directions_expanded,
                integrator,
                log_slice_sample,
                init_energy,
                # num_steps_at_this_depth = 2**iter_ = 1 << iter_
                tf.bitwise.left_shift(1, iter_),
                tree_start_states,
                initial_step_metastate.continue_tree,
                initial_step_metastate.not_divergence,
                momentum_state_memory)

            last_candidate_state = initial_step_metastate.candidate_state
            tree_weight = candidate_tree_state.weight
            if MULTINOMIAL_SAMPLE:
                weight_sum = log_add_exp(tree_weight,
                                         last_candidate_state.weight)
                log_accept_thresh = tree_weight - last_candidate_state.weight
            else:
                weight_sum = tree_weight + last_candidate_state.weight
                log_accept_thresh = tf.math.log(
                    tf.cast(tree_weight, tf.float32) /
                    tf.cast(last_candidate_state.weight, tf.float32))
            log_accept_thresh = tf.where(tf.math.is_nan(log_accept_thresh),
                                         tf.zeros([], log_accept_thresh.dtype),
                                         log_accept_thresh)
            u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            choose_new_state = is_sample_accepted & continue_tree_final

            new_candidate_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(choose_new_state,
                                                     prefer_static.rank(s0)),
                        s0, s1) for s0, s1 in zip(candidate_tree_state.state,
                                                  last_candidate_state.state)
                ],
                target=tf.where(choose_new_state, candidate_tree_state.target,
                                last_candidate_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(
                            choose_new_state, prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            candidate_tree_state.target_grad_parts,
                            last_candidate_state.target_grad_parts)
                ],
                weight=weight_sum)
            # Update the left/right endpoints of the trajectory and check for a
            # trajectory-level U-turn.

            # Alternative approach
            # left_right_mask = tf.transpose(
            #     tf.tile(tf.one_hot(tf.cast(direction, tf.int32), 2),
            #            [1, initial_step_metastate.candidate_state[0].shape[-1], 1]),
            #     [2, 0, 1])

            # trajactory_state_left_right = tf.where(
            #     tf.equal(left_right_mask, 0.),
            #     trajactory_state_left_right,
            #     tf.tile(tree_final_states[1][0][tf.newaxis, ...], [2, 1, 1]))
            new_step_state = tf.nest.pack_sequence_as(
                initial_step_state,
                [
                    # Alternative approach:
                    # tf.where(tf.equal(left_right_mask, 0.),
                    #          v,
                    #          tf.tile(r[tf.newaxis],
                    #                  tf.concat([[2], tf.ones_like(tf.shape(r))], 0)))
                    tf.tensor_scatter_nd_update(v, left_right_index, r)
                    for v, r in zip(tf.nest.flatten(initial_step_state),
                                    tf.nest.flatten(tree_final_states))
                ])
            no_u_turns_trajectory = has_not_u_turn(
                [s[0] for s in new_step_state.state],
                [m[0] for m in new_step_state.momentum],
                [s[1] for s in new_step_state.state],
                [m[1] for m in new_step_state.momentum])

            new_step_metastate = TreeDoublingMetaState(
                candidate_state=new_candidate_state,
                is_accepted=choose_new_state
                | initial_step_metastate.is_accepted,
                energy_diff_sum=(energy_diff_tree_sum +
                                 initial_step_metastate.energy_diff_sum),
                continue_tree=continue_tree_final & no_u_turns_trajectory,
                not_divergence=final_not_divergence,
                leapfrog_count=(initial_step_metastate.leapfrog_count +
                                leapfrogs_taken))

            return iter_ + 1, new_step_state, new_step_metastate
Example #27
    def _loop_build_sub_tree(self, directions, integrator, log_slice_sample,
                             init_energy, iter_, energy_diff_sum_previous,
                             leapfrogs_taken, prev_tree_state,
                             candidate_tree_state, continue_tree_previous,
                             not_divergent_previous, momentum_state_memory):
        """Base case in tree doubling."""
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            # If the tree has not terminated previously, we count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            # Save state and momentum at odd step, check U turn at even step.
            # Note that here we also write to a Placeholder at even step to avoid
            # using tf.cond
            index = iter_ // 2
            if USE_RAGGED_TENSOR:
                write_index_ = self.write_instruction[index]
            else:
                write_index_ = tf.switch_case(index, self.write_instruction)

            write_index = tf.where(tf.equal(iter_ % 2, 0), write_index_,
                                   self.max_tree_depth)

            if USE_TENSORARRAY:
                momentum_state_memory = MomentumStateSwap(
                    momentum_swap=[
                        old.write(write_index, new) for old, new in zip(
                            momentum_state_memory.momentum_swap,
                            next_momentum_parts)
                    ],
                    state_swap=[
                        old.write(write_index, new) for old, new in zip(
                            momentum_state_memory.state_swap, next_state_parts)
                    ])
            else:
                momentum_state_memory = MomentumStateSwap(
                    momentum_swap=[
                        tf.tensor_scatter_nd_update(old, [[write_index]],
                                                    [new])
                        for old, new in zip(
                            momentum_state_memory.momentum_swap,
                            next_momentum_parts)
                    ],
                    state_swap=[
                        tf.tensor_scatter_nd_update(old, [[write_index]],
                                                    [new]) for old, new in
                        zip(momentum_state_memory.state_swap, next_state_parts)
                    ])
            batch_size = prefer_static.size(next_target)
            has_not_u_turn_at_even_step = tf.ones([batch_size], dtype=tf.bool)

            if USE_RAGGED_TENSOR:
                no_u_turns_within_tree = tf.cond(
                    tf.equal(iter_ % 2, 0),
                    lambda: has_not_u_turn_at_even_step,
                    lambda: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                        self.read_instruction, iter_ // 2, directions,
                        momentum_state_memory, next_momentum_parts,
                        next_state_parts))
            else:
                f = lambda int_iter: has_not_u_turn_at_odd_step(  # pylint: disable=g-long-lambda
                    self.read_instruction, int_iter, directions,
                    momentum_state_memory, next_momentum_parts,
                    next_state_parts)
                branch_execution = {
                    x: functools.partial(f, x)
                    for x in range(len(self.read_instruction))
                }
                no_u_turns_within_tree = tf.cond(
                    tf.equal(iter_ % 2, 0),
                    lambda: has_not_u_turn_at_even_step,
                    lambda: tf.switch_case(iter_ // 2, branch_execution))

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            energy = tf.where(tf.math.is_nan(energy),
                              tf.constant(-np.inf, dtype=energy.dtype), energy)
            energy_diff = energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-tf.random.uniform(shape=[batch_size],
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(is_sample_accepted,
                                                     prefer_static.rank(s0)),
                        s0, s1) for s0, s1 in zip(next_state_parts,
                                                  candidate_tree_state.state)
                ],
                target=tf.where(is_sample_accepted, next_target,
                                candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _expand_dims_under_batch_dim(
                            is_sample_accepted, prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                tf.ones([batch_size], dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.clip_by_value(tf.exp(energy_diff), 0., 1.)
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
Example #28
    def _loop_build_sub_tree(self, directions, integrator,
                             current_step_meta_info, iter_,
                             energy_diff_sum_previous,
                             momentum_cumsum_previous, leapfrogs_taken,
                             prev_tree_state, candidate_tree_state,
                             continue_tree_previous, not_divergent_previous,
                             momentum_state_memory):
        """Base case in tree doubling."""
        with tf.name_scope('loop_build_sub_tree'):
            # Take one leapfrog step in the direction v and check divergence
            [
                next_momentum_parts, next_state_parts, next_target,
                next_target_grad_parts
            ] = integrator(prev_tree_state.momentum, prev_tree_state.state,
                           prev_tree_state.target,
                           prev_tree_state.target_grad_parts)

            next_tree_state = TreeDoublingState(
                momentum=next_momentum_parts,
                state=next_state_parts,
                target=next_target,
                target_grad_parts=next_target_grad_parts)
            momentum_cumsum = [
                p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
                                          next_momentum_parts)
            ]
            # If the tree has not terminated previously, we count this leapfrog.
            leapfrogs_taken = tf.where(continue_tree_previous,
                                       leapfrogs_taken + 1, leapfrogs_taken)

            write_instruction = current_step_meta_info.write_instruction
            read_instruction = current_step_meta_info.read_instruction
            init_energy = current_step_meta_info.init_energy

            if GENERALIZED_UTURN:
                state_to_write = momentum_cumsum_previous
                state_to_check = momentum_cumsum
            else:
                state_to_write = next_state_parts
                state_to_check = next_state_parts

            batch_shape = prefer_static.shape(next_target)
            has_not_u_turn_init = prefer_static.ones(batch_shape,
                                                     dtype=tf.bool)

            read_index = read_instruction.gather([iter_])[0]
            no_u_turns_within_tree = has_not_u_turn_at_all_index(  # pylint: disable=g-long-lambda
                read_index,
                directions,
                momentum_state_memory,
                next_momentum_parts,
                state_to_check,
                has_not_u_turn_init,
                log_prob_rank=prefer_static.rank(next_target))

            # Get index to write state into memory swap
            write_index = write_instruction.gather([iter_])
            momentum_state_memory = MomentumStateSwap(
                momentum_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.momentum_swap,
                                        next_momentum_parts)
                ],
                state_swap=[
                    tf.tensor_scatter_nd_update(old, [write_index], [new])
                    for old, new in zip(momentum_state_memory.state_swap,
                                        state_to_write)
                ])

            energy = compute_hamiltonian(next_target, next_momentum_parts)
            current_energy = tf.where(tf.math.is_nan(energy),
                                      tf.constant(-np.inf, dtype=energy.dtype),
                                      energy)
            energy_diff = current_energy - init_energy

            if MULTINOMIAL_SAMPLE:
                not_divergent = -energy_diff < self.max_energy_diff
                weight_sum = log_add_exp(candidate_tree_state.weight,
                                         energy_diff)
                log_accept_thresh = energy_diff - weight_sum
            else:
                log_slice_sample = current_step_meta_info.log_slice_sample
                not_divergent = log_slice_sample - energy_diff < self.max_energy_diff
                # Uniform sampling on the trajectory within the subtree across valid
                # samples.
                is_valid = log_slice_sample <= energy_diff
                weight_sum = tf.where(is_valid,
                                      candidate_tree_state.weight + 1,
                                      candidate_tree_state.weight)
                log_accept_thresh = tf.where(
                    is_valid,
                    -tf.math.log(tf.cast(weight_sum, dtype=tf.float32)),
                    tf.constant(-np.inf, dtype=tf.float32))
            u = tf.math.log1p(-tf.random.uniform(shape=batch_shape,
                                                 dtype=log_accept_thresh.dtype,
                                                 seed=self._seed_stream()))
            is_sample_accepted = u <= log_accept_thresh

            next_candidate_tree_state = TreeDoublingStateCandidate(
                state=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(s0)), s0,
                        s1) for s0, s1 in zip(next_state_parts,
                                              candidate_tree_state.state)
                ],
                target=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    next_target, candidate_tree_state.target),
                target_grad_parts=[
                    tf.where(  # pylint: disable=g-complex-comprehension
                        _rightmost_expand_to_rank(is_sample_accepted,
                                                  prefer_static.rank(grad0)),
                        grad0, grad1) for grad0, grad1 in zip(
                            next_target_grad_parts,
                            candidate_tree_state.target_grad_parts)
                ],
                energy=tf.where(
                    _rightmost_expand_to_rank(is_sample_accepted,
                                              prefer_static.rank(next_target)),
                    current_energy, candidate_tree_state.energy),
                weight=weight_sum)

            continue_tree = not_divergent & continue_tree_previous
            continue_tree_next = no_u_turns_within_tree & continue_tree

            not_divergent_tokeep = tf.where(
                continue_tree_previous, not_divergent,
                prefer_static.ones(batch_shape, dtype=tf.bool))

            # min(1., exp(energy_diff)).
            exp_energy_diff = tf.math.exp(tf.minimum(energy_diff, 0.))
            energy_diff_sum = tf.where(
                continue_tree, energy_diff_sum_previous + exp_energy_diff,
                energy_diff_sum_previous)

            return (
                iter_ + 1,
                energy_diff_sum,
                momentum_cumsum,
                leapfrogs_taken,
                next_tree_state,
                next_candidate_tree_state,
                continue_tree_next,
                not_divergent_previous & not_divergent_tokeep,
                momentum_state_memory,
            )
Example #29
    def projection(self, X, y, resp, weights, t_range):

        print("----Tracing___projection")

        @tf.function
        def expec_ll(alpha1, alpha2):  #, i, j, c1, c2):

            temp_lower = tf.identity(self.lower)
            temp_upper = tf.identity(self.upper)

            self.lower.assign(alpha1)
            self.upper.assign(alpha2)

            log_cond = self.expected_ll(X, y, resp, weights)

            self.lower.assign(temp_lower)
            self.upper.assign(temp_upper)

            #tf.print(self.lower)
            #tf.print(self.upper)

            return log_cond

        tmp_indexes = tf.where(
            tf.less(self.no_ovelap_test(), -self.theta / 50.))
        #tf.print("number of remaining overlapping: ", tf.size(tmp_indexes))

        while not (tf.equal(tf.size(tmp_indexes), 0)):
            #print("while  looop")

            classes = tf.cast(tf.math.floordiv(tmp_indexes, self.n_components),
                              tf.int32)
            good_indexes = tf.cast(
                tf.math.floormod(tmp_indexes, self.n_components), tf.int32)

            score = tf.TensorArray(dtype=tf.float32,
                                   size=0,
                                   dynamic_size=True,
                                   name="score",
                                   clear_after_read=False)
            #Matrix of updates
            alpha1 = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True,
                                    name="alpha1",
                                    clear_after_read=False)

            alpha2 = tf.TensorArray(dtype=tf.float32,
                                    size=0,
                                    dynamic_size=True,
                                    name="alpha2",
                                    clear_after_read=False)

            #tf.print(classes)
            #tf.print(good_indexes)

            #For each update, compute the entropy

            for it in tf.range(tf.minimum(tf.constant(self.data_dim), 20)):
                #print("toto")
                d = t_range[it]
                #tf.print(self.lower)
                #print("d loooooop")

                if self.upper[classes[0, 0], good_indexes[0, 0],
                              d] > self.upper[classes[0, 1],
                                              good_indexes[0, 1], d]:

                    alpha1 = alpha1.write(
                        2 * it,
                        tf.tensor_scatter_nd_update(
                            self.lower,
                            [[classes[0, 0], good_indexes[0, 0], d]],
                            [self.upper[classes[0, 1], good_indexes[0, 1], d]
                             ]))

                    alpha2 = alpha2.write(2 * it, self.upper)
                    score = score.write(
                        2 * it,
                        expec_ll(alpha1.read(2 * it), alpha2.read(2 * it)))
                    #,good_indexes[0,0], good_indexes[0,1], classes[0,0], classes [0,1]))

                else:

                    alpha1 = alpha1.write(
                        2 * it,
                        tf.tensor_scatter_nd_update(
                            self.lower,
                            [[classes[0, 1], good_indexes[0, 1], d]],
                            [self.upper[classes[0, 0], good_indexes[0, 0], d]
                             ]))

                    alpha2 = alpha2.write(2 * it, self.upper)

                    score = score.write(
                        2 * it,
                        expec_ll(alpha1.read(2 * it), alpha2.read(2 * it)))
                    #,
                    #good_indexes[0,0], good_indexes[0,1], classes[0,0], classes [0,1]))

                if self.lower[classes[0, 0], good_indexes[0, 0],
                              d] < self.lower[classes[0, 1],
                                              good_indexes[0, 1], d]:

                    alpha2 = alpha2.write(
                        2 * it + 1,
                        tf.tensor_scatter_nd_update(
                            self.upper,
                            [[classes[0, 0], good_indexes[0, 0], d]],
                            [self.lower[classes[0, 1], good_indexes[0, 1], d]
                             ]))

                    alpha1 = alpha1.write(2 * it + 1, self.lower)

                    score = score.write(
                        2 * it + 1,
                        expec_ll(alpha1.read(2 * it + 1),
                                 alpha2.read(2 * it + 1)))
                    #,
                    #   good_indexes[0,0], good_indexes[0,1], classes[0,0], classes [0,1]))

                else:

                    alpha2 = alpha2.write(
                        2 * it + 1,
                        tf.tensor_scatter_nd_update(
                            self.upper,
                            [[classes[0, 1], good_indexes[0, 1], d]],
                            [self.lower[classes[0, 0], good_indexes[0, 0], d]
                             ]))

                    alpha1 = alpha1.write(2 * it + 1, self.lower)

                    score = score.write(
                        2 * it + 1,
                        expec_ll(alpha1.read(2 * it + 1),
                                 alpha2.read(2 * it + 1)))

            #tf.print(score.stack())

            #change the values of alpha corresponding to the lowest update
            true_score = score.stack()
            #ind = tf.cast(tf.math.argmin(tf.boolean_mask(true_score, tf.greater(true_score,0))), tf.int32)
            ind = tf.cast(tf.math.argmin(true_score), tf.int32)
            #tf.print(ind)

            self.lower.assign(alpha1.read(ind))
            self.upper.assign(alpha2.read(ind))

            # Re-compute the no-overlap test
            tmp_indexes = tf.where(tf.less(self.no_ovelap_test(), -self.theta))
            #if not(tf.equal(tf.size(tmp_indexes), 0)):
            #    tf.print(self.no_ovelap_test()[tmp_indexes[0,0], tmp_indexes[0,1]])
            tf.print("number of remaining overlapping: ", tf.size(tmp_indexes))