def split_latents(x, minibatch_size=1, hy_ncut=1):
    """Split each row of `x` into `hy_ncut + 1` disjoint masked copies.

    For every row, `hy_ncut` cut points are drawn uniformly from [0, dim] and
    sorted. The row is partitioned into the contiguous segments between
    consecutive cut points (plus the prefix before the first and the suffix
    after the last). Each returned tensor equals `x` with every entry outside
    its own segment zeroed, so the returned tensors sum back to `x`.

    Args:
      x: Float32 tensor whose second axis has a static size
        (`x.get_shape().as_list()[1]`); presumably [batch, dim] -- TODO
        confirm with callers.
      minibatch_size: Unused; kept for backward compatibility. The batch
        size is read dynamically from `tf.shape(x)[0]`.
      hy_ncut: Number of cut points per row. 0 returns `[x]` unchanged.

    Returns:
      A list of `hy_ncut + 1` tensors with the same shape as `x`.
    """
    if hy_ncut == 0:
        return [x]
    b = tf.shape(x)[0]
    dim = x.get_shape().as_list()[1]
    # Cut points lie in [0, dim]; a cut at 0 or dim yields an empty segment.
    split_idx = tf.random.uniform(shape=[b, hy_ncut], maxval=dim + 1, dtype=tf.int32)
    split_idx = tf.sort(split_idx, axis=-1)
    idx_range = tf.tile(tf.range(dim)[tf.newaxis, :], [b, 1])

    masks = []
    mask_last = tf.zeros([b, dim], dtype=tf.float32)
    for i in range(hy_ncut):
        # 1 for every position before the i-th (sorted) cut point.
        mask_tmp = tf.cast(idx_range < split_idx[:, i:i + 1], tf.float32)
        # Subtracting the previous prefix leaves only segment i. The previous
        # prefix must be carried forward here (the original assigned it to a
        # misspelled name, so every mask stayed a cumulative prefix and the
        # segments overlapped).
        masks.append(mask_tmp - mask_last)
        mask_last = mask_tmp
    # Final segment: every position at or after the last cut point.
    masks.append(1. - mask_last)
    return [x * mask for mask in masks]
def with_replacement():
    """Return a sorted int64 vector of `min_length` frame indices.

    The result always contains 0 and `sequence_length - 1`. The remaining
    `min_length - 2` entries are drawn uniformly from [0, sequence_length)
    with replacement, so duplicates are possible. `min_length` and
    `sequence_length` come from the enclosing scope.
    """
    uniform = tf.random.uniform(shape=[min_length - 2])
    scaled = uniform * tf.cast(sequence_length, float)
    interior = tf.cast(tf.math.floor(scaled), tf.int64)
    with_endpoints = tf.concat(
        [[0], interior, [sequence_length - 1]], axis=0)
    return tf.sort(with_endpoints)
def with_replacement():
    """Return a sorted int64 vector of `min_length` indices for one episode.

    The result always contains `episode_start` and `episode_end`. Each of the
    `min_length - 2` interior entries is
    `episode_start + floor(u * episode_delta_t)` for an independent uniform
    u in [0, 1), so duplicates are possible. `min_length`, `episode_start`,
    `episode_end` and `episode_delta_t` come from the enclosing scope
    (episode_delta_t is presumably episode_end - episode_start -- TODO
    confirm).
    """
    uniform = tf.random.uniform(shape=[min_length - 2])
    scaled = uniform * tf.cast(episode_delta_t, float)
    offsets = tf.cast(tf.math.floor(scaled), tf.int64)
    interior = episode_start + offsets
    with_endpoints = tf.concat(
        [[episode_start], interior, [episode_end]], axis=0)
    return tf.sort(with_endpoints)
def random_frame():
    """Return `min_length` sorted int64 indices drawn uniformly from [0, sequence_length).

    Used when min_length == 1. `min_length` and `sequence_length` come from
    the enclosing scope.
    """
    scaled = (tf.random.uniform(shape=[min_length])
              * tf.cast(sequence_length, float))
    frames = tf.cast(tf.math.floor(scaled), tf.int64)
    return tf.sort(frames)
