def upsampling_tpu_compatible(data, scale):
    """Nearest neighbor upsampling TPU-compatible implementation.

    This implementation is TPU compatible as opposed to
    tf.image.resize_nearest_neighbor().

    Args:
      data: A 4D tensor of shape [batch, height, width, channels]. Any
        dimension may be statically unknown; the rank must be known.
      scale: An integer multiple to scale resolution of input data.

    Returns:
      A 4D tensor of shape [batch, height*scale, width*scale, channels] with
      the same dtype as `data`.
    """
    with tf.name_scope('upsampling_tpu_compatible'):
        # `as_list()` yields plain ints (or None) regardless of whether
        # iterating a TensorShape returns Dimension objects or ints.
        dims = data.get_shape().as_list()
        if None in dims:
            # Fall back to the runtime shape only for the dimensions that are
            # not known statically, so known sizes stay static.
            dynamic_dims = tf.shape(data)
            dims = [
                dynamic_dims[axis] if size is None else size
                for axis, size in enumerate(dims)
            ]
        bs, height, width, channels = dims

        # Use reshape to quickly upsample the input. The nearest pixel is
        # selected implicitly via broadcasting against a tensor of ones.
        data = tf.reshape(data, [bs, height, 1, width, 1, channels]) * tf.ones(
            [1, 1, scale, 1, scale, 1], dtype=data.dtype)
        return tf.reshape(data, [bs, height * scale, width * scale, channels])
def pad_to_fixed_size(data, pad_value, output_shape):
    """Pad data to a fixed length at the first dimension.

    `data` is first reshaped to [-1, output_shape[1]]; rows filled with
    `pad_value` are then appended until there are output_shape[0] rows.

    Args:
      data: Tensor to be padded to output_shape.
      pad_value: A constant value assigned to the paddings.
      output_shape: The output shape of a 2D tensor,
        [max_num_instances, dimension].

    Returns:
      The padded tensor with output_shape [max_num_instances, dimension] and
      the same dtype as `data`.

    An assertion op fails at run time when `data` holds more than
    max_num_instances rows.
    """
    max_num_instances = output_shape[0]
    dimension = output_shape[1]
    data = tf.reshape(data, [-1, dimension])
    num_instances = tf.shape(data)[0]
    assert_length = tf.Assert(
        tf.less_equal(num_instances, max_num_instances), [num_instances])
    # Everything below depends on pad_length, so the assertion is evaluated
    # before any padding is built.
    with tf.control_dependencies([assert_length]):
        pad_length = max_num_instances - num_instances
    # tf.ones defaults to float32; cast to the input dtype so the concat below
    # also works for integer (or other non-float32) data.
    paddings = tf.cast(
        pad_value * tf.ones([pad_length, dimension]), data.dtype)
    padded_data = tf.concat([data, paddings], axis=0)
    padded_data = tf.reshape(padded_data, output_shape)
    return padded_data
