Example No. 1
  def _compute_is_bus_day_table(self):
    """Computes and caches "is business day" table."""
    if self._table_cache.is_bus_day is not None:
      return self._table_cache.is_bus_day

    with tf.init_scope():
      ordinals = tf.range(self._ordinal_offset,
                          self._ordinal_offset + self._calendar_size)
      # Apply weekend mask
      week_days = (ordinals - 1) % 7
      is_holiday = tf.gather(self._weekend_mask, week_days)

      # Apply holidays
      if self._holidays is not None:
        indices = self._holidays.ordinal() - self._ordinal_offset
        ones_at_indices = tf.scatter_nd(
            tf.expand_dims(indices, axis=-1), tf.ones_like(indices),
            is_holiday.shape)
        is_holiday = tf.bitwise.bitwise_or(is_holiday, ones_at_indices)

      # Add a business day at the beginning and at the end, i.e. at 31 Dec of
      # start_year-1 and at 1 Jan of end_year+1. This trick is to avoid dealing
      # with special cases on boundaries.
      # For example, for Following and Preceding conventions we'd need a special
      # value that means "unknown" in the tables. More complicated conventions
      # then combine the Following and Preceding tables, and would need special
      # treatment of the "unknown" values.
      # With these "fake" business days, all computations are automatically
      # correct, unless we land on those extra days - for this reason we add
      # assertions in all API calls before returning.
      is_bus_day_table = tf.concat([[1], 1 - is_holiday, [1]], axis=0)
      self._table_cache.is_bus_day = is_bus_day_table
    return is_bus_day_table
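Taken on its own, the holiday-marking step above scatters ones at the holiday ordinals and ORs them into the weekend table. A minimal runnable sketch with made-up values (a one-week calendar and a single mid-week holiday):

import tensorflow as tf

calendar_size = 7                                # made-up: one week of ordinals
is_holiday = tf.constant([0, 0, 0, 0, 0, 1, 1])  # weekend pattern for that week
holiday_indices = tf.constant([2])               # a mid-week public holiday
ones_at_indices = tf.scatter_nd(
    tf.expand_dims(holiday_indices, axis=-1),
    tf.ones_like(holiday_indices),
    shape=[calendar_size])
is_holiday = tf.bitwise.bitwise_or(is_holiday, ones_at_indices)
# is_holiday -> [0, 0, 1, 0, 0, 1, 1]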
Example No. 2
def _get_q_slice(q, k, ind, b=None, batch_shape=None):
    """Returns `q1[i]` or `q0[j]` for a batch of indices `i` or `j`."""
    q_ind = tf.concat([ind, tf.expand_dims(tf.gather_nd(k, ind), -1)], axis=1)
    b_updates = tf.gather_nd(q, q_ind)
    if b is None:
        return tf.scatter_nd(ind, b_updates, batch_shape)
    return tf.tensor_scatter_nd_update(b, ind, b_updates)
Example No. 3
    def _assemble(self, global_scatter=True):
        a_local_values = self._local_stiffness_matrix_assembler()
        b_local_values = self._local_load_vector_assembler()
        # Don't scatter into global matrices
        if not global_scatter:
            return a_local_values, b_local_values

        global_stiffness_indices = itertools.chain.from_iterable(
            [itertools.product(row, row) for row in self.mesh.elements])

        if len(self.batch_shape) == 0:
            # add an extra leading dimension onto a_local
            a_local_values = a_local_values[tf.newaxis, ...]
            batch_shape = [1]
        else:
            batch_shape = self.batch_shape

        # flatten a_local_values
        flat_shape = tf.concat(
            (batch_shape,
             [a_local_values.shape[-3] * a_local_values.shape[-1]**2]),
            axis=0)

        a_local_values = tf.reshape(a_local_values,
                                    flat_shape,
                                    name='flatten_local_stiffness_matrix')

        # unpack global_stiffness_indices
        global_stiffness_indices = [*global_stiffness_indices]
        A = tf.map_fn(
            lambda x: tf.scatter_nd(
                global_stiffness_indices,
                x,
                shape=[self.mesh.n_nodes, self.mesh.n_nodes]), a_local_values)
        A = tf.squeeze(A)

        global_load_indices = itertools.chain.from_iterable([
            zip(row, np.zeros(3, dtype=np.intp)) for row in self.mesh.elements
        ])

        b = tf.scatter_nd([*global_load_indices],
                          tf.reshape(b_local_values, [-1],
                                     name='flatten_local_load'),
                          shape=[self.mesh.n_nodes, 1],
                          name='scatter_nd_to_global_load')

        return A, b
Example No. 4
 def testIndexedSlices(self):
   dtype = tf.int64
   iss = tf.IndexedSlices(values=tf.ones([2, 3], dtype=dtype),
                          indices=tf.constant([1, 9]),
                          dense_shape=[10, 3])
   a = array_ops.array(iss, copy=False)
   expected = tf.scatter_nd([[1], [9]], tf.ones([2, 3], dtype=dtype), [10, 3])
   self.assertAllEqual(expected, a)
Example No. 5
    def get_output_shape(self, input_shape):
        if self._output_shape is not None:
            return self._output_shape

        axis = tf.math.mod(self._axis, tf.shape(input_shape)[0])
        return input_shape + tf.scatter_nd(
            indices=[[axis]],
            updates=[len(self._quantiles) - input_shape[axis]],
            shape=tf.shape(input_shape))
Example No. 6
 def compute_output_shape(self, input_shape):
     output_shape = input_shape
     if self._topk is not None:
         axis = tf.math.mod(self._axis, tf.shape(input_shape)[0])
         output_shape = input_shape + tf.scatter_nd(
             indices=[[axis]],
             updates=[self._topk - input_shape[axis]],
             shape=tf.shape(input_shape))
     return tf.TensorShape(output_shape)
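Both shape helpers above rely on the same trick: add a `tf.scatter_nd` delta that touches only the chosen axis. A standalone sketch with made-up numbers:

import tensorflow as tf

input_shape = tf.constant([8, 32, 5])  # made-up input shape
axis = 2
new_size = 3                           # e.g. keep top-3 along the last axis
output_shape = input_shape + tf.scatter_nd(
    indices=[[axis]],
    updates=[new_size - input_shape[axis]],
    shape=tf.shape(input_shape))
# output_shape -> [8, 32, 3]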
Example No. 7
def moveaxis(a, source, destination):  # pylint: disable=missing-docstring
    """Raises ValueError if source, destination not in (-ndim(a), ndim(a))."""
    if not source and not destination:
        return a

    a = asarray(a).data

    if isinstance(source, int):
        source = (source, )
    if isinstance(destination, int):
        destination = (destination, )

    a_rank = utils._maybe_static(tf.rank(a))  # pylint: disable=protected-access

    def _correct_axis(axis, rank):
        if axis < 0:
            return axis + rank
        return axis

    source = tuple(_correct_axis(axis, a_rank) for axis in source)
    destination = tuple(_correct_axis(axis, a_rank) for axis in destination)

    if a.shape.rank is not None:
        perm = [i for i in range(a_rank) if i not in source]
        for dest, src in sorted(zip(destination, source)):
            assert dest <= len(perm)
            perm.insert(dest, src)
    else:
        r = tf.range(a_rank)

        def _remove_indices(a, b):
            """Remove indices (`b`) from `a`."""
            items = tf.unstack(tf.sort(tf.stack(b)), num=len(b))

            i = 0
            result = []

            for item in items:
                result.append(a[i:item])
                i = item + 1

            result.append(a[i:])

            return tf.concat(result, 0)

        minus_sources = _remove_indices(r, source)
        minus_dest = _remove_indices(r, destination)

        perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources,
                             [a_rank])
        perm = tf.tensor_scatter_nd_update(
            perm, tf.expand_dims(destination, 1), source)
    a = tf.transpose(a, perm)

    return utils.tensor_to_ndarray(a)
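For the dynamic-rank branch, the permutation assembly reduces to roughly the following sketch; the rank and axes are made up, and `minus_sources` / `minus_dest` are written out by hand instead of going through `_remove_indices`:

import tensorflow as tf

a_rank = 4
source = tf.constant([0])       # move axis 0 ...
destination = tf.constant([2])  # ... to position 2
minus_sources = tf.constant([1, 2, 3])  # range(a_rank) with `source` removed
minus_dest = tf.constant([0, 1, 3])     # range(a_rank) with `destination` removed
perm = tf.scatter_nd(tf.expand_dims(minus_dest, 1), minus_sources, [a_rank])
perm = tf.tensor_scatter_nd_update(perm, tf.expand_dims(destination, 1), source)
# perm -> [1, 2, 0, 3]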
Example No. 8
def _prepare_grid(times, times_grid, *params):
    """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    times_grid: An optional rank 1 `Tensor` representing time discretization
      grid. If `times` are not on the grid, then the nearest points from the
      grid are used.
    *params: Parameters of the Heston model. Either scalar `Tensor`s of the
      same `dtype` or instances of `PiecewiseConstantFunc`.

  Returns:
    Tuple `(all_times, mask)`.
    `all_times` is a 1-D real `Tensor` containing all points from `times`, a
    zero at the start, and the jump locations of piecewise constant parameters
    (or the points of `times_grid`, if it is supplied). The `Tensor` is sorted
    in ascending order and may contain duplicates.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
    which elements of `all_times` correspond to the values from `times`.
    Guarantees that all_times[0] = 0 and mask[0] = False.
  """
    if times_grid is None:
        additional_times = []
        for param in params:
            if hasattr(param, 'is_piecewise_constant'):
                if param.is_piecewise_constant:
                    # Flatten all jump locations
                    additional_times.append(
                        tf.reshape(param.jump_locations(), [-1]))
        zeros = tf.constant([0], dtype=times.dtype)
        all_times = tf.concat([zeros] + [times] + additional_times, axis=0)
        all_times = tf.sort(all_times)
        time_indices = tf.searchsorted(all_times, times, out_type=tf.int32)
    else:
        all_times = times_grid
        time_indices = tf.searchsorted(times_grid, times, out_type=tf.int32)
        # Adjust indices to bring `times` closer to `times_grid`.
        times_diff_1 = tf.gather(times_grid, time_indices) - times
        times_diff_2 = tf.gather(times_grid,
                                 tf.nn.relu(time_indices - 1)) - times
        time_indices = tf.where(
            tf.math.abs(times_diff_2) > tf.math.abs(times_diff_1),
            time_indices, tf.nn.relu(time_indices - 1))
    # Create a boolean mask to identify the iterations that have to be recorded.
    mask = tf.scatter_nd(indices=tf.expand_dims(tf.cast(time_indices,
                                                        dtype=tf.int64),
                                                axis=1),
                         updates=tf.fill(tf.shape(times), True),
                         shape=tf.shape(all_times, out_type=tf.int64))
    return all_times, mask
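The final masking step, viewed in isolation, looks roughly like this sketch with made-up times; note that `scatter_nd` with boolean updates, as used here, may not work on GPUs (a later example scatters integers and converts afterwards for that reason):

import tensorflow as tf

all_times = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0])
times = tf.constant([0.5, 1.5])
time_indices = tf.constant([1, 3])  # positions of `times` within `all_times`
mask = tf.scatter_nd(
    indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
    updates=tf.fill(tf.shape(times), True),
    shape=tf.shape(all_times, out_type=tf.int64))
# mask -> [False, True, False, True, False]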
Example No. 9
 def _get_reset_state_indices():
     reset_indices_obs = tf.nest.map_structure(
         lambda t: tf.gather_nd(t, reset_indices), observation)
     # shape: [num_indices_to_reset, ...]
     reset_indices_state = self.get_initial_state(
         reset_indices_obs, batch_size=tf.shape(reset_indices)[0])
     # Scatter tensors in `reset_indices_state` to shape: [num_timesteps,
     # batch_size, ...]
     return tf.nest.map_structure(
         lambda reset_tensor: tf.scatter_nd(
             indices=reset_indices,
             updates=reset_tensor,
             shape=done.shape.as_list() + reset_tensor.shape.as_list()[1:]),
         reset_indices_state)
Example No. 10
    def test_sanity_check_sweep_over_features(self):
        num_outputs = 100
        num_features = 3
        design_matrix, true_weights, targets = self.evaluate(
            self._random_regression_task(
                num_outputs=num_outputs,
                num_features=num_features,
                # Specify weights with a clear sparsity pattern.
                weights=tf.convert_to_tensor([10., 0., -10.]),
                seed=test_util.test_seed()))

        sampler = dynamic_spike_and_slab.DynamicSpikeSlabSampler(
            design_matrix,
            # Ensure the probability of keeping an irrelevant feature is tiny.
            nonzero_prior_prob=1e-6)
        initial_state = sampler._initialize_sampler_state(
            targets=targets,
            nonzeros=tf.convert_to_tensor([True, True, True]),
            observation_noise_variance=1.)
        final_state = self.evaluate(
            sampler._resample_all_features(initial_state,
                                           seed=test_util.test_seed()))

        # Check that we recovered the true sparsity pattern and approximate weights.
        weights_posterior_precision = (sampler.x_transpose_x +
                                       sampler.weights_prior_precision)
        conditional_weights_mean = _compute_conditional_weights_mean(
            final_state.nonzeros, weights_posterior_precision,
            final_state.x_transpose_y)
        self.assertAllEqual(final_state.nonzeros, [True, False, True])
        indices = tf.where(final_state.nonzeros)
        conditional_weights_mean = tf.scatter_nd(indices,
                                                 conditional_weights_mean,
                                                 true_weights.shape)
        self.assertAllClose(conditional_weights_mean,
                            true_weights,
                            rtol=0.05,
                            atol=0.15)

        posterior = sampler._get_conditional_posterior(final_state)
        posterior_variances, posterior_weights = self.evaluate(
            posterior.sample(seed=test_util.test_seed()))
        self.assertAllFinite(posterior_variances)
        self.assertAllFinite(posterior_weights)
Example No. 11
  def test_categorical_resampler_chi2(self):
    # Test categorical resampler using chi-squared test.

    num_probs = 50
    num_distributions = 3
    unnormalized_probs = tfd.Uniform(
        low=np.float64(0),
        high=np.float64(1.)).sample([num_distributions, num_probs],
                                    seed=test_util.test_seed())
    probs = unnormalized_probs / tf.reduce_sum(
        unnormalized_probs, axis=-1, keepdims=True)

    # chi-squared test is valid as long as `num_samples` is
    # large compared to `num_probs`.
    num_particles = 10000
    num_samples = 2

    sample = resample_independent(
        tf.math.log(dist_util.move_dimension(probs,
                                             source_idx=-1,
                                             dest_idx=0)),
        num_particles,
        [num_samples],
        seed=test_util.test_seed())
    sample = dist_util.move_dimension(sample,
                                      source_idx=0,
                                      dest_idx=-1)
    # TODO(dpiponi): reimplement this test in vectorized form rather than with
    # loops.
    for sample_index in range(num_samples):
      for prob_index in range(num_distributions):
        counts = tf.scatter_nd(
            indices=sample[sample_index, prob_index][:, tf.newaxis],
            updates=tf.ones(num_particles, dtype=tf.int32),
            shape=[num_probs])
        expected_samples = probs[prob_index] * num_particles
        chi2 = tf.reduce_sum(
            (tf.cast(counts, tf.float64) -
             expected_samples)**2 / expected_samples,
            axis=-1)
        self.assertAllLess(
            tfd.Chi2(df=np.float64(num_probs - 1)).cdf(chi2),
            0.9999)
Example No. 12
def _get_endpoint_a(i, j, q1_i, q0_j, batch_shape):
    """Determine the beginning of the interval, `a`."""
    # if i < 0: a = q0[j]
    i_lt_0 = tf.less(i, 0)
    ind = tf.where(i_lt_0)
    a_update = tf.gather_nd(q0_j, ind)
    a = tf.scatter_nd(ind, a_update, batch_shape)

    # elif j < 0: a = q1[i]
    j_lt_0 = tf.less(j, 0)
    ind = tf.where(j_lt_0)
    a_update = tf.gather_nd(q1_i, ind)
    a = tf.tensor_scatter_nd_update(a, ind, a_update)

    # else: a = max(q0[j], q1[i])
    ind = tf.where(~(i_lt_0 | j_lt_0))
    q_max = tf.maximum(q0_j, q1_i)
    a_update = tf.gather_nd(q_max, ind)
    a = tf.tensor_scatter_nd_update(a, ind, a_update)
    return a
Example No. 13
    def test_categorical_resampler_chi2(self):
        # Test categorical resampler using chi-squared test.
        if self.use_xla and tf.executing_eagerly():
            self.skipTest('No need to test XLA under all execution regimes.')

        num_probs = 50
        num_distributions = 3
        unnormalized_probs = tfd.Uniform(low=self.dtype(0),
                                         high=self.dtype(1.)).sample(
                                             [num_distributions, num_probs],
                                             seed=42)
        probs = unnormalized_probs / tf.reduce_sum(
            unnormalized_probs, axis=-1, keepdims=True)

        # chi-squared test is valid as long as `num_samples` is
        # large compared to `num_probs`.
        num_particles = 10000
        num_samples = 2

        sample = self.maybe_compiler(resample_independent)(
            tf.math.log(dist_util.move_dimension(probs, source_idx=-1, dest_idx=0)),
            num_particles, [num_samples])
        sample = dist_util.move_dimension(sample, source_idx=0, dest_idx=-1)
        # TODO(dpiponi): reimplement this test in vectorized form rather than with
        # loops.
        for sample_index in range(num_samples):
            for prob_index in range(num_distributions):
                counts = tf.scatter_nd(
                    indices=sample[sample_index, prob_index][:, tf.newaxis],
                    updates=tf.ones(num_particles, dtype=tf.int32),
                    shape=[num_probs])
                expected_samples = probs[prob_index] * num_particles
                chi2 = tf.reduce_sum(
                    (tf.cast(counts, self.dtype) - expected_samples)**2 /
                    expected_samples,
                    axis=-1)
                self.assertAllLess(
                    tfd.Chi2(df=self.dtype(num_probs - 1)).cdf(chi2), 0.99995)
Example No. 14
    def _solve_dirichlet(self):
        # check the boundary conditions are correct
        if self.domain.boundary.boundary_condition_type != 'Dirichlet':
            raise ValueError(
                "_solve_dirichlet must be used with Dirichlet boundary conditions."
            )

        A, b = self._assemble()

        # work out if any batching needs to be done
        if len(self.batch_shape) == 0:
            # add an extra leading dimension onto A
            A = A[tf.newaxis, ...]
            batch_shape = [1]  # false batch shape, squeezed out by end
        else:
            batch_shape = self.batch_shape

        # get the interior of A
        global_stiffness_interior_indices = [
            *itertools.product(self.mesh.interior_node_indices, repeat=2)
        ]

        Ainterior = tf.map_fn(
            lambda x: tf.reshape(
                tf.gather_nd(x, global_stiffness_interior_indices), [
                    len(self.mesh.interior_node_indices),
                    len(self.mesh.interior_node_indices)
                ]), A)

        b_interior = tf.gather_nd(b, [
            *zip(self.mesh.interior_node_indices,
                 [0] * len(self.mesh.interior_node_indices))
        ])

        interior_bound_indices = [
            *itertools.product(self.mesh.interior_node_indices,
                               self.mesh.boundary_node_indices)
        ]

        Aint_bnd = tf.map_fn(
            lambda x: tf.reshape(tf.gather_nd(x, interior_bound_indices), [
                len(self.mesh.interior_node_indices),
                len(self.mesh.boundary_node_indices)
            ]), A)

        bnd_node_indices = np.array(self.mesh.boundary_node_indices,
                                    dtype=np.intp)
        int_node_indices = np.array(self.mesh.interior_node_indices,
                                    dtype=np.intp)

        # get the value on the boundary
        g = self.domain.boundary.g

        # convert the stiffness matrices to operators for batch matmul
        Ainterior_op = tf.linalg.LinearOperatorFullMatrix(Ainterior)
        Aint_bnd_op = tf.linalg.LinearOperatorFullMatrix(Aint_bnd)

        # add the fixed dirichlet conditions to sol
        # ToDo: Batch boundary values
        sol = tf.scatter_nd(bnd_node_indices[:, None],
                            g,
                            shape=[self.mesh.n_nodes, 1])

        b_ = b_interior[..., tf.newaxis] - Aint_bnd_op.matmul(g)
        sol_interior = Ainterior_op.solve(b_)

        # sol_interior has a batched shape [b, n_interior_nodes, 1]
        return tf.squeeze(
            tf.map_fn(
                lambda x: tf.tensor_scatter_nd_add(
                    sol, int_node_indices[:, None], x), sol_interior)
        )[..., tf.newaxis]  # kills pseudo-batch dimensions, but keeps output a vector
Example No. 15
def _accumulating_for_loop(body_fn, initial_state, params, num_iterations,
                           name=None):
  """Version of for_loop with multiple values of num_iterations."""
  # Every tensor in nested tensors (state and Jacobian) gets an extra
  # "accumulating" dimension in front. Functions _create_accumulators etc. below
  # help to work with this dimension.

  with tf.name_scope(name or "accumulating_for_loop"):
    max_iterations = tf.math.reduce_max(num_iterations)
    acc_size = num_iterations.shape[0]

    # num_iteration = [2, 5] -> mask = [0, 0, 1, 0, 0, 1]. Tells when we should
    # increment acc index before writing. Last element won't be used (i = 0..4).
    mask = tf.scatter_nd(indices=tf.expand_dims(num_iterations, axis=-1),
                         updates=tf.ones_like(num_iterations),
                         shape=(max_iterations + 1,))

    n = len(initial_state)

    @tf.custom_gradient
    def inner(*args):
      initial_state, params = args[:n], args[n:]
      def while_cond(i, acc_index, acc_state, acc_jac):
        del acc_index, acc_state, acc_jac
        return i < max_iterations

      def while_body(i, acc_index, acc_state, acc_jac):
        state = _read_from_accumulators(acc_state, acc_index)
        jac = _read_from_accumulators(acc_jac, acc_index)
        with tf.GradientTape(persistent=True) as tape:
          tape.watch(state)
          tape.watch(params)
          next_state = tuple(body_fn(i, state))
        step_jac = _compute_step_jacobian(state, next_state, params, tape)
        next_jac = _multiply_jacobians(step_jac, jac)
        acc_index += mask[i]
        acc_state = _write_to_accumulators(acc_state, next_state, acc_index)
        acc_jac = _write_to_accumulators(acc_jac, next_jac, acc_index)

        return i + 1, acc_index, acc_state, acc_jac

      initial_acc_state = _create_accumulators(initial_state, acc_size)
      initial_acc_state = _write_to_accumulators(initial_acc_state,
                                                 initial_state, 0)

      initial_jac = _make_unit_jacobian(initial_state, params)
      initial_acc_jac = _create_accumulators(initial_jac, acc_size)
      initial_acc_jac = _write_to_accumulators(initial_acc_jac, initial_jac, 0)

      loop_vars = (0, 0, initial_acc_state, initial_acc_jac)

      _, _, final_acc_state, final_acc_jac = tf.compat.v2.while_loop(
          while_cond, while_body, loop_vars=loop_vars,
          maximum_iterations=max_iterations)

      def gradient(*ws):
        # Same as in for_loop, except we need to sum over the accumulating
        # dimension. E.g. if x = for_loop(... num_iterations=[2, 5]) and
        # y = 2*x[0] + 3*x[1], then taking gradient of y will lead to ws having
        # coeffs 2 and 3 in the acc dimension, and we should sum over it.
        ws = [tf.expand_dims(w, axis=-2) for w in ws]
        ws = [ws]  # expand dims on block level as well.

        js, jp = final_acc_jac
        ws_js, ws_jp = _block_matmul(ws, js), _block_matmul(ws, jp)

        ws_js, ws_jp = ws_js[0], ws_jp[0]
        ws_js = [tf.squeeze(t, axis=-2) for t in ws_js]
        ws_jp = [tf.squeeze(t, axis=[-2, -1]) for t in ws_jp]

        # Sum over acc axis.
        ws_js = [tf.math.reduce_sum(t, axis=0) for t in ws_js]
        # ws_jp should be 0-dimensional
        ws_jp = [tf.math.reduce_sum(t) for t in ws_jp]

        return ws_js + ws_jp

      return final_acc_state, gradient

    # tf.custom_gradient can only handle a flat sequence of args.
    args = tuple(initial_state + params)
    return inner(*args)
Example No. 16
def repeat(a, repeats, axis=None):
    """Repeat elements of the array along specified axes.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    repeats: 0-d or 1-d array_like. The number of times each element along
      `axis` will be repeated. If this has size 1, each element along the axis
      is repeated the same number of times.
    axis: Optional. The axis along which to repeat. If None, the input array
      is flattened.

  Returns:
    An ndarray with same type as `a`.

  Raises:
    ValueError: If `repeats` has rank > 1 or an incompatible shape.
  """
    a = array_creation.asarray(a)
    repeats = array_creation.asarray(repeats)
    if repeats.ndim > 1:
        raise ValueError('repeats must be a scalar or 1-d array.')
    repeats = ravel(repeats)  # Convert to 1-d array.
    # As per documentation, if axis is None, the input is flattened
    # and a flattened output is returned.
    if axis is None:
        a = ravel(a)
        axis = 0
    elif axis < 0:
        axis += a.ndim

    # Broadcast repeats to match shape of axis.
    if len(repeats) == 1:
        repeats = utils.tensor_to_ndarray(
            tf.tile(repeats.data, [a.shape[axis]]))

    if a.shape[axis] != len(repeats):
        raise ValueError(
            'Shape mismatch. `repeats` expected to have shape ({},)'
            ' but has ({},)'.format(a.shape[axis], len(repeats)))

    # Example:
    #
    # a: [[1, 2, 3],
    #     [4, 5, 6]]
    # axis: 1
    # repeats: [3, 1, 2]
    # Output: [[1, 1, 1, 2, 3, 3],
    #          [4, 4, 4, 5, 6, 6]]
    #
    # Algorithm:
    # 1. Calculate cumulative sum of repeats.
    repeats_cumsum = cumsum(repeats)  # [3, 4, 6]
    # 2. Use `scatter_nd` to generate an indices list for use in `tf.gather`.
    scatter_indices_t = repeats_cumsum[:-1].data  # [3, 4]
    scatter_indices_t = tf.expand_dims(scatter_indices_t, 1)  # [[3], [4]]
    scatter_updates_t = tf.ones([len(repeats) - 1], dtype=tf.int32)  # [1, 1]
    scatter_shape_t = ravel(repeats_cumsum[-1]).data  # [6]
    #    `tf.scatter_nd([[3], [4]], [1, 1], [6])` -> `[0, 0, 0, 1, 1, 0]`
    indices_t = tf.scatter_nd(scatter_indices_t, scatter_updates_t,
                              scatter_shape_t)
    indices_t = tf.cumsum(indices_t)  # [0, 0, 0, 1, 2, 2]
    # 3. Use `tf.gather` to gather indices along `axis`.
    result_t = tf.gather(a, indices_t, axis=axis)

    return utils.tensor_to_ndarray(result_t)
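The scatter_nd / cumsum / gather trick at the heart of `repeat` can also be run standalone; this sketch reuses the made-up values from the comments above:

import tensorflow as tf

a = tf.constant([[1, 2, 3], [4, 5, 6]])
repeats = tf.constant([3, 1, 2])
repeats_cumsum = tf.cumsum(repeats)                       # [3, 4, 6]
scatter_indices = tf.expand_dims(repeats_cumsum[:-1], 1)  # [[3], [4]]
indices = tf.scatter_nd(scatter_indices,
                        tf.ones(tf.shape(repeats) - 1, dtype=tf.int32),
                        repeats_cumsum[-1:])              # [0, 0, 0, 1, 1, 0]
indices = tf.cumsum(indices)                              # [0, 0, 0, 1, 2, 2]
result = tf.gather(a, indices, axis=1)
# result -> [[1, 1, 1, 2, 3, 3],
#            [4, 4, 4, 5, 6, 6]]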
Example No. 17
def _get_spl_tensor(bins: int, frequencies: tf.Tensor,
                    spls: tf.Tensor) -> tf.Tensor:
    indices = tf.expand_dims(tf.cast(frequencies, dtype=tf.int32), axis=1)
    return tf.scatter_nd(indices, updates=spls, shape=tf.constant([bins]))
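A small usage sketch, assuming the `_get_spl_tensor` helper above is in scope; the frequencies are made up and act as integer bin indices:

import tensorflow as tf

spl = _get_spl_tensor(bins=6,
                      frequencies=tf.constant([1.0, 4.0]),
                      spls=tf.constant([62.5, 70.0]))
# spl -> [0., 62.5, 0., 0., 70., 0.]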
Example No. 18
        def op(x, kernel):
            input_dtype = dtype_util.common_dtype([x, kernel],
                                                  dtype_hint=tf.float32)
            x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
            kernel = tf.convert_to_tensor(kernel,
                                          dtype=input_dtype,
                                          name='kernel')

            batch_shape, event_shape = ps.split(ps.shape(x),
                                                num_or_size_splits=[-1, 3])
            xh, xw, c_in = ps.unstack(event_shape, num=3)

            kernel_shape = ps.shape(kernel)
            c_out = kernel_shape[-1]
            kernel_batch = kernel_shape[:-2]
            assertions = _maybe_validate_input_shapes(
                kernel_shape,
                channels_in=c_in,
                filter_height=fh,
                filter_width=fw,
                validate_args=validate_args)

            with tf.control_dependencies(assertions):

                # If the kernel does not have batch shape, fall back to
                # `conv2d_transpose` (unless dilations > 1, which is not implemented in
                # `conv2d_transpose`).
                if (tf.get_static_value(ps.rank(kernel)) == 2
                        and all(d == 1 for d in dilations)):
                    return _call_conv2d_transpose(x,
                                                  kernel=kernel,
                                                  filter_shape=filter_shape,
                                                  strides=(strides, ) * rank,
                                                  padding=padding,
                                                  dilations=dilations,
                                                  c_out=c_out,
                                                  batch_shape=batch_shape,
                                                  event_shape=event_shape)

                n = ps.maximum(0, ps.rank(x) - 3)
                paddings = ps.pad(padding_vals,
                                  paddings=[[n, 1], [0, 0]],
                                  constant_values=0)

                x_pad = tf.pad(x, paddings=paddings, constant_values=0)
                x_pad_shape = ps.shape(x_pad)[:-3]
                flat_shape = ps.pad(x_pad_shape,
                                    paddings=[[0, 1]],
                                    constant_values=-1)
                flat_x = tf.reshape(x_pad, shape=flat_shape)

                idx, s = im2row_index(
                    (xh + tf.reduce_sum(padding_vals[0]),
                     xw + tf.reduce_sum(padding_vals[1]), c_in),
                    block_shape=(sub_fh, sub_fw),
                    slice_step=(1, 1),
                    dilations=dilations)

                x_ = tf.gather(flat_x, indices=idx, axis=-1)
                im_x = tf.reshape(x_,
                                  shape=ps.concat([x_pad_shape, s], axis=0))

                # Add channels to subkernel indices
                idx_event = event_ind * [[c_in, 1]]
                idx_event_channels = (idx_event[tf.newaxis] + tf.stack(
                    [ps.range(c_in),
                     tf.zeros(
                         (c_in, ), dtype=dtype)], axis=-1)[:, tf.newaxis, :])
                idx_event = tf.squeeze(tf.batch_to_space(idx_event_channels,
                                                         block_shape=[c_in],
                                                         crops=[[0, 0]]),
                                       axis=0)
                idx_event_broadcast = tf.broadcast_to(
                    idx_event,
                    shape=ps.concat(
                        [kernel_batch, ps.shape(idx_event)], axis=0))

                # Add cartesian product of batch indices, since scatter_nd can only be
                # applied to leading dimensions.
                idx_batch = tf.stack(tf.meshgrid(*[
                    ps.range(b_, delta=1, dtype=dtype)
                    for b_ in tf.unstack(kernel_batch)
                ],
                                                 indexing='ij'),
                                     axis=ps.size(kernel_batch))

                idx_batch = tf.cast(idx_batch,
                                    dtype=dtype)  # empty tensor is float

                idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
                    (ps.shape(idx_event)[0], 1), dtype=dtype)
                idx_kernel = tf.concat(
                    [idx_batch_broadcast, idx_event_broadcast], axis=-1)

                kernel_mat = tf.scatter_nd(
                    idx_kernel,
                    updates=kernel,
                    shape=ps.cast(ps.concat([
                        kernel_batch,
                        [sub_fh * sub_fw * c_in, strides**2, c_out]
                    ],
                                            axis=0),
                                  dtype=dtype))

                kernel_mat = tf.reshape(
                    kernel_mat,
                    shape=ps.concat(
                        [ps.shape(kernel_mat)[:-2], [strides**2 * c_out]],
                        axis=0))

                kernel_mat = kernel_mat[..., tf.newaxis, :, :]
                out = tf.matmul(im_x, kernel_mat)
                broadcast_batch_shape = ps.broadcast_shape(
                    batch_shape, kernel_batch)

                if strides > 1:
                    tot_size = tf.reduce_prod(broadcast_batch_shape)
                    flat_out = tf.reshape(out,
                                          shape=ps.concat([[tot_size],
                                                           ps.shape(out)[-3:]],
                                                          axis=0))
                    out = tf.nn.depth_to_space(flat_out, block_size=strides)

                if padding == 'VALID':
                    out_height = fh + strides * (xh - 1)
                    out_width = fw + strides * (xw - 1)
                elif padding == 'SAME':
                    out_height = xh * strides
                    out_width = xw * strides

                out = out[..., truncate_top:truncate_top + out_height,
                          truncate_left:truncate_left + out_width, :]
                out = tf.reshape(
                    out,
                    shape=ps.concat([
                        broadcast_batch_shape, [out_height, out_width, c_out]
                    ],
                                    axis=0))
                return out
Example No. 19
 def get_positions_shift(positions: tf.Tensor,
                         exclusive: bool) -> tf.Tensor:
     index_2d = _get_2d_index(mention_batch_positions, positions)
     return tf.cumsum(tf.scatter_nd(index_2d, mention_mask, old_shape),
                      axis=1,
                      exclusive=exclusive)
Example No. 20
 def _call_sample_n(self, *args, **kwargs):
     xs = super()._call_sample_n(*args, **kwargs)
     return tf.scatter_nd(indices=self._indices,
                          updates=xs,
                          shape=[self._size])
Example No. 21
    def linear_interpolation_operator(self, index_points):
        """ Returns the linear operator that carries out interpolation of the solution at node points.

        Parameters
        ----------
        index_points : array, shape=[..., nobs, 1]
            points at which we want to interpolate

        Returns
        -------
        O : batched array, shape = [..., nobs, mesh.npoints]
            linear operator such that O @ u interpolates u[mesh.points] onto
            index_points


        .. note::

            The matrix of the operator returned will have 2 entries per row, and so
            will be sparse when `mesh.npoints >> index_points.shape[-1]`.

        """
        # we need to find which element each of the index points is in;
        # we exploit the fact that the interval mesh is ordered and then
        # broadcast `>`
        is_greater = tf.squeeze(
            index_points[:, tf.newaxis, :] > self.nodes[tf.newaxis, ...])

        # find the max along each row, corresponding to an index point, to find
        # the last interval [ta[i], tb[i]] for which index_point[m] > ta[i]
        index_point_elements = tf.argmax(
            tf.cast(is_greater, tf.int32) *
            tf.range(0, self.n_nodes, 1)[tf.newaxis, :],
            axis=-1)

        ta = tf.gather(self.nodes, index_point_elements)
        tb = tf.gather(self.nodes, index_point_elements + 1)

        index_point_element_volumes = tf.gather(
            self.element_volumes, index_point_elements)[:, tf.newaxis]

        # now we have found which element the index point is in we
        # evaluate the local basis functions on these elements
        phi1 = (tb - index_points) / index_point_element_volumes
        phi2 = (index_points - ta) / index_point_element_volumes

        # indices to update Oa
        #[*zip(np.arange(index_points.shape[-2]), index_point_elements)]
        Oa_indices = tf.concat([
            np.arange(index_points.shape[-2])[:, tf.newaxis],
            index_point_elements[:, tf.newaxis]
        ],
                               axis=-1)
        Oa = tf.scatter_nd(Oa_indices,
                           tf.squeeze(phi1),
                           shape=[index_points.shape[-2], self.n_nodes])

        # increment the index_point_elements by one, taking advantage of the sorting
        Ob_indices = tf.concat([
            np.arange(index_points.shape[-2])[:, tf.newaxis],
            (index_point_elements + 1)[:, tf.newaxis]
        ],
                               axis=-1)
        Ob = tf.scatter_nd(Ob_indices,
                           tf.squeeze(phi2),
                           shape=[index_points.shape[-2], self.n_nodes])

        return Oa + Ob
Example No. 22
def mask_mentions_and_tokens_tf(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    dense_span_starts: tf.Tensor,
    dense_span_ends: tf.Tensor,
    non_mention_mask_rate: float,
    mention_mask_rate: float,
    max_mlm_targets: int,
    mask_token_id: int,
    vocab_size: int,
    random_replacement_prob: float = 0.1,
    identity_replacement_prob: float = 0.1,
) -> Dict[str, tf.Tensor]:
    """Randomly masks whole mentions and random tokens up to a maximum.

  First, mentions are masked according to mention mask rate. If a mention is
  masked, all tokens in the mention are replaced by the mask token. If the
  passage has at least one mention and the mention mask rate is greater than
  zero, we mask at least one mention.

  After masking mentions, if there are fewer masked tokens than maximum mlm
  targets, we additionally mask non-mention words. TODO: If a token in a word
  is masked, all tokens in the word are masked. Some proportion of targets are
  not masked to ameliorate pretrain-finetune mismatch. If there are insufficient
  masked tokens, the target array is padded up to max targets.

  Args:
    text_ids: [seq_length] tensor with token ids.
    text_mask: [seq_length] tensor with 1s for tokens and 0 for padding.
    dense_span_starts: [seq_length] tensor with 1s for mention start positions
      and 0 otherwise.
    dense_span_ends: [seq_length] tensor with 1s for mention end positions and 0
      otherwise.
    non_mention_mask_rate: percentage of non mention tokens to be masked.
    mention_mask_rate: percentage of mentions to be masked.
    max_mlm_targets: total number of mlm targets.
    mask_token_id: token id for mask token.
    vocab_size: vocabulary size.
    random_replacement_prob: probability that a to-be-masked token will be
      replaced with a random token instead of [MASK].
    identity_replacement_prob: probability that a to-be-masked token will be
      replaced with itself instead of [MASK].

  Returns:
    Dictionary with masked text, mask positions, target ids, target weights.
  """
    # Mask mentions
    mention_start_positions = non_zero_1d(dense_span_starts)
    mention_end_positions = non_zero_1d(dense_span_ends)
    mention_masked_positions = mask_tokens_by_spans(text_ids,
                                                    mention_start_positions,
                                                    mention_end_positions,
                                                    mention_mask_rate,
                                                    max_mlm_targets)

    dense_is_mention = get_dense_is_inside_for_dense_spans(
        dense_span_starts, dense_span_ends)
    dense_is_not_mention = 1 - dense_is_mention
    dense_is_not_mention = dense_is_not_mention * text_mask

    # Mask non-mentions
    non_mention_start_positions = non_zero_1d(dense_is_not_mention)
    # TODO(urikz): Implement whole-word masking
    non_mention_end_positions = non_mention_start_positions
    non_mention_masked_positions = mask_tokens_by_spans(
        text_ids, non_mention_start_positions, non_mention_end_positions,
        non_mention_mask_rate,
        max_mlm_targets - tf.shape(mention_masked_positions)[0])

    # Merge masked positions for mention and non-mention tokens
    mlm_target_positions = tf.concat(
        [mention_masked_positions, non_mention_masked_positions], -1)
    n_mlm_target_positions = tf.shape(mlm_target_positions)

    # Get target IDs, weights and other features
    mlm_target_ids = tf.gather(text_ids, mlm_target_positions)
    mlm_target_weights = tf.ones(n_mlm_target_positions, dtype=tf.int64)
    mlm_target_is_mention = tf.ones(tf.shape(mention_masked_positions),
                                    dtype=tf.int64)
    seq_length = tf.shape(text_ids)[0]
    dense_is_masked = sparse_to_dense_1d(mlm_target_positions, seq_length)

    # Replace masked tokens with [MASK], random or original tokens.
    replacement_scores = tf.random.uniform(n_mlm_target_positions)
    replacement_tokens = tf.where(
        replacement_scores >
        random_replacement_prob + identity_replacement_prob,
        # replace tokens with [MASK]
        tf.cast(tf.fill(n_mlm_target_positions, value=mask_token_id),
                dtype=tf.int64),
        tf.where(
            replacement_scores > random_replacement_prob,
            # keep original
            mlm_target_ids,
            # replace with random token
            tf.random.uniform(n_mlm_target_positions,
                              maxval=vocab_size,
                              dtype=tf.int64)))
    replacement_positions = tf.expand_dims(mlm_target_positions, 1)
    # Indices should be tf.int32 only.
    replacement_positions = tf.cast(replacement_positions, tf.int32)
    replacement_tokens = tf.scatter_nd(replacement_positions,
                                       replacement_tokens, tf.shape(text_ids))
    masked_text_ids = (text_ids * (1 - dense_is_masked) +
                       replacement_tokens * dense_is_masked)

    return {
        'masked_text_ids':
        masked_text_ids,
        'mlm_target_positions':
        dynamic_padding_1d(mlm_target_positions, max_mlm_targets),
        'mlm_target_ids':
        dynamic_padding_1d(mlm_target_ids, max_mlm_targets),
        'mlm_target_weights':
        dynamic_padding_1d(mlm_target_weights, max_mlm_targets),
        'mlm_target_is_mention':
        dynamic_padding_1d(mlm_target_is_mention, max_mlm_targets),
        'dense_is_masked':
        dense_is_masked,
    }
Example No. 23
def segment_diff(x,
                 segment_ids,
                 order=1,
                 exclusive=False,
                 dtype=None,
                 name=None):
    """Computes difference of successive elements in a segment.

  For a complete description of segment_* ops see documentation of
  `tf.segment_max`. This op extends the `diff` functionality to segmented
  inputs.

  The behaviour of this op is the same as that of the op `diff` within each
  segment. The result is effectively a concatenation of the results of `diff`
  applied to each segment.

  ## Example

  ```python
    x = tf.constant([2, 5, 1, 7, 9] + [32, 10, 12, 3] + [4, 8, 5])
    segments = tf.constant([0, 0, 0, 0, 0] + [1, 1, 1, 1] + [2, 2, 2])
    # First order diff. Expected result: [3, -4, 6, 2, -22, 2, -9, 4, -3]
    dx1 = segment_diff(
        x, segment_ids=segments, order=1, exclusive=True)
    # Non-exclusive, second order diff.
    # Expected result: [2, 5, -1, 2, 8, 32, 10, -20, -7, 4, 8, 1]
    dx2 = segment_diff(
        x, segment_ids=segments, order=2, exclusive=False)
  ```

  Args:
    x: A rank 1 `Tensor` of any dtype for which arithmetic operations are
      permitted.
    segment_ids: A `Tensor`. Must be one of the following types: int32, int64. A
      1-D tensor whose size is equal to the size of `x`. Values should be sorted
      and can be repeated.
    order: Positive Python int. The order of the difference to compute. `order =
      1` corresponds to the difference between successive elements.
      Default value: 1
    exclusive: Python bool. See description above.
      Default value: False
    dtype: Optional `tf.Dtype`. If supplied, the dtype for `x` to use when
      converting to `Tensor`.
      Default value: None which maps to the default dtype inferred by TF.
    name: Python `str` name prefixed to Ops created by this class.
      Default value: None which is mapped to the default name 'segment_diff'.

  Returns:
    diffs: A `Tensor` of the same dtype as `x`. Assuming that each segment is
      of length greater than or equal to order, if `exclusive` is True,
      then the size is `n-order*k` where `n` is the size of x,
      `k` is the number of different segment ids supplied if `segment_ids` is
      not None or 1 if `segment_ids` is None. If any of the segments is of
      length less than the order, then the size is:
      `n-sum(min(order, length(segment_j)), j)` where the sum is over segments.
      If `exclusive` is False, then the size is `n`.
  """
    with tf.compat.v1.name_scope(name, default_name='segment_diff',
                                 values=[x]):
        x = tf.convert_to_tensor(x, dtype=dtype)
        raw_diffs = diff_ops.diff(x, order=order, exclusive=exclusive)
        if segment_ids is None:
            return raw_diffs
        # If segment ids are supplied, raw_diffs are incorrect at locations:
        # p, p+1, ... min(p+order-1, m_p-1) where p is the index of the first
        # element of a segment other than the very first segment (which is
        # already correct). m_p is the segment length.
        # Find positions where the segments begin.
        has_segment_changed = tf.concat(
            [[False],
             tf.not_equal(segment_ids[1:] - segment_ids[:-1], 0)],
            axis=0)
        # Shape [k, 1]
        segment_start_index = tf.cast(tf.where(has_segment_changed),
                                      dtype=tf.int32)
        segment_end_index = tf.concat([
            tf.reshape(segment_start_index, [-1])[1:], [tf.size(segment_ids)]
        ],
                                      axis=0)
        segment_end_index = tf.reshape(segment_end_index, [-1, 1])
        # The indices of locations that need to be adjusted. This needs to be
        # constructed in steps. First we generate p, p+1, ... p+order-1.
        # Shape [num_segments-1, order]
        fix_indices = (segment_start_index +
                       tf.range(order, dtype=segment_start_index.dtype))
        in_bounds = tf.where(fix_indices < segment_end_index)
        # Keep only the ones in bounds.
        fix_indices = tf.reshape(tf.gather_nd(fix_indices, in_bounds), [-1, 1])

        needs_fix = tf.scatter_nd(
            fix_indices,
            # Unfortunately, scatter_nd doesn't support bool on GPUs so we need to
            # do ints here and then convert to bool.
            tf.reshape(tf.ones_like(fix_indices, dtype=tf.int32), [-1]),
            shape=tf.shape(x))
        # If exclusive is False, then needs_fix means we need to replace the values
        # in raw_diffs at those locations with the values in x.
        needs_fix = tf.cast(needs_fix, dtype=tf.bool)
        if not exclusive:
            return tf.where(needs_fix, x, raw_diffs)

        # If exclusive is True, we have to be more careful. The raw_diffs
        # computation has removed the first 'order' elements. After removing the
        # corresponding elements from needs_fix, we use it to remove the elements
        # from raw_diffs.
        return tf.boolean_mask(raw_diffs, tf.logical_not(needs_fix[order:]))
Example No. 24
def add_entity_tokens(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    mention_mask: tf.Tensor,
    mention_batch_positions: tf.Tensor,
    mention_start_positions: tf.Tensor,
    mention_end_positions: tf.Tensor,
    new_length: int,
    mlm_target_positions: Optional[tf.Tensor] = None,
    mlm_target_weights: Optional[tf.Tensor] = None,
    entity_start_token_id: int = default_values.ENTITY_START_TOKEN,
    entity_end_token_id: int = default_values.ENTITY_END_TOKEN,
) -> Dict[str, tf.Tensor]:
    """Adds entity start / end tokens around mentions.

  Inserts `entity_start_token_id` and `entity_end_token_id` tokens around each
  mention and updates mention_start_positions / mention_end_positions to point
  to these tokens.

  New text length will be `new_length` and texts will be truncated if necessary.
  If a mention no longer fits into the new text, its mask (`mention_mask`) will
  be set to 0.

  The function can also update MLM position and weights (`mlm_target_positions`
  and `mlm_target_weights`) if these arguments are provided. Similarly to
  mentions, if MLM position no longer fits into the new text, its mask
  (`mlm_target_weights`) will be set to 0.

  Args:
    text_ids: [seq_length] tensor with token ids.
    text_mask: [seq_length] tensor with 1s for tokens and 0 for padding.
    mention_mask: [n_mentions] mask indicating whether a mention is a padding.
    mention_batch_positions: [n_mentions] sample ID of a mention in the batch.
    mention_start_positions: [n_mentions] position of a mention first token
      within a sample.
    mention_end_positions: [n_mentions] position of a mention last token within
      a sample.
    new_length: new length of text after entity tokens will be added.
    mlm_target_positions: [batch_size, max_mlm_targets] positions of tokens to
      be used for MLM task.
    mlm_target_weights: [batch_size, max_mlm_targets] mask indicating whether
      `mlm_target_positions` is a padding.
    entity_start_token_id: token to be used as entity start token.
    entity_end_token_id: token to be used as entity end token.

  Returns:
    New text_ids and text_mask, updated mentions positions including
    mention_start_positions, mention_end_positions and mention_mask.
    Returns updated mlm_target_positions and mlm_target_weights if they were
    provided as arguments.
  """
    batch_size = tf.shape(text_ids)[0]
    old_length = tf.shape(text_ids)[1]
    new_shape = (batch_size, new_length)

    mentions_fit_mask = compute_which_mentions_fit_with_entity_tokens(
        mention_mask,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        batch_size,
        old_length,
        new_length,
    )
    # Ignore mentions that do not fit into the new texts.
    new_mention_mask = mention_mask * mentions_fit_mask
    mention_start_positions = mention_start_positions * new_mention_mask
    mention_end_positions = mention_end_positions * new_mention_mask

    positions = compute_positions_shift_with_entity_tokens(
        new_mention_mask, mention_batch_positions, mention_start_positions,
        mention_end_positions, batch_size, old_length)

    def get_2d_index(positions: tf.Tensor) -> tf.Tensor:
        return _get_2d_index(mention_batch_positions, positions)

    def get_new_positions(old_positions: tf.Tensor) -> tf.Tensor:
        index_2d = get_2d_index(old_positions)
        return tf.gather_nd(positions, index_2d)

    new_mention_start_positions = get_new_positions(
        mention_start_positions) - 1
    new_mention_start_positions = new_mention_start_positions * new_mention_mask
    new_mention_end_positions = get_new_positions(mention_end_positions) + 1
    new_mention_end_positions = new_mention_end_positions * new_mention_mask

    if mlm_target_positions is not None:
        if mlm_target_weights is None:
            raise ValueError('`mlm_target_weights` must be specified if '
                             '`mlm_target_positions` is provided.')
        mlm_target_positions = tf.gather(positions,
                                         mlm_target_positions,
                                         batch_dims=1)
        mlm_target_positions_within_len = tf.less(mlm_target_positions,
                                                  new_length)
        mlm_target_positions_within_len = tf.cast(
            mlm_target_positions_within_len, mlm_target_weights.dtype)
        mlm_target_weights = mlm_target_weights * mlm_target_positions_within_len
        # Zero-out positions for pad MLM targets
        mlm_target_positions = mlm_target_positions * mlm_target_weights

    # Cut texts that are longer than `new_length`
    text_within_new_length = tf.less(positions, new_length)
    text_ids = text_ids * tf.cast(text_within_new_length, text_ids.dtype)
    text_mask = text_mask * tf.cast(text_within_new_length, text_mask.dtype)
    positions = tf.minimum(positions, new_length - 1)

    # Prepare 2D index for tokens positions in the next text_ids and text_mask.
    # Note that we use flat 2D index and flat values
    # (e.g. `tf.reshape(text_ids, [-1])`) since `tf.scatter_nd` does not support
    # batch dimension.
    batch_positions = _batched_range(old_length, batch_size, 1,
                                     positions.dtype)
    batch_positions = tf.reshape(batch_positions, [-1])
    text_index_2d = _get_2d_index(batch_positions, tf.reshape(positions, [-1]))

    new_text_ids = tf.scatter_nd(text_index_2d, tf.reshape(text_ids, [-1]),
                                 new_shape)
    new_text_mask = tf.scatter_nd(text_index_2d, tf.reshape(text_mask, [-1]),
                                  new_shape)

    # Insert entity start / end tokens into the new text_ids and text_mask.
    new_mention_start_positions_2d = get_2d_index(new_mention_start_positions)
    new_mention_end_positions_2d = get_2d_index(new_mention_end_positions)

    new_text_ids = tf.tensor_scatter_nd_add(
        new_text_ids, new_mention_start_positions_2d,
        new_mention_mask * entity_start_token_id)
    new_text_ids = tf.tensor_scatter_nd_add(
        new_text_ids, new_mention_end_positions_2d,
        new_mention_mask * entity_end_token_id)

    new_mention_mask = tf.cast(new_mention_mask, dtype=text_mask.dtype)
    new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                             new_mention_start_positions_2d,
                                             new_mention_mask)
    new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                             new_mention_end_positions_2d,
                                             new_mention_mask)

    features = {
        'text_ids': new_text_ids,
        'text_mask': new_text_mask,
        'mention_start_positions': new_mention_start_positions,
        'mention_end_positions': new_mention_end_positions,
        'mention_mask': new_mention_mask,
    }

    if mlm_target_positions is not None:
        features['mlm_target_weights'] = mlm_target_weights
        features['mlm_target_positions'] = mlm_target_positions

    return features
Example No. 25
def _update_batch(ind, b_update, b=None, batch_shape=None):
    """Updates a batch of `i`, `j`, `q1[i]` or `q0[j]`."""
    updates = tf.gather_nd(b_update, ind)
    if b is None:
        return tf.scatter_nd(ind, updates, batch_shape)
    return tf.tensor_scatter_nd_update(b, ind, updates)
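A usage sketch with made-up batch values, assuming `_update_batch` above is in scope: it scatters into a fresh tensor when `b` is None and updates `b` otherwise:

import tensorflow as tf

ind = tf.constant([[0], [2]])            # batch entries to update
b_update = tf.constant([10., 20., 30.])
fresh = _update_batch(ind, b_update, batch_shape=[3])
# fresh -> [10., 0., 30.]
existing = _update_batch(ind, b_update, b=tf.constant([1., 2., 3.]))
# existing -> [10., 2., 30.]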
Example No. 26
def _cap_positives_mask(untiled_mask, diagonal_mask, num_views, positives_cap):
    r"""Cap positives in the provided untiled_mask.

      'positives_cap' specifies the maximum number of positives *other* than
      augmentations of the anchor. Positives will be evenly sampled from all
      views.

  Args:
    untiled_mask: Tensor of shape [local_batch_size, global_batch_size] that has
      entry (r, c) == 1 if feature entries in rows r and c are from the same
      class. Else (r, c) == 0.
    diagonal_mask: Tensor with the same shape as `untiled_mask`. When
      local_batch_size == global_batch_size this is just an identity matrix.
      Otherwise, it is an identity matrix of size `local_batch_size` that is
      padded with 0's in the 2nd dimension to match the target shape. This is
      used to indicate where the anchor views exist in the global batch of
      views.
    num_views: Integer number of total views.
    positives_cap: Integer maximum number of positives *other* than
      augmentations of anchor. Infinite if < 0. Must be multiple of num_views.
      Including augmentations, a maximum of (positives_cap + num_views - 1)
      positives is possible. This parameter modifies the contrastive numerator
      by selecting which positives are present in the summation, and which
      positives contribute to the denominator if denominator_mode ==
      enums.LossDenominatorMode.ALL.

  Returns:
    A tf.Tensor with the modified `untiled_mask`.
  """
    untiled_mask_no_diagonal = tf.math.minimum(untiled_mask,
                                               1. - diagonal_mask)
    untiled_positives_per_anchor = positives_cap // num_views

    # Grabs top-k positives from each row in the mask. Can end up with negatives
    # incorrectly marked as positives if fewer than `untiled_positives_per_anchor`
    # exist in any row of `untiled_mask_no_diagonal`. However, these false
    # positives will be masked out before the function returns.
    _, top_k_col_idx = tf.math.top_k(untiled_mask_no_diagonal,
                                     untiled_positives_per_anchor)
    top_k_row_idx = tf.expand_dims(tf.range(tf.shape(untiled_mask)[0]), axis=1)

    # Construct |top_k_idx|, a tensor of shape
    # [untiled_positives_per_anchor * local_batch_size, 2]. Each row represents
    # the 2D index in a
    # [local_batch_size * num_anchor_views, global_batch_size * num_views] size
    # tensor which holds a positive; all others are negatives.
    top_k_idx = tf.reshape(
        tf.stack([
            tf.tile(top_k_row_idx,
                    (1, untiled_positives_per_anchor)), top_k_col_idx
        ],
                 axis=-1), (-1, 2))

    # Construct |untiled_mask|. Sets positives to 1 according to top_k_idx
    # above.
    untiled_mask_capped = tf.scatter_nd(
        top_k_idx,
        tf.ones(shape=tf.shape(top_k_idx)[0],
                dtype=untiled_mask_no_diagonal.dtype),
        untiled_mask_no_diagonal.shape)
    untiled_mask_capped = tf.math.maximum(untiled_mask_capped, diagonal_mask)
    return untiled_mask * untiled_mask_capped
Example No. 27
def prepare_grid(*, times, time_step, dtype, num_time_steps=None,
                 times_grid=None):
  """Prepares grid of times for path generation.

  Args:
    times:  Rank 1 `Tensor` of increasing positive real values. The times at
      which the path points are to be evaluated.
    time_step: Rank 0 real `Tensor`. Maximal distance between points in
      resulting grid.
    dtype: `tf.Dtype` of the input and output `Tensor`s.
    num_time_steps: Number of points on the grid. If supplied, a uniform grid
      is constructed for `[time_step, times[-1] - time_step]` consisting of
      max(0, num_time_steps - len(times)) points that is then concatenated with
      times. This parameter guarantees the number of points on the time grid
      is `max(len(times), num_time_steps)` and that `times` are included in the
      grid.
      Default value: `None`, which means that a uniform grid is created
       containing all points from `times` and the uniform grid of points between
       `[0, times[-1]]` with grid size equal to `time_step`.
    times_grid: An optional rank 1 `Tensor` representing time discretization
      grid. If `times` are not on the grid, then the nearest points from the
      grid are used.
      Default value: `None`, which means that times grid is computed using
      `time_step` and `num_time_steps`.

  Returns:
    Tuple `(all_times, mask, time_indices)`.
    `all_times` is a 1-D real `Tensor`. If `num_time_steps` is supplied the
      shape of the output is `max(num_time_steps, len(times))`. Otherwise
      consists of all points from `times` and the uniform grid of points between
      `[0, times[-1]]` with grid size equal to `time_step`.
    `mask` is a boolean 1-D `Tensor` of the same shape as `all_times`, showing
      which elements of `all_times` correspond to the values from `times`.
      Guarantees that all_times[0] = 0 and mask[0] = False.
    `time_indices`. An integer `Tensor` of the same shape as `times` indicating
    `times` indices in `all_times`.
  """
  if times_grid is None:
    if num_time_steps is None:
      all_times, time_indices = _grid_from_time_step(
          times=times, time_step=time_step, dtype=dtype)
    else:
      all_times, time_indices = _grid_from_num_times(
          times=times, time_step=time_step, num_time_steps=num_time_steps)
  else:
    all_times = times_grid
    time_indices = tf.searchsorted(times_grid, times)
    # Adjust indices to bring `times` closer to `times_grid`.
    times_diff_1 = tf.gather(times_grid, time_indices) - times
    times_diff_2 = tf.gather(
        times_grid, tf.math.maximum(time_indices-1, 0)) - times
    time_indices = tf.where(
        tf.math.abs(times_diff_2) > tf.math.abs(times_diff_1),
        time_indices,
        tf.math.maximum(time_indices - 1, 0))
  # Create a boolean mask to identify the iterations that have to be recorded.
  # Use `tf.scatter_nd` because it handles duplicates. Also we first create
  # an int64 Tensor and then create a boolean mask because scatter_nd with
  # booleans is currently not supported on GPUs.
  mask = tf.scatter_nd(
      indices=tf.expand_dims(tf.cast(time_indices, dtype=tf.int64), axis=1),
      updates=tf.fill(tf.shape(times), 1),
      shape=tf.shape(all_times, out_type=tf.int64))
  mask = tf.where(mask > 0, True, False)

  return all_times, mask, time_indices
Example No. 28
def _delete_tf(tensor, idx, axis=0):
    """Deletes from a tensor along an axis at the given index."""
    n = tf.shape(tensor)[axis]
    t = tf.ones_like(idx, dtype=tf.bool)
    m = ~tf.scatter_nd(tf.expand_dims(idx, 1), t, [n])
    return tf.boolean_mask(tensor, m, axis=axis)
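A usage sketch with made-up values, assuming `_delete_tf` above is in scope (the boolean `scatter_nd` it builds the mask with may not work on GPUs):

import tensorflow as tf

x = tf.constant([[0, 0], [1, 1], [2, 2], [3, 3]])
out = _delete_tf(x, tf.constant([1, 3]), axis=0)  # drop rows 1 and 3
# out -> [[0, 0], [2, 2]]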
Example No. 29
def sparse_to_dense_1d(sparse_values: tf.Tensor, seq_length: int):
    """Convert sparse tensor ([0, 1, 4]) to dense tensor ([1, 1, 0, 0, 1])."""
    updates = tf.fill(tf.shape(sparse_values), value=1)
    updates = tf.cast(updates, sparse_values.dtype)
    return tf.scatter_nd(tf.expand_dims(sparse_values, 1), updates,
                         [seq_length])
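A usage sketch spelling out the docstring's example, assuming `sparse_to_dense_1d` above is in scope:

import tensorflow as tf

dense = sparse_to_dense_1d(tf.constant([0, 1, 4]), seq_length=5)
# dense -> [1, 1, 0, 0, 1]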