def clip_eta(eta, ord, eps):
    """
    Helper function to clip the perturbation to epsilon norm ball.
    :param eta: A tensor with the current perturbation. The first axis is
                treated as the batch axis; norms are taken over all the
                remaining axes.
    :param ord: Order of the norm (mimics Numpy).
                Possible values: np.inf, 1 or 2.
    :param eps: Epsilon, bound of the perturbation.
    :return: A tensor with the same shape as `eta` in which each example's
             perturbation lies inside the `ord`-norm ball of radius `eps`.
    :raises ValueError: if `ord` is not np.inf, 1 or 2.
    """

    # Clipping perturbation eta to self.ord norm ball
    if ord not in [np.inf, 1, 2]:
        raise ValueError('ord must be np.inf, 1, or 2.')
    # `range` works on both Python 2 and 3; `xrange` exists only on Python 2
    # (or when imported from six.moves).
    reduc_ind = list(range(1, len(eta.get_shape())))
    avoid_zero_div = 1e-12
    if ord == np.inf:
        eta = clip_by_value(eta, -eps, eps)
    elif ord == 1:
        # Implements a projection algorithm onto the l1-ball from
        # (Duchi et al. 2008) that runs in time O(d*log(d)) where d is the
        # input dimension.
        # Paper link (Duchi et al. 2008): https://dl.acm.org/citation.cfm?id=1390191

        eps = tf.cast(eps, eta.dtype)

        dim = tf.reduce_prod(tf.shape(eta)[1:])
        eta_flat = tf.reshape(eta, (-1, dim))
        abs_eta = tf.abs(eta_flat)

        if 'sort' in dir(tf):
            mu = -tf.sort(-abs_eta, axis=-1)
        else:
            # `tf.sort` is only available in TF 1.13 onwards
            mu = tf.nn.top_k(abs_eta, k=dim, sorted=True)[0]
        cumsums = tf.cumsum(mu, axis=-1)
        # js[j] = 1 / (j + 1) for 0-based j.
        js = tf.cast(tf.divide(1, tf.range(1, dim + 1)), eta.dtype)
        # t marks sorted positions satisfying mu_j > (cumsum_j - eps) / j.
        t = tf.cast(tf.greater(mu - js * (cumsums - eps), 0), eta.dtype)

        # rho is a 0-based index, hence the `1 + rho` count below.
        rho = tf.argmax(t * cumsums, axis=-1)
        rho_val = tf.reduce_max(t * cumsums, axis=-1)
        theta = tf.divide(rho_val - eps, tf.cast(1 + rho, eta.dtype))

        eta_sgn = tf.sign(eta_flat)
        eta_proj = eta_sgn * tf.maximum(abs_eta - theta[:, tf.newaxis], 0)
        eta_proj = tf.reshape(eta_proj, tf.shape(eta))

        norm = tf.reduce_sum(tf.abs(eta), reduc_ind)
        # Only examples already outside the ball are projected. `norm` has
        # shape [batch]; this relies on TF1 `tf.where` selecting whole rows
        # when the condition is a vector.
        eta = tf.where(tf.greater(norm, eps), eta_proj, eta)
    elif ord == 2:
        # avoid_zero_div must go inside sqrt to avoid a divide by zero
        # in the gradient through this operation
        norm = tf.sqrt(
            tf.maximum(avoid_zero_div,
                       reduce_sum(tf.square(eta), reduc_ind, keepdims=True)))
        # We must *clip* to within the norm ball, not *normalize* onto the
        # surface of the ball
        factor = tf.minimum(1., div(eps, norm))
        eta = eta * factor
    return eta
def without_replacement():
    """Return a sorted vector of `min_length` distinct-interior episode indices.

    The result always contains `episode_start` and `episode_end`. The
    `min_length - 2` interior entries are drawn without replacement from
    [episode_start + 1, episode_end). The reshape pins the static shape and
    fails at runtime if fewer than `min_length - 2` candidates exist.
    `min_length`, `episode_start` and `episode_end` come from the enclosing
    scope.
    """
    shuffled = tf.random.shuffle(
        tf.range(episode_start + 1, episode_end))
    interior = tf.reshape(shuffled[:min_length - 2], [min_length - 2])
    with_endpoints = tf.concat(
        [[episode_start], interior, [episode_end]], axis=0)
    return tf.sort(with_endpoints)
