def _InputBatch(self):
  p = self.params

  @tf.function
  def ReadData():
    x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                             [p.data_dtype, p.label_dtype])
    # Always convert to float32.
    return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

  # Loads data and label into memory and keeps them around.
  data, label = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  b, shape = self.InfeedBatchSize(), list(p.data_shape)
  data = tf.reshape(data, [-1] + shape)
  label = tf.reshape(label, [-1])
  label = py_utils.HasShape(label, [tf.shape(data)[0]])
  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=b,
      repeat=p.repeat,
      seed=p.random_seed if p.random_seed else 0)
  n = tf.shape(sample_ids)[0]
  raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
  ret = py_utils.NestedMap(
      raw=raw,
      data=self._Preprocess(raw),
      label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
      weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
  if not py_utils.use_tpu():
    ret['sample_ids'] = sample_ids
  return ret
def ApplyBias():
  """Bias and update log_probs and consistent."""

  def TileForBeamAndFlatten(tensor):
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam * src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from the previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 but is overridden later.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Get the label and weight slices corresponding to the current time_step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, p.dtype)

  # Convert from dense labels to sparse label probs.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  # Avoid 0 probs, which may cause issues with the log below.
  uncertainty = tf.constant(1e-10, p.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  return tf.math.log(probs), consistent
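# A minimal, standalone sketch of the biasing math above: build near-one-hot
# label probs with a small uncertainty mass, then interpolate them with the
# predicted probs using a per-example weight in [0, 1]. The `demo_*` names are
# hypothetical; the real code operates on beam-search state tensors.
import tensorflow as tf

demo_log_probs = tf.math.log(tf.constant([[0.7, 0.2, 0.1]]))  # [batch=1, vocab=3]
demo_label = tf.constant([2])                                 # target token id
demo_weight = tf.constant([[0.5]])                            # bias strength
uncertainty = 1e-10
vocab = tf.shape(demo_log_probs)[1]
label_probs = tf.one_hot(
    demo_label,
    vocab,
    on_value=1.0 - uncertainty,
    off_value=uncertainty / tf.cast(vocab - 1, tf.float32))
pred_probs = tf.exp(demo_log_probs)
biased = (1.0 - demo_weight) * pred_probs + demo_weight * label_probs
print(tf.math.log(biased))  # log of the interpolated distribution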
def _GatherStep(x_in, t):
  """Gathers for one time step.

  Args:
    x_in: A tensor of shape [T, B, ...]. We first take slice(t) from the
      tensor, gather old_hyp_ids from the slice, and write the gathered
      slice back in place to update the original x_in.
    t: The current time step.

  Returns:
    The updated x_in and time step.
  """
  x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
  return inplace_ops.alias_inplace_update(x_in, t, x), t + 1
def ReOrderHyps(x_in):
  """Reorders x_in based on prev hyp ids."""
  if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
      x_in.shape.ndims > 0):
    if x_in.shape.ndims > 2 and not p.batch_major_state:
      # Use corrected indices only here for batch major compute, as key/value
      # caches are the states being affected.
      correct_old_hyp_ids = (
          old_hyp_ids_in_cache_order
          if p.batch_major_compute else old_hyp_ids)
      x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
    else:
      x_out = tf.gather(x_in, old_hyp_ids)
    x_out.set_shape(x_in.get_shape())
    return x_out
  else:
    return x_in
def _GetTaskIds(self, source_id):
  """Looks up the correct task_id from the source_id tensor."""
  if self.params.file_pattern_task_ids:
    file_task_ids = tf.constant(
        self.params.file_pattern_task_ids, dtype=tf.int32)
    source_id = tf.gather(file_task_ids, source_id)
  src_task_id = source_id
  tgt_task_id = source_id
  if self.params.task_to_src_lang_map:
    src_lang_ids = tf.constant(
        self.params.task_to_src_lang_map, dtype=tf.int32)
    src_task_id = tf.gather(src_lang_ids, src_task_id)
  if self.params.task_to_tgt_lang_map:
    tgt_lang_ids = tf.constant(
        self.params.task_to_tgt_lang_map, dtype=tf.int32)
    tgt_task_id = tf.gather(tgt_lang_ids, tgt_task_id)
  return src_task_id, tgt_task_id
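# A small usage sketch of the lookup pattern above: each tf.gather maps ids
# through a constant table, so chaining them translates a file-pattern index
# into source/target language ids. The table contents below are made-up
# example values, not anything from a real config.
import tensorflow as tf

source_id = tf.constant([0, 2, 1])                      # file-pattern index per example
file_task_ids = tf.constant([3, 1, 0], dtype=tf.int32)  # file pattern -> task id
task_id = tf.gather(file_task_ids, source_id)           # -> [3, 0, 1]

src_lang_ids = tf.constant([10, 11, 12, 13], dtype=tf.int32)  # task id -> source lang
tgt_lang_ids = tf.constant([20, 21, 22, 23], dtype=tf.int32)  # task id -> target lang
src_task_id = tf.gather(src_lang_ids, task_id)          # -> [13, 10, 11]
tgt_task_id = tf.gather(tgt_lang_ids, task_id)          # -> [23, 20, 21]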
def gather(self, values, max_value=None, axis=0, batch_major_state=True):
  """Returns `values` gathered at the ids provided to the constructor.

  Args:
    values: Values to gather from.
    max_value: The largest value in `values`.
    axis: Axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or
      not. Defaults to True. For the Transformer model, batch_major_state is
      set to False (time is the major dim).

  Returns:
    Gathered values.

  Raises:
    ValueError: If the dtype is not supported.
    NotImplementedError: If axis is not 0 or 1.
  """
  # Carry out the gather via matmul if the values are floating point or can
  # be represented exactly in bfloat16 (TPU internal matmul dtype).
  dtype = values.dtype
  if dtype in (tf.bfloat16, tf.float32,
               tf.bool) or _has_bfloat16_repr(max_value):
    return self._matmul_gather(
        values, axis=axis, batch_major_state=batch_major_state)
  elif dtype == tf.int32:
    # For int32s with a max_value that can't be represented exactly in
    # floating point, we decompose `values` into parts that can be
    # represented exactly, gather each part individually, and recombine to
    # get the final gathered values.
    max_value = max_value or _MAX_INT32
    if max_value <= _MAX_BFLOAT16_INT**2:
      # Break `values` into two bfloat16-representable parts. The low part
      # values are in [-255, 255]; the high part values are in [-256, 256].
      signs = tf.sign(values)
      abs_values = tf.abs(values)
      low_part = signs * tf.bitwise.bitwise_and(abs_values, 0xff)
      high_part = signs * tf.bitwise.right_shift(abs_values, 8)
      low_part_gathered = self._matmul_gather(
          low_part, axis=axis, batch_major_state=batch_major_state)
      high_part_gathered = self._matmul_gather(
          high_part, axis=axis, batch_major_state=batch_major_state)
      return tf.bitwise.left_shift(high_part_gathered, 8) + low_part_gathered
    else:
      # For larger-magnitude int32s, we could break them up into 3 or 4
      # byte-sized matmuls, but regular-old tf.gather() is more efficient.
      return tf.gather(values, self._ids, axis=axis)
  else:
    raise ValueError("Unsupported dtype %s" % values.dtype)
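# A self-contained check of the low/high-byte decomposition used above. It
# swaps in plain tf.gather for _matmul_gather (which is TPU-specific) purely
# to show that splitting, gathering, and recombining reproduces the direct
# gather for int32 values whose magnitude fits in two bfloat16-exact bytes.
import tensorflow as tf

values = tf.constant([1234, -56789, 255, -256, 0], dtype=tf.int32)
ids = tf.constant([4, 1, 0], dtype=tf.int32)

signs = tf.sign(values)
abs_values = tf.abs(values)
low_part = signs * tf.bitwise.bitwise_and(abs_values, 0xff)
high_part = signs * tf.bitwise.right_shift(abs_values, 8)

low_gathered = tf.gather(low_part, ids)    # stand-in for the matmul gather
high_gathered = tf.gather(high_part, ids)
recombined = tf.bitwise.left_shift(high_gathered, 8) + low_gathered

tf.debugging.assert_equal(recombined, tf.gather(values, ids))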
def ReOrderHyps(x_in):
  """Reorders x_in based on prev hyp ids."""
  if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
    # For rank > 1 tensors we make use of an efficient matmul-based gather on
    # TPU that takes into account the range of the values. For R1 we rely on
    # tf.gather and XLA to optimize it efficiently for the R1 layout.
    if x_in.shape.ndims > 1:
      if p.batch_major_state:
        num_hyps = tf.shape(old_hyp_ids)[0]
        x_out = beam_search_tpu_ops.fast_gather(
            x_in,
            old_hyp_ids,
            num_hyps,
            max_value=None,
            batch_major_state=p.batch_major_state)
      else:
        # Use corrected indices only here for batch major compute, as
        # key/value caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)

        def _GatherStep(x_in, t):
          """Gathers for one time step.

          Args:
            x_in: A tensor of shape [T, B, ...]. We first take slice(t) from
              the tensor, gather old_hyp_ids from the slice, and write the
              gathered slice back in place to update the original x_in.
            t: The current time step.

          Returns:
            The updated x_in and time step.
          """
          x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
          return inplace_ops.alias_inplace_update(x_in, t, x), t + 1

        x_out, _ = tf.while_loop(lambda _, t: t <= cur_step, _GatherStep,
                                 (x_in, tf.zeros([], tf.int32)))
    else:
      x_out = tf.gather(x_in, old_hyp_ids)
    x_out.set_shape(x_in.get_shape())
    return x_out
  else:
    return x_in
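# The while_loop above rewrites one time slice per step via
# inplace_ops.alias_inplace_update. Purely as an illustration, the same
# reordering of a time-major state [T, B, ...] can be written functionally
# (no in-place aliasing, so less memory-friendly) with standard ops:
import tensorflow as tf

def reorder_time_major_state(state, hyp_ids, cur_step):
  """Gathers hyp_ids along the batch dim for time steps 0..cur_step."""

  def _step(x, t):
    reordered = tf.gather(x[t], hyp_ids)        # [B, ...] slice, reordered
    indices = tf.reshape(t, [1, 1])             # which time step to overwrite
    return tf.tensor_scatter_nd_update(x, indices, reordered[tf.newaxis]), t + 1

  out, _ = tf.while_loop(lambda _, t: t <= cur_step, _step,
                         (state, tf.zeros([], tf.int32)))
  return out

state = tf.reshape(tf.range(12, dtype=tf.float32), [3, 2, 2])  # [T=3, B=2, D=2]
print(reorder_time_major_state(state, tf.constant([1, 0]), cur_step=1))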
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""
  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)
      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
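# For sparse gradients (tf.IndexedSlices), only a subset of variable rows has
# gradient values, so the code above gathers exactly those rows before adding
# the histograms. A tiny standalone illustration with a hand-built
# IndexedSlices (the shapes and values here are arbitrary):
import tensorflow as tf

var = tf.Variable(tf.random.normal([10, 4]))      # e.g. an embedding table
grad = tf.IndexedSlices(
    values=tf.random.normal([3, 4]),              # gradients for 3 rows only
    indices=tf.constant([2, 5, 7]),
    dense_shape=tf.constant([10, 4]))

var_rows = tf.gather(var, grad.indices)           # same rows as the gradient
grad_rows = grad.values
print(var_rows.shape, grad_rows.shape)            # (3, 4) (3, 4)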
def _matmul_gather(self, values, axis=0, batch_major_state=True):
  """Returns values gathered.

  Args:
    values: Values to gather from.
    axis: Axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or
      not. Defaults to True. For the Transformer model, batch_major_state is
      set to False (time is the major dim).

  Returns:
    Gathered values.

  Raises:
    NotImplementedError: If axis is neither 0 nor 1.
  """
  dtype = values.dtype
  if dtype != tf.float32 and dtype != tf.bfloat16:
    values = tf.cast(values, tf.float32)
  if axis == 0:
    if values.shape.rank is not None and values.shape.rank > 2:
      if not batch_major_state:
        values = tf.transpose(values, [1, 0, 2])
      results = tf.cast(tf.gather(values, tf.cast(self._ids, tf.int32)), dtype)
      # pylint: disable=g-long-ternary
      return (tf.transpose(results, [1, 0, 2])
              if not batch_major_state else results)
      # pylint: enable=g-long-ternary
    else:
      one_hot_ids = tf.one_hot(self._ids, self._ids_size, dtype=values.dtype)
      return tf.cast(tf.matmul(one_hot_ids, values), dtype)
  elif axis == 1:
    one_hot_ids = tf.one_hot(
        self._ids, self._ids_size, dtype=values.dtype, axis=0)
    return tf.cast(tf.matmul(values, one_hot_ids), dtype)
  else:
    raise NotImplementedError("Only row/col-wise gather implemented.")
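# A short check of the one-hot matmul trick used above: multiplying a
# [num_ids, num_rows] one-hot matrix by `values` selects rows, which matches
# tf.gather on axis 0, and a one-hot matrix on the right (built with axis=0)
# selects columns, matching tf.gather on axis 1. Illustrative only; the TPU
# benefit comes from matmul-friendly layouts, not from this equivalence.
import tensorflow as tf

values = tf.reshape(tf.range(16, dtype=tf.float32), [4, 4])
ids = tf.constant([2, 0, 3])

one_hot_rows = tf.one_hot(ids, 4, dtype=values.dtype)          # [3, 4]
rows_via_matmul = tf.matmul(one_hot_rows, values)              # == tf.gather(values, ids)
tf.debugging.assert_near(rows_via_matmul, tf.gather(values, ids))

one_hot_cols = tf.one_hot(ids, 4, dtype=values.dtype, axis=0)  # [4, 3]
cols_via_matmul = tf.matmul(values, one_hot_cols)              # == tf.gather(values, ids, axis=1)
tf.debugging.assert_near(cols_via_matmul, tf.gather(values, ids, axis=1))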
def _Pack(self, batch):
  """Packs a given batch. Note that this may change the batch size.

  This function packs the input batch and adds .segment_ids and .segment_pos
  fields to its `src` and `tgt` fields.

  Args:
    batch: a `.NestedMap` of input tensors to be packed. It is modified in
      place.
  """
  src_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
  tgt_actual_seq_len = tf.math.reduce_sum(
      tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
  summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
  summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

  if not self.params.packing_factor:
    # Supply segment_ids and segment_pos with no packing.
    batch.src.segment_ids = batch.src.ids_indicator
    batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
    batch.tgt.segment_ids = batch.tgt.ids_indicator
    batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
    return

  (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
   tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
       src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
       self.params.source_max_length, self.params.target_max_length)

  uniq_src_indices_in_input = tf.unique(
      tf.reshape(src_indices_in_input, [-1])).y
  uniq_tgt_indices_in_input = tf.unique(
      tf.reshape(tgt_indices_in_input, [-1])).y
  summary_utils.histogram(
      'packed_source_seq_lengths',
      tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
  summary_utils.histogram(
      'packed_target_seq_lengths',
      tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

  # We deferred adding .paddings and use its complement .ids_indicator
  # exclusively so that we can apply the packing with padding set to 0 for
  # all fields.
  def ApplyPackingToSource(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
    return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

  batch.src = batch.src.Transform(ApplyPackingToSource)
  batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
  batch.src.segment_pos = src_segment_pos

  def ApplyPackingToTarget(x):
    if x.dtype == tf.string:
      return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
    return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

  batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
  batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
  batch.tgt.segment_pos = tgt_segment_pos
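# _GetSegmentPos is not shown above. A plausible stand-in (an assumption, not
# the actual helper) that returns 0-based within-sequence positions, zeroed
# where ids_indicator is 0, can be written with a cumulative sum:
import tensorflow as tf

def get_segment_pos(ids_indicator):
  """ids_indicator: [batch, time] with 1 at real tokens, 0 at padding."""
  ids_indicator = tf.cast(ids_indicator, tf.int32)
  return (tf.cumsum(ids_indicator, axis=1) - 1) * ids_indicator

print(get_segment_pos(tf.constant([[1, 1, 1, 0, 0],
                                   [1, 1, 0, 0, 0]])))
# [[0 1 2 0 0]
#  [0 1 0 0 0]]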