Example #1
    def _sample_n(self, n, seed=None):
        seed = seed_stream.SeedStream(seed, salt='von_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere for free, and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to an "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(should_continue)

            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                w = tf1.where(should_continue,
                              (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(
                        tf.random.uniform(sample_batch_shape,
                                          seed=seed(),
                                          dtype=self.dtype)))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.nn.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                            axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.nn.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.nn.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ..., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
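The rejection loop above follows Wood (1994). Below is a self-contained sketch of the same sampler under the hypothetical name `sample_vmf_x`, assuming a scalar `concentration`; it mirrors the `b`, `x`, `c` parameterization of the loop body and is illustrative rather than the library implementation.

```python
import tensorflow as tf
import tensorflow_probability as tfp

def sample_vmf_x(concentration, event_dim, n):
  """Wood '94 rejection sampler for the mean-direction coordinate of a vMF."""
  dim = tf.cast(event_dim - 1, tf.float32)
  kappa = tf.cast(concentration, tf.float32)
  b = dim / (2. * kappa + tf.sqrt(4. * kappa**2 + dim**2))
  x = (1. - b) / (1. + b)
  c = kappa * x + dim * tf.math.log1p(-x**2)
  beta = tfp.distributions.Beta(dim / 2., dim / 2.)

  def body(w, accepted):
    z = beta.sample(n)
    w_new = (1. - (1. + b) * z) / (1. - (1. - b) * z)
    u = tf.random.uniform([n])
    # Accept where the Wood '94 envelope test passes.
    ok = kappa * w_new + dim * tf.math.log1p(-x * w_new) - c >= tf.math.log(u)
    return tf.where(accepted, w, tf.where(ok, w_new, w)), accepted | ok

  w, _ = tf.while_loop(lambda w, a: ~tf.reduce_all(a), body,
                       (tf.zeros([n]), tf.zeros([n], tf.bool)))
  return w  # e.g. sample_vmf_x(10., event_dim=5, n=4)
```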
Example #2
def _generate_detections(boxes,
                         scores,
                         max_total_size=100,
                         nms_iou_threshold=0.3,
                         score_threshold=0.05,
                         pre_nms_num_boxes=5000):
  """Generate the final detections given the model outputs.

  This unrolls over classes with a while-loop-based NMS, which can be
  parallelized over the batch dimension.

  Args:
    boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
      N, 1, 4], which stacks box predictions on all feature levels. The N is the
      number of total anchors on all levels.
    scores: a tensor with shape [batch_size, N, num_classes], which stacks class
      probability on all feature levels. The N is the number of total anchors on
      all levels. The num_classes is the number of classes predicted by the
      model. Note that the class_outputs here is the raw score.
    max_total_size: a scalar representing maximum number of boxes retained over
      all classes.
    nms_iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    score_threshold: a float representing the threshold for deciding when to
      remove boxes based on score.
    pre_nms_num_boxes: an int number of top candidate detections per class
      before NMS.

  Returns:
    nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
      representing top detected boxes in [y1, x1, y2, x2].
    nms_scores: `float` Tensor of shape [batch_size, max_total_size]
      representing sorted confidence scores for detected boxes. The values are
      between [0, 1].
    nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
      classes for detected boxes.
    valid_detections: `int` Tensor of shape [batch_size] indicating the number
      of valid detections; only the top `valid_detections` boxes are valid.
  """
  with tf.name_scope('generate_detections'):
    nmsed_boxes = []
    nmsed_classes = []
    nmsed_scores = []
    valid_detections = []
    batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
    num_classes = scores.get_shape().as_list()[2]
    for i in range(num_classes):
      boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
      scores_i = scores[:, :, i]
      # Obtains pre_nms_num_boxes before running NMS.
      scores_i, indices = tf.nn.top_k(
          scores_i,
          k=tf.minimum(tf.shape(input=scores_i)[-1], pre_nms_num_boxes))
      boxes_i = tf.gather(boxes_i, indices, batch_dims=1, axis=1)

      # Filter out scores.
      boxes_i, scores_i = box_utils.filter_boxes_by_scores(
          boxes_i, scores_i, min_score_threshold=score_threshold)

      (nmsed_scores_i, nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
          tf.cast(scores_i, tf.float32),
          tf.cast(boxes_i, tf.float32),
          max_total_size,
          iou_threshold=nms_iou_threshold)
      nmsed_classes_i = tf.fill([batch_size, max_total_size], i)
      nmsed_boxes.append(nmsed_boxes_i)
      nmsed_scores.append(nmsed_scores_i)
      nmsed_classes.append(nmsed_classes_i)
    nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
    nmsed_scores = tf.concat(nmsed_scores, axis=1)
    nmsed_classes = tf.concat(nmsed_classes, axis=1)
    nmsed_scores, indices = tf.nn.top_k(
        nmsed_scores, k=max_total_size, sorted=True)
    nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
    nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1)
    valid_detections = tf.reduce_sum(
        input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32), axis=1)
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
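For comparison, TensorFlow ships a fused op with the same per-class NMS semantics; a minimal sketch using `tf.image.combined_non_max_suppression` (an alternative to, not part of, the code above):

```python
import tensorflow as tf

batch, num_anchors, num_classes = 2, 1000, 5
boxes = tf.random.uniform([batch, num_anchors, 1, 4])   # boxes shared across classes
scores = tf.random.uniform([batch, num_anchors, num_classes])

(nmsed_boxes, nmsed_scores, nmsed_classes,
 valid_detections) = tf.image.combined_non_max_suppression(
     boxes, scores,
     max_output_size_per_class=100,
     max_total_size=100,
     iou_threshold=0.3,
     score_threshold=0.05)
```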
Example #3
def sample_and_preprocess(video,
                          labels,
                          seq_label,
                          seq_len,
                          name,
                          num_steps,
                          augment,
                          sample_all=False,
                          sample_all_stride=1,
                          add_shape=False):
    """Samples frames and prepares them for training."""

    if sample_all:
        # When dealing with very long videos we can choose to sub-sample to fit
        # data in memory. But be aware this also evaluates over a subset of frames.
        # Subsampling the validation set videos when reporting performance is not
        # recommended.
        steps = tf.range(0, seq_len, sample_all_stride)
        seq_len = tf.shape(steps)[0]
        chosen_steps = steps
    else:
        stride = CONFIG.DATA.STRIDE
        sampling_strategy = CONFIG.DATA.SAMPLING_STRATEGY

        # TODO(debidatta) : More flexible sampling
        if sampling_strategy == 'stride':
            # Offset can be set between 0 and maximum location from which we can get
            # total coverage of the video without having to pad.
            # This handles sampling over longer sequences.
            offset = tf.random.uniform(
                (),
                0,
                tf.maximum(tf.cast(1, tf.int64), seq_len - stride * num_steps),
                dtype=tf.int64)
            # This handles sampling over shorter sequences by padding the last frame
            # many times. This is not ideal for the way alignment training batches are
            # created.
            steps = tf.minimum(
                seq_len - 1,
                tf.range(offset, offset + num_steps * stride + 1, stride))
            steps = steps[:num_steps]
        elif sampling_strategy == 'offset_uniform':
            # Sample a random offset less than a provided max offset. Among all frames
            # higher than the chosen offset, randomly sample num_frames
            check1 = tf.debugging.assert_greater_equal(
                seq_len,
                tf.cast(CONFIG.DATA.RANDOM_OFFSET, tf.int64),
                message='Random offset is more than sequence length.')
            check2 = tf.less_equal(
                tf.cast(num_steps, tf.int64),
                seq_len - tf.cast(CONFIG.DATA.RANDOM_OFFSET, tf.int64),
            )

            def _sample_random():
                with tf.control_dependencies([tf.identity(check1.outputs[0])]):
                    offset = CONFIG.DATA.RANDOM_OFFSET
                    steps = tf.random.shuffle(tf.range(offset, seq_len))
                    steps = tf.gather(steps, tf.range(0, num_steps))
                    steps = tf.gather(
                        steps,
                        tf.nn.top_k(steps, k=num_steps).indices[::-1])
                    return steps

            def _sample_all():
                return tf.range(0, num_steps, dtype=tf.int64)

            steps = tf.cond(check2, _sample_random, _sample_all)

        else:
            raise ValueError(
                'Sampling strategy %s is unknown. Supported values are '
                'stride, offset_uniform.' % sampling_strategy)

        if not sample_all and 'tcn' in CONFIG.TRAINING_ALGO:
            pos_window = CONFIG.TCN.POSITIVE_WINDOW
            # pylint: disable=g-long-lambda
            pos_steps = tf.map_fn(
                lambda step: tf.random.uniform(
                    (), minval=step - pos_window, maxval=step, dtype=tf.int64),
                steps)
            # pylint: enable=g-long-lambda
            steps = tf.stack([pos_steps, steps])
            steps = tf.reshape(tf.transpose(steps), (-1, ))

        # Store chosen indices.
        chosen_steps = steps
        # Get multiple context steps depending on config at selected steps.
        steps = tf.reshape(tf.map_fn(get_steps, steps), [-1])
        steps = tf.maximum(tf.cast(0, tf.int64), steps)
        steps = tf.minimum(seq_len - 1, steps)

    shape_all_steps = CONFIG.DATA.NUM_STEPS * num_steps
    if not sample_all and 'tcn' in CONFIG.TRAINING_ALGO:
        shape_all_steps *= 2

    # Select data based on the chosen steps.
    video = tf.gather(video, steps)
    # Decode the encoded JPEG images
    video = tf.map_fn(tf.image.decode_jpeg,
                      video,
                      parallel_iterations=FLAGS.num_parallel_calls,
                      dtype=tf.uint8)
    # Take images in range [0, 255] and normalize to [0, 1]
    video = tf.map_fn(normalize_input,
                      video,
                      parallel_iterations=FLAGS.num_parallel_calls,
                      dtype=tf.float32)
    # Perform data-augmentation and return images in range [-1, 1]
    video = preprocess_input(video, augment)
    if add_shape:
        video.set_shape(
            [shape_all_steps, CONFIG.IMAGE_SIZE, CONFIG.IMAGE_SIZE, 3])

    if CONFIG.DATA.FRAME_LABELS:
        labels = tf.gather(labels, steps)
        if add_shape:
            labels.set_shape([shape_all_steps])

    return {
        'frames': video,
        'frame_labels': labels,
        'chosen_steps': chosen_steps,
        'seq_lens': seq_len,
        'seq_labels': seq_label,
        'name': name
    }
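To make the 'stride' branch concrete, here is its index arithmetic in isolation, with toy values standing in for `CONFIG` (hypothetical numbers):

```python
import tensorflow as tf

seq_len = tf.constant(30, tf.int64)
num_steps, stride = 8, 3
offset = tf.random.uniform(
    (), 0, tf.maximum(tf.cast(1, tf.int64), seq_len - stride * num_steps),
    dtype=tf.int64)
steps = tf.minimum(seq_len - 1,
                   tf.range(offset, offset + num_steps * stride + 1, stride))
steps = steps[:num_steps]  # 8 frame indices, `stride` apart, clipped to the video
print(steps.numpy())
```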
Example #4
def count_integers(arr,
                   weights=None,
                   minlength=None,
                   maxlength=None,
                   axis=None,
                   dtype=tf.int32,
                   name=None):
    """Counts the number of occurrences of each value in an integer array `arr`.

  Works like `tf.math.bincount`, but provides an `axis` kwarg that specifies
  dimensions to reduce over.  With
    `~axis = [i for i in range(arr.ndim) if i not in axis]`,
  this function returns a `Tensor` of shape `[K] + arr.shape[~axis]`.

  If `minlength` and `maxlength` are not given, `K = tf.reduce_max(arr) + 1`
  if `arr` is non-empty, and 0 otherwise.
  If `weights` are non-None, then index `i` of the output stores the sum of the
  value in `weights` at each index where the corresponding value in `arr` is
  `i`.

  Args:
    arr: An `int32` `Tensor` of non-negative values.
    weights: If non-None, must be the same shape as arr. For each value in
      `arr`, the bin will be incremented by the corresponding weight instead of
      1.
    minlength: If given, ensures the output has length at least `minlength`,
      padding with zeros at the end if necessary.
    maxlength: If given, skips values in `arr` that are equal or greater than
      `maxlength`, ensuring that the output has length at most `maxlength`.
    axis: A `0-D` or `1-D` `int32` `Tensor` (with static values) designating
      dimensions in `arr` to reduce over.
      `Default value:` `None`, meaning reduce over all dimensions.
    dtype: If `weights` is None, determines the type of the output bins.
    name: A name scope for the associated operations (optional).

  Returns:
    A vector with the same dtype as `weights` or the given `dtype`. The bin
    values.
  """
    with tf.name_scope(name or 'count_integers'):
        if axis is None:
            return tf.math.bincount(arr,
                                    weights=weights,
                                    minlength=minlength,
                                    maxlength=maxlength,
                                    dtype=dtype)

        arr = tf.convert_to_tensor(arr, dtype=tf.int32, name='arr')
        arr_ndims = _get_static_ndims(arr, expect_static=True)

        axis = _make_static_axis_non_negative_list(axis, arr_ndims)

        # ~axis from docstring.  Dims in arr that are not in axis.
        not_axis = sorted(set(range(arr_ndims)).difference(axis))

        # If we're reducing over everything, just use standard bincount.
        if not not_axis:
            return tf.math.bincount(arr,
                                    weights=weights,
                                    minlength=minlength,
                                    maxlength=maxlength,
                                    dtype=dtype)

        # Move dims in ~axis to the left, so we can tf.map_fn bincount over them,
        # Producing counts for every index I in ~axis.
        # Thus, flat_arr is not totally flat, it just has the dims in ~axis
        # flattened.
        flat_arr = _move_dims_to_flat_end(arr,
                                          not_axis,
                                          arr_ndims,
                                          right_end=False)
        minlength = minlength if minlength is not None else tf.reduce_max(
            arr) + 1
        maxlength = maxlength if maxlength is not None else tf.reduce_max(
            arr) + 1

        # tf.map_fn over dim 0.
        if weights is None:

            def one_bincount(arr_slice):
                return tf.math.bincount(arr_slice,
                                        weights=None,
                                        minlength=minlength,
                                        maxlength=maxlength,
                                        dtype=dtype)

            flat_counts = tf.map_fn(one_bincount,
                                    elems=flat_arr,
                                    fn_output_signature=dtype)
        else:
            weights = tf.convert_to_tensor(weights, name='weights')
            _get_static_ndims(weights,
                              expect_static=True,
                              expect_ndims=arr_ndims)
            flat_weights = _move_dims_to_flat_end(weights,
                                                  not_axis,
                                                  arr_ndims,
                                                  right_end=False)

            def one_bincount(arr_and_weights_slices):
                arr_slice, weights_slice = arr_and_weights_slices
                return tf.math.bincount(arr_slice,
                                        weights=weights_slice,
                                        minlength=minlength,
                                        maxlength=maxlength,
                                        dtype=dtype)

            flat_counts = tf.map_fn(one_bincount,
                                    elems=[flat_arr, flat_weights],
                                    fn_output_signature=weights.dtype)

        # flat_counts.shape = [prod(~axis), K], because map_fn stacked on axis 0.
        # bincount needs to have the K bins in axis 0, so transpose...
        flat_counts_t = tf.transpose(a=flat_counts, perm=[1, 0])

        # Throw in this assert, to ensure shape assumptions are correct.
        _get_static_ndims(flat_counts_t, expect_ndims=2, expect_static=True)

        # not_axis_shape = arr.shape[~axis]
        not_axis_shape = tf.gather(tf.shape(arr), indices=not_axis)

        # The first index of flat_counts_t indexes bins 0,..,K-1, the rest are ~axis
        out_shape = tf.concat([[-1], not_axis_shape], axis=0)

        return tf.reshape(flat_counts_t, out_shape)
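A quick check of the `axis` semantics against plain `tf.math.bincount` (this function appears to correspond to `tfp.stats.count_integers`):

```python
import tensorflow as tf

arr = tf.constant([[0, 0, 2],
                   [3, 2, 1]])
# count_integers(arr, axis=1) bincounts each row and stacks the results along
# the trailing dimension, giving shape [K, 2] with K = max(arr) + 1 = 4:
per_row = tf.stack([tf.math.bincount(arr[0], minlength=4, maxlength=4),
                    tf.math.bincount(arr[1], minlength=4, maxlength=4)],
                   axis=-1)
print(per_row.numpy())  # [[2 0], [0 1], [1 1], [0 1]]
```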
Example #5
def add_entity_tokens(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    mention_mask: tf.Tensor,
    mention_batch_positions: tf.Tensor,
    mention_start_positions: tf.Tensor,
    mention_end_positions: tf.Tensor,
    new_length: int,
    mlm_target_positions: Optional[tf.Tensor] = None,
    mlm_target_weights: Optional[tf.Tensor] = None,
    entity_start_token_id: int = default_values.ENTITY_START_TOKEN,
    entity_end_token_id: int = default_values.ENTITY_END_TOKEN,
) -> Dict[str, tf.Tensor]:
    """Adds entity start / end tokens around mentions.

  Inserts `entity_start_token_id` and `entity_end_token_id` tokens around each
  mention and updates mention_start_positions / mention_end_positions to point
  to these tokens.

  New text length will be `new_length` and texts will be truncated if necessary.
  If a mention no longer fits into the new text, its mask (`mention_mask`) will
  be set to 0.

  The function can also update MLM position and weights (`mlm_target_positions`
  and `mlm_target_weights`) if these arguments are provided. Similarly to
  mentions, if an MLM position no longer fits into the new text, its mask
  (`mlm_target_weights`) will be set to 0.

  Args:
    text_ids: [seq_length] tensor with token ids.
    text_mask: [seq_length] tensor with 1s for tokens and 0 for padding.
    mention_mask: [n_mentions] mask indicating whether a mention is a padding.
    mention_batch_positions: [n_mentions] sample ID of a mention in the batch.
    mention_start_positions: [n_mentions] position of a mention first token
      within a sample.
    mention_end_positions: [n_mentions] position of a mention last token within
      a sample.
    new_length: new length of the text after entity tokens are added.
    mlm_target_positions: [batch_size, max_mlm_targets] positions of tokens to
      be used for MLM task.
    mlm_target_weights: [batch_size, max_mlm_targets] mask indicating whether
      `mlm_target_positions` is a padding.
    entity_start_token_id: token to be used as entity start token.
    entity_end_token_id: token to be used as entity end token.

  Returns:
    New text_ids and text_mask, updated mentions positions including
    mention_start_positions, mention_end_positions and mention_mask.
    Returns updated mlm_target_positions and mlm_target_weights if they were
    provided as arguments.
  """
    batch_size = tf.shape(text_ids)[0]
    old_length = tf.shape(text_ids)[1]
    new_shape = (batch_size, new_length)

    mentions_fit_mask = compute_which_mentions_fit_with_entity_tokens(
        mention_mask,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        batch_size,
        old_length,
        new_length,
    )
    # Ignore mentions that do not fit into the new text.
    new_mention_mask = mention_mask * mentions_fit_mask
    mention_start_positions = mention_start_positions * new_mention_mask
    mention_end_positions = mention_end_positions * new_mention_mask

    positions = compute_positions_shift_with_entity_tokens(
        new_mention_mask, mention_batch_positions, mention_start_positions,
        mention_end_positions, batch_size, old_length)

    def get_2d_index(positions: tf.Tensor) -> tf.Tensor:
        return _get_2d_index(mention_batch_positions, positions)

    def get_new_positions(old_positions: tf.Tensor) -> tf.Tensor:
        index_2d = get_2d_index(old_positions)
        return tf.gather_nd(positions, index_2d)

    new_mention_start_positions = get_new_positions(
        mention_start_positions) - 1
    new_mention_start_positions = new_mention_start_positions * new_mention_mask
    new_mention_end_positions = get_new_positions(mention_end_positions) + 1
    new_mention_end_positions = new_mention_end_positions * new_mention_mask

    if mlm_target_positions is not None:
        if mlm_target_weights is None:
            raise ValueError('`mlm_target_weights` must be specified if '
                             '`mlm_target_positions` is provided.')
        mlm_target_positions = tf.gather(positions,
                                         mlm_target_positions,
                                         batch_dims=1)
        mlm_target_positions_within_len = tf.less(mlm_target_positions,
                                                  new_length)
        mlm_target_positions_within_len = tf.cast(
            mlm_target_positions_within_len, mlm_target_weights.dtype)
        mlm_target_weights = mlm_target_weights * mlm_target_positions_within_len
        # Zero-out positions for pad MLM targets
        mlm_target_positions = mlm_target_positions * mlm_target_weights

    # Cut texts that are longer than `new_length`
    text_within_new_length = tf.less(positions, new_length)
    text_ids = text_ids * tf.cast(text_within_new_length, text_ids.dtype)
    text_mask = text_mask * tf.cast(text_within_new_length, text_mask.dtype)
    positions = tf.minimum(positions, new_length - 1)

    # Prepare 2D index for tokens positions in the next text_ids and text_mask.
    # Note that we use flat 2D index and flat values
    # (e.g. `tf.reshape(text_ids, [-1])`) since `tf.scatter_nd` does not support
    # batch dimension.
    batch_positions = _batched_range(old_length, batch_size, 1,
                                     positions.dtype)
    batch_positions = tf.reshape(batch_positions, [-1])
    text_index_2d = _get_2d_index(batch_positions, tf.reshape(positions, [-1]))

    new_text_ids = tf.scatter_nd(text_index_2d, tf.reshape(text_ids, [-1]),
                                 new_shape)
    new_text_mask = tf.scatter_nd(text_index_2d, tf.reshape(text_mask, [-1]),
                                  new_shape)

    # Insert entity start / end tokens into the new text_ids and text_mask.
    new_mention_start_positions_2d = get_2d_index(new_mention_start_positions)
    new_mention_end_positions_2d = get_2d_index(new_mention_end_positions)

    new_text_ids = tf.tensor_scatter_nd_add(
        new_text_ids, new_mention_start_positions_2d,
        new_mention_mask * entity_start_token_id)
    new_text_ids = tf.tensor_scatter_nd_add(
        new_text_ids, new_mention_end_positions_2d,
        new_mention_mask * entity_end_token_id)

    new_mention_mask = tf.cast(new_mention_mask, dtype=text_mask.dtype)
    new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                             new_mention_start_positions_2d,
                                             new_mention_mask)
    new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                             new_mention_end_positions_2d,
                                             new_mention_mask)

    features = {
        'text_ids': new_text_ids,
        'text_mask': new_text_mask,
        'mention_start_positions': new_mention_start_positions,
        'mention_end_positions': new_mention_end_positions,
        'mention_mask': new_mention_mask,
    }

    if mlm_target_positions is not None:
        features['mlm_target_weights'] = mlm_target_weights
        features['mlm_target_positions'] = mlm_target_positions

    return features
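The core mechanism above is scattering tokens into shifted positions, then adding marker tokens with `tf.tensor_scatter_nd_add`; a toy single-sample sketch with hypothetical token ids:

```python
import tensorflow as tf

text_ids = tf.constant([[101, 7, 8, 9, 102]])   # one sample; mention is tokens 1..2
positions = tf.constant([[0, 2, 3, 5, 6]])      # shift leaves gaps at 1 and 4
new_length = 8
index_2d = tf.stack([tf.zeros([5], tf.int32), positions[0]], axis=1)
new_ids = tf.scatter_nd(index_2d, text_ids[0], [1, new_length])
# Add an entity start token (=1) before the mention and an end token (=2) after.
new_ids = tf.tensor_scatter_nd_add(new_ids, [[0, 1]], [1])
new_ids = tf.tensor_scatter_nd_add(new_ids, [[0, 4]], [2])
print(new_ids.numpy())  # [[101   1   7   8   2   9 102   0]]
```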
Example #6
def _orthogonal_complement_e_i(vectors, i, gram_schmidt_iters):
    """Computes a basis for the orthogonal complement to `e_i` in `span(vectors)`.

  The orthogonal complement of the coordinate vector `e_i` of the vector space
  `V` is the set of all vectors in `V` that are orthogonal to `e_i`.

  We compute this by first choosing a column `j` of `vectors` with a non-zero
  entry in coordinate `i`. This vector (`col_j`) is subtracted from all other
  vectors with an appropriate weight to zero out row `i`. Finally, we
  orthonormalize
  using (modified) Gram-Schmidt. For performance reasons, the calling code
  specifies the G-S iteration count.

  For example, suppose we start with the matrix of column vectors:

  ```none
  [ 2  4  7 ]
  [ 4  2  4 ]
  [ 6  6  3 ]
  ```

  If we suppose `i = 1`, we are being asked to zero-out the middle row, i.e.
  orthogonalize with respect to the coordinate vector `e_1 = [0, 1, 0]^T`. We
  can do so by picking `j = argmax(mat[i, :])`, so `j = 0` in this case. Then,
  compute the appropriate weights that would zero out the row, i.e.
  `w=[1, 0.5, 1]` and subtract `mat[:, j:j+1] * w = [2, 4, 6]^T * [1, .5, 1]`.
  This yields the intermediate:

  ```none
  [ 2  4  7 ]   [ 2  1  2 ]   [ 0  3  5 ]
  [ 4  2  4 ] - [ 4  2  4 ] = [ 0  0  0 ]
  [ 6  6  3 ]   [ 6  3  6 ]   [ 0  3 -3 ]
  ```

  We rotate the zero column to the end, and finally return the result of
  applying Gram-Schmidt orthogonalization, i.e.

  ```none
  [ sqrt(.5)  sqrt(.5) 0 ]
  [     0        0     0 ]
  [ sqrt(.5) -sqrt(.5) 0 ]
  ```

  Args:
    vectors: A Tensor of vectors of shape `[..., d, n]` we are orthogonalizing.
    i: The coordinate (against dimension `d`) w.r.t. which we orthogonalize.
    gram_schmidt_iters: Number of iterations of Gram-Schmidt orthonormalization
      to run, generally `n_vectors - iter_num`. Since each iteration of sampling
      reduces the number of nonzero columns by one (in the `n` dim), this allows
      us to save iterations of orthonormalization work.

  Returns:
    orthogonal: A Tensor of shape `[..., d, n]` representing the subspace
      spanned by `vectors` that is orthogonal to `e_i`, the `i`-th coordinate
      vector. The tensor is orthonormalized. It contains at least one more zero
      row (`i`) and zero column than the input vectors (exactly one more if all
      nonzero columns of `vectors` are linearly independent).
  """
    i = tf.convert_to_tensor(i, dtype_hint=tf.int32)
    row_i = tf.gather(vectors, i, axis=-2, batch_dims=len(i.shape))
    j = tf.argmax(tf.abs(row_i), axis=-1)  # Max for numerical stability.
    col_j = tf.gather(vectors, j, axis=-1, batch_dims=len(j.shape))
    val_i_j = tf.gather(row_i, j, axis=-1, batch_dims=len(j.shape))
    weights = row_i / val_i_j[..., tf.newaxis]
    delta = weights[..., tf.newaxis, :] * col_j[..., :, tf.newaxis]
    result = (vectors - delta)
    # Rotate the new zero column to the end.
    d = ps.shape(vectors)[-2]
    n = ps.shape(vectors)[-1]
    mask_d = tf.not_equal(tf.range(d, dtype=i.dtype),
                          i[..., tf.newaxis])[..., tf.newaxis]
    shift_indices = tf.range(n, dtype=j.dtype)
    shift_indices = shift_indices + tf.cast(
        shift_indices >= j[..., tf.newaxis], j.dtype)
    shift_indices = tf.where(shift_indices >= tf.cast(n, j.dtype),
                             j[..., tf.newaxis], shift_indices)
    result = tf.gather(result,
                       shift_indices,
                       axis=-1,
                       batch_dims=len(shift_indices.shape) - 1)
    mask_n = tf.not_equal(tf.range(n), n - 1)
    result = tf.where(mask_d & mask_n, result, 0)  # Make exactly zero.
    # Orthonormalize. This is equivalent, but faster than tf.linalg.qr(result).q
    return tfp_math.gram_schmidt(result, gram_schmidt_iters)
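The worked example in the docstring can be verified numerically; a short sketch using `tf.linalg.qr` in place of the `tfp_math.gram_schmidt` helper:

```python
import tensorflow as tf

m = tf.constant([[2., 4., 7.],
                 [4., 2., 4.],
                 [6., 6., 3.]])
w = m[1, :] / m[1, 0]                       # weights zeroing row i=1 via column j=0
reduced = m - m[:, :1] * w[tf.newaxis, :]   # [[0, 3, 5], [0, 0, 0], [0, 3, -3]]
q, _ = tf.linalg.qr(reduced[:, 1:])         # orthonormalize the nonzero columns
print(q.numpy())  # columns ~ [+-sqrt(.5), 0, +-sqrt(.5)], as in the docstring
```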
Example #7
    def __init__(self,
                 background_vertices,
                 background_attributes,
                 background_triangles,
                 camera_origin,
                 look_at,
                 camera_up,
                 field_of_view,
                 image_size,
                 near_plane,
                 far_plane,
                 bottom_left=(0.0, 0.0),
                 name=None):
        """Initializes TriangleRasterizer with OpenGL parameters and the background.

    Note:
      In the following, A1 to An are optional batch dimensions.

    Args:
      background_vertices: A tensor of shape `[V, 3]` containing `V` 3D
        vertices. Note that these background vertices will be used in every
        rasterized image.
      background_attributes: A tensor of shape `[V, K]` containing `V` vertices
        associated with K-dimensional attributes. Pixels for which the first
        visible surface is in the background geometry will make use of
        `background_attribute` for estimating their own attribute. Note that
        these background attributes will be used in every rasterized image.
      background_triangles: An integer tensor of shape `[T, 3]` containing `T`
        triangles, each associated with 3 vertices from `background_vertices`.
        Note that these background triangles will be used in every rasterized
        image.
      camera_origin: A Tensor of shape `[A1, ..., An, 3]`, where the last axis
        represents the 3D position of the camera.
      look_at: A Tensor of shape `[A1, ..., An, 3]`, with the last axis storing
        the position where the camera is looking at.
      camera_up: A Tensor of shape `[A1, ..., An, 3]`, where the last axis
        defines the up vector of the camera.
      field_of_view:  A Tensor of shape `[A1, ..., An, 1]`, where the last axis
        represents the vertical field of view of the frustum expressed in
        radians. Note that values for `field_of_view` must be in the range (0,
        pi).
      image_size: A tuple (height, width) containing the dimensions in pixels of
        the rasterized image".
      near_plane: A Tensor of shape `[A1, ..., An, 1]`, where the last axis
        captures the distance between the viewer and the near clipping plane.
        Note that values for `near_plane` must be non-negative.
      far_plane: A Tensor of shape `[A1, ..., An, 1]`, where the last axis
        captures the distance between the viewer and the far clipping plane.
        Note that values for `far_plane` must be non-negative.
      bottom_left: A Tensor of shape `[A1, ..., An, 2]`, where the last axis
        captures the position (in pixels) of the lower left corner of the
        screen. Defaults to (0.0, 0.0).
      name: A name for this op. Defaults to 'triangle_rasterizer_init'.
    """
        with tf.compat.v1.name_scope(
                name, "triangle_rasterizer_init",
            (background_vertices, background_attributes, background_triangles,
             camera_origin, look_at, camera_up, field_of_view, near_plane,
             far_plane, bottom_left)):

            background_vertices = tf.convert_to_tensor(
                value=background_vertices)
            background_attributes = tf.convert_to_tensor(
                value=background_attributes)
            background_triangles = tf.convert_to_tensor(
                value=background_triangles)

            shape.check_static(tensor=background_vertices,
                               tensor_name="background_vertices",
                               has_rank=2,
                               has_dim_equals=(-1, 3))
            shape.check_static(tensor=background_attributes,
                               tensor_name="background_attributes",
                               has_rank=2)
            shape.check_static(
                tensor=background_triangles,
                tensor_name="background_triangles",
                has_rank=2,
                has_dim_equals=(-1, 3))
            shape.compare_batch_dimensions(
                tensors=(background_vertices, background_attributes),
                last_axes=-2,
                tensor_names=("background_geometry", "background_attribute"),
                broadcast_compatible=False)

            background_vertices = tf.expand_dims(background_vertices, axis=0)
            background_attributes = tf.expand_dims(background_attributes,
                                                   axis=0)

            height = float(image_size[0])
            width = float(image_size[1])

            self._background_geometry = tf.gather(background_vertices,
                                                  background_triangles,
                                                  axis=-2)
            self._background_attribute = tf.gather(background_attributes,
                                                   background_triangles,
                                                   axis=-2)

            self._camera_origin = tf.convert_to_tensor(value=camera_origin)
            self._look_at = tf.convert_to_tensor(value=look_at)
            self._camera_up = tf.convert_to_tensor(value=camera_up)
            self._field_of_view = tf.convert_to_tensor(value=field_of_view)
            self._image_size_glm = tf.convert_to_tensor(value=(width, height))
            self._image_size_int = (int(width), int(height))
            self._near_plane = tf.convert_to_tensor(value=near_plane)
            self._far_plane = tf.convert_to_tensor(value=far_plane)
            self._bottom_left = tf.convert_to_tensor(value=bottom_left)

            # Construct the pixel grid. Note that OpenGL uses half-integer pixel
            # centers.
            px = tf.linspace(0.5, width - 0.5, num=int(width))
            py = tf.linspace(0.5, height - 0.5, num=int(height))
            xv, yv = tf.meshgrid(px, py)
            self._pixel_position = tf.stack((xv, yv), axis=-1)

            # Construct the view projection matrix.
            world_to_camera = glm.look_at_right_handed(camera_origin, look_at,
                                                       camera_up)
            perspective_matrix = glm.perspective_right_handed(
                field_of_view, (width / height, ), near_plane, far_plane)
            perspective_matrix = tf.squeeze(perspective_matrix)
            self._view_projection_matrix = tf.linalg.matmul(
                perspective_matrix, world_to_camera)
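The half-integer pixel grid built near the end of the constructor is a generic pattern; a minimal standalone sketch:

```python
import tensorflow as tf

width, height = 4, 3
px = tf.linspace(0.5, width - 0.5, num=width)    # [0.5, 1.5, 2.5, 3.5]
py = tf.linspace(0.5, height - 0.5, num=height)  # [0.5, 1.5, 2.5]
xv, yv = tf.meshgrid(px, py)
pixel_position = tf.stack((xv, yv), axis=-1)
print(pixel_position.shape)  # (3, 4, 2): one (x, y) pixel center per location
```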
Example #8
def mask_mentions_and_tokens_tf(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    dense_span_starts: tf.Tensor,
    dense_span_ends: tf.Tensor,
    non_mention_mask_rate: float,
    mention_mask_rate: float,
    max_mlm_targets: int,
    mask_token_id: int,
    vocab_size: int,
    random_replacement_prob: float = 0.1,
    identity_replacement_prob: float = 0.1,
) -> Dict[str, tf.Tensor]:
    """Randomly masks whole mentions and random tokens up to a maximum.

  First, mentions are masked according to mention mask rate. If a mention is
  masked, all tokens in the mention are replaced by the mask token. If the
  passage has at least one mention and the mention rask rate is greater than
  zero, we mask at least one mention.

  After masking mentions, if there are fewer masked tokens than maximum mlm
  targets, we additionally mask non-mention words. (TODO: implement whole-word
  masking, so that if a token in a word is masked, all tokens in the word are
  masked.) Some proportion of targets are
  not masked to ameliorate pretrain-finetune mismatch. If there are insufficient
  masked tokens, the target array is padded up to max targets.

  Args:
    text_ids: [seq_length] tensor with token ids.
    text_mask: [seq_length] tensor with 1s for tokens and 0 for padding.
    dense_span_starts: [seq_length] tensor with 1s for mention start positions
      and 0 otherwise.
    dense_span_ends: [seq_length] tensor with 1s for mention end positions and 0
      otherwise.
    non_mention_mask_rate: percentage of non mention tokens to be masked.
    mention_mask_rate: percentage of mentions to be masked.
    max_mlm_targets: total number of mlm targets.
    mask_token_id: token id for mask token.
    vocab_size: vocabulary size.
    random_replacement_prob: probability that a to-be-masked token will be
      replaced with a random token instead of [MASK].
    identity_replacement_prob: probability that a to-be-masked token will be
      replaced with itself instead of [MASK].

  Returns:
    Dictionary with masked text, mask positions, target ids, target weights.
  """
    # Mask mentions
    mention_start_positions = non_zero_1d(dense_span_starts)
    mention_end_positions = non_zero_1d(dense_span_ends)
    mention_masked_positions = mask_tokens_by_spans(text_ids,
                                                    mention_start_positions,
                                                    mention_end_positions,
                                                    mention_mask_rate,
                                                    max_mlm_targets)

    dense_is_mention = get_dense_is_inside_for_dense_spans(
        dense_span_starts, dense_span_ends)
    dense_is_not_mention = 1 - dense_is_mention
    dense_is_not_mention = dense_is_not_mention * text_mask

    # Mask non-mentions
    non_mention_start_positions = non_zero_1d(dense_is_not_mention)
    # TODO(urikz): Implement whole-word masking
    non_mention_end_positions = non_mention_start_positions
    non_mention_masked_positions = mask_tokens_by_spans(
        text_ids, non_mention_start_positions, non_mention_end_positions,
        non_mention_mask_rate,
        max_mlm_targets - tf.shape(mention_masked_positions)[0])

    # Merge masked positions for mention and non-mention tokens
    mlm_target_positions = tf.concat(
        [mention_masked_positions, non_mention_masked_positions], -1)
    n_mlm_target_positions = tf.shape(mlm_target_positions)

    # Get target IDs, weights and other features
    mlm_target_ids = tf.gather(text_ids, mlm_target_positions)
    mlm_target_weights = tf.ones(n_mlm_target_positions, dtype=tf.int64)
    mlm_target_is_mention = tf.ones(tf.shape(mention_masked_positions),
                                    dtype=tf.int64)
    seq_length = tf.shape(text_ids)[0]
    dense_is_masked = sparse_to_dense_1d(mlm_target_positions, seq_length)

    # Replace masked tokens with [MASK], random or original tokens.
    replacement_scores = tf.random.uniform(n_mlm_target_positions)
    replacement_tokens = tf.where(
        replacement_scores >
        random_replacement_prob + identity_replacement_prob,
        # replace tokens with [MASK]
        tf.cast(tf.fill(n_mlm_target_positions, value=mask_token_id),
                dtype=tf.int64),
        tf.where(
            replacement_scores > random_replacement_prob,
            # keep original
            mlm_target_ids,
            # replace with random token
            tf.random.uniform(n_mlm_target_positions,
                              maxval=vocab_size,
                              dtype=tf.int64)))
    replacement_positions = tf.expand_dims(mlm_target_positions, 1)
    # Indices should be tf.int32 only.
    replacement_positions = tf.cast(replacement_positions, tf.int32)
    replacement_tokens = tf.scatter_nd(replacement_positions,
                                       replacement_tokens, tf.shape(text_ids))
    masked_text_ids = (text_ids * (1 - dense_is_masked) +
                       replacement_tokens * dense_is_masked)

    return {
        'masked_text_ids':
        masked_text_ids,
        'mlm_target_positions':
        dynamic_padding_1d(mlm_target_positions, max_mlm_targets),
        'mlm_target_ids':
        dynamic_padding_1d(mlm_target_ids, max_mlm_targets),
        'mlm_target_weights':
        dynamic_padding_1d(mlm_target_weights, max_mlm_targets),
        'mlm_target_is_mention':
        dynamic_padding_1d(mlm_target_is_mention, max_mlm_targets),
        'dense_is_masked':
        dense_is_masked,
    }
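The nested `tf.where` above implements BERT-style 80/10/10 replacement; the same decision table in isolation, with hand-picked "uniform" scores so that every branch fires:

```python
import tensorflow as tf

target_ids = tf.constant([11, 12, 13, 14], tf.int64)
scores = tf.constant([0.95, 0.15, 0.05, 0.5])  # stand-ins for U[0, 1) draws
mask_id = tf.constant(103, tf.int64)
random_ids = tf.random.uniform([4], maxval=1000, dtype=tf.int64)
out = tf.where(scores > 0.2,                   # 1 - 0.1 - 0.1
               tf.fill([4], mask_id),          # 80%: [MASK]
               tf.where(scores > 0.1,
                        target_ids,            # 10%: keep original
                        random_ids))           # 10%: random token
print(out.numpy())  # [103, 12, <random>, 103]
```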
Example #9
def lu_reconstruct(lower_upper, perm, validate_args=False, name=None):
  """The inverse LU decomposition, `X == lu_reconstruct(*tf.linalg.lu(X))`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_reconstruct').

  Returns:
    x: The original input to `tf.linalg.lu`, i.e., `x` as in,
      `lu_reconstruct(*tf.linalg.lu(x))`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[3., 4], [1, 2]],
       [[7., 8], [3, 4]]]
  x_reconstructed = tfp.math.lu_reconstruct(*tf.linalg.lu(x))
  tf.assert_near(x, x_reconstructed)
  # ==> True
  ```

  """
  with tf.name_scope(name or 'lu_reconstruct'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')

    assertions = _lu_reconstruct_assertions(lower_upper, perm, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)

    shape = tf.shape(lower_upper)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(shape[:-1], dtype=lower_upper.dtype))
    upper = tf.linalg.band_part(lower_upper, num_lower=0, num_upper=-1)
    x = tf.matmul(lower, upper)

    if lower_upper.shape.ndims is None or lower_upper.shape.ndims != 2:
      # We either don't know the batch rank or there are >0 batch dims.
      batch_size = tf.reduce_prod(shape[:-2])
      d = shape[-1]
      x = tf.reshape(x, [batch_size, d, d])
      perm = tf.reshape(perm, [batch_size, d])
      perm = tf.map_fn(tf.math.invert_permutation, perm)
      batch_indices = tf.broadcast_to(
          tf.range(batch_size)[:, tf.newaxis],
          [batch_size, d])
      x = tf.gather_nd(x, tf.stack([batch_indices, perm], axis=-1))
      x = tf.reshape(x, shape)
    else:
      x = tf.gather(x, tf.math.invert_permutation(perm))

    x.set_shape(lower_upper.shape)
    return x
Example #10
def compute_stochastic_alignment_loss(embs, steps, seq_lens, num_steps,
                                      batch_size, loss_type, similarity_type,
                                      num_cycles, cycle_length, temperature,
                                      label_smoothing, variance_lambda,
                                      huber_delta, normalize_indices):
    """Compute cycle-consistency loss by stochastically sampling cycles.

  Args:
    embs: Tensor, sequential embeddings of the shape [N, T, D] where N is the
      batch size, T is the number of timesteps in the sequence, D is the size of
      the embeddings.
    steps: Tensor, step indices/frame indices of the embeddings of the shape
      [N, T] where N is the batch size, T is the number of the timesteps.
    seq_lens: Tensor, Lengths of the sequences from which the sampling was done.
      This can provide additional information to the alignment loss.
    num_steps: Integer/Tensor, Number of timesteps in the embeddings.
    batch_size: Integer/Tensor, Batch size.
    loss_type: String, This specifies the kind of loss function to use.
      Currently supported loss functions: 'classification', 'regression_mse',
      'regression_mse_var', 'regression_huber'.
    similarity_type: String, Currently supported similarity metrics: 'l2',
      'cosine'.
    num_cycles: Integer, number of cycles to match while aligning
      stochastically.  Only used in the stochastic version.
    cycle_length: Integer, Lengths of the cycle to use for matching. Only used
      in the stochastic version. By default, this is set to 2.
    temperature: Float, temperature scaling used to scale the similarity
      distributions calculated using the softmax function.
    label_smoothing: Float, Label smoothing argument used in
      tf.keras.losses.categorical_crossentropy function and described in this
      paper https://arxiv.org/pdf/1701.06548.pdf.
    variance_lambda: Float, Weight of the variance of the similarity
      predictions while cycling back. If this is high then the low variance
      similarities are preferred by the loss while making this term low results
      in high variance of the similarities (more uniform/random matching).
    huber_delta: float, Huber delta described in tf.keras.losses.huber_loss.
    normalize_indices: Boolean, If True, normalizes indices by sequence lengths.
      Useful for ensuring numerical instabilities don't arise, as sequence
      indices can be large numbers.

  Returns:
    loss: Tensor, Scalar loss tensor that imposes the chosen variant of the
      cycle-consistency loss.
  """
    # Generate cycles.
    cycles = gen_cycles(num_cycles, batch_size, cycle_length)

    logits, labels = _align(cycles, embs, num_steps, num_cycles, cycle_length,
                            similarity_type, temperature)

    if loss_type == 'classification':
        loss = classification_loss(logits, labels, label_smoothing)
    elif 'regression' in loss_type:
        steps = tf.gather(steps, cycles[:, 0])
        seq_lens = tf.gather(seq_lens, cycles[:, 0])
        loss = regression_loss(logits, labels, num_steps, steps, seq_lens,
                               loss_type, normalize_indices, variance_lambda,
                               huber_delta)
    else:
        raise ValueError('Unidentified loss type %s. Currently supported loss '
                         'types are: regression_mse, regression_huber, '
                         'classification.' % loss_type)
    return loss
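`gen_cycles` is defined elsewhere in the module; a plausible sketch of its contract, assuming each row is `cycle_length` distinct batch indices closed back on the first (hypothetical implementation, not the original):

```python
import tensorflow as tf

def gen_cycles(num_cycles, batch_size, cycle_length=2):
  """Samples batch-index cycles of shape [num_cycles, cycle_length + 1]."""
  idx = tf.stack([tf.random.shuffle(tf.range(batch_size))[:cycle_length]
                  for _ in range(num_cycles)])
  # Repeat the first index at the end so each row closes its own cycle.
  return tf.concat([idx, idx[:, :1]], axis=1)
```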
Example #11
def lu_solve(lower_upper, perm, rhs,
             validate_args=False,
             name=None):
  """Solves systems of linear eqns `A X = RHS`, given LU factorizations.

  Note: this function does not verify the implied matrix is actually invertible
  nor is this condition checked even when `validate_args=True`.

  Args:
    lower_upper: `lu` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `lower_upper = L + U - eye`.
    perm: `p` as returned by `tf.linalg.lu`, i.e., if
      `matmul(P, matmul(L, U)) = X` then `perm = argmax(P)`.
    rhs: Matrix-shaped float `Tensor` representing targets for which to solve;
      `A X = RHS`. To handle vector cases, use:
      `lu_solve(..., rhs[..., tf.newaxis])[..., 0]`.
    validate_args: Python `bool` indicating whether arguments should be checked
      for correctness. Note: this function does not verify the implied matrix is
      actually invertible, even when `validate_args=True`.
      Default value: `False` (i.e., don't validate arguments).
    name: Python `str` name given to ops managed by this object.
      Default value: `None` (i.e., 'lu_solve').

  Returns:
    x: The `X` in `A @ X = RHS`.

  #### Examples

  ```python
  import numpy as np
  import tensorflow as tf
  import tensorflow_probability as tfp

  x = [[[1., 2],
        [3, 4]],
       [[7, 8],
        [3, 4]]]
  inv_x = tfp.math.lu_solve(*tf.linalg.lu(x), rhs=tf.eye(2))
  tf.assert_near(tf.linalg.inv(x), inv_x)
  # ==> True
  ```

  """

  with tf.name_scope(name or 'lu_solve'):
    lower_upper = tf.convert_to_tensor(
        lower_upper, dtype_hint=tf.float32, name='lower_upper')
    perm = tf.convert_to_tensor(perm, dtype_hint=tf.int32, name='perm')
    rhs = tf.convert_to_tensor(
        rhs, dtype_hint=lower_upper.dtype, name='rhs')

    assertions = _lu_solve_assertions(lower_upper, perm, rhs, validate_args)
    if assertions:
      with tf.control_dependencies(assertions):
        lower_upper = tf.identity(lower_upper)
        perm = tf.identity(perm)
        rhs = tf.identity(rhs)

    if rhs.shape.ndims == 2 and perm.shape.ndims == 1:
      # Both rhs and perm have scalar batch_shape.
      permuted_rhs = tf.gather(rhs, perm, axis=-2)
    else:
      # Either rhs or perm have non-scalar batch_shape or we can't determine
      # this information statically.
      rhs_shape = tf.shape(rhs)
      broadcast_batch_shape = tf.broadcast_dynamic_shape(
          rhs_shape[:-2],
          tf.shape(perm)[:-1])
      d, m = rhs_shape[-2], rhs_shape[-1]
      rhs_broadcast_shape = tf.concat([broadcast_batch_shape, [d, m]], axis=0)

      # Tile out rhs.
      broadcast_rhs = tf.broadcast_to(rhs, rhs_broadcast_shape)
      broadcast_rhs = tf.reshape(broadcast_rhs, [-1, d, m])

      # Tile out perm and add batch indices.
      broadcast_perm = tf.broadcast_to(perm, rhs_broadcast_shape[:-1])
      broadcast_perm = tf.reshape(broadcast_perm, [-1, d])
      broadcast_batch_size = tf.reduce_prod(broadcast_batch_shape)
      broadcast_batch_indices = tf.broadcast_to(
          tf.range(broadcast_batch_size)[:, tf.newaxis],
          [broadcast_batch_size, d])
      broadcast_perm = tf.stack([broadcast_batch_indices, broadcast_perm],
                                axis=-1)

      permuted_rhs = tf.gather_nd(broadcast_rhs, broadcast_perm)
      permuted_rhs = tf.reshape(permuted_rhs, rhs_broadcast_shape)

    lower = tf.linalg.set_diag(
        tf.linalg.band_part(lower_upper, num_lower=-1, num_upper=0),
        tf.ones(tf.shape(lower_upper)[:-1], dtype=lower_upper.dtype))
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        lower_upper,  # Only upper is accessed.
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            lower, permuted_rhs),
        lower=False)
Example #12
def batch_gather(params, indices, axis=-1):
  return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)
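What `batch_dims=1` buys here, with standalone values (row `b` of the result gathers from row `b` of `params`; note that in the snippet above `batch_dims` is a free variable captured from the enclosing scope):

```python
import tensorflow as tf

params = tf.constant([[10, 11, 12],
                      [20, 21, 22]])
indices = tf.constant([[2, 0],
                       [1, 1]])
print(tf.gather(params, indices, axis=-1, batch_dims=1).numpy())
# [[12 10]
#  [21 21]]
```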
Example #13
File: extensions.py  Project: zsunpku/trax
    def wrapper(*args):
        """Wrapper that wraps/unwraps args, retvals, and runs the function."""
        if _pmap_config.devices() is not None:
            raise ValueError(
                "Found a surrounding pmap. Nested pmap is not supported "
                "yet.")
        # TODO(wangpeng): Maybe we should use `asarray` to convert everything to
        # ndarray first.
        args = _np_to_tf(args)

        flattened_input_args = tf.nest.flatten(args)
        flattened_per_device_args = [[] for _ in devices]
        for arg in flattened_input_args:
            if isinstance(arg, tf.Tensor):
                # TODO(nareshmodi): Try and use the dynamic shape instead.
                if (not arg.shape.rank) or arg.shape[0] != len(devices):
                    # TODO(nareshmodi): Fix this restriction
                    raise ValueError(
                        "Input tensors need to have a first dimension equal to "
                        "the number of devices; got tensor of shape %s and %s devices"
                        % (arg.shape, len(devices)))
                # NOTE: Alternatively use tf.split, and place the split tensors on the
                # appropriate device. The best solution for this is to have an API that
                # splits a tensor across devices.
                for j, device in enumerate(devices):
                    updated_arg = tf.gather(arg, j)
                    # TODO(wangpeng): Investigate whether we need a tf.identity for TPU.
                    if not has_tpu:
                        with tf.device(device):
                            updated_arg = tf.identity(updated_arg)
                    flattened_per_device_args[j].append(updated_arg)
            elif isinstance(arg, ShardedNdArray):
                for device_args, tensor in zip(flattened_per_device_args,
                                               arg.tensors):
                    device_args.append(tensor)
            else:
                for device_args in flattened_per_device_args:
                    device_args.append(arg)

        all_per_device_args = [
            tf.nest.pack_sequence_as(args, device_args)
            for device_args in flattened_per_device_args
        ]

        with pmap_config(axis_name, devices):
            results = pmap_fn(all_per_device_args)

        # Rewrap things. This can probably be written better.
        flattened_results = [tf.nest.flatten(result) for result in results]
        final_tree = []

        # TODO(nareshmodi): assert all items in flattened_results have the same
        # structures

        for i in range(len(flattened_results[0])):
            tensors = []
            for j, device in enumerate(devices):
                assert isinstance(flattened_results[j][i], tf.Tensor), (
                    "currently only tensor return items are supported")
                tensors.append(flattened_results[j][i])
            final_tree.append(ShardedNdArray(tensors))

        final_actual_result = tf.nest.pack_sequence_as(results[0], final_tree)

        # Workaround b/121383831
        if (has_tpu and isinstance(final_actual_result, list)
                and len(final_actual_result)
                == 1) and not _orig_result_is_list.val:
            return final_actual_result[0]
        else:
            return final_actual_result
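
A minimal sketch of the per-device split performed by the wrapper above; the device strings are illustrative stand-ins:

```python
import tensorflow as tf

devices = ["/device:CPU:0", "/device:CPU:0"]  # stand-ins for real accelerators
arg = tf.constant([[1., 2.], [3., 4.]])       # leading dim == len(devices)

per_device_args = []
for j, device in enumerate(devices):
    shard = tf.gather(arg, j)  # row j; the leading axis is dropped
    with tf.device(device):
        per_device_args.append(tf.identity(shard))
```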
Example #14
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
    """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ...,         p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

  ```python
  # alignment is:                       v axis
  # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
  # indices.shape   ==         [i[0], i[1], i[2]]
  #                                     ^ indices_axis
  result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met.
  """
    with tf.name_scope(name):
        params = tf.convert_to_tensor(params, name='params')
        indices = tf.convert_to_tensor(indices, name='indices')

        params_ndims = params.shape.ndims
        indices_ndims = indices.shape.ndims
        # `axis` dtype must match ndims, which are 64-bit Python ints.
        axis = tf.get_static_value(tf.convert_to_tensor(axis, dtype=tf.int64))
        indices_axis = tf.get_static_value(
            tf.convert_to_tensor(indices_axis, dtype=tf.int64))

        if params_ndims is None:
            raise ValueError(
                'Rank of `params` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if axis is None:
            raise ValueError(
                '`axis` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if indices_axis is None:
            raise ValueError(
                '`indices_axis` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if indices_axis > axis:
            raise ValueError(
                '`indices_axis` should be <= `axis`, but was {} > {}'.format(
                    indices_axis, axis))

        if params_ndims < 1:
            raise ValueError(
                'Rank of params should be `> 0`, but was {}'.format(
                    params_ndims))

        if indices_ndims is not None and indices_ndims < 1:
            raise ValueError(
                'Rank of indices should be `> 0`, but was {}'.format(
                    indices_ndims))

        if (indices_ndims is not None
                and (indices_ndims - indices_axis > params_ndims - axis)):
            raise ValueError(
                '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
                'indices_axis` ({} - {}), but was not.'.format(
                    params_ndims, axis, indices_ndims, indices_axis))

        # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
        # transpose `indices_axis` to be the rightmost dimension of `indices`...
        transposed_indices = dist_util.move_dimension(indices,
                                                      source_idx=indices_axis,
                                                      dest_idx=-1)

        # ... and `axis` to be the corresponding (aligned as in the docstring)
        # dimension of `params`.
        broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
        transposed_params = dist_util.move_dimension(
            params, source_idx=axis, dest_idx=broadcast_indices_ndims - 1)

        # Next we broadcast `indices` so that its shape has the same prefix as
        # `params.shape`.
        transposed_params_shape = prefer_static.shape(transposed_params)
        result_shape = prefer_static.concat([
            transposed_params_shape[:broadcast_indices_ndims - 1],
            prefer_static.shape(indices)[indices_axis:indices_axis + 1],
            transposed_params_shape[broadcast_indices_ndims:]
        ],
                                            axis=0)
        broadcast_indices = prefer_static.broadcast_to(
            transposed_indices, result_shape[:broadcast_indices_ndims])

        result_t = tf.gather(transposed_params,
                             broadcast_indices,
                             batch_dims=broadcast_indices_ndims - 1,
                             axis=broadcast_indices_ndims - 1)
        return dist_util.move_dimension(result_t,
                                        source_idx=broadcast_indices_ndims - 1,
                                        dest_idx=axis)
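
A quick shape check of the docstring's `rank(params) == 5`, `rank(indices) == 3`, `axis=2`, `indices_axis=1` example, using the function defined above; the shapes are illustrative:

```python
import tensorflow as tf

params = tf.random.normal([2, 3, 4, 5, 6])
indices = tf.random.uniform([3, 7, 5], maxval=4, dtype=tf.int32)
result = index_remapping_gather(params, indices, axis=2, indices_axis=1)
print(result.shape)  # (2, 3, 7, 5, 6): p[2] == 4 is replaced by i[1] == 7
```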
Example #15
    def rasterize(self,
                  scene_vertices=None,
                  scene_attributes=None,
                  scene_triangles=None,
                  name=None):
        """Rasterizes the scene.

    This rasterizer estimates which triangle is associated with each pixel using
    OpenGL. The attribute values are then estimated using TensorFlow,
    allowing gradients to flow through the attributes. Attributes can be
    depth, appearance, or more generally, any K-dimensional representation. Note
    that, similarly to algorithms like Iterative Closest Point (ICP), not having
    gradients through the correspondences does not prevent optimizing the scene
    geometry. Custom gradients can be defined to work around this limitation.

    Note:
      In the following, A1 to An are optional batch dimensions.

    Args:
      scene_vertices: A tensor of shape `[A1, ..., An, V, 3]` containing batches
        of `V` vertices, each defined by a 3D point.
      scene_attributes: A tensor of shape `[A1, ..., An, V, K]` containing
        batches of `V` vertices, each associated with K-dimensional attributes.
      scene_triangles: A tensor of shape `[T, 3]` containing `T` triangles, each
        associated with 3 vertices from `scene_vertices`.
      name: A name for this op. Defaults to 'triangle_rasterizer_rasterize'.

    Returns:
      A tensor of shape `[A1, ..., An, H, W, K]` containing batches of images of
      height `H` and width `W`, where each pixel contains attributes rasterized
      from the scene.
    """
        with tf.compat.v1.name_scope(
                name, "triangle_rasterizer_rasterize",
            (scene_vertices, scene_attributes, scene_triangles)):
            scene_vertices = tf.convert_to_tensor(value=scene_vertices)
            scene_attributes = tf.convert_to_tensor(value=scene_attributes)
            scene_triangles = tf.convert_to_tensor(value=scene_triangles)

            shape.check_static(tensor=scene_vertices,
                               tensor_name="scene_vertices",
                               has_rank_greater_than=1,
                               has_dim_equals=((-1, 3)))
            shape.compare_batch_dimensions(tensors=(scene_vertices,
                                                    scene_attributes),
                                           last_axes=-2,
                                           tensor_names=("vertex_positions",
                                                         "vertex_attributes"),
                                           broadcast_compatible=False)
            shape.check_static(tensor=scene_triangles,
                               tensor_name="scene_triangles",
                               has_dim_equals=((-1, 3)))

            batch_dims_triangles = len(scene_triangles.shape[:-2])
            scene_attributes = tf.gather(scene_attributes,
                                         scene_triangles,
                                         axis=-2,
                                         batch_dims=batch_dims_triangles)
            scene_geometry = tf.gather(scene_vertices,
                                       scene_triangles,
                                       axis=-2,
                                       batch_dims=batch_dims_triangles)

            batch_shape = scene_geometry.shape[:-3]
            batch_shape = [_dim_value(dim) for dim in batch_shape]

            background_geometry = tf.broadcast_to(
                self._background_geometry,
                batch_shape + self._background_geometry.shape)
            background_attribute = tf.broadcast_to(
                self._background_attribute,
                batch_shape + self._background_attribute.shape)
            geometry = tf.concat((background_geometry, scene_geometry),
                                 axis=-3)
            attributes = tf.concat((background_attribute, scene_attributes),
                                   axis=-3)

            view_projection_matrix = tf.broadcast_to(
                input=self._view_projection_matrix,
                shape=batch_shape + self._view_projection_matrix.shape)
            rasterized_face = render_ops.rasterize(
                num_points=geometry.shape[-3],
                variable_names=("view_projection_matrix", "triangular_mesh"),
                variable_kinds=("mat", "buffer"),
                variable_values=(view_projection_matrix,
                                 tf.reshape(geometry,
                                            shape=batch_shape + [-1])),
                output_resolution=self._image_size_int,
                vertex_shader=vertex_shader,
                geometry_shader=geometry_shader,
                fragment_shader=fragment_shader)
            triangle_index = tf.cast(rasterized_face[..., 0], tf.int32)
            vertices_per_pixel = tf.gather(geometry,
                                           triangle_index,
                                           axis=-3,
                                           batch_dims=len(batch_shape))
            attributes_per_pixel = tf.gather(attributes,
                                             triangle_index,
                                             axis=-3,
                                             batch_dims=len(batch_shape))
            return glm.perspective_correct_interpolation(
                vertices_per_pixel, attributes_per_pixel, self._pixel_position,
                self._camera_origin, self._look_at, self._camera_up,
                self._field_of_view, self._image_size_glm, self._near_plane,
                self._far_plane, self._bottom_left)
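
A small sketch of the triangle-vertex lookup used above, without batch dimensions (so `batch_dims=0`); shapes are illustrative:

```python
import tensorflow as tf

vertices = tf.random.normal([8, 3])   # V=8 vertices, 3D positions
triangles = tf.constant([[0, 1, 2],
                         [2, 3, 4]])  # T=2 triangles
per_triangle = tf.gather(vertices, triangles, axis=-2, batch_dims=0)
print(per_triangle.shape)  # (2, 3, 3): [T, vertices per triangle, xyz]
```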
Example #16
    def _parse_predict_data(self, data):
        """Parse data for ShapeMask training."""
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']
        masks = data['groundtruth_instance_masks']

        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(image)[0:2]

        # If categories are not used, collapses all category ids into a single
        # binary (foreground) class.
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Converts boxes from normalized coordinates to pixel coordinates.
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=1.0,
            aug_scale_max=1.0)
        image_scale = image_info[2, :]
        offset = image_info[3, :]

        # Resizes and crops boxes and masks.
        boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                  self._output_size, offset)
        masks = input_utils.resize_and_crop_masks(
            tf.expand_dims(masks, axis=-1), image_scale, self._output_size,
            offset)

        # Filters out ground truth boxes that are all zeros.
        indices = input_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        labels = {
            'anchor_boxes': input_anchor.multilevel_boxes,
            'image_info': image_info,
        }
        if self._mode == ModeKeys.PREDICT_WITH_GT:
            # Converts boxes from normalized coordinates to pixel coordinates.
            groundtruths = {
                'source_id':
                data['source_id'],
                'num_detections':
                tf.shape(data['groundtruth_classes']),
                'boxes':
                box_utils.denormalize_boxes(data['groundtruth_boxes'],
                                            image_shape),
                'classes':
                data['groundtruth_classes'],
                # 'masks': tf.squeeze(masks, axis=-1),
                'areas':
                data['groundtruth_area'],
                'is_crowds':
                tf.cast(data['groundtruth_is_crowd'], tf.int32),
            }
            groundtruths['source_id'] = dataloader_utils.process_source_id(
                groundtruths['source_id'])
            groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
                groundtruths, self._max_num_instances)
            # Computes training labels.
            (cls_targets, box_targets,
             num_positives) = anchor_labeler.label_anchors(
                 boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
            # Packs labels for model_fn outputs.
            labels.update({
                'cls_targets': cls_targets,
                'box_targets': box_targets,
                'num_positives': num_positives,
                'groundtruths': groundtruths,
            })
        return image, labels
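
A minimal sketch of the "filter all-zero boxes" step, with `tf.where` standing in for the project's `input_utils.get_non_empty_box_indices` helper:

```python
import tensorflow as tf

boxes = tf.constant([[0., 0., 10., 10.],
                     [0., 0., 0., 0.],   # empty box, filtered out
                     [5., 5., 8., 9.]])
classes = tf.constant([1., 2., 3.])
heights = boxes[:, 2] - boxes[:, 0]
widths = boxes[:, 3] - boxes[:, 1]
indices = tf.where((heights > 0.) & (widths > 0.))[:, 0]
boxes = tf.gather(boxes, indices)       # keeps rows 0 and 2
classes = tf.gather(classes, indices)
```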
Example #17
        def collater_fn(batch: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
            batch = mm_collater_fn(batch)

            retrieve_masked = config.get('retrieve_masked', False)

            # Subselect mentions for which to retrieve corresponding memory.
            # We want to sample mentions which are linked, not masked, and not padded.
            scores = tf.random.uniform(
                tf.shape(batch['mention_target_is_masked'])) + 2 * tf.cast(
                    batch['mention_target_weights'], tf.float32)

            if not retrieve_masked:
                scores -= tf.cast(batch['mention_target_is_masked'],
                                  tf.float32)

            _, mention_target_retrieval_indices = tf.math.top_k(
                scores, k=max_retrieval_indices)

            mention_retrieval_indices = tf.gather(
                batch['mention_target_indices'],
                mention_target_retrieval_indices)
            retrieval_mention_mask = tf.gather(
                batch['mention_target_weights'],
                mention_target_retrieval_indices)
            # Set weights to 0 for masked retrievals if we do not want to
            # include them.
            if not retrieve_masked:
                retrieval_mention_mask *= tf.gather(
                    1 - tf.cast(batch['mention_target_is_masked'], tf.int32),
                    mention_target_retrieval_indices)

            retrieval_mention_start_positions = tf.gather(
                batch['mention_start_positions'], mention_retrieval_indices)
            retrieval_text_identifiers = tf.gather(batch['text_identifiers'],
                                                   mention_retrieval_indices)
            retrieval_mention_hash = mention_preprocess_utils.modified_cantor_pairing(
                tf.cast(retrieval_mention_start_positions, tf.int64),
                retrieval_text_identifiers)
            retrieval_mention_hash = tf.cast(retrieval_mention_hash, tf.int32)

            retrieval_mention_sort_ids = tf.searchsorted(
                memory_hash_sorted, retrieval_mention_hash)

            # Searchsorted does not check whether the value is present in the
            # array; it just finds the insertion point. Here we check presence
            # and fall back to the default retrieval if the value is absent.
            hash_not_present_mask = tf.not_equal(
                retrieval_mention_hash,
                tf.gather(memory_hash_sorted, retrieval_mention_sort_ids))
            hash_not_present = tf.where(hash_not_present_mask)
            update_values = tf.fill((tf.shape(hash_not_present)[0], ),
                                    tf.shape(hash_sorted_idx)[0] - 1)
            retrieval_mention_sort_ids = tf.tensor_scatter_nd_update(
                retrieval_mention_sort_ids, hash_not_present, update_values)

            # Set mask to 0 if no mention is found
            batch['retrieval_mention_mask'] = retrieval_mention_mask * (
                1 - tf.cast(hash_not_present_mask, tf.int32))

            retrieval_mention_ids = tf.gather(hash_sorted_idx,
                                              retrieval_mention_sort_ids)
            retrieval_mention_values = tf.gather(memory_table,
                                                 retrieval_mention_ids)
            # Match passage entity_ids with memory entity ids as sanity check.
            if memory_entity_pattern:
                retrieval_memory_entity_ids = tf.gather(
                    memory_entity_ids, retrieval_mention_ids)
                retrieval_passage_entity_ids = tf.gather(
                    tf.cast(batch['mention_target_ids'], tf.int32),
                    mention_target_retrieval_indices)
                entity_does_not_match = tf.not_equal(
                    retrieval_memory_entity_ids, retrieval_passage_entity_ids)

                batch['entity_does_not_match'] = tf.logical_and(
                    entity_does_not_match,
                    tf.cast(batch['retrieval_mention_mask'], tf.bool))

            batch['retrieval_mention_values'] = retrieval_mention_values
            batch['retrieval_mention_scores'] = tf.ones_like(
                batch['retrieval_mention_mask'])
            batch['retrieval_mention_batch_positions'] = tf.gather(
                batch['mention_batch_positions'], mention_retrieval_indices)
            batch['retrieval_mention_start_positions'] = retrieval_mention_start_positions  # pylint: disable=line-too-long
            batch['retrieval_mention_end_positions'] = tf.gather(
                batch['mention_end_positions'], mention_retrieval_indices)
            batch['mention_retrieval_indices'] = mention_retrieval_indices

            return batch
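
A minimal sketch of the membership check used above: `tf.searchsorted` only returns an insertion point, so presence must be verified by comparing the element found there (values beyond the last element would additionally need clipping):

```python
import tensorflow as tf

haystack_sorted = tf.constant([2, 5, 7, 11])
needles = tf.constant([5, 6, 11])
ids = tf.searchsorted(haystack_sorted, needles)               # [1, 2, 3]
present = tf.equal(tf.gather(haystack_sorted, ids), needles)  # [True, False, True]
```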
Example #18
def process_batchwise_mention_targets(
    dense_span_starts: tf.Tensor,
    dense_span_ends: tf.Tensor,
    dense_mention_ids: tf.Tensor,
    dense_linked_mention_mask: tf.Tensor,
    dense_is_masked: tf.Tensor,
    max_mentions: int,
    max_mention_targets: int,
) -> Dict[str, tf.Tensor]:
    """Processes mention targets and subsamples/pads as necessary.

  This function does two things. First, it selects which mentions to mark as
  mentions for mention-aware text encoders (in case the number of mentions
  exceeds the max number of mentions). Second, it selects which linked
  mentions to use as targets for mention objectives. To reduce subsampling and
  padding, the function operates over all mentions in a batch, generating
  flattened arrays. The encoder reconstructs the original mention positions
  from an array which specifies each mention's position in the batch. Linked
  mentions are given priority for sampling.

  Args:
    dense_span_starts: dense mention start positions.
    dense_span_ends: dense mention end positions.
    dense_mention_ids: dense entity ids for linked mentions in passage.
    dense_linked_mention_mask: dense mask for linked mentions in passage.
    dense_is_masked: dense mask for masked positions in passage.
    max_mentions: max number of mentions to be considered in model.
    max_mention_targets: max number of mentions to be used for linking loss.

  Returns:
    Mention starts, mention ends, mention mask,
    mention target indices (into start/end positions),
    mention target ids, mention target weights, mention_target_batch_positions,
    mention_target_start_positions, mention_target_end_positions
  """

    seq_len = tf.shape(dense_span_starts)[1]

    # The linking mask has 1s for every position of a mention; we only
    # want it at mention starts...
    linking_mask_start_indexed = dense_span_starts * dense_linked_mention_mask

    # values in {0, 1, 2}:
    # 0: not a mention start.
    # 1: an unlinked mention start.
    # 2: a linked mention start.
    prioritized_span_starts = dense_span_starts + linking_mask_start_indexed
    prioritized_span_starts = tf.cast(prioritized_span_starts, tf.float32)

    # Add random [0, 1) values for uniform sampling in case
    # there are more mentions than `max_mentions`.
    prioritized_span_starts += tf.random.uniform(
        tf.shape(prioritized_span_starts))

    _, global_start_indices = tf.math.top_k(_flatten(prioritized_span_starts),
                                            k=max_mentions)

    dense_span_starts_flatten = _flatten(dense_span_starts)
    dense_span_ends_at_starts = get_dense_span_ends_from_starts(
        dense_span_starts_flatten, _flatten(dense_span_ends))
    global_end_indices = tf.gather(dense_span_ends_at_starts,
                                   global_start_indices)

    dtype = dense_span_starts.dtype
    mention_batch_positions = tf.math.floordiv(global_start_indices, seq_len)
    mention_batch_positions = tf.cast(mention_batch_positions, dtype=dtype)
    mention_start_positions = tf.math.floormod(global_start_indices, seq_len)
    mention_start_positions = tf.cast(mention_start_positions, dtype=dtype)
    mention_end_positions = tf.math.floormod(global_end_indices, seq_len)
    mention_end_positions = tf.cast(mention_end_positions, dtype=dtype)
    mention_mask = tf.gather(dense_span_starts_flatten, global_start_indices)
    mention_mask = tf.cast(mention_mask, dtype=dtype)
    mention_batch_positions *= mention_mask
    mention_start_positions *= mention_mask
    mention_end_positions *= mention_mask

    mention_target_weights = tf.gather(_flatten(linking_mask_start_indexed),
                                       global_start_indices)
    mention_target_weights = mention_target_weights[:max_mention_targets]
    mention_target_weights = tf.cast(mention_target_weights, dtype=dtype)
    mention_target_indices = tf.range(max_mention_targets, dtype=dtype)
    mention_target_indices = mention_target_indices * mention_target_weights
    mention_target_ids = tf.gather(_flatten(dense_mention_ids),
                                   global_start_indices)
    mention_target_ids = mention_target_ids[:max_mention_targets]
    mention_target_ids = tf.cast(mention_target_ids, dtype=dtype)
    mention_target_ids = mention_target_ids * mention_target_weights
    indices = tf.stack((mention_batch_positions, mention_start_positions),
                       axis=1)
    mention_is_masked = tf.gather_nd(dense_is_masked, indices)
    mention_target_is_masked = mention_is_masked[:max_mention_targets]

    features = {
        'mention_batch_positions': mention_batch_positions,
        'mention_start_positions': mention_start_positions,
        'mention_end_positions': mention_end_positions,
        'mention_mask': mention_mask,
        'mention_is_masked': mention_is_masked,
        'mention_target_ids': mention_target_ids,
        'mention_target_indices': mention_target_indices,
        'mention_target_is_masked': mention_target_is_masked,
    }
    mention_target_features = prepare_mention_target_features(
        mention_batch_positions, mention_start_positions,
        mention_end_positions, mention_mask, mention_target_weights,
        mention_target_indices)
    features.update(mention_target_features)
    return features
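
A minimal sketch of the flatten/recover trick described in the docstring: a flat index into a `[batch_size, seq_len]` tensor decomposes into a batch position and a position within the sequence:

```python
import tensorflow as tf

seq_len = 128
global_indices = tf.constant([5, 130, 300])
batch_positions = tf.math.floordiv(global_indices, seq_len)  # [0, 1, 2]
seq_positions = tf.math.floormod(global_indices, seq_len)    # [5, 2, 44]
```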
Example #19
  def forward_rates(self,
                    market: pmd.ProcessedMarketData,
                    past_fixing: Optional[types.FloatTensor] = None,
                    name: Optional[str] = None
                    ) -> Tuple[types.DateTensor, types.FloatTensor]:
    """Returns forward rates for the floating leg.

    Args:
      market: An instance of `ProcessedMarketData`.
      past_fixing: An optional `Tensor` of shape compatible with
        `batch_shape + [1]`. Represents the fixings for the cashflows as
        observed at `market.date`.
      name: Python str. The name to give to the ops created by this function.
        Default value: `None` which maps to 'forward_rates'.

    Returns:
      A tuple of two `Tensor`s of shape `batch_shape + [num_cashflows]`
      containing the dates and the corresponding forward rates for each stream
      based on the input market data.
    """
    name = name or (self._name + "_forward_rates")
    with tf.name_scope(name):
      reference_curve = get_discount_curve(
          self._reference_curve_type, market, self._reference_mask)
      valuation_date = dateslib.convert_to_date_tensor(market.date)

      # Previous fixing date
      coupon_start_date_ord = self._coupon_start_dates.ordinal()
      coupon_end_date_ord = self._coupon_end_dates.ordinal()
      valuation_date_ord = valuation_date.ordinal()
      batch_shape = tf.shape(coupon_start_date_ord)[:-1]
      # Broadcast valuation date batch shape for tf.searchsorted
      valuation_date_ord += tf.expand_dims(
          tf.zeros(batch_shape, dtype=tf.int32), axis=-1)
      ind = tf.maximum(tf.searchsorted(coupon_start_date_ord,
                                       valuation_date_ord) - 1, 0)
      # Fixings are assumed to be the same as coupon start dates
      # TODO(b/177047910): add fixing settlement dates.
      # Shape `batch_shape + [1]`
      fixing_dates_ord = tf.gather(
          coupon_start_date_ord, ind,
          batch_dims=len(coupon_start_date_ord.shape) - 1)
      fixing_end_dates_ord = tf.gather(
          coupon_end_date_ord, ind,
          batch_dims=len(coupon_start_date_ord.shape) - 1)
      fixing_dates = dateslib.dates_from_ordinals(fixing_dates_ord)
      fixing_end_dates = dateslib.dates_from_ordinals(fixing_end_dates_ord)
      # Get fixings. Shape batch_shape + [1]
      if past_fixing is None:
        past_fixing = _get_fixings(
            fixing_dates,
            fixing_end_dates,
            self._reference_curve_type,
            self._reference_mask,
            market)
      else:
        past_fixing = tf.convert_to_tensor(past_fixing, dtype=self._dtype,
                                           name="past_fixing")
      forward_rates = reference_curve.forward_rate(
          self._accrual_start_date,
          self._accrual_end_date,
          day_count_fraction=self._daycount_fractions)
      # Shape batch_shape + [num_cashflows]
      forward_rates = tf.where(self._daycount_fractions > 0., forward_rates,
                               tf.zeros_like(forward_rates))
      # If coupon end date is before the valuation date, the payment is in the
      # past. If valuation date is between coupon start date and coupon end
      # date, then the rate has been fixed but not paid. Otherwise the rate is
      # not fixed and should be read from the curve.
      # Shape batch_shape + [num_cashflows]
      forward_rates = tf.where(
          self._coupon_end_dates < valuation_date,
          tf.constant(0, dtype=self._dtype),
          tf.where(self._coupon_start_dates >= valuation_date,
                   forward_rates, past_fixing))
      return self._coupon_end_dates, forward_rates
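
A minimal sketch (illustrative date ordinals) of the "previous fixing" lookup above: `searchsorted - 1`, clipped at 0, finds the latest coupon start at or before the valuation date:

```python
import tensorflow as tf

coupon_start_ord = tf.constant([[100, 200, 300]])  # batched date ordinals
valuation_ord = tf.constant([[250]])
ind = tf.maximum(tf.searchsorted(coupon_start_ord, valuation_ord) - 1, 0)
fixing_ord = tf.gather(coupon_start_ord, ind, batch_dims=1)  # [[200]]
```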
Example #20
def _sample_from_edpp(eigenvectors, vector_onehot, seed):
    """Samples a batch of subsets from a DPP given pre-selected elementary DPPs.

  Recall that an elementary DPP is a DPP with eigenvalues all exactly 0 or 1.
  This function implements the second step of standard sampling algorithm for
  DPPs, by sampling subsets based on the E-DPPs obtained by selecting
  `vector_onehot` against the DPP's original eigenvectors.

  Args:
    eigenvectors: A Tensor of `float32` of shape `[..., num_points, num_vecs]`
      representing the eigenvectors of a DPP's L-ensemble matrix, eigenvectors
      in columns. Generally, `num_vecs == num_points`; we name separately to
      distinguish axes.
    vector_onehot:  A Tensor of shape `[..., num_vecs]` whose innermost
      dimension corresponds to many-hot subset encodings. Each subset selects
      the eigenvectors of the original DPP that define an elementary DPP.
    seed: The random seed.

  Returns:
    samples: A many-hot `int32` Tensor of shape `[..., num_points]`
      encoding a batch of sampled subsets.
  """
    with tf.name_scope('sample_from_edpp'):
        seed = samplers.sanitize_seed(seed)
        # Sort the 1's to the front, and sort corresponding eigenvectors, then mask.
        vector_onehot = tf.cast(vector_onehot, eigenvectors.dtype)
        vector_indices = tf.argsort(vector_onehot,
                                    axis=-1,
                                    direction='DESCENDING')
        vector_onehot = tf.gather(vector_onehot,
                                  vector_indices,
                                  axis=-1,
                                  batch_dims=len(vector_indices.shape) - 1)
        eigenvectors = tf.gather(eigenvectors,
                                 vector_indices,
                                 axis=-1,
                                 batch_dims=len(vector_indices.shape) - 1)
        eigenvectors = eigenvectors * vector_onehot[..., tf.newaxis, :]
        sample_size = tf.reduce_sum(tf.cast(vector_onehot, tf.int32), axis=-1)
        max_sample_size = tf.reduce_max(sample_size)

        d = ps.shape(eigenvectors)[-2]
        n = ps.shape(eigenvectors)[-1]

        # Slice eigvecs to do less work in eager/non-XLA modes.
        if FAST_PATH_ENABLED and not JAX_MODE and (
                tf.executing_eagerly()
                or not control_flow_util.GraphOrParentsInXlaContext(
                    tf1.get_default_graph())):
            # We can save some work in non-XLA contexts by reducing the size of the
            # eigenvectors.
            eigenvectors = eigenvectors[..., :max_sample_size]
            n = max_sample_size

        def cond(i, *_):
            return i < max_sample_size

        def body(i, vecs, cur_sample, seed):
            sample_seed, next_seed = samplers.split_seed(seed)
            # Squared norm at each coordinate across the active subspace.
            is_active = (i < sample_size)
            coord_prob = tf.reduce_sum(tf.square(vecs), axis=-1)
            coord_logits = tf.where(is_active[..., tf.newaxis],
                                    tf.math.log(coord_prob), 0.)

            idx = categorical.Categorical(logits=coord_logits).sample(
                seed=sample_seed)
            new_vecs = tf.where(
                (tf.range(n) < sample_size[..., tf.newaxis, tf.newaxis] - i -
                 1) & ~cur_sample[..., tf.newaxis],
                _orthogonal_complement_e_i(vecs,
                                           i=tf.where(is_active, idx, 0),
                                           gram_schmidt_iters=max_sample_size -
                                           i), 0.)
            # Since `tf.range(n)` may have unknown shape in the statement
            # above, we set the shape explicitly.
            tensorshape_util.set_shape(new_vecs, vecs.shape)
            vecs = tf.where(is_active[..., tf.newaxis, tf.newaxis], new_vecs,
                            vecs)
            cur_sample = (cur_sample |
                          (tf.equal(tf.range(d), idx[..., tf.newaxis])
                           & is_active[..., tf.newaxis]))
            return i + 1, vecs, cur_sample, next_seed

        _, _, sample, _ = tf.while_loop(
            cond, body,
            (tf.zeros([], tf.int32, name='i'), eigenvectors,
             tf.zeros(ps.shape(eigenvectors)[:-1], dtype=tf.bool), seed))

        return tf.cast(sample, tf.int32)
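
A minimal sketch of the "sort selected vectors to the front" step above: a descending argsort of the one-hot mask, applied with a batched gather:

```python
import tensorflow as tf

onehot = tf.constant([[0., 1., 0., 1.]])
order = tf.argsort(onehot, axis=-1, direction='DESCENDING')
sorted_onehot = tf.gather(onehot, order, axis=-1,
                          batch_dims=len(order.shape) - 1)
# sorted_onehot == [[1., 1., 0., 0.]]
```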
Example #21
def prepare_grid(*, times, time_step, dtype, num_time_steps=None,
                 times_grid=None):
  """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.
    num_time_steps: Number of points on the grid. If supplied, a uniform grid
      is constructed for `[time_step, times[-1] - time_step]` consisting of
      max(0, num_time_steps - len(times)) points that are then concatenated
      with `times`. This parameter guarantees that the number of points on the
      time grid is `max(len(times), num_time_steps)` and that `times` are
      included in the grid.
      Default value: `None`, which means that a uniform grid is created
      containing all points from `times` and the uniform grid of points
      between `[0, times[-1]]` with grid size equal to `time_step`.
    times_grid: An optional rank 1 `Tensor` representing time discretization
      grid. If `times` are not on the grid, then the nearest points from the
      grid are used.
      Default value: `None`, which means that times grid is computed using
      `time_step` and `num_time_steps`.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor`. If `num_time_steps` is supplied, the
      shape of the output is `max(num_time_steps, len(times))`. Otherwise it
      consists of all points from `times` and the uniform grid of points
      between `[0, times[-1]]` with grid size equal to `time_step`.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
      which elements of `all_times` correspond to the values from `times`.
      Guarantees that `all_times[0]=0` and `mask[0]=False`.
    `time_indices`. An integer `Tensor` of the same shape as `times`,
      indicating the indices of `times` in `all_times`.
  """
  if times_grid is None:
    if num_time_steps is None:
      all_times, time_indices = _grid_from_time_step(
          times=times, time_step=time_step, dtype=dtype)
    else:
      all_times, time_indices = _grid_from_num_times(
          times=times, time_step=time_step, num_time_steps=num_time_steps)
  else:
    all_times = times_grid
    time_indices = tf.searchsorted(times_grid, times)
    # Adjust indices to bring `times` closer to `times_grid`.
    times_diff_1 = tf.gather(times_grid, time_indices) - times
    times_diff_2 = tf.gather(
        times_grid, tf.math.maximum(time_indices-1, 0)) - times
    time_indices = tf.where(
        tf.math.abs(times_diff_2) > tf.math.abs(times_diff_1),
        time_indices,
        tf.math.maximum(time_indices - 1, 0))
  # Create a boolean mask to identify the iterations that have to be recorded.
  # Use `tf.scatter_nd` because it handles duplicates. Also, we first create
  # an integer Tensor and then convert it to a boolean mask because scatter_nd
  # with booleans is currently not supported on GPUs.
  mask = tf.scatter_nd(
      indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
      updates=tf.fill(tf.shape(times), 1),
      shape=tf.shape(all_times, out_type=tf.int64))
  mask = tf.where(mask > 0, True, False)

  return all_times, mask, time_indices
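
A minimal sketch of the `scatter_nd` mask trick in the code above: integer updates tolerate duplicate indices, and the result is then thresholded into a boolean mask:

```python
import tensorflow as tf

all_times = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
time_indices = tf.constant([2, 4, 4], dtype=tf.int64)  # duplicates are fine
mask = tf.scatter_nd(
    indices=tf.expand_dims(time_indices, axis=1),
    updates=tf.fill(tf.shape(time_indices), 1),
    shape=tf.shape(all_times, out_type=tf.int64))
mask = tf.where(mask > 0, True, False)
# mask == [False, False, True, False, True]
```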
Example #22
def find_interval_index(query_xs,
                        interval_lower_xs,
                        last_interval_is_closed=False,
                        dtype=None,
                        name=None):
  """Function to find the index of the interval where query points lies.

  Given a list of adjacent half-open intervals [x_0, x_1), [x_1, x_2), ...,
  [x_{n-1}, x_n), [x_n, inf), described by a list [x_0, x_1, ..., x_{n-1}, x_n].
  Return the index where the input query points lie. If x >= x_n, n is returned,
  and if x < x_0, -1 is returned. If `last_interval_is_closed` is set to `True`,
  the last interval [x_{n-1}, x_n] is interpreted as closed (including x_n).

  #### Example

  ```python
  interval_lower_xs = [0.25, 0.5, 1.0, 2.0, 3.0]
  query_xs = [0.25, 3.0, 5.0, 0.0, 0.5, 0.8]
  result = find_interval_index(query_xs, interval_lower_xs)
  # result == [0, 4, 4, -1, 1, 1]
  ```

  Args:
    query_xs: Rank 1 real `Tensor` of any size, the list of x coordinates for
      which the interval index is to be found. The values must be strictly
      increasing.
    interval_lower_xs: Rank 1 `Tensor` of the same shape and dtype as
      `query_xs`. The values x_0, ..., x_n that define the interval starts.
    last_interval_is_closed: If set to `True`, the last interval is interpreted
      as closed.
    dtype: Optional `tf.Dtype`. If supplied, the dtype for `query_xs` and
      `interval_lower_xs`.
      Default value: None which maps to the default dtype inferred from
      `query_xs`.
    name: Optional name of the operation.

  Returns:
    A tensor that matches the shape of `query_xs` with dtype=int32 containing
    the indices of the intervals containing query points. `-1` means the query
    point lies before all intervals and `n-1` means that the point lies in the
    last half-open interval (if `last_interval_is_closed` is `False`) or that
    the point lies to the right of all intervals (if `last_interval_is_closed`
    is `True`).
  """
  name = name or 'find_interval_index'
  with tf.name_scope(name):
    # TODO(b/138988951): add ability to validate that intervals are increasing.
    # TODO(b/138988951): validate that if last_interval_is_closed, input size
    # must be > 1.
    query_xs = tf.convert_to_tensor(query_xs, dtype=dtype)
    dtype = dtype or query_xs.dtype
    interval_lower_xs = tf.convert_to_tensor(interval_lower_xs, dtype=dtype)

    # Result assuming that last interval is half-open.
    indices = tf.searchsorted(interval_lower_xs, query_xs, side='right') - 1

    # Handling the branch if the last interval is closed.
    last_index = tf.shape(interval_lower_xs)[-1] - 1
    last_x = tf.gather(interval_lower_xs, [last_index], axis=-1)
    # `should_cap` is a boolean tensor that is true at a cell iff `indices` at
    # that cell equals the last index and the query x <= the right boundary of
    # the last interval.
    should_cap = tf.logical_and(
        tf.equal(indices, last_index), tf.less_equal(query_xs, last_x))

    # Cap to last_index if the query x is not in the last interval; otherwise,
    # cap to last_index - 1.
    caps = last_index - tf.cast(should_cap, dtype=tf.int32)

    return tf.compat.v1.where(last_interval_is_closed,
                              tf.minimum(indices, caps), indices)
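
The core of the lookup above, reduced to a sketch: `searchsorted` with `side='right'` minus one maps points left of `x_0` to `-1` and points in `[x_k, x_{k+1})` to `k`:

```python
import tensorflow as tf

boundaries = tf.constant([0.25, 0.5, 1.0, 2.0, 3.0])
queries = tf.constant([0.0, 0.25, 0.8, 5.0])
indices = tf.searchsorted(boundaries, queries, side='right') - 1
# indices == [-1, 0, 1, 4]
```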
Example #23
 def _forward_event_shape_tensor(self, input_shape):
     perm = self._make_perm(tf.size(input_shape), self.perm)
     return tf.gather(input_shape, perm)
Example #24
def percentile(x,
               q,
               axis=None,
               interpolation=None,
               keepdims=False,
               validate_args=False,
               preserve_gradients=True,
               keep_dims=None,
               name=None):
    """Compute the `q`-th percentile(s) of `x`.

  Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
  way from the minimum to the maximum in a sorted copy of `x`.

  The values and distances of the two nearest neighbors as well as the
  `interpolation` parameter will determine the percentile if the normalized
  ranking does not match the location of `q` exactly.

  This function is the same as the median if `q = 50`, the same as the minimum
  if `q = 0` and the same as the maximum if `q = 100`.

  Multiple percentiles can be computed at once by using `1-D` vector `q`.
  Dimension zero of the returned `Tensor` will index the different percentiles.

  Compare to `numpy.percentile`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.  If `axis` is not `None`,
      `x` must have statically known number of dimensions.
    q:  Scalar or vector `Tensor` with values in `[0, 100]`. The percentile(s).
    axis:  Optional `0-D` or `1-D` integer `Tensor` with constant values. The
      axis that index independent samples over which to return the desired
      percentile.  If `None` (the default), treat every dimension as a sample
      dimension, returning a scalar.
    interpolation: {'nearest', 'linear', 'lower', 'higher', 'midpoint'}.
      Default value: 'nearest'.  This specifies the interpolation method to
      use when the desired quantile lies between two data points `i < j`:
        * linear: i + (j - i) * fraction, where fraction is the fractional part
          of the index surrounded by i and j.
        * lower: `i`.
        * higher: `j`.
        * nearest: `i` or `j`, whichever is nearest.
        * midpoint: (i + j) / 2.
      `linear` and `midpoint` interpolation do not work with integer dtypes.
    keepdims:  Python `bool`. If `True`, the last dimension is kept with size 1.
      If `False`, the last dimension is removed from the output shape.
    validate_args:  Whether to add runtime checks of argument validity. If
      False, and arguments are incorrect, correct behavior is not guaranteed.
    preserve_gradients:  Python `bool`.  If `True`, ensure that gradient w.r.t
      the percentile `q` is preserved in the case of linear interpolation.
      If `False`, the gradient will be (incorrectly) zero when `q` corresponds
      to a point in `x`.
    keep_dims: deprecated, use keepdims instead.
    name:  A Python string name to give this `Op`.  Default is 'percentile'.

  Returns:
    A `(rank(q) + N - len(axis))` dimensional `Tensor` of same dtype as `x`, or,
      if `axis` is `None`, a `rank(q)` `Tensor`.  The first `rank(q)` dimensions
      index quantiles for different values of `q`.

  Raises:
    ValueError:  If argument 'interpolation' is not an allowed type.
    ValueError:  If interpolation type not compatible with `dtype`.

  #### Examples

  ```python
  # Get 30th percentile with default ('nearest') interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30.)
  ==> 2.0

  # Get 30th percentile with 'linear' interpolation.
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=30., interpolation='linear')
  ==> 1.9

  # Get 30th and 70th percentiles with 'lower' interpolation
  x = [1., 2., 3., 4.]
  tfp.stats.percentile(x, q=[30., 70.], interpolation='lower')
  ==> [1., 3.]

  # Get 100th percentile (maximum).  By default, this is computed over every dim
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100.)
  ==> 4.

  # Treat the leading dim as indexing samples, and find the 100th quantile (max)
  # over all such samples.
  x = [[1., 2.],
       [3., 4.]]
  tfp.stats.percentile(x, q=100., axis=[0])
  ==> [3., 4.]
  ```

  """
    keepdims = keepdims if keep_dims is None else keep_dims
    del keep_dims
    name = name or 'percentile'
    allowed_interpolations = {
        'linear', 'lower', 'higher', 'nearest', 'midpoint'
    }

    if interpolation is None:
        interpolation = 'nearest'
    else:
        if interpolation not in allowed_interpolations:
            raise ValueError(
                'Argument `interpolation` must be in {}. Found {}.'.format(
                    allowed_interpolations, interpolation))

    with tf.name_scope(name):
        x = tf.convert_to_tensor(x, name='x')

        if (interpolation in {'linear', 'midpoint'}
                and dtype_util.is_integer(x.dtype)):
            raise TypeError(
                '{} interpolation not allowed with dtype {}'.format(
                    interpolation, x.dtype))

        # Double is needed here and below, else we get the wrong index if the array
        # is huge along axis.
        q = tf.cast(q, tf.float64)
        _get_static_ndims(q, expect_ndims_no_more_than=1)

        if validate_args:
            q = distribution_util.with_dependencies([
                assert_util.assert_rank_in(q, [0, 1]),
                assert_util.assert_greater_equal(q, tf.cast(0., tf.float64)),
                assert_util.assert_less_equal(q, tf.cast(100., tf.float64))
            ], q)

        # Move `axis` dims of `x` to the rightmost, call it `y`.
        if axis is None:
            y = tf.reshape(x, [-1])
        else:
            x_ndims = _get_static_ndims(x,
                                        expect_static=True,
                                        expect_ndims_at_least=1)
            axis = _make_static_axis_non_negative_list(axis, x_ndims)
            y = _move_dims_to_flat_end(x, axis, x_ndims, right_end=True)

        frac_at_q_or_below = q / 100.

        # Sort everything in ascending order; doing the sort once lets the
        # multiple gather calls below share the sorted result (under the hood)
        # via CSE.
        sorted_y = tf.sort(y, axis=-1, direction='ASCENDING')

        d = tf.cast(tf.shape(y)[-1], tf.float64)

        def _get_indices(interp_type):
            """Get values of y at the indices implied by interp_type."""
            if interp_type == 'lower':
                indices = tf.math.floor((d - 1) * frac_at_q_or_below)
            elif interp_type == 'higher':
                indices = tf.math.ceil((d - 1) * frac_at_q_or_below)
            elif interp_type == 'nearest':
                indices = tf.round((d - 1) * frac_at_q_or_below)
            # d - 1 will be distinct from d in int32, but not necessarily double.
            # So clip to avoid out of bounds errors.
            return tf.clip_by_value(tf.cast(indices, tf.int32), 0,
                                    tf.shape(y)[-1] - 1)

        if interpolation in ['nearest', 'lower', 'higher']:
            gathered_y = tf.gather(sorted_y,
                                   _get_indices(interpolation),
                                   axis=-1)
        elif interpolation == 'midpoint':
            gathered_y = 0.5 * (
                tf.gather(sorted_y, _get_indices('lower'), axis=-1) +
                tf.gather(sorted_y, _get_indices('higher'), axis=-1))
        elif interpolation == 'linear':
            # Copy-paste of docstring on interpolation:
            # linear: i + (j - i) * fraction, where fraction is the fractional part
            # of the index surrounded by i and j.
            larger_y_idx = _get_indices('higher')
            exact_idx = (d - 1) * frac_at_q_or_below
            if preserve_gradients:
                # If q corresponds to a point in x, we will initially have
                # larger_y_idx == smaller_y_idx.
                # This results in the gradient w.r.t. fraction being zero (recall `q`
                # enters only through `fraction`...and see that things cancel).
                # The fix is to ensure that smaller_y_idx and larger_y_idx are always
                # separated by exactly 1.
                smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)
                larger_y_idx = tf.minimum(smaller_y_idx + 1,
                                          tf.shape(y)[-1] - 1)
                fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx
            else:
                smaller_y_idx = _get_indices('lower')
                fraction = tf.math.ceil(
                    (d - 1) * frac_at_q_or_below) - exact_idx

            fraction = tf.cast(fraction, y.dtype)
            gathered_y = (
                tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction) +
                tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction)

        # Propagate NaNs
        if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):
            # Apparently tf.is_nan doesn't like other dtypes
            nan_batch_members = tf.reduce_any(tf.math.is_nan(x), axis=axis)
            right_rank_matched_shape = tf.pad(tf.shape(nan_batch_members),
                                              paddings=[[0, tf.rank(q)]],
                                              constant_values=1)
            nan_batch_members = tf.reshape(nan_batch_members,
                                           shape=right_rank_matched_shape)
            nan = np.array(np.nan, dtype_util.as_numpy_dtype(gathered_y.dtype))
            gathered_y = tf.where(nan_batch_members, nan, gathered_y)

        # Expand dimensions if requested
        if keepdims:
            if axis is None:
                ones_vec = tf.ones(shape=[
                    _get_best_effort_ndims(x) + _get_best_effort_ndims(q)
                ],
                                   dtype=tf.int32)
                gathered_y *= tf.ones(ones_vec, dtype=x.dtype)
            else:
                gathered_y = _insert_back_keepdims(gathered_y, axis)

        # If q is a scalar, then result has the right shape.
        # If q is a vector, then result has trailing dim of shape q.shape, which
        # needs to be rotated to dim 0.
        return distribution_util.rotate_transpose(gathered_y, tf.rank(q))
Example #25
 def _inverse_event_shape_tensor(self, output_shape):
     perm = self._make_perm(tf.size(output_shape), tf.argsort(self.perm))
     return tf.gather(output_shape, perm)
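
A minimal sketch of the forward/inverse pairing across Examples #23 and #25: `tf.argsort` of a permutation is its inverse, so gathering with both recovers the original shape:

```python
import tensorflow as tf

perm = tf.constant([2, 0, 1])
input_shape = tf.constant([3, 4, 5])
output_shape = tf.gather(input_shape, perm)            # [5, 3, 4]
recovered = tf.gather(output_shape, tf.argsort(perm))  # [3, 4, 5]
```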
Example #26
def barrier_price(*,
                  volatilities,
                  strikes,
                  expiries,
                  spots,
                  barriers,
                  rebates=None,
                  discount_rates=None,
                  continuous_dividends=None,
                  cost_of_carries=None,
                  is_barrier_down=None,
                  is_knock_out=None,
                  is_call_options=None,
                  dtype=None,
                  name=None):
    """Prices barrier options in a Black-Scholes Model.

  Computes the prices of options with a single barrier in a Black-Scholes world
  as described in Ref. [1]. Note that the barrier is applied continuously.

  #### Example

  This example is taken from Ref. [2], Page 154.

  ```python
  import numpy as np
  import tf_quant_finance as tff

  dtype = np.float32
  discount_rates = np.array([.08, .08])
  continuous_dividends = np.array([.04, .04])
  spots = np.array([100., 100.])
  strikes = np.array([90., 90.])
  barriers = np.array([95., 95.])
  rebates = np.array([3., 3.])
  volatilities = np.array([.25, .25])
  expiries = np.array([.5, .5])
  barriers_type = np.array([5, 1])
  is_barrier_down = np.array([True, False])
  is_knock_out = np.array([False, False])
  is_call_options = np.array([True, True])

  price = tff.black_scholes.barrier_price(
      volatilities=volatilities, strikes=strikes, expiries=expiries,
      spots=spots, barriers=barriers, rebates=rebates,
      discount_rates=discount_rates, continuous_dividends=continuous_dividends,
      is_barrier_down=is_barrier_down, is_knock_out=is_knock_out,
      is_call_options=is_call_options)

  # Expected output
  #  `Tensor` with values [9.024, 7.7627]
  ```

  #### References

  [1]: Lee Clewlow, Javier Llanos, Chris Strickland.
    Pricing Exotic Options in a Black-Scholes World, 1994.
    https://warwick.ac.uk/fac/soc/wbs/subjects/finance/research/wpaperseries/1994/94-54.pdf
  [2]: Espen Gaarder Haug, The Complete Guide to Option Pricing Formulas,
    2nd Edition, 1997

  Args:
    volatilities: Real `Tensor` of any shape and dtype. The volatilities to
      expiry of the options to price.
    strikes: A real `Tensor` of the same dtype and compatible shape as
      `volatilities`. The strikes of the options to be priced.
    expiries: A real `Tensor` of same dtype and compatible shape as
      `volatilities`. The expiry of each option. The units should be such that
      `expiry * volatility**2` is dimensionless.
    spots: A real `Tensor` of any shape that broadcasts to the shape of the
      `volatilities`. The current spot price of the underlying.
    barriers: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. The barriers of each option.
    rebates: A real `Tensor` of same dtype as the `volatilities` and of the
      shape that broadcasts with `volatilities`. For knockouts, this is a
      fixed cash payout in case the barrier is breached. For knockins, this is a
      fixed cash payout in case the barrier level is not breached. In the former
      case, the rebate is paid immediately on breach whereas in the latter, the
      rebate is paid at the expiry of the option.
      Default value: `None` which maps to no rebates.
    discount_rates: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`.
      Discount rates, or risk free rates.
      Default value: `None`, equivalent to discount_rate = 0.
    continuous_dividends: A real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`. A
      continuous dividend rate paid by the underlier. If `None`, then
      defaults to zero dividends.
      Default value: `None`, equivalent to zero dividends.
    cost_of_carries: An optional real `Tensor` of same dtype as the
      `volatilities` and of the shape that broadcasts with `volatilities`.
      Cost of storing a physical commodity, the cost of interest paid when
      long, or the opportunity cost, or the cost of paying dividends when short.
      If not `None`, `continuous_dividends` is calculated as r - c,
      where r is the `discount_rates` and c is `cost_of_carries`.
    is_barrier_down: A real `Tensor` of `boolean` values and of the shape
      that broadcasts with `volatilities`. True if barrier is below asset
      price at expiration.
      Default value: `True`.
    is_knock_out: A real `Tensor` of `boolean` values and of the shape
      that broadcasts with `volatilities`. True if option is knock out
      else false.
      Default value: `True`.
    is_call_options: A real `Tensor` of `boolean` values and of the shape
      that broadcasts with `volatilities`. True if option is call else
      false.
      Default value: `True`.
    dtype: Optional `tf.DType`. If supplied, the dtype to be used for conversion
      of any supplied non-`Tensor` arguments to `Tensor`.
      Default value: `None` which maps to the default dtype inferred by
      TensorFlow.
    name: str. The name for the ops created by this function.
      Default value: `None` which is mapped to the default name `barrier_price`.
  Returns:
    option_prices: A `Tensor` of same shape as `spots`. The approximate price of
    the barrier options under the Black-Scholes model.
  """
    # The computation is done as in Ref [2], where each integral is split into
    # two matrices. The first matrix contains the algebraic terms and the second
    # matrix contains the probability distribution terms. Masks are used to
    # filter the appropriate terms for calculating each integral. A dot product
    # of each row in the matrices, coupled with the masks, then yields the
    # prices of the barrier options.
    if (continuous_dividends is not None) and (cost_of_carries is not None):
        raise ValueError(
            'At most one of continuous_dividends and cost_of_carries '
            'may be supplied')
    with tf.name_scope(name or 'barrier_price'):
        spots = tf.convert_to_tensor(spots, dtype=dtype, name='spots')
        dtype = spots.dtype
        strikes = tf.convert_to_tensor(strikes, dtype=dtype, name='strikes')
        volatilities = tf.convert_to_tensor(volatilities,
                                            dtype=dtype,
                                            name='volatilities')
        expiries = tf.convert_to_tensor(expiries, dtype=dtype, name='expiries')
        barriers = tf.convert_to_tensor(barriers, dtype=dtype, name='barriers')
        if rebates is not None:
            rebates = tf.convert_to_tensor(rebates,
                                           dtype=dtype,
                                           name='rebates')
        else:
            rebates = tf.zeros_like(spots, dtype=dtype, name='rebates')

        # Convert all to tensor and enforce float dtype where required
        if discount_rates is not None:
            discount_rates = tf.convert_to_tensor(discount_rates,
                                                  dtype=dtype,
                                                  name='discount_rates')
        else:
            discount_rates = tf.convert_to_tensor(0.0,
                                                  dtype=dtype,
                                                  name='discount_rates')

        if continuous_dividends is None:
            continuous_dividends = tf.convert_to_tensor(
                0.0, dtype=dtype, name='continuous_dividends')

        if cost_of_carries is not None:
            cost_of_carries = tf.convert_to_tensor(cost_of_carries,
                                                   dtype=dtype,
                                                   name='cost_of_carries')
        else:
            cost_of_carries = discount_rates - continuous_dividends

        if is_barrier_down is None:
            is_barrier_down = tf.constant(1, name='is_barrier_down')
        else:
            is_barrier_down = tf.convert_to_tensor(is_barrier_down,
                                                   dtype=tf.bool,
                                                   name='is_barrier_down')
            is_barrier_down = tf.where(is_barrier_down, 1, 0)
        if is_knock_out is None:
            is_knock_out = tf.constant(1, name='is_knock_out')
        else:
            is_knock_out = tf.convert_to_tensor(is_knock_out,
                                                dtype=tf.bool,
                                                name='is_knock_out')
            is_knock_out = tf.where(is_knock_out, 1, 0)
        if is_call_options is None:
            is_call_options = tf.constant(1, name='is_call_options')
        else:
            is_call_options = tf.convert_to_tensor(is_call_options,
                                                   dtype=tf.bool,
                                                   name='is_call_options')
            is_call_options = tf.where(is_call_options, 1, 0)

        # Indices which range from 0-7 are used to select the appropriate
        # mask for each barrier
        indices = tf.bitwise.left_shift(is_barrier_down,
                                        2) + tf.bitwise.left_shift(
                                            is_knock_out, 1) + is_call_options
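        # For example, a down-and-out call (is_barrier_down=1, is_knock_out=1,
        # is_call_options=1) encodes to 0b111 = 7 and selects the last row of
        # the mask matrices below, while an up-and-in put encodes to 0b000 = 0
        # and selects the first row.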

        # Masks select the appropriate terms for the integral approximations.
        # Each integral is split into an algebraic term and a probability
        # distribution term, giving 12 entries per matrix
        # (6 integrals, 2 terms each).
        # shape = [8, 12]
        mask_matrix_greater_strike = tf.constant([
            [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # up and in put
            [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # up and in call
            [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1],  # up and out put
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # up and out call
            [0, 0, 1, 1, -1, -1, 1, 1, 0, 0, 1, 1],  # down and in put
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # down and in call
            [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # down and out put
            [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1]
        ])  # down and out call

        mask_matrix_lower_strike = tf.constant([
            [0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0],  # up and in put
            [0, 0, 1, 1, -1, -1, 1, 1, 1, 1, 0, 0],  # up and in call
            [1, 1, 0, 0, -1, -1, 0, 0, 0, 0, 1, 1],  # up and out put
            [1, 1, -1, -1, 1, 1, -1, -1, 0, 0, 1, 1],  # up and out call
            [1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],  # down and in put
            [1, 1, -1, -1, 0, 0, 1, 1, 1, 1, 0, 0],  # down and in call
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],  # down and out put
            [0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1]
        ])  # down and out call

        # Create masks
        # Masks are shape [strikes.shape, 12]
        masks_lower = tf.gather(mask_matrix_lower_strike, indices, axis=0)
        masks_greater = tf.gather(mask_matrix_greater_strike, indices, axis=0)
        strikes_greater = tf.expand_dims(strikes > barriers, axis=-1)
        masks = tf.where(strikes_greater, masks_greater, masks_lower)
        masks = tf.cast(masks, dtype=dtype)
        one = tf.constant(1, dtype=dtype)
        call_or_put = tf.cast(tf.where(tf.equal(is_call_options, 0), -one,
                                       one),
                              dtype=dtype)
        below_or_above = tf.cast(tf.where(tf.equal(is_barrier_down, 0), -one,
                                          one),
                                 dtype=dtype)

        # Calculate params for integrals
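        # These are the intermediate quantities and normal-CDF arguments used
        # in the barrier pricing formulas of Ref [2]. Note that `lamda` is
        # deliberately spelled without the 'b' to avoid Python's `lambda`
        # keyword.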
        sqrt_var = volatilities * tf.math.sqrt(expiries)
        mu = (cost_of_carries) - ((volatilities**2) / 2)
        lamda = 1 + (mu / (volatilities**2))
        x = (tf.math.log(spots / strikes) / (sqrt_var)) + (lamda * sqrt_var)
        x1 = (tf.math.log(spots / barriers) / (sqrt_var)) + (lamda * sqrt_var)
        y = (tf.math.log((barriers**2) / (spots * strikes)) /
             (sqrt_var)) + (lamda * sqrt_var)
        y1 = (tf.math.log(barriers / spots) / (sqrt_var)) + (lamda * sqrt_var)
        b = ((mu**2) +
             (2 * (volatilities**2) * discount_rates)) / (volatilities**2)
        z = (tf.math.log(barriers / spots) / (sqrt_var)) + (b * sqrt_var)
        a = mu / (volatilities**2)

        # Other params used for integrals
        discount_rates_exponent = tf.math.exp(-discount_rates * expiries,
                                              name='discount_rates_exponent')
        continuous_dividends_exponent = tf.math.exp(
            (cost_of_carries - discount_rates) * expiries,
            name='continuous_dividends_exponent')
        barriers_ratio = tf.math.divide(barriers, spots, name='barriers_ratio')
        spots_term = call_or_put * spots * continuous_dividends_exponent
        strikes_term = call_or_put * strikes * discount_rates_exponent

        # rank is used to stack elements and reduce_sum
        strike_rank = strikes.shape.rank

        # Constructing Matrix with first and second algebraic terms for each
        # integral [strike.shape, 12]
        terms_mat = tf.stack(
            (spots_term, -strikes_term, spots_term, -strikes_term, spots_term *
             (barriers_ratio**(2 * lamda)), -strikes_term *
             (barriers_ratio**((2 * lamda) - 2)), spots_term *
             (barriers_ratio**(2 * lamda)), -strikes_term *
             (barriers_ratio**((2 * lamda) - 2)), rebates *
             discount_rates_exponent, -rebates * discount_rates_exponent *
             (barriers_ratio**((2 * lamda) - 2)), rebates *
             (barriers_ratio**(a + b)), rebates * (barriers_ratio**(a - b))),
            name='term_matrix',
            axis=strike_rank)

        # Constructing Matrix with first and second norm for each integral
        # [strikes.shape, 12]
        cdf_mat = tf.stack(
            (call_or_put * x, call_or_put *
             (x - sqrt_var), call_or_put * x1, call_or_put *
             (x1 - sqrt_var), below_or_above * y, below_or_above *
             (y - sqrt_var), below_or_above * y1, below_or_above *
             (y1 - sqrt_var), below_or_above *
             (x1 - sqrt_var), below_or_above * (y1 - sqrt_var),
             below_or_above * z, below_or_above * (z - (2 * b * sqrt_var))),
            name='cdf_matrix',
            axis=strike_rank)
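        # `_ncdf` (defined elsewhere in this module) applies the standard
        # normal CDF elementwise.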
        cdf_mat = _ncdf(cdf_mat)
        # Calculating and returning price for each option
        return tf.reduce_sum(masks * terms_mat * cdf_mat, axis=strike_rank)
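
A minimal usage sketch for the pricer above, assuming the enclosing function is named `barrier_price` (as the default op name suggests) and TF2 eager execution; the inputs describe an illustrative down-and-out call:

import tensorflow as tf

# Hypothetical call to the barrier_price function defined above: a
# down-and-out call with spot 100, strike 90, barrier 95 and rebate 3.
price = barrier_price(
    volatilities=tf.constant([0.25], dtype=tf.float64),
    strikes=tf.constant([90.0], dtype=tf.float64),
    expiries=tf.constant([0.5], dtype=tf.float64),
    spots=tf.constant([100.0], dtype=tf.float64),
    barriers=tf.constant([95.0], dtype=tf.float64),
    rebates=tf.constant([3.0], dtype=tf.float64),
    discount_rates=tf.constant([0.08], dtype=tf.float64),
    cost_of_carries=tf.constant([0.04], dtype=tf.float64),
    is_barrier_down=tf.constant([True]),
    is_knock_out=tf.constant([True]),
    is_call_options=tf.constant([True]))
print(price.numpy())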
Example #27
 def _indices_to_words(self, indices):
     return tf.gather(self._vocab_tensor, indices)
Example #28
def _generate_detections_per_image(boxes,
                                   scores,
                                   max_total_size=100,
                                   nms_iou_threshold=0.3,
                                   score_threshold=0.05,
                                   pre_nms_num_boxes=5000):
  """Generate the final detections per image given the model outputs.

  Args:
    boxes: a tensor with shape [N, num_classes, 4] or [N, 1, 4], which stacks
      box predictions on all feature levels. N is the total number of anchors
      on all levels.
    scores: a tensor with shape [N, num_classes], which stacks class
      probabilities on all feature levels. N is the total number of anchors
      on all levels and num_classes is the number of classes predicted by the
      model. Note that these are the raw class scores.
    max_total_size: a scalar representing maximum number of boxes retained over
      all classes.
    nms_iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    score_threshold: a float representing the threshold for deciding when to
      remove boxes based on score.
    pre_nms_num_boxes: an int number of top candidate detections per class
      before NMS.

  Returns:
    nms_boxes: `float` Tensor of shape [max_total_size, 4] representing top
      detected boxes in [y1, x1, y2, x2].
    nms_scores: `float` Tensor of shape [max_total_size] representing sorted
      confidence scores for detected boxes. The values are between [0, 1].
    nms_classes: `int` Tensor of shape [max_total_size] representing classes for
      detected boxes.
    valid_detections: an `int` scalar `Tensor` giving the number of valid
      detections; only the top `valid_detections` boxes are valid.
  """
  nmsed_boxes = []
  nmsed_scores = []
  nmsed_classes = []
  num_classes_for_box = boxes.get_shape().as_list()[1]
  num_classes = scores.get_shape().as_list()[1]
  for i in range(num_classes):
    boxes_i = boxes[:, min(num_classes_for_box-1, i)]
    scores_i = scores[:, i]

    # Obtains pre_nms_num_boxes before running NMS.
    scores_i, indices = tf.nn.top_k(
        scores_i, k=tf.minimum(tf.shape(input=scores_i)[-1], pre_nms_num_boxes))
    boxes_i = tf.gather(boxes_i, indices)

    (nmsed_indices_i,
     nmsed_num_valid_i) = tf.image.non_max_suppression_padded(
         tf.cast(boxes_i, tf.float32),
         tf.cast(scores_i, tf.float32),
         max_total_size,
         iou_threshold=nms_iou_threshold,
         score_threshold=score_threshold,
         pad_to_max_output_size=True,
         name='nms_detections_' + str(i))
    nmsed_boxes_i = tf.gather(boxes_i, nmsed_indices_i)
    nmsed_scores_i = tf.gather(scores_i, nmsed_indices_i)
    # Sets scores of invalid boxes to -1.
    nmsed_scores_i = tf.where(
        tf.less(tf.range(max_total_size), [nmsed_num_valid_i]), nmsed_scores_i,
        -tf.ones_like(nmsed_scores_i))
    nmsed_classes_i = tf.fill([max_total_size], i)
    nmsed_boxes.append(nmsed_boxes_i)
    nmsed_scores.append(nmsed_scores_i)
    nmsed_classes.append(nmsed_classes_i)
  # Concats results from all classes and sorts them.
  nmsed_boxes = tf.concat(nmsed_boxes, axis=0)
  nmsed_scores = tf.concat(nmsed_scores, axis=0)
  nmsed_classes = tf.concat(nmsed_classes, axis=0)
  nmsed_scores, indices = tf.nn.top_k(
      nmsed_scores,
      k=max_total_size,
      sorted=True)
  nmsed_boxes = tf.gather(nmsed_boxes, indices)
  nmsed_classes = tf.gather(nmsed_classes, indices)
  valid_detections = tf.reduce_sum(
      input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32))
  return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
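
A quick smoke test for the function above with random inputs (shapes follow the docstring; the numbers are illustrative):

import tensorflow as tf

# 1000 anchors, 5 classes, class-agnostic box predictions.
boxes = tf.random.uniform([1000, 1, 4])
scores = tf.random.uniform([1000, 5])
nms_boxes, nms_scores, nms_classes, valid_detections = (
    _generate_detections_per_image(boxes, scores, max_total_size=100))
print(nms_boxes.shape, nms_scores.shape, valid_detections.numpy())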
Example #29
    def _parse_train_data(self, data):
        """Parse data for ShapeMask training."""
        classes = data['groundtruth_classes']
        boxes = data['groundtruth_boxes']
        masks = data['groundtruth_instance_masks']
        is_crowds = data['groundtruth_is_crowd']
        # Skips annotations with `is_crowd` = True.
        if self._skip_crowd_during_training and self._is_training:
            num_groundtruths = tf.shape(classes)[0]
            with tf.control_dependencies([num_groundtruths, is_crowds]):
                indices = tf.cond(
                    tf.greater(tf.size(is_crowds), 0),
                    lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
                    lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
            classes = tf.gather(classes, indices)
            boxes = tf.gather(boxes, indices)
            masks = tf.gather(masks, indices)

        # Gets original image and its size.
        image = data['image']
        image_shape = tf.shape(image)[0:2]

        # If categories are not used, collapses all positive class ids into a
        # single foreground class (id = 1).
        if not self._use_category:
            classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)

        # Normalizes image with mean and std pixel values.
        image = input_utils.normalize_image(image)

        # Flips image randomly during training.
        if self._aug_rand_hflip:
            image, boxes, masks = input_utils.random_horizontal_flip(
                image, boxes, masks)

        # Converts boxes from normalized coordinates to pixel coordinates.
        boxes = box_utils.denormalize_boxes(boxes, image_shape)

        # Resizes and crops image.
        image, image_info = input_utils.resize_and_crop_image(
            image,
            self._output_size,
            self._output_size,
            aug_scale_min=self._aug_scale_min,
            aug_scale_max=self._aug_scale_max)
        image_scale = image_info[2, :]
        offset = image_info[3, :]

        # Resizes and crops boxes and masks.
        boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
                                                  self._output_size, offset)

        # Filters out ground truth boxes that are all zeros.
        indices = input_utils.get_non_empty_box_indices(boxes)
        boxes = tf.gather(boxes, indices)
        classes = tf.gather(classes, indices)
        masks = tf.gather(masks, indices)

        # Assigns anchors.
        input_anchor = anchor.Anchor(self._min_level, self._max_level,
                                     self._num_scales, self._aspect_ratios,
                                     self._anchor_size, self._output_size)
        anchor_labeler = anchor.AnchorLabeler(input_anchor,
                                              self._match_threshold,
                                              self._unmatched_threshold)
        (cls_targets, box_targets,
         num_positives) = anchor_labeler.label_anchors(
             boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))

        # Sample groundtruth masks/boxes/classes for mask branch.
        num_masks = tf.shape(masks)[0]
        mask_shape = tf.shape(masks)[1:3]

        # Pad sampled boxes/masks/classes to a constant batch size.
        padded_boxes = input_utils.pad_to_fixed_size(boxes,
                                                     self._num_sampled_masks)
        padded_classes = input_utils.pad_to_fixed_size(classes,
                                                       self._num_sampled_masks)
        padded_masks = input_utils.pad_to_fixed_size(masks,
                                                     self._num_sampled_masks)

        # Randomly samples groundtruth masks for mask branch training. For
        # images without groundtruth masks, the dummy padded tensors are
        # sampled instead.
        rand_indices = tf.random.shuffle(
            tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
        rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
        rand_indices = rand_indices[0:self._num_sampled_masks]
        rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])
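        # The shuffle/mod trick above draws `self._num_sampled_masks` indices:
        # with at least that many groundtruth masks it samples without
        # replacement; with fewer it samples with replacement (via the mod);
        # and with zero masks every index maps to 0, i.e. the dummy padding.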

        sampled_boxes = tf.gather(padded_boxes, rand_indices)
        sampled_classes = tf.gather(padded_classes, rand_indices)
        sampled_masks = tf.gather(padded_masks, rand_indices)
        # Jitter the sampled boxes to mimic the noisy detections.
        sampled_boxes = box_utils.jitter_boxes(
            sampled_boxes, noise_scale=self._box_jitter_scale)
        sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
        # Compute mask targets in feature crop. A feature crop fully contains a
        # sampled box.
        mask_outer_boxes = box_utils.compute_outer_boxes(
            sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
        mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes,
                                                self._output_size)
        # Compensates for the offset of mask_outer_boxes to map them back to
        # the original image scale.
        mask_outer_boxes_ori = mask_outer_boxes
        mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
        mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0),
                                        [1, 2])
        norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
            mask_outer_boxes_ori, mask_shape)

        # Set sampled_masks shape to [batch_size, height, width, 1].
        sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1),
                                tf.float32)
        mask_targets = tf.image.crop_and_resize(
            sampled_masks,
            norm_mask_outer_boxes_ori,
            box_indices=tf.range(self._num_sampled_masks),
            crop_size=[self._mask_crop_size, self._mask_crop_size],
            method='bilinear',
            extrapolation_value=0,
            name='train_mask_targets')
        mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
                                tf.ones_like(mask_targets),
                                tf.zeros_like(mask_targets))
        mask_targets = tf.squeeze(mask_targets, axis=-1)
        if self._up_sample_factor > 1:
            fine_mask_targets = tf.image.crop_and_resize(
                sampled_masks,
                norm_mask_outer_boxes_ori,
                box_indices=tf.range(self._num_sampled_masks),
                crop_size=[
                    self._mask_crop_size * self._up_sample_factor,
                    self._mask_crop_size * self._up_sample_factor
                ],
                method='bilinear',
                extrapolation_value=0,
                name='train_mask_targets')
            fine_mask_targets = tf.where(
                tf.greater_equal(fine_mask_targets, 0.5),
                tf.ones_like(fine_mask_targets),
                tf.zeros_like(fine_mask_targets))
            fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
        else:
            fine_mask_targets = mask_targets

        # If bfloat16 is used, casts input image to tf.bfloat16.
        if self._use_bfloat16:
            image = tf.cast(image, dtype=tf.bfloat16)

        valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
        if self._mask_train_class == 'all':
            mask_is_valid = valid_image * tf.ones_like(sampled_classes,
                                                       tf.int32)
        else:
            # Get the intersection of sampled classes with training splits.
            mask_valid_classes = tf.cast(
                tf.expand_dims(
                    class_utils.coco_split_class_ids(self._mask_train_class),
                    1), sampled_classes.dtype)
            match = tf.reduce_any(
                tf.equal(tf.expand_dims(sampled_classes, 0),
                         mask_valid_classes), 0)
            mask_is_valid = valid_image * tf.cast(match, tf.int32)

        # Packs labels for model_fn outputs.
        labels = {
            'cls_targets': cls_targets,
            'box_targets': box_targets,
            'anchor_boxes': input_anchor.multilevel_boxes,
            'num_positives': num_positives,
            'image_info': image_info,
            # For ShapeMask.
            'mask_boxes': sampled_boxes,
            'mask_outer_boxes': mask_outer_boxes,
            'mask_targets': mask_targets,
            'fine_mask_targets': fine_mask_targets,
            'mask_classes': sampled_classes,
            'mask_is_valid': mask_is_valid,
        }
        return image, labels
Example #30
 def _indices_to_words(self, indices):
   return tf.gather(self._vocab_tensor, indices)
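
The helper above is a pure table lookup; a self-contained sketch of the same pattern (the vocabulary and indices below are made up):

import tensorflow as tf

# Hypothetical vocabulary table; tf.gather maps each index to its word.
vocab_tensor = tf.constant(['<pad>', 'the', 'cat', 'sat'])
indices = tf.constant([1, 2, 3])
print(tf.gather(vocab_tensor, indices).numpy())  # [b'the' b'cat' b'sat']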
Example #31
 def adjust_day(year, month, day):
     is_leap = date_utils.is_leap_year(year)
     days_in_months = tf.constant(_DAYS_IN_MONTHS_COMBINED, tf.int32)
     max_days = tf.gather(
         days_in_months, month + 12 * tf.dtypes.cast(is_leap, tf.int32))
     return tf.math.minimum(day, max_days)
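
`_DAYS_IN_MONTHS_COMBINED` is not shown in this snippet; a minimal self-contained sketch of the same idea, assuming a 0-based month offset (the original table's exact layout may differ):

import tensorflow as tf

# Hypothetical combined table: the 12 non-leap month lengths followed by the
# 12 leap-year month lengths, so index = (month - 1) + 12 * is_leap.
_DAYS_IN_MONTHS = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
_DAYS_IN_MONTHS_LEAP = [31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]

def adjust_day_sketch(year, month, day):
    # Clamps `day` to the last valid day of (year, month).
    is_leap = (year % 4 == 0) & ((year % 100 != 0) | (year % 400 == 0))
    table = tf.constant(_DAYS_IN_MONTHS + _DAYS_IN_MONTHS_LEAP, tf.int32)
    max_days = tf.gather(table, (month - 1) + 12 * tf.cast(is_leap, tf.int32))
    return tf.math.minimum(day, max_days)

print(adjust_day_sketch(tf.constant(2021), tf.constant(2),
                        tf.constant(31)).numpy())  # 28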