def parser(record):
    """Parse one serialized tfrecord example into model input features.

    Reads `seq_len`, `reuse_len`, `perm_size`, `num_predict` and
    `use_bfloat16` from the surrounding scope.

    The sequence is split at `reuse_len`; `_local_perm` is applied to each
    part separately and the results are stitched back together. When
    `num_predict` is set, the targets are compacted to the positions selected
    by the target mask and padded up to `num_predict` entries.

    Args:
      record: A serialized example, as accepted by tf.parse_single_example.

    Returns:
      The feature dict, after `_convert_example` has been applied to it.
    """
    feature_spec = {
        key: tf.FixedLenFeature([seq_len], tf.int64)
        for key in ("input", "target", "seg_id")
    }
    feature_spec["label"] = tf.FixedLenFeature([1], tf.int64)
    feature_spec["is_masked"] = tf.FixedLenFeature([seq_len], tf.int64)

    # Retrieve the serialized example.
    example = tf.parse_single_example(
        serialized=record,
        features=feature_spec)

    token_ids = example.pop("input")
    target_ids = example.pop("target")
    masked_flags = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # Permute the reused prefix and the remaining suffix independently.
    spans = (
        (slice(None, reuse_len), reuse_len),
        (slice(reuse_len, None), non_reuse_len),
    )
    per_span = []
    for span, span_len in spans:
        per_span.append(_local_perm(
            token_ids[span],
            target_ids[span],
            masked_flags[span],
            perm_size,
            span_len))
    (head_perm, head_target, head_target_mask, head_k, head_q) = per_span[0]
    (tail_perm, tail_target, tail_target_mask, tail_k, tail_q) = per_span[1]

    # Extend each block to the full width: the prefix rows get ones over the
    # suffix columns, the suffix rows get zeros over the prefix columns.
    head_perm = tf.concat(
        [head_perm, tf.ones([reuse_len, non_reuse_len])], axis=1)
    tail_perm = tf.concat(
        [tf.zeros([non_reuse_len, reuse_len]), tail_perm], axis=1)
    perm_mask = tf.concat([head_perm, tail_perm], axis=0)
    target_ids = tf.concat([head_target, tail_target], axis=0)
    target_mask = tf.concat([head_target_mask, tail_target_mask], axis=0)
    input_k = tf.concat([head_k, tail_k], axis=0)
    input_q = tf.concat([head_q, tail_q], axis=0)

    if num_predict is None:
        example["target"] = tf.reshape(target_ids, [seq_len])
        example["target_mask"] = tf.reshape(target_mask, [seq_len])
    else:
        positions = tf.range(seq_len, dtype=tf.int64)
        is_target = tf.cast(target_mask, tf.bool)
        positions = tf.boolean_mask(positions, is_target)

        # Extra padding due to CLS/SEP introduced after preprocessing.
        actual_num_predict = tf.shape(positions)[0]
        pad_len = num_predict - actual_num_predict

        # target_mapping: one one-hot row per predicted position, then
        # all-zero rows up to num_predict.
        mapping = tf.one_hot(positions, seq_len, dtype=tf.float32)
        mapping_pad = tf.zeros([pad_len, seq_len], dtype=mapping.dtype)
        mapping = tf.concat([mapping, mapping_pad], axis=0)
        example["target_mapping"] = tf.reshape(
            mapping, [num_predict, seq_len])

        # target: keep only the predicted positions, zero-padded.
        picked_targets = tf.boolean_mask(target_ids, is_target)
        target_pad = tf.zeros([pad_len], dtype=picked_targets.dtype)
        picked_targets = tf.concat([picked_targets, target_pad], axis=0)
        example["target"] = tf.reshape(picked_targets, [num_predict])

        # target mask: ones for real predictions, zeros for the padding.
        compact_mask = tf.concat(
            [tf.ones([actual_num_predict], dtype=tf.float32),
             tf.zeros([pad_len], dtype=tf.float32)],
            axis=0)
        example["target_mask"] = tf.reshape(compact_mask, [num_predict])

    # Reshape back to fixed shapes.
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    _convert_example(example, use_bfloat16)

    for name, tensor in example.items():
        tf.logging.info("%s: %s", name, tensor)

    return example
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    """Sample a factorization order and build the matching attention mask.

    Args:
      inputs: int64 Tensor in shape [seq_len], input ids.
      targets: int64 Tensor in shape [seq_len], target ids.
      is_masked: bool Tensor in shape [seq_len]. True means being selected
        for partial prediction.
      perm_size: the length of longest permutation. Could be set to be
        reuse_len. Should not be larger than reuse_len or there will be data
        leaks.
      seq_len: int, sequence length.

    Returns:
      A tuple (perm_mask, new_targets, target_mask, inputs_k, inputs_q):
      perm_mask is a float32 [seq_len, seq_len] mask, new_targets is the
      first input id followed by `targets` without its last element,
      target_mask is a float32 [seq_len] mask, inputs_k is `inputs`
      unchanged, and inputs_q is the same tensor as target_mask.
    """
    # Generate permutation indices: lay the positions out as rows of
    # `perm_size`, shuffle along the transposed leading axis, then flatten
    # back to a vector of length seq_len.
    positions = tf.range(seq_len, dtype=tf.int64)
    chunked = tf.reshape(positions, [-1, perm_size])
    shuffled = tf.random_shuffle(tf.transpose(chunked))
    index = tf.reshape(tf.transpose(shuffled), [-1])

    # Token categories used by `perm_mask` and `target_mask`.
    is_sep = tf.equal(inputs, SEP_ID)
    is_cls = tf.equal(inputs, CLS_ID)
    non_func_tokens = tf.logical_not(tf.logical_or(is_sep, is_cls))

    not_masked = tf.logical_not(is_masked)
    non_mask_tokens = tf.logical_and(not_masked, non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`.
    # `target_tokens` cannot see themselves, so they keep their own index;
    # every other position is bumped by one.
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked
    #    (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    not_later = self_rev_index[:, None] <= rev_index[None, :]
    perm_mask = tf.logical_and(not_later, masked_or_func_tokens)
    perm_mask = tf.cast(perm_mask, tf.float32)

    # New target: [next token] for LM and [curr token] (self) for PLM.
    new_targets = tf.concat([inputs[0: 1], targets[: -1]], axis=0)

    # The content stream consumes the raw inputs; the query stream consumes
    # the target mask.
    inputs_k = inputs
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q