def get_dense_span_ends_from_starts(dense_span_starts: tf.Tensor,
                                    dense_span_ends: tf.Tensor) -> tf.Tensor:
    """For every mention start positions finds the corresponding end position."""
    seq_len = tf.shape(dense_span_starts)[0]
    start_pos = tf.cast(tf.where(tf.equal(dense_span_starts, 1)), tf.int32)
    end_pos = tf.cast(tf.squeeze(tf.where(tf.equal(dense_span_ends, 1)), 1),
                      tf.int32)
    dense_span_ends_from_starts = tf.zeros(seq_len, dtype=tf.int32)
    dense_span_ends_from_starts = tf.tensor_scatter_nd_add(
        dense_span_ends_from_starts, start_pos, end_pos)
    return dense_span_ends_from_starts
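
# A minimal usage sketch (illustrative values, not part of the original
# source): dense 0/1 markers for spans (1, 2) and (4, 5) in a length-6
# sequence. The result stores each span's end index at its start index.
def _example_get_dense_span_ends_from_starts():
    starts = tf.constant([0, 1, 0, 0, 1, 0], dtype=tf.int32)
    ends = tf.constant([0, 0, 1, 0, 0, 1], dtype=tf.int32)
    # Expected result: [0, 2, 0, 0, 5, 0]
    return get_dense_span_ends_from_starts(starts, ends)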
            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                hessian_column_with_l2 = sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord),
                    adjoint_a=True)

                if l2_regularizer is not None:
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)

                # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
                # order to conform to `hess_matmul_x_update`.
                n = tf.rank(hessian_column_with_l2)
                perm = tf.roll(tf.range(n), shift=1, axis=0)
                hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2,
                                                      perm=perm)

                # Update the entire batch at `coord` even if `delta` may be 0 at some
                # batch coordinates. In those cases, adding `delta` is a no-op.
                x_update = tf.tensor_scatter_nd_add(x_update, [[coord]],
                                                    [delta])

                with tf.control_dependencies([x_update]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)

                    # Hint that loop vars retain the same shape.
                    x_update_diff_norm_sq_.set_shape(
                        x_update_diff_norm_sq_.shape.merge_with(
                            x_update_diff_norm_sq.shape))
                    hess_matmul_x_update_.set_shape(
                        hess_matmul_x_update_.shape.merge_with(
                            hess_matmul_x_update.shape))

                    return [
                        x_update_diff_norm_sq_, x_update, hess_matmul_x_update_
                    ]
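            # Illustrative sketch (toy values, not from the original source) of
            # the single-coordinate update above: `tf.tensor_scatter_nd_add`
            # only touches row `coord`, leaving every other coordinate of
            # `x_update` untouched, e.g.
            #   tf.tensor_scatter_nd_add([0., 0., 0., 0.], [[2]], [0.5])
            #   # -> [0., 0., 0.5, 0.]
            # Batch entries whose `delta` is 0 simply receive a no-op add.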
Example #3
        def loop_body(auc_prev, p0_prev, p1_prev, i, j):
            """Body of the loop to integrate over the ROC/PR curves."""

            # We will align intervals from q0 and q1 by moving from end to beginning,
            # stopping with whatever boundary we encounter first. This boundary is
            # `b`. We find `a` by continuing the search from `b` to the left.
            b, i, j = _get_endpoint_b(i, j, q1, q0, batch_shape)

            # Get q1[i] and q1[i+1]
            ind = tf.where(tf.greater_equal(i, 0))
            q1_i = _get_q_slice(q1, i, ind, batch_shape=batch_shape)
            ip1 = tf.minimum(i + 1, n_q1 - 1)
            q1_ip1 = _get_q_slice(q1, ip1, ind, batch_shape=batch_shape)

            # Get q0[j] and q0[j+1]
            ind = tf.where(tf.greater_equal(j, 0))
            q0_j = _get_q_slice(q0, j, ind, batch_shape=batch_shape)
            jp1 = tf.minimum(j + 1, n_q0 - 1)
            q0_jp1 = _get_q_slice(q0, jp1, ind, batch_shape=batch_shape)

            a = _get_endpoint_a(i, j, q1_i, q0_j, batch_shape)

            # Calculate the proportion [a, b) represents of [q1[i], q1[i+1]).
            d1 = _get_interval_proportion(i, n_q1, q1_i, q1_ip1, a, b,
                                          batch_shape)

            # Calculate the proportion [a, b) represents of [q0[j], q0[j+1]).
            d0 = _get_interval_proportion(j, n_q0, q0_j, q0_jp1, a, b,
                                          batch_shape)

            # Notice that because we assumed within-bucket values are distributed
            # uniformly, we end up with something identical to the trapezoidal
            # rule: definite_integral += (b - a) * (f(a) + f(b)) / 2.
            if curve == 'ROC':
                auc = auc_prev + d0 * (p1_prev + d1 / 2.)
            else:
                total_scaled_delta = n0 * d0 + n1 * d1
                total_scaled_cdf_at_b = n0 * p0_prev + n1 * p1_prev

                def get_auprc_update():
                    proportion = (total_scaled_cdf_at_b /
                                  (total_scaled_cdf_at_b + total_scaled_delta))
                    definite_integral = (
                        (n1 / tf.square(total_scaled_delta)) *
                        (d1 * total_scaled_delta + tf.math.log(proportion) *
                         (d1 * total_scaled_cdf_at_b -
                          p1_prev * total_scaled_delta)))
                    return d1 * definite_integral

                # Values should be non-negative and we use > 0.0 rather than != 0.0 to
                # catch possible numerical imprecision.
                delta_gt_0 = tf.greater(total_scaled_delta, 0.)
                cdf_gt_0 = tf.greater(total_scaled_cdf_at_b, 0.)
                d1_gt_0 = tf.greater(d1, 0.)
                ind = tf.where(delta_gt_0 & cdf_gt_0 & d1_gt_0)
                auc_update = tf.gather_nd(get_auprc_update(), ind)
                auc = tf.tensor_scatter_nd_add(auc_prev, ind, auc_update)

            # TODO(jvdillon): In calculating AUROC and AUPRC, we probably should
            # resolve ties randomly, making the following eight states possible (where
            # e = 1 means we expected a positive trial and e = 0 means we expected a
            # negative trial):
            #
            #   P(y = 1 | pi(x) > delta)
            #   P(y = 1 | pi(x) = delta, e = 1) 0.5
            #   P(y = 1 | pi(x) = delta, e = 0) 0.5
            #   P(y = 1 | pi(x) < delta)
            #
            #   P(y = 0 | pi(x) > delta)
            #   P(y = 0 | pi(x) = delta, e = 1) 0.5
            #   P(y = 0 | pi(x) = delta, e = 0) 0.5
            #   P(y = 0 | pi(x) < delta)
            #
            # This makes the math a bit harder and it's not clear this adds much,
            # especially when we're assuming piecewise uniformity.

            # Accumulate this mass (d1, d0) for the next iteration.
            p1 = p1_prev + d1
            p0 = p0_prev + d0

            return auc, p0, p1, i, j
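        # Illustrative sketch (toy values, not from the original source) of the
        # masked accumulation used in the non-ROC branch above: batch indices
        # where the guard is False are absent from `ind`, so those entries keep
        # their previous value, e.g.
        #   ind = tf.where([True, False, True])                   # [[0], [2]]
        #   upd = tf.gather_nd([1., 2., 3.], ind)                 # [1., 3.]
        #   tf.tensor_scatter_nd_add([10., 10., 10.], ind, upd)   # [11., 10., 13.]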
Example #4
        def collater_fn(batch: Dict[Text, tf.Tensor]) -> Dict[Text, tf.Tensor]:
            """Collater function for mention classification task. See BaseTask."""

            new_batch = {}

            # Sample mentions uniformly across batch
            mention_mask = tf.reshape(batch['mention_mask'],
                                      [n_candidate_mentions])
            sample_scores = tf.random.uniform(
                shape=[n_candidate_mentions]) * tf.cast(
                    mention_mask, tf.float32)

            mention_target_indices = tf.reshape(
                batch['mention_target_indices'], [bsz])

            # We want to make sure that the target mentions always have priority
            # when we sample `max_batch_mentions` out of all available mentions.
            # Additionally, we want these target mentions to be in the same order
            # as their samples. In other words, we want the first sampled mention
            # to be the target mention from the first sample, the second sampled
            # mention to be the target mention from the second sample, etc.

            # Positions of target mentions in the flat array
            mention_target_indices_flat = (tf.cast(
                tf.range(bsz) * max_mentions_per_sample,
                mention_target_indices.dtype) + mention_target_indices)
            # These extra scores make sure that target mentions have priority and
            # will be sampled in the correct order.
            mention_target_extra_score_flat = tf.cast(
                tf.reverse(tf.range(bsz) + 1, axis=[0]), tf.float32)
            # The model assumes that there is only ONE target mention per sample.
            # Moreover, we want to select them according to the order of samples:
            # target mention from sample 0, target mention from sample 1, ..., etc.
            sample_scores = tf.tensor_scatter_nd_add(
                sample_scores, tf.expand_dims(mention_target_indices_flat, 1),
                mention_target_extra_score_flat)

            sampled_indices = tf.math.top_k(sample_scores,
                                            max_batch_mentions,
                                            sorted=True).indices
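            # Illustrative sketch (toy values, not from the original source) of
            # the target-priority trick above: the masked uniform scores lie in
            # [0, 1), while the extra scores scattered onto the target positions
            # are bsz, bsz - 1, ..., 1, so targets strictly dominate top_k and
            # come out first, in sample order. E.g. with bsz = 2,
            # max_mentions_per_sample = 2 and targets at flat positions [1, 2]:
            #   sample_scores            = [0.3, 0.9, 0.1, 0.4]
            #   extra scores at [1, 2]   = [2.0, 1.0]
            #   after the scatter-add    = [0.3, 2.9, 1.1, 0.4]
            #   top_k(..., 3).indices    = [1, 2, 3]  (targets first, in order)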

            # Double-check target mentions were selected correctly.
            assert_op = tf.assert_equal(
                sampled_indices[:bsz],
                tf.cast(mention_target_indices_flat, sampled_indices.dtype))

            with tf.control_dependencies([assert_op]):
                mention_mask = tf.gather(mention_mask, sampled_indices)
            dtype = batch['mention_start_positions'].dtype
            mention_start_positions = tf.gather(
                tf.reshape(batch['mention_start_positions'],
                           [n_candidate_mentions]), sampled_indices)
            mention_end_positions = tf.gather(
                tf.reshape(batch['mention_end_positions'],
                           [n_candidate_mentions]), sampled_indices)

            mention_batch_positions = tf.gather(
                tf.repeat(tf.range(bsz, dtype=dtype), max_mentions_per_sample),
                sampled_indices)

            new_batch['text_ids'] = batch['text_ids']
            new_batch['text_mask'] = batch['text_mask']
            new_batch['classifier_target'] = tf.reshape(
                batch['target'], [bsz, config.max_num_labels_per_sample])
            new_batch['classifier_target_mask'] = tf.reshape(
                batch['target_mask'], [bsz, config.max_num_labels_per_sample])

            new_batch['mention_mask'] = mention_mask
            new_batch['mention_start_positions'] = mention_start_positions
            new_batch['mention_end_positions'] = mention_end_positions
            new_batch['mention_batch_positions'] = mention_batch_positions
            new_batch['mention_target_indices'] = tf.range(bsz, dtype=dtype)

            if config.get('max_length_with_entity_tokens') is not None:
                batch_with_entity_tokens = mention_preprocess_utils.add_entity_tokens(
                    text_ids=new_batch['text_ids'],
                    text_mask=new_batch['text_mask'],
                    mention_mask=new_batch['mention_mask'],
                    mention_batch_positions=new_batch[
                        'mention_batch_positions'],
                    mention_start_positions=new_batch[
                        'mention_start_positions'],
                    mention_end_positions=new_batch['mention_end_positions'],
                    new_length=config.max_length_with_entity_tokens,
                )
                # Update `text_ids`, `text_mask`, `mention_mask`, `mention_*_positions`
                new_batch.update(batch_with_entity_tokens)
                # Update `max_length`
                max_length = config.max_length_with_entity_tokens
            else:
                max_length = encoder_config.max_length

            new_batch['mention_target_batch_positions'] = tf.gather(
                new_batch['mention_batch_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_start_positions'] = tf.gather(
                new_batch['mention_start_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_end_positions'] = tf.gather(
                new_batch['mention_end_positions'],
                new_batch['mention_target_indices'])
            new_batch['mention_target_weights'] = tf.ones(bsz)

            # Fake IDs -- some encoders (ReadTwice) need them
            new_batch['mention_target_ids'] = tf.zeros(bsz)

            new_batch['segment_ids'] = tf.zeros_like(new_batch['text_ids'])

            position_ids = tf.expand_dims(tf.range(max_length, dtype=dtype),
                                          axis=0)
            new_batch['position_ids'] = tf.tile(position_ids, (bsz, 1))

            return new_batch
def add_entity_tokens(
    text_ids: tf.Tensor,
    text_mask: tf.Tensor,
    mention_mask: tf.Tensor,
    mention_batch_positions: tf.Tensor,
    mention_start_positions: tf.Tensor,
    mention_end_positions: tf.Tensor,
    new_length: int,
    mlm_target_positions: Optional[tf.Tensor] = None,
    mlm_target_weights: Optional[tf.Tensor] = None,
    entity_start_token_id: int = default_values.ENTITY_START_TOKEN,
    entity_end_token_id: int = default_values.ENTITY_END_TOKEN,
) -> Dict[str, tf.Tensor]:
    """Adds entity start / end tokens around mentions.

  Inserts `entity_start_token_id` and `entity_end_token_id` tokens around each
  mention and updates mention_start_positions / mention_end_positions to point
  to these tokens.

  New text length will be `new_length` and texts will be truncated if necessary.
  If a mention no longer fits into the new text, its mask (`mention_mask`) will
  be set to 0.

  The function can also update MLM positions and weights (`mlm_target_positions`
  and `mlm_target_weights`) if these arguments are provided. Similarly to
  mentions, if an MLM position no longer fits into the new text, its mask
  (`mlm_target_weights`) will be set to 0.

  Args:
    text_ids: [batch_size, seq_length] tensor with token ids.
    text_mask: [batch_size, seq_length] tensor with 1s for tokens and 0 for
      padding.
    mention_mask: [n_mentions] mask indicating whether a mention is padding.
    mention_batch_positions: [n_mentions] sample ID of a mention in the batch.
    mention_start_positions: [n_mentions] position of a mention first token
      within a sample.
    mention_end_positions: [n_mentions] position of a mention last token within
      a sample.
    new_length: new length of the text after entity tokens are added.
    mlm_target_positions: [batch_size, max_mlm_targets] positions of tokens to
      be used for MLM task.
    mlm_target_weights: [batch_size, max_mlm_targets] mask indicating whether
      each of `mlm_target_positions` is padding.
    entity_start_token_id: token to be used as entity start token.
    entity_end_token_id: token to be used as entity end token.

  Returns:
    New text_ids and text_mask, updated mentions positions including
    mention_start_positions, mention_end_positions and mention_mask.
    Returns updated mlm_target_positions and mlm_target_weights if they were
    provided as arguments.
  """
    batch_size = tf.shape(text_ids)[0]
    old_length = tf.shape(text_ids)[1]
    new_shape = (batch_size, new_length)

    mentions_fit_mask = compute_which_mentions_fit_with_entity_tokens(
        mention_mask,
        mention_batch_positions,
        mention_start_positions,
        mention_end_positions,
        batch_size,
        old_length,
        new_length,
    )
    # Ignore mentions that do not fit into the new texts.
    new_mention_mask = mention_mask * mentions_fit_mask
    mention_start_positions = mention_start_positions * new_mention_mask
    mention_end_positions = mention_end_positions * new_mention_mask

    positions = compute_positions_shift_with_entity_tokens(
        new_mention_mask, mention_batch_positions, mention_start_positions,
        mention_end_positions, batch_size, old_length)

    def get_2d_index(positions: tf.Tensor) -> tf.Tensor:
        return _get_2d_index(mention_batch_positions, positions)

    def get_new_positions(old_positions: tf.Tensor) -> tf.Tensor:
        index_2d = get_2d_index(old_positions)
        return tf.gather_nd(positions, index_2d)

    new_mention_start_positions = get_new_positions(
        mention_start_positions) - 1
    new_mention_start_positions = new_mention_start_positions * new_mention_mask
    new_mention_end_positions = get_new_positions(mention_end_positions) + 1
    new_mention_end_positions = new_mention_end_positions * new_mention_mask

    if mlm_target_positions is not None:
        if mlm_target_weights is None:
            raise ValueError('`mlm_target_weights` must be specified if '
                             '`mlm_target_positions` is provided.')
        mlm_target_positions = tf.gather(positions,
                                         mlm_target_positions,
                                         batch_dims=1)
        mlm_target_positions_within_len = tf.less(mlm_target_positions,
                                                  new_length)
        mlm_target_positions_within_len = tf.cast(
            mlm_target_positions_within_len, mlm_target_weights.dtype)
        mlm_target_weights = mlm_target_weights * mlm_target_positions_within_len
        # Zero-out positions for pad MLM targets
        mlm_target_positions = mlm_target_positions * mlm_target_weights

    # Cut texts that are longer than `new_length`
    text_within_new_length = tf.less(positions, new_length)
    text_ids = text_ids * tf.cast(text_within_new_length, text_ids.dtype)
    text_mask = text_mask * tf.cast(text_within_new_length, text_mask.dtype)
    positions = tf.minimum(positions, new_length - 1)

    # Prepare a 2D index for token positions in the new text_ids and text_mask.
    # Note that we use a flat 2D index and flat values
    # (e.g. `tf.reshape(text_ids, [-1])`) since `tf.scatter_nd` does not support
    # a batch dimension.
    batch_positions = _batched_range(old_length, batch_size, 1,
                                     positions.dtype)
    batch_positions = tf.reshape(batch_positions, [-1])
    text_index_2d = _get_2d_index(batch_positions, tf.reshape(positions, [-1]))

    new_text_ids = tf.scatter_nd(text_index_2d, tf.reshape(text_ids, [-1]),
                                 new_shape)
    new_text_mask = tf.scatter_nd(text_index_2d, tf.reshape(text_mask, [-1]),
                                  new_shape)
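    # Illustrative sketch (toy values, not from the original source) of the
    # flat scatter above: each (batch, new_position) index pair addresses one
    # cell of the [batch_size, new_length] output, e.g.
    #   tf.scatter_nd([[0, 0], [0, 2], [1, 1]], [7, 8, 9], [2, 4])
    #   # -> [[7, 0, 8, 0], [0, 9, 0, 0]]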

    # Insert entity start / end tokens into the new text_ids and text_mask.
    new_mention_start_positions_2d = get_2d_index(new_mention_start_positions)
    new_mention_end_positions_2d = get_2d_index(new_mention_end_positions)

    new_text_ids = tf.tensor_scatter_nd_add(
        new_text_ids, new_mention_start_positions_2d,
        new_mention_mask * entity_start_token_id)
    new_text_ids = tf.tensor_scatter_nd_add(
        new_text_ids, new_mention_end_positions_2d,
        new_mention_mask * entity_end_token_id)

    new_mention_mask = tf.cast(new_mention_mask, dtype=text_mask.dtype)
    new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                             new_mention_start_positions_2d,
                                             new_mention_mask)
    new_text_mask = tf.tensor_scatter_nd_add(new_text_mask,
                                             new_mention_end_positions_2d,
                                             new_mention_mask)

    features = {
        'text_ids': new_text_ids,
        'text_mask': new_text_mask,
        'mention_start_positions': new_mention_start_positions,
        'mention_end_positions': new_mention_end_positions,
        'mention_mask': new_mention_mask,
    }

    if mlm_target_positions is not None:
        features['mlm_target_weights'] = mlm_target_weights
        features['mlm_target_positions'] = mlm_target_positions

    return features
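
# A minimal usage sketch (hypothetical token ids, not from the original
# source): a batch of one text with a single mention spanning tokens 2..3,
# expanded from length 8 to new_length = 10.
#
#   features = add_entity_tokens(
#       text_ids=tf.constant([[101, 11, 12, 13, 14, 15, 102, 0]]),
#       text_mask=tf.constant([[1, 1, 1, 1, 1, 1, 1, 0]]),
#       mention_mask=tf.constant([1]),
#       mention_batch_positions=tf.constant([0]),
#       mention_start_positions=tf.constant([2]),
#       mention_end_positions=tf.constant([3]),
#       new_length=10)
#
# Per the docstring, the returned `text_ids` / `text_mask` have length 10 with
# the entity start token inserted immediately before the mention and the
# entity end token immediately after it; `mention_start_positions` and
# `mention_end_positions` point at these inserted marker tokens.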
    def _solve_dirichlet(self):
        # check the boundary conditions are correct
        if self.domain.boundary.boundary_condition_type != 'Dirichlet':
            raise ValueError(
                "_solve_dirichlet must be used with Dirichlet boundary conditions."
            )

        A, b = self._assemble()

        # work out if any batching needs to be done
        if len(self.batch_shape) == 0:
            # add an extra leading dimension onto A
            A = A[tf.newaxis, ...]
            batch_shape = [1]  # fake batch shape, squeezed out at the end
        else:
            batch_shape = self.batch_shape

        # get the interior of A
        global_stiffness_interior_indices = [
            *itertools.product(self.mesh.interior_node_indices, repeat=2)
        ]

        Ainterior = tf.map_fn(
            lambda x: tf.reshape(
                tf.gather_nd(x, global_stiffness_interior_indices), [
                    len(self.mesh.interior_node_indices),
                    len(self.mesh.interior_node_indices)
                ]), A)
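        # Illustrative sketch (toy values, not from the original source): for
        # interior_node_indices = [0, 2], the product indices are
        # [(0, 0), (0, 2), (2, 0), (2, 2)], so per batch member the gather_nd
        # plus reshape extracts the 2x2 interior block A[[0, 2]][:, [0, 2]].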

        b_interior = tf.gather_nd(b, [
            *zip(self.mesh.interior_node_indices,
                 [0] * len(self.mesh.interior_node_indices))
        ])

        interior_bound_indices = [
            *itertools.product(self.mesh.interior_node_indices,
                               self.mesh.boundary_node_indices)
        ]

        Aint_bnd = tf.map_fn(
            lambda x: tf.reshape(tf.gather_nd(x, interior_bound_indices), [
                len(self.mesh.interior_node_indices),
                len(self.mesh.boundary_node_indices)
            ]), A)

        bnd_node_indices = np.array(self.mesh.boundary_node_indices,
                                    dtype=np.intp)
        int_node_indices = np.array(self.mesh.interior_node_indices,
                                    dtype=np.intp)

        # get the value on the boundary
        g = self.domain.boundary.g

        # convert the stiffness matrices to operators for batch matmul
        Ainterior_op = tf.linalg.LinearOperatorFullMatrix(Ainterior)
        Aint_bnd_op = tf.linalg.LinearOperatorFullMatrix(Aint_bnd)

        # Add the fixed Dirichlet conditions to sol.
        # TODO: batch boundary values.
        sol = tf.scatter_nd(bnd_node_indices[:, None],
                            g,
                            shape=[self.mesh.n_nodes, 1])

        b_ = b_interior[..., tf.newaxis] - Aint_bnd_op.matmul(g)
        sol_interior = Ainterior_op.solve(b_)
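        # Worked form of the elimination above (restated from the code for
        # clarity): splitting nodes into interior (I) and boundary (B) sets,
        # with u_B fixed to g by the Dirichlet condition, the system reduces to
        #     A_II u_I = b_I - A_IB g,
        # so `sol_interior` holds u_I and the scatter below adds it into `sol`,
        # which already contains g at the boundary nodes.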

        # sol_interior has a batched shape [b, n_interior_nodes, 1]
        return tf.squeeze(
            tf.map_fn(
                lambda x: tf.tensor_scatter_nd_add(
                    sol, int_node_indices[:, None], x), sol_interior)
        )[..., tf.newaxis]  # kills pseudo-batch dimensions, but keeps output a vector