def top_p_logits(logits, p):
    """Mask logits outside the top-p (nucleus) set with -1e10.

    Sorted in descending order, a token is kept while the probability mass of
    all strictly higher-ranked tokens is below `p`. The token that crosses `p`
    is therefore kept, and so is the top token whenever p > 0.

    Args:
      logits: 2-D float tensor; reductions run over axis 1, so presumably
        [batch, vocab].
      p: Cumulative probability threshold.

    Returns:
      A tensor like `logits` in which entries smaller than the smallest kept
      logit of their row are replaced with -1e10.
    """
    with tf.variable_scope('top_p_logits'):
        logits_sort = tf.sort(logits, direction='DESCENDING')
        probs_sort = tf.nn.softmax(logits_sort)
        # Exclusive cumsum: probability mass strictly above each position.
        probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
        # Fill excluded slots with the dtype's maximum so reduce_min ignores
        # them. A fixed sentinel such as 1000 gives a wrong threshold as soon
        # as a real logit exceeds it.
        sentinel = tf.ones_like(logits_sort) * logits_sort.dtype.max
        logits_masked = tf.where(probs_sums < p, logits_sort, sentinel)  # [batchsize, vocab]
        min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True)  # [batchsize, 1]
        return tf.where(
            logits < min_logits,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
def top_p_logits(logits, p):
    """Nucleus sampling: mask logits outside the top-p set with -1e10.

    Sorted in descending order, the tokens whose inclusive cumulative
    probability is <= p are kept (at least the top token is always kept).
    Every logit strictly below the smallest kept logit of its row is
    replaced with -1e10.

    Args:
      logits: 2-D float tensor, presumably [batch, vocab]. The batch size may
        be dynamic.
      p: Cumulative probability threshold.

    Returns:
      A tensor with the same shape as `logits`.
    """
    # Dynamic batch size, so an unknown static batch dimension also works.
    batch = tf.shape(logits)[0]
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    # Index of the last kept sorted token: the count of positions with
    # cumulative probability <= p, minus one, clamped at 0.
    last_kept = tf.maximum(
        tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0)
    indices = tf.stack([tf.range(0, batch), last_kept], axis=-1)
    # Shape [batch, 1] so the comparison below is per row. A [batch] vector
    # would broadcast along the vocab axis instead.
    min_values = tf.gather_nd(sorted_logits, indices)[:, tf.newaxis]
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
def padded_where(condition, length):
    """TPU friendly version of tf.where(cond) with fixed length and padding.

    This is a wrapper around tf.where(cond) that returns the coordinates of
    the True elements of cond (case where x and y are None). This version,
    however, returns only the coordinates along the last dimension, in a
    tensor whose last dimension is fixed by `length`. If a row of `condition`
    has fewer than `length` True elements, that row is right-padded with
    zeros. Otherwise it is truncated to `length` entries.

    Note: the result is produced by slicing, not padding. If `length` exceeds
    the last dimension `n` of `condition`, the last dimension of both outputs
    is `n`, not `length`. The mod/div by `n` also assumes n > 0.

    Args:
      condition: tf.Tensor of type boolean; any shape.
      length: Length of (last dimension of) the returned tensor.

    Returns:
      Two int32 tensors, shaped like `condition` except that the last
      dimension is min(length, n):
        - the indices along the last dimension of `condition` where values
          are True, in increasing order, followed by zero padding.
        - a mask with 1s at valid positions of the first tensor and 0s at
          padded positions.
    """
    condition_shape = shape(condition)
    n = condition_shape[-1]

    # Build a tensor that counts indices from 0 to length of condition.
    ixs = tf.broadcast_to(tf.range(n, dtype=tf.int32), condition_shape)

    # Build tensor where True condition values get their index value or
    # n (== len(condition)) otherwise.
    ixs = tf.where(condition, ixs, tf.ones_like(condition, dtype=tf.int32) * n)

    # Sort indices (so that indices for False values == n, will be placed last),
    # and get the desired number of entries, truncating by `length`.
    ixs = tf.sort(ixs)[..., 0:length]

    # All values are in [0, n]. For the first tensor, mod n turns the
    # sentinel n into 0 and leaves real indices unchanged. For the second,
    # div n is 1 exactly for the sentinel, so 1 - div is 1 where values are
    # < n and 0 where values are == n.
    return tf.mod(ixs, n), (1 - tf.div(ixs, n))
def _single_token_mask(inputs, tgt_len, num_predict, exclude_mask=None):
    """Sample individual tokens as prediction targets.

    Positions whose token equals FLAGS.cls_id, FLAGS.sep_id or FLAGS.pad_id,
    and any position set in `exclude_mask`, are never selected. Up to
    `num_predict` of the remaining positions in [0, tgt_len) are chosen
    uniformly at random.

    Returns:
      (is_target, target_mask): a bool tensor and a float tensor, both of
      shape [tgt_len], with True/1.0 at the selected positions.
    """
    blocked = tf.equal(inputs, FLAGS.cls_id)
    for special_id in (FLAGS.sep_id, FLAGS.pad_id):
        blocked = tf.logical_or(blocked, tf.equal(inputs, special_id))
    if exclude_mask is not None:
        blocked = tf.logical_or(blocked, exclude_mask)
    allowed = tf.logical_not(blocked)

    candidates = tf.boolean_mask(tf.range(tgt_len, dtype=tf.int64), allowed)
    # sparse_to_dense expects sorted indices, hence the sort after sampling.
    chosen = tf.sort(tf.random.shuffle(candidates)[:num_predict])

    target_mask = tf.sparse_to_dense(sparse_indices=chosen,
                                     output_shape=[tgt_len],
                                     sparse_values=1.0,
                                     default_value=0.0)
    return tf.cast(target_mask, tf.bool), target_mask
def block_delete_msa(protein, config):
    """Sample MSA by deleting contiguous blocks.

    Jumper et al. (2021) Suppl. Alg. 1 "MSABlockDeletion"

    A (possibly random) number of blocks is chosen, each starting at a
    uniformly random row and spanning floor(num_seq * msa_fraction_per_block)
    rows (clipped to the MSA). Their union is removed from every feature
    listed in _MSA_FEATURE_NAMES, except row 0, which is always kept.

    Arguments:
      protein: batch dict containing the msa
      config: ConfigDict with parameters

    Returns:
      updated protein
    """
    num_seq = shape_helpers.shape_list(protein['msa'])[0]
    rows_per_block = tf.cast(
        tf.floor(tf.cast(num_seq, tf.float32) * config.msa_fraction_per_block),
        tf.int32)

    num_blocks = (
        tf.random.uniform([], 0, config.num_blocks + 1, dtype=tf.int32)
        if config.randomize_num_blocks
        else config.num_blocks)

    block_starts = tf.random.uniform([num_blocks], 0, num_seq, dtype=tf.int32)
    # One row per block holding its consecutive row indices, clipped into range.
    rows_to_delete = block_starts[:, None] + tf.range(rows_per_block)
    rows_to_delete = tf.clip_by_value(rows_to_delete, 0, num_seq - 1)
    rows_to_delete = tf.unique(tf.sort(tf.reshape(rows_to_delete, [-1])))[0]

    # Row 0 is left out of the candidate set so the original sequence survives.
    survivors = tf.sets.difference(
        tf.range(1, num_seq)[None], rows_to_delete[None])
    survivors = tf.squeeze(tf.sparse.to_dense(survivors), 0)
    keep_indices = tf.concat([[0], survivors], axis=0)

    for feature_name in _MSA_FEATURE_NAMES:
        if feature_name in protein:
            protein[feature_name] = tf.gather(protein[feature_name], keep_indices)
    return protein
def get_doc_rep_with_masked_sent(input_sent_reps_doc,
                                 sent_mask_embedding,
                                 input_mask_doc_level,
                                 batch_size_static=32,
                                 max_masked_sent_per_doc=2,
                                 loop_sent_number_per_doc=32):
    """Get the document representations with masked sentences.

    For every document, `num_masked = max(max_masked_sent_per_doc, 2)`
    distinct positions in [0, loop_sent_number_per_doc) are sampled. The
    sentence embeddings at those positions are overwritten with
    `sent_mask_embedding`. Within a document, each masked sentence is the
    positive example for its own position, and the co-masked sentences serve
    as negatives.

    Args:
      input_sent_reps_doc: float Tensor. The independent sentence embeddings
        without masks for the sentences in the current document. The shape is
        [batch, loop_sent_number_per_doc, hidden].
      sent_mask_embedding: float Tensor. The sentence embedding vector for the
        masked position. The shape is [hidden].
      input_mask_doc_level: int Tensor. The input masks on the document level
        to identify whether a location is a real sentence (mask = 1) or a
        padded sentence (mask = 0). The shape is
        [batch, loop_sent_number_per_doc].
      batch_size_static: scalar. The static batch size depending on the
        training or the evaluation mode.
      max_masked_sent_per_doc: scalar. The maximum number of masked sentences
        per document (raised to at least 2).
      loop_sent_number_per_doc: scalar. The number of looped sentences per
        document.

    Returns:
      A tuple (doc_reps, masked_sent_index, masked_sent_weight):
        doc_reps: the stacked per-document embeddings with the sampled
          positions replaced by `sent_mask_embedding`.
        masked_sent_index: [batch_size_static, num_masked] sorted sampled
          positions per document.
        masked_sent_weight: float32 [batch_size_static, num_masked]; 1.0 where
          the sampled position is a real sentence (index below that
          document's number of real sentences) and 0.0 for padded positions.
    """
    # Mask at least two sentences so there is always a candidate pool for
    # negative sentence sampling. No word/sentence masks are added during
    # prediction or inference.
    num_masked = max(max_masked_sent_per_doc, 2)

    doc_reps = tf.unstack(input_sent_reps_doc, num=batch_size_static)
    real_counts = tf.unstack(
        tf.reduce_sum(input_mask_doc_level, 1), num=batch_size_static)

    sampled_positions = []
    sampled_weights = []
    for doc_idx, real_count in enumerate(real_counts):
        # TPU needs fixed shapes, so positions are sampled from the full
        # padded range. Positions that land on padding get weight 0 below,
        # mirroring how masked-word LM weights are handled.
        positions = tf.slice(
            tf.random_shuffle(tf.range(loop_sent_number_per_doc)),
            [0], [num_masked])
        positions = tf.sort(positions)
        sampled_positions.append(positions)

        sampled_weights.append(
            tf.cast(tf.less(positions, real_count), tf.float32))

        scatter_indices = tf.reshape(positions, [num_masked, -1])
        # One copy of the mask embedding per masked position.
        mask_rows = tf.reshape(
            tf.tile(sent_mask_embedding, [num_masked]),
            [num_masked, -1])
        doc_reps[doc_idx] = tf.tensor_scatter_update(
            doc_reps[doc_idx], scatter_indices, mask_rows)

    return (tf.stack(doc_reps),
            tf.stack(sampled_positions),
            tf.stack(sampled_weights))
def without_replacement():
    """Return a sorted vector of `min_length` frame indices with a distinct interior.

    The result always contains 0 and `sequence_length - 1`. The interior
    entries are drawn without replacement from [1, sequence_length - 1).
    `min_length` and `sequence_length` come from the enclosing scope.
    """
    shuffled = tf.random.shuffle(tf.range(1, sequence_length - 1))
    interior = shuffled[:min_length - 2]
    with_endpoints = tf.concat(
        [[0], interior, [sequence_length - 1]], axis=0)
    return tf.sort(with_endpoints)
def _get_richer_data(self, fake_data):
    """Apply random insertion / deletion / swap edits to each sequence.

    Every row of `fake_data.inputs.input_ids` is cut into N + 1 contiguous
    segments, where N = int(max_predictions_per_seq * rich_prob). Each of the
    first N segments then gets one edit applied at its end, and its labels
    (starting from `fake_data.is_fake_tokens`) are updated:
      - 2: insert one token (random id in [1, V), or a bilm lookup when
        `use_bilm`) and drop the last token of the final segment, only if
        that segment has >= 2 tokens;
      - 3: drop the segment's last token, relabel its new last label as 3 and
        append a 0 input / 0 label to the final segment;
      - 4: swap the segment's last two tokens and label the last one 4.
    Rows with fewer than 2 * N + 1 non-padding tokens are left unchanged.

    Returns:
      A RicherData namedtuple (inputs, is_fake_tokens, sampled_tokens). The
      input mask is recomputed as `input_ids != 0`.
    """
    inputs_tf = fake_data.inputs.input_ids
    labels_tf = fake_data.is_fake_tokens
    lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)
    #retrieve the basic config
    V = self._bert_config.vocab_size
    #sub: 10%, del + ins: 5%
    N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
    B, L = modeling.get_shape_list(inputs_tf)
    nlms = 0
    bilm = None
    if self._config.use_bilm:
        # NOTE(review): the bilm table is re-read from disk and embedded as a
        # graph constant on every call.
        with open(self._config.bilm_file, 'rb') as f:
            bilm = tf.constant(np.load(f), tf.int32)
        _, nlms = modeling.get_shape_list(bilm)
    #make multiple partitions for edit op
    splits_list = []
    for i in range(B):
        # Draw up to 2N+1 distinct positions in [1, len); every second one
        # (sorted indices 2, 4, ..., 2N) becomes a cut point. If there are too
        # few distinct draws, fall back to the fixed positions 1..2N+1.
        one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
        one, _ = tf.unique(one)
        one = tf.cond(tf.less(tf.shape(one)[0], N * 2 + 1),
                      lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
                      lambda: tf.sort(tf.reshape(one[: N * 2 + 1], [1, N * 2 + 1])))
        splits_list.append(one[:, 2::2])
    splits_tf = tf.concat(splits_list, 0)
    # Segment sizes: differences between consecutive boundaries 0, cuts..., L.
    splits_up = tf.concat([splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
    splits_lo = tf.concat([tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
    size_splits = splits_up - splits_lo
    #update the inputs and labels giving random insertion and deletion
    new_labels_list = []
    new_inputs_list = []
    for i in range(B):
        inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
        labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
        one_inputs = []
        one_labels = []
        size_split = len(inputs_splits)
        # The final segment absorbs length changes so the row keeps length L.
        inputs_end = inputs_splits[-1]
        labels_end = labels_splits[-1]
        for j in range(size_split-1):
            inputs = inputs_splits[j]
            labels = labels_splits[j]
            #label 1 for substitution (presumably already present in
            #fake_data.is_fake_tokens -- TODO confirm)
            # NOTE(review): random.randint is Python-level and runs while the
            # graph is built, so in graph mode the chosen edit per segment is
            # fixed for every step. Values of rand_op other than 2/3/4 leave
            # the segment unchanged.
            rand_op = random.randint(2, self._config.num_preds - 1)
            if rand_op == 2:
                #label 2 for insertion
                if bilm is None:
                    #noise
                    insert_tok = tf.random.uniform([1], 1, V, tf.int32)
                else:
                    #2-gram prediction
                    insert_tok = tf.expand_dims(bilm[inputs[-1], random.randint(0, nlms-1)], 0)
                is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
                inputs = tf.cond(is_end_valid,
                                 lambda: tf.concat([inputs, insert_tok], 0),
                                 lambda: inputs)
                labels = tf.cond(is_end_valid,
                                 lambda: tf.concat([labels, tf.constant([2])], 0),
                                 lambda: labels)
                inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1], lambda: inputs_end)
                labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1], lambda: labels_end)
            elif rand_op == 3:
                #label 3 for deletion
                # Labels drop two entries and add one, matching the single
                # dropped input token; the 3 lands on the token that preceded
                # the deleted one.
                labels = tf.concat([labels[:-2], tf.constant([3])], 0)
                inputs = inputs[:-1]
                inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
                labels_end = tf.concat([labels_end, tf.constant([0])], 0)
            elif rand_op == 4:
                #label 4 for swapping
                labels = tf.concat([labels[:-1], tf.constant([4])], 0)
                inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
            one_labels.append(labels)
            one_inputs.append(inputs)
        one_inputs.append(inputs_end)
        one_labels.append(labels_end)
        one_inputs_tf = tf.concat(one_inputs, 0)
        one_labels_tf = tf.concat(one_labels, 0)
        # Rows too short for 2N+1 cut candidates keep their original data.
        one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: inputs_tf[i, :], lambda: one_inputs_tf)
        one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: labels_tf[i, :], lambda: one_labels_tf)
        new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
        new_labels_list.append(tf.expand_dims(one_labels_tf, 0))
    new_inputs_tf = tf.concat(new_inputs_list, 0)
    new_labels_tf = tf.concat(new_labels_list, 0)
    new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)
    updated_inputs = pretrain_data.get_updated_inputs(
        fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
    RicherData = collections.namedtuple("RicherData", [
        "inputs", "is_fake_tokens", "sampled_tokens"])
    return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                      sampled_tokens=fake_data.sampled_tokens)