Example #1
def split_latents(x, minibatch_size=1, hy_ncut=1):
    # x: [b, dim]
    # b = minibatch_size
    # b = x.get_shape().as_list()[0]
    if hy_ncut == 0:
        return [x]
    b = tf.shape(x)[0]
    dim = x.get_shape().as_list()[1]
    split_idx = tf.random.uniform(shape=[b, hy_ncut],
                                  maxval=dim + 1,
                                  dtype=tf.int32)
    split_idx = tf.sort(split_idx, axis=-1)
    idx_range = tf.tile(tf.range(dim)[tf.newaxis, :], [b, 1])
    masks = []
    mask_last = tf.zeros([b, dim], dtype=tf.float32)
    for i in range(hy_ncut):
        mask_tmp = tf.cast(idx_range < split_idx[:, i:i + 1], tf.float32)
        masks.append(mask_tmp - mask_last)
        mask_last = mask_tmp
    mask_tmp = tf.cast(idx_range < split_idx[:, -1:], tf.float32)
    masks.append(1. - mask_tmp)
    x_split_ls = [x * mask for mask in masks]
    # mask_1 = tf.cast(idx_range < split_idx[:, tf.newaxis], tf.float32)
    # mask_2 = 1. - mask_1
    # return x * mask_1, x * mask_2
    return x_split_ls
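
A quick sanity check (a minimal sketch, assuming TF 2.x eager execution; the input below is made up for illustration): the masks cover disjoint ranges of the feature axis, so the per-split tensors add up to the original x.

import tensorflow as tf

x = tf.random.normal([4, 8])            # [batch, dim]
splits = split_latents(x, hy_ncut=2)    # hy_ncut + 1 masked copies of x
print(len(splits))                      # 3
# The masks partition the dim axis, so summing the splits recovers x.
print(tf.reduce_max(tf.abs(sum(splits) - x)).numpy())  # ~0.0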
Example #2
 def with_replacement():
     # sample integers from [0, seq_len)
     indices = (tf.random.uniform(shape=[min_length - 2]) *
                tf.cast(sequence_length, float))
     middle = tf.cast(tf.math.floor(indices), tf.int64)
     return tf.sort(
         tf.concat([[0], middle, [sequence_length - 1]], axis=0))
Example #3
 def with_replacement():
     # sample integers from [episode_start, episode_end)
     indices = tf.random.uniform(shape=[min_length - 2]) * \
       tf.cast(episode_delta_t, float)
     middle = episode_start + tf.cast(tf.math.floor(indices), tf.int64)
     return tf.sort(
         tf.concat([[episode_start], middle, [episode_end]], axis=0))
Example #4
 def random_frame():
   # Used when min_length == 1.
   indices = (
       tf.random.uniform(shape=[min_length]) *
       tf.cast(sequence_length, float))
   middle = tf.cast(tf.math.floor(indices), tf.int64)
   return tf.sort(middle)
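
Examples #2-#4 (and the without-replacement variants in Examples #6 and #13) are nested helpers that close over `sequence_length`, `min_length`, and the episode bounds from an enclosing frame-sampling function. A standalone sketch of the with-replacement variant (hypothetical values, TF 2.x eager assumed):

import tensorflow as tf

sequence_length = tf.constant(20, dtype=tf.int64)  # hypothetical sequence length
min_length = 5                                     # hypothetical number of frames to keep

# Sample min_length - 2 interior frames uniformly (with replacement), then
# pin the first and last frame and sort the result.
indices = (tf.random.uniform(shape=[min_length - 2]) *
           tf.cast(sequence_length, float))
middle = tf.cast(tf.math.floor(indices), tf.int64)
frames = tf.sort(tf.concat([[0], middle, [sequence_length - 1]], axis=0))
print(frames.numpy())  # e.g. [ 0  3  7 14 19]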
Example #5
def clip_eta(eta, ord, eps):
    """
  Helper function to clip the perturbation to epsilon norm ball.
  :param eta: A tensor with the current perturbation.
  :param ord: Order of the norm (mimics Numpy).
              Possible values: np.inf, 1 or 2.
  :param eps: Epsilon, bound of the perturbation.
  """

    # Clipping perturbation eta to self.ord norm ball
    if ord not in [np.inf, 1, 2]:
        raise ValueError('ord must be np.inf, 1, or 2.')
    reduc_ind = list(range(1, len(eta.get_shape())))
    avoid_zero_div = 1e-12
    if ord == np.inf:
        eta = tf.clip_by_value(eta, -eps, eps)
    elif ord == 1:
        # Implements a projection algorithm onto the l1-ball from
        # (Duchi et al. 2008) that runs in time O(d*log(d)) where d is the
        # input dimension.
        # Paper link (Duchi et al. 2008): https://dl.acm.org/citation.cfm?id=1390191

        eps = tf.cast(eps, eta.dtype)

        dim = tf.reduce_prod(tf.shape(eta)[1:])
        eta_flat = tf.reshape(eta, (-1, dim))
        abs_eta = tf.abs(eta_flat)

        if 'sort' in dir(tf):
            mu = -tf.sort(-abs_eta, axis=-1)
        else:
            # `tf.sort` is only available in TF 1.13 onwards
            mu = tf.nn.top_k(abs_eta, k=dim, sorted=True)[0]
        cumsums = tf.cumsum(mu, axis=-1)
        js = tf.cast(tf.divide(1, tf.range(1, dim + 1)), eta.dtype)
        t = tf.cast(tf.greater(mu - js * (cumsums - eps), 0), eta.dtype)

        rho = tf.argmax(t * cumsums, axis=-1)
        rho_val = tf.reduce_max(t * cumsums, axis=-1)
        theta = tf.divide(rho_val - eps, tf.cast(1 + rho, eta.dtype))

        eta_sgn = tf.sign(eta_flat)
        eta_proj = eta_sgn * tf.maximum(abs_eta - theta[:, tf.newaxis], 0)
        eta_proj = tf.reshape(eta_proj, tf.shape(eta))

        norm = tf.reduce_sum(tf.abs(eta), reduc_ind)
        eta = tf.where(tf.greater(norm, eps), eta_proj, eta)

    elif ord == 2:
        # avoid_zero_div must go inside sqrt to avoid a divide by zero
        # in the gradient through this operation
        norm = tf.sqrt(
            tf.maximum(avoid_zero_div,
                       tf.reduce_sum(tf.square(eta), reduc_ind, keepdims=True)))
        # We must *clip* to within the norm ball, not *normalize* onto the
        # surface of the ball
        factor = tf.minimum(1., tf.divide(eps, norm))
        eta = eta * factor
    return eta
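
A quick check of the np.inf and 2-norm branches (a minimal sketch, assuming TF 2.x eager execution with `numpy` and `tensorflow` imported; the perturbation below is made up): after clipping, the per-example norm does not exceed eps.

import numpy as np
import tensorflow as tf

eta = tf.random.normal([4, 3, 8])   # hypothetical batch of perturbations
eps = 0.5

clipped_inf = clip_eta(eta, np.inf, eps)
print(tf.reduce_max(tf.abs(clipped_inf)).numpy())   # <= 0.5

clipped_l2 = clip_eta(eta, 2, eps)
norms = tf.sqrt(tf.reduce_sum(tf.square(clipped_l2), axis=[1, 2]))
print(norms.numpy())   # each entry <= 0.5 (up to float rounding)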
Example #6
 def without_replacement():
     # sample integers from [episode_start, episode_end)
     indices = tf.random.shuffle(
         tf.range(episode_start + 1, episode_end))
     middle = indices[:min_length - 2]
     middle = tf.reshape(middle, [min_length - 2])
     return tf.sort(
         tf.concat([[episode_start], middle, [episode_end]], axis=0))
Example #7
def top_p_logits(logits, p):
    with tf.variable_scope('top_p_logits'):
        logits_sort = tf.sort(logits, direction='DESCENDING')
        probs_sort = tf.nn.softmax(logits_sort)
        probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
        logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(logits_sort)*1000) # [batchsize, vocab]
        min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True) # [batchsize, 1]
        return tf.where(
            logits < min_logits,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
Example #8
def top_p_logits(logits, p):
    """Nucleus sampling"""
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
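
A small worked example for this second variant (a sketch, assuming TF 2.x eager execution; the logits are made up): with p = 0.9, roughly 83% of the softmax mass sits on the first two tokens, so only those survive.

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.5, -1.0, -3.0]])   # [batch=1, vocab=5]
filtered = top_p_logits(logits, p=0.9)
print(filtered.numpy())   # first two logits kept, the rest pushed to -1e10

# Sampling from the filtered logits only ever returns the kept token ids.
samples = tf.random.categorical(filtered, num_samples=5)
print(samples.numpy())    # e.g. [[0 0 1 0 1]]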
Example #9
def padded_where(condition, length):
    """TPU friendly version of tf.where(cond) with fixed length and padding.

  This is a wrapper around tf.where(cond) that returns the coordinates of the
  True elements of cond (case where x and y are None). This version, however,
  returns a fixed length tensor of coordinates, determined by `length`.  If the
  number of True elements in `condition` is less than `length`, then the
  returned tensor is right-padded with zeros. Otherwise, the returned tensor is
  truncated to `length` size.

  Args:
    condition: tf.Tensor of type boolean; any shape.
    length: Length of (last dimension of) the returned tensor.

  Returns:
    Two tensors:
    - a tensor of type int32, with same shape as `condition`, representing
      coordinates of the last dimension of `condition` tensor where values are
      True.
    - a mask tensor of type int32 with 1s in valid indices of the first tensor,
      and 0s for padded indices.
  """
    condition_shape = tf.shape(condition)
    n = condition_shape[-1]

    # Build a tensor that counts indices from 0 to length of condition.
    ixs = tf.broadcast_to(tf.range(n, dtype=tf.int32), condition_shape)

    # Build tensor where True condition values get their index value or
    # n (== len(condition)) otherwise.
    ixs = tf.where(condition, ixs, tf.ones_like(condition, dtype=tf.int32) * n)

    # Sort indices (so that indices for False values == n, will be placed last),
    # and get the desired number of entries, truncating by `length`.
    ixs = tf.sort(ixs)[Ellipsis, 0:length]

    # For the first tensor, zero out values == n. For the second tensor, put 1s
    # where values are < n, and 0s where values are == n.
    return tf.math.floormod(ixs, n), (1 - tf.math.floordiv(ixs, n))
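
For example (a minimal sketch, assuming TF 2.x eager execution; the condition tensor is made up), padding and truncation behave as described in the docstring:

import tensorflow as tf

cond = tf.constant([[False, True, False, True, True]])
coords, mask = padded_where(cond, length=4)
print(coords.numpy())   # [[1 3 4 0]] -- True positions, right-padded with 0
print(mask.numpy())     # [[1 1 1 0]] -- 1 for valid coordinates, 0 for padding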
Example #10
def _single_token_mask(inputs, tgt_len, num_predict, exclude_mask=None):
    """Sample individual tokens as prediction targets."""
    func_mask = tf.equal(inputs, FLAGS.cls_id)
    func_mask = tf.logical_or(func_mask, tf.equal(inputs, FLAGS.sep_id))
    func_mask = tf.logical_or(func_mask, tf.equal(inputs, FLAGS.pad_id))
    if exclude_mask is None:
        exclude_mask = func_mask
    else:
        exclude_mask = tf.logical_or(func_mask, exclude_mask)
    candidate_mask = tf.logical_not(exclude_mask)

    all_indices = tf.range(tgt_len, dtype=tf.int64)
    candidate_indices = tf.boolean_mask(all_indices, candidate_mask)
    masked_pos = tf.random.shuffle(candidate_indices)
    masked_pos = tf.sort(masked_pos[:num_predict])
    target_mask = tf.sparse_to_dense(sparse_indices=masked_pos,
                                     output_shape=[tgt_len],
                                     sparse_values=1.0,
                                     default_value=0.0)
    is_target = tf.cast(target_mask, tf.bool)

    return is_target, target_mask
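
This one depends on `FLAGS` for the special-token ids and on the deprecated `tf.sparse_to_dense`. A self-contained sketch of the same sampling idea (hypothetical ids, `tf.scatter_nd` used in place of `tf.sparse_to_dense`, TF 2.x eager assumed):

import tensorflow as tf

CLS_ID, SEP_ID, PAD_ID = 3, 4, 0     # hypothetical special-token ids
inputs = tf.constant([3, 11, 12, 13, 4, 14, 15, 4, 0, 0])
tgt_len, num_predict = 10, 3

# Exclude functional tokens, then sample num_predict target positions.
func_mask = tf.equal(inputs, CLS_ID)
func_mask = tf.logical_or(func_mask, tf.equal(inputs, SEP_ID))
func_mask = tf.logical_or(func_mask, tf.equal(inputs, PAD_ID))
candidate_indices = tf.boolean_mask(tf.range(tgt_len), tf.logical_not(func_mask))
masked_pos = tf.sort(tf.random.shuffle(candidate_indices)[:num_predict])
target_mask = tf.scatter_nd(masked_pos[:, None],
                            tf.ones([num_predict]), [tgt_len])
print(target_mask.numpy())   # e.g. [0. 1. 1. 0. 0. 1. 0. 0. 0. 0.]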
Example #11
def block_delete_msa(protein, config):
    """Sample MSA by deleting contiguous blocks.

  Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

  Arguments:
    protein: batch dict containing the msa
    config: ConfigDict with parameters

  Returns:
    updated protein
  """
    num_seq = shape_helpers.shape_list(protein['msa'])[0]
    block_num_seq = tf.cast(
        tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
        tf.int32)

    if config.randomize_num_blocks:
        nb = tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
    else:
        nb = config.num_blocks

    del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
    del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
    del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
    del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

    # Make sure we keep the original sequence
    sparse_diff = tf.sets.difference(
        tf.range(1, num_seq)[None], del_indices[None])
    keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
    keep_indices = tf.concat([[0], keep_indices], axis=0)

    for k in _MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = tf.gather(protein[k], keep_indices)

    return protein
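
The index bookkeeping is the reusable part here; a stripped-down sketch (hypothetical sizes, TF 2.x eager assumed) of how the deleted blocks become a sorted, de-duplicated index set whose complement, with row 0 always kept, selects the surviving MSA rows:

import tensorflow as tf

num_seq, block_num_seq, nb = 10, 3, 2   # hypothetical MSA and block sizes

del_block_starts = tf.random.uniform([nb], 0, num_seq, dtype=tf.int32)
del_blocks = del_block_starts[:, None] + tf.range(block_num_seq)
del_blocks = tf.clip_by_value(del_blocks, 0, num_seq - 1)
del_indices = tf.unique(tf.sort(tf.reshape(del_blocks, [-1])))[0]

# Complement of del_indices within [1, num_seq); row 0 is always kept.
sparse_diff = tf.sets.difference(tf.range(1, num_seq)[None], del_indices[None])
keep_indices = tf.squeeze(tf.sparse.to_dense(sparse_diff), 0)
keep_indices = tf.concat([[0], keep_indices], axis=0)
print(del_indices.numpy(), keep_indices.numpy())
# e.g. [4 5 6 7 8 9] [0 1 2 3]  -- the rows gathered from protein['msa']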
Example #12
def get_doc_rep_with_masked_sent(input_sent_reps_doc,
                                 sent_mask_embedding,
                                 input_mask_doc_level,
                                 batch_size_static=32,
                                 max_masked_sent_per_doc=2,
                                 loop_sent_number_per_doc=32):
  """Get the document representations with masked sentences.

  Args:
      input_sent_reps_doc: float Tensor. The independent sentence embeddings
        without masks for the sentences in the current document. The shape is
        [batch, loop_sent_number_per_doc, hidden].
      sent_mask_embedding: float Tensor. The sentence embedding vector for the
        masked position. The shape is [hidden].
      input_mask_doc_level: int Tensor. The input masks on the document level to
        identify whether a location is a real sentence (mask = 1) or a padded
        sentence (mask = 0). The shape is [batch, loop_sent_number_per_doc].
      batch_size_static: scalar. The static batch size depending on the training
        or the evaluation mode.
      max_masked_sent_per_doc: scalar. The maximum number of masked sentences
        per document.
      loop_sent_number_per_doc: scalar. The number of looped sentences per
        document.

  Returns:
    The document representations with masked sentences and the positions/
    weights for each masked sentence. The masked sentence weight is 1 for a
    sampled real sentence position and 0 for a padded sentence position.
  """
  # We at least mask two sentences to build a candidate sentence pool for
  # negative sentence sampling. We generate the masked_sent_index and
  # masked_sent_weight for each document. Note that we do not add any word
  # or sentence level masks during prediction or inference stage.
  max_masked_sent_per_doc = max(max_masked_sent_per_doc, 2)
  input_sent_reps_doc_list = tf.unstack(
      input_sent_reps_doc, num=batch_size_static)
  real_sent_number_per_doc = tf.unstack(
      tf.reduce_sum(input_mask_doc_level, 1), num=batch_size_static)
  masked_sent_index_list = []
  masked_sent_weight_list = []

  # For each example in the current batch, we randomly sample
  # max_masked_sent_per_doc positions to mask the sentences. For each masked
  # sentence position, the sentence in the current position is the positive
  # example. The other co-masked sentences are the negative examples.
  # The sampled sentence indexes will not be duplicated.
  for batch_i in range(0, batch_size_static):
    # Since everything in TPU must have a fixed shape, here the max sampled
    # sentence index can be as large as loop_sent_number_per_doc. We will
    # generate the corresponding sentence LM weights to reduce the impact
    # on the final masked sentence LM loss following a similar way with the
    # handling of masked word LM loss and masked word LM weights.
    real_sent_number = real_sent_number_per_doc[batch_i]
    sampled_sent_index = tf.slice(
        tf.random_shuffle(tf.range(loop_sent_number_per_doc)), [0],
        [max_masked_sent_per_doc])
    sampled_sent_index = tf.sort(sampled_sent_index)
    masked_sent_index_list.append(sampled_sent_index)
    # Generates the corresponding sampled_sent_weight
    sample_sent_weight = tf.cast(
        tf.less(sampled_sent_index, real_sent_number), tf.float32)
    masked_sent_weight_list.append(sample_sent_weight)

    indices = tf.reshape(sampled_sent_index, [max_masked_sent_per_doc, -1])
    # Duplicates sent_mask_embedding for each masked position.
    updates = tf.reshape(
        tf.tile(
            sent_mask_embedding,
            [max_masked_sent_per_doc],
        ), [max_masked_sent_per_doc, -1])
    input_sent_reps_doc_list[batch_i] = tf.tensor_scatter_update(
        input_sent_reps_doc_list[batch_i], indices, updates)
  # Here masked_sent_index_list is a list of tensors, where each tensor stores
  # the masked sentence positions for each document in the current batch. After
  # tf.stack, the shape is [batch, max_masked_sent_per_doc].
  # Here masked_sent_weight_list is a list of tensors, where each tensor stores
  # the masked sentence weights for each document in the current batch. After
  # tf.stack, the shape is [batch, max_masked_sent_per_doc].
  return (tf.stack(input_sent_reps_doc_list), tf.stack(masked_sent_index_list),
          tf.stack(masked_sent_weight_list))
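
The per-document masking step boils down to scattering the mask embedding into the sampled sentence rows; a stripped-down sketch of just that step (hypothetical sizes, TF 2.x eager assumed, with `tf.tensor_scatter_nd_update` as the current name of the scatter op used above):

import tensorflow as tf

loop_sent_number_per_doc, hidden, max_masked = 6, 4, 2
sent_reps = tf.random.normal([loop_sent_number_per_doc, hidden])
sent_mask_embedding = tf.zeros([hidden])   # hypothetical mask vector

# Sample sentence positions without duplicates, then overwrite those rows.
sampled = tf.sort(tf.random.shuffle(tf.range(loop_sent_number_per_doc))[:max_masked])
indices = tf.reshape(sampled, [max_masked, 1])
updates = tf.reshape(tf.tile(sent_mask_embedding, [max_masked]), [max_masked, -1])
masked_reps = tf.tensor_scatter_nd_update(sent_reps, indices, updates)
print(masked_reps.numpy())   # sampled rows replaced by the mask embedding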
Example #13
 def without_replacement():
     # sample integers from [1, seq_len-1)
     indices = tf.random.shuffle(tf.range(1, sequence_length - 1))
     middle = indices[:min_length - 2]
     return tf.sort(
         tf.concat([[0], middle, [sequence_length - 1]], axis=0))
Example #14
  def _get_richer_data(self, fake_data):
    inputs_tf = fake_data.inputs.input_ids
    labels_tf = fake_data.is_fake_tokens
    lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)
    #retrieve the basic config
    V = self._bert_config.vocab_size
    #sub: 10%, del + ins: 5%
    N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
    B, L = modeling.get_shape_list(inputs_tf)
    nlms = 0
    bilm = None
    if self._config.use_bilm:
      with open(self._config.bilm_file, 'rb') as f:
        bilm = tf.constant(np.load(f), tf.int32)
      _, nlms = modeling.get_shape_list(bilm)
    #make multiple partitions for edit op
    splits_list = []
    for i in range(B):
      one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
      one, _ = tf.unique(one)
      one = tf.cond(tf.less(tf.shape(one)[0], N * 2 + 1),
                    lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
                    lambda: tf.sort(tf.reshape(one[: N * 2 + 1], [1, N * 2 + 1])))
      splits_list.append(one[:, 2::2])
    splits_tf = tf.concat(splits_list, 0)
    splits_up = tf.concat([splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
    splits_lo = tf.concat([tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
    size_splits = splits_up - splits_lo
    # update the inputs and labels with random insertions and deletions
    new_labels_list = []
    new_inputs_list = []
    for i in range(B):
      inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
      labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
      one_inputs = []
      one_labels = []
      size_split = len(inputs_splits)
      inputs_end = inputs_splits[-1]
      labels_end = labels_splits[-1]
      for j in range(size_split-1):
        inputs = inputs_splits[j]
        labels = labels_splits[j]  # label 1 for substitution
        rand_op = random.randint(2, self._config.num_preds - 1) 
        if rand_op == 2: #label 2 for insertion
          if bilm is None: #noise
            insert_tok = tf.random.uniform([1], 1, V, tf.int32)
          else: #2-gram prediction
            insert_tok = tf.expand_dims(bilm[inputs[-1], random.randint(0, nlms-1)], 0)
          is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
          inputs = tf.cond(is_end_valid, lambda: tf.concat([inputs, insert_tok], 0), lambda: inputs)
          labels = tf.cond(is_end_valid, lambda: tf.concat([labels, tf.constant([2])], 0), lambda: labels)
          inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1], lambda: inputs_end)
          labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1], lambda: labels_end)
        elif rand_op == 3: #label 3 for deletion
          labels = tf.concat([labels[:-2], tf.constant([3])], 0)
          inputs = inputs[:-1]
          inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
          labels_end = tf.concat([labels_end, tf.constant([0])], 0)
        elif rand_op == 4: #label 4 for swapping
          labels = tf.concat([labels[:-1], tf.constant([4])], 0)
          inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
        one_labels.append(labels)
        one_inputs.append(inputs)
      one_inputs.append(inputs_end)
      one_labels.append(labels_end)
      one_inputs_tf = tf.concat(one_inputs, 0)
      one_labels_tf = tf.concat(one_labels, 0)
      one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: inputs_tf[i, :], lambda: one_inputs_tf)
      one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: labels_tf[i, :], lambda: one_labels_tf)
      new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
      new_labels_list.append(tf.expand_dims(one_labels_tf, 0))

    new_inputs_tf = tf.concat(new_inputs_list, 0)
    new_labels_tf = tf.concat(new_labels_list, 0)
    new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)
    updated_inputs = pretrain_data.get_updated_inputs(
        fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
    RicherData = collections.namedtuple("RicherData", [
        "inputs", "is_fake_tokens", "sampled_tokens"])
    return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                     sampled_tokens=fake_data.sampled_tokens)
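
The core per-example trick above is drawing a set of unique, sorted split points (with a fallback when too few unique positions are available); a stripped-down sketch of that step alone (hypothetical lengths, TF 2.x eager assumed):

import tensorflow as tf

N = 2            # hypothetical edit budget (as in the function above)
seq_len = 12     # hypothetical true length of one example

one = tf.random.uniform([N * 4], 1, seq_len, tf.int32)
one, _ = tf.unique(one)
# Fall back to a fixed range if fewer than N * 2 + 1 unique points were drawn.
one = tf.cond(tf.less(tf.shape(one)[0], N * 2 + 1),
              lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
              lambda: tf.sort(tf.reshape(one[:N * 2 + 1], [1, N * 2 + 1])))
splits = one[:, 2::2]
print(splits.numpy())   # e.g. [[5 9]] -- sorted interior split points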