def _TransposeAttentions(x):
  """Swaps the first two axes of a rank-3 tensor, keeping the last in place.

  Args:
    x: A rank-3 tensor (the permutation [1, 0, 2] requires exactly 3 axes).

  Returns:
    `x` with axes 0 and 1 exchanged.
  """
  swapped = tf.transpose(x, [1, 0, 2])
  return swapped
def FProp(self, theta, input_batch, interpolation_batch=None, lambdas=None):
  # pyformat: disable
  """Interpolates source ids in input_batch and interpolation_batch.

  Refer to Eq. (4) in paper https://arxiv.org/abs/2106.04060.
  It is a standard Transformer Encoder if interpolation_batch is None.

  Args:
    theta: A `.NestedMap` object containing weights values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - segment_ids, segment_pos: Read only when p.packed_input is set.
      - embs: Optional. When present and not None (and interpolation_batch
        is given), used in place of the embedding lookup of ids.
    interpolation_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - embs: Embeddings of ids.
    lambdas: A pair of tensors to combine embeddings of ids in input_batch
      and interpolation_batch. Only read when interpolation_batch is given;
      each element gets a trailing axis appended before use.

  Returns:
    A NestedMap of

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set
        in transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the
        model (and all layers), or None otherwise.
      - embedded_inputs: [batch, time, depth] embedded inputs tokens without
        positional encodings (captured before task and position embeddings
        are added, and before the transpose to time-major).
  """
  # pyformat: enable
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    # Attach rank/shape checks to the ids so they run whenever ids are used.
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    # None makes the `[:, :max_seq_length]` slices below no-ops.
    max_seq_length = None
    if (not py_utils.use_tpu() and
        FLAGS.transformer_encoder_truncates_inputs):
      # Longest unpadded length in the batch; everything past it is dropped.
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Verify that the part being truncated consists only of padding.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    if interpolation_batch is not None:
      # NOTE(review): interpolation_batch.ids/paddings are not truncated to
      # max_seq_length above - confirm callers keep both batches aligned
      # when truncation is enabled.
      other_input_ids = interpolation_batch.ids
      if not p.shared_emb:
        other_input_embs = self.token_emb.EmbLookup(
            theta.token_emb, tf.reshape(other_input_ids, [-1]))
      else:
        other_input_embs = self.softmax.EmbLookup(
            theta.softmax, tf.reshape(other_input_ids, [-1]))
      # Append a trailing axis so each lambda broadcasts over the depth axis.
      lambdas = [tf.expand_dims(a, -1) for a in lambdas]
      if 'embs' in input_batch and input_batch.embs is not None:
        input_embs = input_batch.embs
      # NOTE(review): the reshapes in the `else` run only when
      # interpolation_batch.embs is absent; presumably callers supply embs
      # for both batches or for neither - verify.
      if 'embs' in interpolation_batch and interpolation_batch.embs is not None:
        other_input_embs = interpolation_batch.embs
      else:
        input_embs = tf.reshape(
            input_embs,
            [-1, tf.shape(input_ids)[1], p.token_emb.embedding_dim])
        other_input_embs = tf.reshape(
            other_input_embs,
            [-1, tf.shape(other_input_ids)[1], p.token_emb.embedding_dim])
      # Weighted combination of the two embedded batches.
      input_embs = lambdas[0] * input_embs + lambdas[1] * other_input_embs
      # With 0/1 paddings, sum - 1 clipped to [0, 1] is 1 only where both
      # batches are padded.
      paddings = paddings + interpolation_batch.paddings - 1.0
      paddings = tf.clip_by_value(paddings, 0.0, 1.0)

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # Snapshot of the token embeddings before task/position embeddings.
    orig_input_embs = input_embs
    if p.task_emb:
      if interpolation_batch is None:
        input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                              input_batch.task_ids)
      else:
        # Task embeddings are interpolated with the same lambdas.
        task_embs = self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)
        other_task_embs = self.task_emb.EmbLookup(
            theta.task_emb, interpolation_batch.task_ids)
        task_embs = lambdas[0] * task_embs + lambdas[1] * other_task_embs
        input_embs += task_embs

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      # Add a leading axis of size 1 so the same positions broadcast over
      # the batch.
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # Switch paddings (and segment ids) to [time, batch].
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Augment padding for masked source word positions.
    dtype = paddings.dtype
    source_mask = tf.where(
        tf.equal(input_ids, p.source_mask_id),
        tf.ones_like(input_ids, dtype=dtype),
        tf.zeros_like(input_ids, dtype=dtype))
    # Make sure padding is between 0 and 1.
    paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0, 1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Decodes greedily, keeping a single hypothesis per source sequence.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. It is called
      with (theta, encoder_outputs, 1) and must return
      (initial_results, other_states). Please refer to the class header
      comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback,
      forwarded to every greedy step. Please refer to the class header
      comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback, forwarded to every greedy step. Please refer to the class
      header comments for more details.
    max_steps: Maximum number of decoding steps. When None,
      self.params.target_seq_len is used.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
    src_batch_size.

    - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
      token is encountered during search.
    - hyp_lens: [num_hyps].
    - done_hyps: [num_hyps], whether or not an eos is encountered.
  """
  p = self.params
  if max_steps is None:
    max_steps = p.target_seq_len

  # The trailing 1 is num_hyps_per_beam: one hypothesis per beam.
  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, 1)

  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' in initial_results:
    # Caller-provided first-step ids, constrained to [num_hyps, 1].
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    sos = tf.constant(p.target_sos_id, dtype=tf.int32)
    step_ids = tf.fill([num_hyps, 1], sos)

  cur_step = tf.constant(0, dtype=tf.int32)
  done_hyps = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  hyp_lens = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  # Step-major buffer; transposed to [num_hyps, max_steps] before returning.
  hyp_ids = inplace_ops.empty(
      shape=[max_steps, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def _KeepGoing(step, unused_ids, unused_hyp_ids, unused_hyp_lens, finished,
                 unused_flat_states):
    """True while the step budget remains and some hyp is unfinished."""
    within_budget = step < max_steps
    some_unfinished = tf.logical_not(tf.reduce_all(finished))
    return tf.logical_and(within_budget, some_unfinished)

  def _Advance(step, ids, all_ids, lens, finished, flat_states):
    """Runs one greedy step, re-packing/flattening the opaque states."""
    packed_states = other_states.Pack(flat_states)
    step_outputs = self._GreedySearchStep(
        theta, encoder_outputs, step, ids, all_ids, lens, finished,
        packed_states, pre_beam_search_step_callback,
        post_beam_search_step_callback)
    (next_step, next_ids, all_ids, lens, finished,
     next_states) = step_outputs
    return (next_step, next_ids, all_ids, lens, finished,
            next_states.Flatten())

  flat_other_states = other_states.Flatten()
  fixed_shape_vars = (cur_step, step_ids, hyp_ids, hyp_lens, done_hyps)
  invariants = tuple(
      tf.TensorShape(t.get_shape()) for t in fixed_shape_vars)
  invariants += (_GetShapes(flat_other_states, none_shapes=True),)

  loop_outputs = tf.while_loop(
      _KeepGoing,
      _Advance,
      loop_vars=fixed_shape_vars + (flat_other_states,),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=invariants)
  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = loop_outputs

  # transpose hyp_ids so it matches BeamSearchDecode's output
  final_hyp_ids = tf.transpose(final_hyp_ids)
  return final_hyp_ids, final_hyp_lens, final_done_hyps
def FProp(self, theta, input_batch):
  """Embeds the source ids and encodes them with the TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - segment_ids, segment_pos: Read only when p.packed_input is set.

  Returns:
    A NestedMap containing:

      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set
        in transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the
        model (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    segment_ids = None
    segment_pos = None

    # Rank and shape checks ride along with the ids as control dependencies.
    shape_checks = [
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ]
    token_ids = py_utils.with_dependencies(shape_checks, input_batch.ids)

    truncate = (not py_utils.use_tpu() and
                tf.flags.FLAGS.transformer_encoder_truncates_inputs)
    if truncate:
      # Longest unpadded length in the batch.
      longest = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything beyond `longest` must be padding before it is cut off.
      expect_true = tf.constant(True, tf.bool)
      tail_is_padding = tf.reduce_all(
          input_batch.paddings[:, longest:] > 0.5)
      pads = py_utils.with_dependencies(
          [py_utils.assert_equal(expect_true, tail_is_padding)],
          input_batch.paddings)
      token_ids = token_ids[:, :longest]
      pads = pads[:, :longest]
      if p.packed_input:
        segment_ids = input_batch.segment_ids[:, :longest]
        segment_pos = input_batch.segment_pos[:, :longest]
    else:
      pads = input_batch.paddings
      if p.packed_input:
        segment_ids = input_batch.segment_ids
        segment_pos = input_batch.segment_pos

    num_steps = tf.shape(token_ids)[1]
    emb_dim = p.token_emb.embedding_dim

    # Token embeddings: flat lookup, then back to [batch, time, dim].
    flat_embs = self.token_emb.EmbLookup(theta.token_emb,
                                         tf.reshape(token_ids, [-1]))
    embs = tf.reshape(flat_embs, [-1, num_steps, emb_dim])
    # [time, batch, dim] copy taken before positions are added.
    token_only_embs = tf.transpose(embs, [1, 0, 2])

    if p.packed_input:
      pos_embs = self.position_emb.FPropWithPosition(theta.position_emb,
                                                     segment_pos)
    else:
      pos_embs = self.position_emb.FProp(theta.position_emb, num_steps)
      # Leading axis of 1 lets the positions broadcast over the batch.
      pos_embs = tf.reshape(pos_embs, [1, num_steps, emb_dim])
    embs = embs + pos_embs

    if p.model_dim != emb_dim:
      embs = self.emb_proj.FProp(theta.emb_proj, embs)

    # Everything handed to the stack is time-major.
    pads = tf.transpose(pads)
    if p.packed_input:
      segment_ids = tf.transpose(segment_ids)

    embs = self.input_dropout.FProp(theta.input_dropout, embs)

    # [time, batch, dim]
    stack_input = tf.transpose(embs, [1, 0, 2])

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, stack_input, pads, segment_ids)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=token_only_embs)
def _InferenceSubgraph_Default(self):
  """Default inference subgraph.

  Tokenizes the fed strings, runs the language model over them and exposes
  per-token and per-sample log-perplexities. All fetches are batch-major.

  Returns:
    (fetches, feeds), with:

    - fetches: A dictionary of fetches, containing:

      - log_pplx_per_token: A matrix of shape [batch, time]. [i, j]
        is i-th input text's j-th token's log prob.
      - paddings: A matrix of shape [batch, time]. The padding mask.
      - lengths: A vector of shape [batch]. [i] is the number of unpadded
        tokens of the i-th input text.
      - log_pplx_per_sample: A vector of shape [batch]. [i]
        is i-th input text's log prob.
      - num_oovs_per_sample: A vector of shape [batch] counting the total
        number of out-of-vocabulary tokens in each input.
      - tokens_from_labels: A vector of shape [batch] holding the label ids
        of each input mapped back to strings using the vocabulary.
      - ids: A matrix of shape [batch, time]. [i, j]
        is i-th input text's j-th token's id.

    - feeds: A dictionary of feeds, containing:

      - text: A placeholder for a vector of strings.
  """
  text = tf.placeholder(tf.string, shape=[None])
  # [batch, time]
  ids, labels, paddings = self.input_generator.StringsToIds(text)
  lengths = tf.reduce_sum(tf.to_int32(1 - paddings), axis=1)
  tokens_from_labels = self.input_generator.IdsToStrings(labels, lengths)
  # Count labels equal to the tokenizer's unk id, ignoring padded positions.
  oovs = tf.equal(labels, self.input_generator.tokenizer.unk_id)
  num_oovs_per_sample = tf.to_int32(
      tf.round(tf.reduce_sum(tf.to_float(oovs) * (1 - paddings), axis=1)))
  # [time, batch]
  ids, paddings, labels, weights = self._TrimIfPossibleThenTranspose(
      ids, paddings, labels, 1.0 - paddings)
  batch_size = tf.shape(ids)[1]
  xent_output, _ = self.lm.FPropDefaultTheta(
      inputs=ids,
      paddings=paddings,
      state0=self.lm.zero_state(self.theta.lm, batch_size),
      labels=py_utils.NestedMap(class_ids=labels, class_weights=weights))

  per_example_xent = py_utils.HasShape(xent_output.per_example_xent,
                                       tf.shape(ids))
  # Sum over the time axis (axis 0 here), skipping padded positions.
  log_pplx_per_sample = tf.reduce_sum(
      per_example_xent * (1 - paddings), axis=0)
  fetches = {
      'log_pplx_per_token':  # [batch, time]
          tf.transpose(per_example_xent),
      'paddings':  # [batch, time]
          tf.transpose(paddings),
      'lengths':  # [batch]
          lengths,
      'log_pplx_per_sample':  # [batch]
          log_pplx_per_sample,
      'num_oovs_per_sample':  # [batch], int32
          num_oovs_per_sample,
      'tokens_from_labels':  # [batch], string
          tokens_from_labels,
      # `ids` is time-major at this point; transpose it back so the fetch
      # is [batch, time] like the other per-token fetches.
      'ids':  # [batch, time], int32
          tf.transpose(ids)
  }
  feeds = {'text': text}
  return fetches, feeds
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks. Mostly opaque to BeamSearchHelper, except that it should
      contain either a 'seq_lengths' field of shape [source_batch_size] or
      a 'padding' field of shape [source_max_lengths, source_batch_size]
      (a tensor, or a NestedMap whose first flattened entry is used).
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to override
      `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`. Its sixth positional field is always passed
    as None here.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  # Very low initial score, so any real hypothesis score replaces it.
  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  # Per-step buffers carried through the loop, all [max_steps, num_hyps, ...].
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  # Order matters: index 5 (in_done_hyps) is read back after the loop.
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    # Continue while steps remain and the step function has not reported
    # that everything is done.
    return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    # other_states travels through the loop flattened; re-pack it for the
    # step and flatten the result again.
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Assume that `encoder_outputs.padding` has shape
  # [source_max_lengths, source_batch_size] by default, and compute
  # `encoded_seq_lengths` accordingly. This can be overridden by directly
  # passing `seq_lengths` in the `encoder_outputs` NestedMap.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      # Only the first flattened padding tensor is used.
      encoded_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(
                  1.0 - tf.transpose(source_paddings.Flatten()[0]), 1)),
          tf.int32)
    else:
      encoded_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  # A tensor-valued max_steps cannot be passed as a static length, so 0 is
  # passed instead.
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def MaxPool3D(points, point_features, pooling_idx, closest_idx):
  """Apply max pooling to a point cloud with computed sampling indices.

  pooling_idx and closest_idx are the outputs of a sampler such as
  FurthestPointSampler.

  The pooling operation results in a point cloud with fewer points, where the
  pooled points are specified by pooling_idx. Each element of pooling_idx
  contains an integer in the range [0, P1) containing the index of the point
  in points/points_features.

  Max pooling is performed by assigning each point to its closest pooled
  point, and then taking a max over the features of points assigned. We assume
  that this mapping is provided by closest_idx, where each element should
  contain an integer in the range [0, P2) containing the index of the pooled
  point that each point is assigned to.

  Note: This logic for pooling assumes that there will be at least one value >
  0 per sampled region for each feature, otherwise it will return 0.
  Additionally, it does a reduce over a masked version of the features, so
  mean and min would not work without a change in the logic.

  Args:
    points: a floating point tf.Tensor with shape [N, P1, 3]
    point_features: a floating point tf.Tensor with shape [N, P1, C]
    pooling_idx: A tf.int32 tf.Tensor of shape [N, P2] with the index of which
      points we want to keep. Each value should be in the range [0, P1).
    closest_idx: A tf.int32 tf.Tensor of shape [N, P1] representing which
      sampled point is closest to each original point. Each value should be in
      the range of [0, P2).

  Returns:
    A tuple of tf.Tensors (pooled_points, pooled_features).

    pooled_points has shape [N, P2, 3] representing the locations of each
    selected point. P2 corresponds to num_pooled_points.

    pooled_features has shape [N, P2, C] representing the pooled features at
    each point.
  """
  batch_size, num_points = py_utils.GetShape(points, 2)
  point_features = py_utils.HasShape(point_features,
                                     [batch_size, num_points, -1])
  pooling_idx = py_utils.HasShape(pooling_idx, [batch_size, -1])
  # closest_idx must assign every input point, so it shares [N, P1].
  closest_idx = py_utils.HasShape(closest_idx, [batch_size, num_points])
  _, num_output_points = py_utils.GetShape(pooling_idx)

  # Gather new point locations.
  pooled_points = tf.array_ops.batch_gather(points, pooling_idx)

  # Build the mask in the features' dtype so the product below is valid for
  # any floating point feature type, not only float32.
  mask = tf.one_hot(
      closest_idx, num_output_points, dtype=point_features.dtype)  # [N, P1, P2]
  mask = tf.transpose(mask, [2, 0, 1])  # [P2, N, P1]

  def _PartialPoolFeaturesFn(partial_mask):
    # [N, P1, 1]; broadcasting over C replaces an explicit tile.
    partial_mask = tf.reshape(partial_mask, [batch_size, num_points, 1])
    # Note: This method of pooling assumes there will be a value > 0
    # And will only work with max under this condition.
    return tf.reduce_max(partial_mask * point_features, axis=1)

  # Performing a map_fn over the pooled points is more memory efficient.
  pooled_point_features = tf.map_fn(_PartialPoolFeaturesFn, mask)  # [P2, N, C]
  pooled_point_features = tf.transpose(pooled_point_features,
                                       [1, 0, 2])  # [N, P2, C]

  return pooled_points, pooled_point_features
def FProp(self,
          theta,
          source_input,
          source_paddings,
          target_input=None,
          target_paddings=None,
          source_segment_id=None,
          target_segment_id=None,
          labels=None,
          label_weights=None,
          source_pos_id=None,
          target_pos_id=None,
          source_task_id=None,
          target_task_id=None):
  """Transforms source sequence of Tensors with Transformers layers.

  When p.softmax_tpl is set, additionally computes the per-example
  cross-entropy of the stack's logits against `labels`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_input: A sequence of ints indicating source input ids of [time,
      batch] shape or [batch, time] if batch_dim is 0.
    source_paddings: A sequence of 0s and 1s indicating input paddings of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    target_input: A sequence of ints indicating target input ids of [time,
      batch] shape or [batch, time] if batch_dim is 0. Required when
      p.num_decoder_layers > 0.
    target_paddings: [target_time, target_batch] or [target_batch,
      target_time] if batch_dim is 0. Required when
      p.num_decoder_layers > 0.
    source_segment_id: A sequence of ints indicating source segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0. Required when
      p.packed_input is set.
    target_segment_id: A sequence of ints indicating target segment ids of
      [time, batch] shape or [batch, time] if batch_dim is 0.
    labels: A sequence of ints indicating label ids of [time, batch] shape,
      or [batch, time] if batch_dim is 0. Required when p.softmax_tpl is
      set.
    label_weights: A sequence of floats indicates label weights of [time,
      batch] shape, or [batch, time] if batch_dim is 0. Required when
      p.softmax_tpl is set.
    source_pos_id: A sequence of ints indicating source position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0. Required when
      p.packed_input is set.
    target_pos_id: A sequence of ints indicating target position ids of
      [time, batch] shape, or [batch, time] if batch_dim is 0.
    source_task_id: A sequence of ints indicating source task ids of [time,
      batch] shape, or [batch, time] if batch_dim is 0.
    target_task_id: A sequence of ints indicating target task ids of [time,
      batch] shape, or [batch, time] if batch_dim is 0.

  Returns:
    If p.softmax_tpl is not set: transformer_output with shape
    [time, batch, dim] or [batch, time, dim] if batch_dim is 0.

    Otherwise a tuple (per_example_xent, logits), where logits is the output
    of the stack and per_example_xent has the shape of the first two
    dimensions of logits.
  """
  p = self.params
  if p.num_decoder_layers > 0:
    assert target_input is not None
    assert target_paddings is not None
  if p.packed_input:
    assert source_segment_id is not None, (
        'Need to specify src_segment_id if packed input is supported.')
    assert source_pos_id is not None, (
        'Need to specify src_pos_id for packed input and embeddings.')

  logits = super(GPipeTransformerStack,
                 self).FProp(theta, source_input, source_paddings,
                             target_input, target_paddings, source_segment_id,
                             target_segment_id, source_pos_id, target_pos_id,
                             source_task_id, target_task_id)
  if not p.softmax_tpl:
    return logits

  # The loss path cannot work without labels and weights; fail early with a
  # clear message instead of an opaque error from tf.reshape(None, ...).
  assert labels is not None, (
      'Need to specify labels when softmax_tpl is set.')
  assert label_weights is not None, (
      'Need to specify label_weights when softmax_tpl is set.')

  # Flatten everything to [time * batch(, classes)] for the softmax.
  label_weights = tf.reshape(label_weights, [-1])
  target_probs = None
  if p.label_smoothing:
    if p.batch_dim:  # Time-major
      # The smoother is fed transposed inputs and its output is transposed
      # back to the layout of the logits.
      target_probs = tf.transpose(
          self.smoother.FProp(
              theta.smoother,
              tf.transpose(target_paddings),
              tf.transpose(labels),
              target_ids=None), [1, 0, 2])
    else:
      target_probs = self.smoother.FProp(
          theta.smoother, target_paddings, labels, target_ids=None)
    target_probs = tf.reshape(target_probs, [-1, p.softmax_tpl.num_classes])
  reshaped_logits = tf.reshape(logits, [-1, p.softmax_tpl.num_classes])
  tgt_labels = tf.reshape(labels, [-1])

  # The softmax layer lives in the last pipeline cell.
  num_splits = len(p.splits)
  softmax = self.children['cell_{}'.format(num_splits - 1)].softmax
  softmax_theta = theta['cell_{}'.format(num_splits - 1)].softmax
  per_example_xent, _ = softmax.XentLossFromLogits(
      softmax_theta,
      reshaped_logits,
      class_weights=label_weights,
      class_ids=tgt_labels,
      class_probabilities=target_probs)
  # Restore the leading two dimensions of the logits.
  xent_shape = tf.shape(logits)[:2]
  per_example_xent = tf.reshape(per_example_xent, xent_shape)
  return per_example_xent, logits
def _testDecoderFPropFloatHelper(self,
                                 func_inline=False,
                                 num_decoder_layers=1,
                                 target_seq_len=5,
                                 residual_start=0):
  """Computes decoder from params and computes loss with random inputs.

  Builds a decoder in a fresh graph, feeds it fixed-seed random encoder
  outputs and hard-coded targets, and evaluates the loss.

  Args:
    func_inline: Value for the session's `do_function_inlining` optimizer
      option.
    num_decoder_layers: Assigned to `p.rnn_layers`.
    target_seq_len: Assigned to `p.target_seq_len`.
    residual_start: Assigned to `p.residual_start`.

  Returns:
    The evaluated value of `metrics['loss'][0]`.
  """
  cluster = cluster_factory.ForTestingWorker(add_summary=True)
  config = tf.ConfigProto(
      graph_options=tf.GraphOptions(
          optimizer_options=tf.OptimizerOptions(
              do_function_inlining=func_inline)))
  with cluster, self.session(
      graph=tf.Graph(), use_gpu=False, config=config) as sess:
    tf.set_random_seed(8372749040)
    vn_config = py_utils.VariationalNoiseParams(None, False, False)
    p = self._DecoderParams(vn_config)
    p.rnn_layers = num_decoder_layers
    p.residual_start = residual_start
    p.target_seq_len = target_seq_len
    dec = p.Instantiate()
    # Encoder outputs: random values of shape [5, 2, 8] with a [5, 2]
    # padding matrix.
    src_seq_len = 5
    src_enc = tf.random_normal([src_seq_len, 2, 8], seed=9283748)
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=tf.float32)
    encoder_outputs = py_utils.NestedMap(
        encoded=src_enc, padding=src_enc_padding)
    # Each target tensor below is written as a [5, 4] constant and then
    # transposed to [4, 5].
    target_ids = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                     [5, 6, 7, 8], [10, 5, 2, 5]],
                    dtype=tf.int32))
    target_labels = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                     [5, 7, 8, 10], [10, 5, 2, 4]],
                    dtype=tf.int32))
    target_paddings = tf.transpose(
        tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                     [1, 1, 1, 1]],
                    dtype=tf.float32))
    target_transcripts = tf.constant(
        ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
    # Weight is 1 exactly where the target is not padded.
    target_weights = 1.0 - target_paddings
    targets = py_utils.NestedMap({
        'ids': target_ids,
        'labels': target_labels,
        'weights': target_weights,
        'paddings': target_paddings,
        'transcripts': target_transcripts,
    })
    metrics = dec.FPropDefaultTheta(encoder_outputs, targets).metrics
    loss = metrics['loss'][0]
    correct_predicts = metrics['fraction_of_correct_next_step_preds'][0]
    summaries = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES))

    tf.global_variables_initializer().run()
    # The accuracy metric is evaluated alongside the loss but its value is
    # discarded.
    loss_v, _ = sess.run([loss, correct_predicts])
    # Evaluate the merged summaries as well; the result is discarded.
    summaries.eval()
    return loss_v
def Sample(self,
           decoder_theta,
           encoder_outputs,
           random_seed,
           init_state_callback,
           pre_step_callback,
           post_step_callback,
           init_step_ids=None):
  """Samples target sequences, one target sequence per source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  The loop runs either through `recurrent.Recurrent` (p.use_recurrent) or
  through `tf.while_loop` with TensorArrays collecting the per-step ids and
  logits.

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      A 'segment_id' entry whose value is None is deleted from it in place.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.
    init_step_ids: [batch], optional init step ids, default to SOS.

  Returns:
    A NestedMap containing the following tensors

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but maybe ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
      a padded timestep.
  """
  p = self.params
  assert p.temperature > 0
  assert p.top_k >= 0
  assert p.num_hyps_per_beam >= 1
  # The default of 1 makes a missing 'segment_id' skip this branch; only a
  # present-but-None entry is removed.
  if getattr(encoder_outputs, 'segment_id', 1) is None:
    # Remove None values, which are not supported by recurrent.
    del encoder_outputs['segment_id']
  # init_state_callback may modify 'encoder_outputs', e.g., by inserting
  # 'packed_src'.
  bs_result, bs_state = init_state_callback(decoder_theta, encoder_outputs,
                                            p.num_hyps_per_beam)
  # 'recurrent_theta' represents all cross-timestep information used by the
  # recurrent loop below, including layer theta and encoder outputs.
  recurrent_theta = py_utils.NestedMap(
      random_seed=random_seed, encoder_outputs=encoder_outputs)
  batch = tf.shape(bs_result.log_probs)[0]
  recurrent_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=bs_result.log_probs,
      # Start with target_sos_id.
      ids=init_step_ids if init_step_ids is not None else tf.fill(
          [batch], tf.cast(p.target_sos_id, tf.int32)),
      bs_state=bs_state)

  if p.use_recurrent:
    # Step() deletes its inputs in this mode; the dummy tensor only carries
    # a leading dimension of p.target_seq_len.
    inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))
  else:
    # In while_loop mode the "inputs" slot doubles as the output
    # accumulator: one TensorArray entry is written per step.
    inputs = py_utils.NestedMap(
        ids=tf.TensorArray(dtype=tf.int32, size=p.target_seq_len),
        logits=tf.TensorArray(
            dtype=bs_result.log_probs.dtype, size=p.target_seq_len),
    )

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    if p.use_recurrent:
      del inputs
    with tf.name_scope('single_sampler_step'):
      # Compute logits and states.
      bs_result, bs_state1 = pre_step_callback(
          decoder_theta,
          recurrent_theta.encoder_outputs,
          tf.expand_dims(state0.ids, 1),  # [batch, 1].
          state0.bs_state,
          num_hyps_per_beam=p.num_hyps_per_beam)
      batch = tf.shape(bs_result.log_probs)[0]
      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = bs_result.log_probs

      if p.top_k > 0:
        # Restrict sampling to the k highest-scoring entries, optionally
        # renormalizing them with log_softmax.
        topk_logits, topk_ids = tf.math.top_k(state1.logits, k=p.top_k)
        sample_logits = tf.nn.log_softmax(
            topk_logits) if p.top_k_renormalize else topk_logits
      else:
        sample_logits = state1.logits

      # Sample ids from logits. [batch].
      # The seed pairs the caller's seed with the timestep, so every step
      # draws with a different seed.
      ids = tf.reshape(
          tf.random.stateless_categorical(
              sample_logits / p.temperature,
              num_samples=1,
              seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
              dtype=state0.ids.dtype,
              name='sample_next_id'), [batch])
      # With top-k, `ids` index into the top-k list; map them back through
      # topk_ids.
      state1.ids = tf.gather(
          topk_ids, ids, axis=1, batch_dims=1) if p.top_k > 0 else ids

      if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
        # Where is_last_chunk is set, a sampled target_eoc_id is replaced
        # by target_eos_id.
        state1.ids = tf.where(
            tf.math.logical_and(bs_result.is_last_chunk,
                                tf.equal(state1.ids, p.target_eoc_id)),
            tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
      state1.bs_state = post_step_callback(decoder_theta,
                                           recurrent_theta.encoder_outputs,
                                           state1.ids, bs_state1)
    if p.use_recurrent:
      return state1, py_utils.NestedMap()
    else:
      inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
      inputs.logits = inputs.logits.write(state0.timestep, state1.logits)
      return (recurrent_theta, state1, inputs)

  if p.use_recurrent:

    def StopFn(t, theta, state):
      del t, theta  # Unused: this stop function only uses the state ids.
      return tf.equal(state.ids, p.target_eos_id)
  else:

    def StopFn(recurrent_theta, state, inputs):
      # Used as the while_loop condition, hence the negation: keep looping
      # until every id equals target_eos_id.
      del recurrent_theta, inputs
      return tf.logical_not(
          tf.reduce_all(tf.equal(state.ids, p.target_eos_id)))

  if p.use_stop_fn:
    stop_fn = StopFn
  else:
    stop_fn = None

  if p.use_recurrent:
    accumulated_states, _ = recurrent.Recurrent(
        recurrent_theta,
        recurrent_state0,
        inputs,
        Step,
        stop_fn=stop_fn,
        allow_implicit_capture=True)
  else:
    # NOTE(review): this path passes StopFn directly as the loop condition,
    # so p.use_stop_fn has no effect here - confirm that is intended.
    loop_vars = (recurrent_theta, recurrent_state0, inputs)
    (_, _, accumulated_states) = tf.while_loop(
        StopFn,
        Step,
        loop_vars=loop_vars,
        shape_invariants=_GetShapes(loop_vars, none_shapes=True),
        back_prop=False,
        maximum_iterations=p.target_seq_len)
    accumulated_states.ids = accumulated_states.ids.stack()
    accumulated_states.logits = accumulated_states.logits.stack()

  # Accumulated values are step-major; convert to batch-major.
  result = py_utils.NestedMap(
      logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
      ids=tf.transpose(accumulated_states.ids))
  result.paddings = tf.cast(
      _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
  # Force ids to be eos_id if the timestep is padded.
  result.ids = tf.where(
      tf.equal(result.paddings, 0), result.ids,
      tf.fill(tf.shape(result.ids), p.target_eos_id))
  # Pin the static shapes using the static batch size of the initial
  # log_probs and p.target_seq_len.
  static_batch_size = bs_result.log_probs.shape[0]
  result.ids.set_shape([static_batch_size, p.target_seq_len])
  result.paddings.set_shape([static_batch_size, p.target_seq_len])
  return result
def expand_tensor(tensor, block_dims):
  """Replicates every entry of a rank-2 tensor into a block of copies.

  The result equals the Kronecker product of `tensor` with an all-ones matrix
  of shape `block_dims`.

  Example::

    tensor = [[1,2]
              [3,4]]
    block_dims = [2,2]

    result = [[1 1 2 2]
              [1 1 2 2]
              [3 3 4 4]
              [3 3 4 4]]

  Args:
    tensor: The 2D tensor to expand. Its static shape is read via
      `shape.as_list()` whenever an expansion factor is greater than 1.
    block_dims: Two integers: the row and the column expansion factor.

  Returns:
    The expanded tensor. When neither factor exceeds 1 the input tensor is
    returned unchanged.

  Raises:
    ValueError: If `tensor` is not rank 2 or `block_dims` does not have
      exactly 2 elements.
  """
  if tensor.get_shape().ndims != 2:
    raise ValueError('Input tensor must be rank 2')
  if len(block_dims) != 2:
    raise ValueError('block_dims must have 2 elements')
  row_factor, col_factor = block_dims

  def _interleave_positions(row_count, factor):
    """Returns the destination row of each row of the tiled tensor."""
    # Row r of the k-th tiled copy is written to output row r * factor + k,
    # which places all copies of one source row next to each other.
    copy_idx = np.arange(factor, dtype=np.int32).reshape([factor, 1])
    row_idx = np.arange(row_count, dtype=np.int32).reshape([1, row_count])
    return (row_idx * factor + copy_idx).reshape([row_count * factor, 1])

  def _repeat_rows(matrix, factor):
    """Repeats each row of `matrix` `factor` times, keeping copies adjacent."""
    row_count, col_count = matrix.shape.as_list()
    out_shape = [row_count * factor, col_count]
    positions = tf.constant(_interleave_positions(row_count, factor))
    tiled = tf.tile(matrix, [factor, 1])
    return tf.scatter_nd(positions, tiled, out_shape)

  result = tensor

  # Grow the row dimension first.
  if row_factor > 1:
    result = _repeat_rows(tensor, row_factor)

  # Columns are grown by repeating the rows of the transpose and then
  # transposing back.
  if col_factor > 1:
    result = tf.transpose(_repeat_rows(tf.transpose(result), col_factor))

  return result
def flat_beam_search(batch_size,
                     beam_size,
                     max_steps,
                     dec_callback,
                     dec_state,
                     bos_id=1,
                     eos_id=2,
                     length_norm_alpha=0.8,
                     beam_gap=3.0,
                     top_k_fn=tf.math.top_k,
                     prefix=None,
                     prefix_len=None,
                     fprop_dtype=tf.float32,
                     ext_size=0,
                     nbest_size=None,
                     debug=True):
  """Flat beam search.

  Hypotheses are stored as 0/1 masks over a flat buffer of
  `beam_size * max_steps` token slots. The decoder is first called once on the
  (possibly default, bos-only) prefix, then a `tf.while_loop` repeatedly calls
  it on `beam_size` hypotheses per batch row, updates the n-best list of
  terminated hypotheses and selects the next `beam_size` extensions.

  Args:
    batch_size: batch size
    beam_size: beam size limit in number of hyps
    max_steps: max steps
    dec_callback: decoder callback (see above). It is invoked as
      `dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask, dec_state, t)`
      and must return `(logits, dec_state)`.
    dec_state: decoder state
    bos_id: <s> token id
    eos_id: </s> token id
    length_norm_alpha: length normalization parameter
    beam_gap: early stopping threshold; None to disable
    top_k_fn: top_k function to call
    prefix: (optional) int32 tensor [batch_size, prefix_max]
    prefix_len: (optional) int32 tensor [batch_size]
    fprop_dtype: fprop dtype
    ext_size: int >= beam_size, extension buffer size
    nbest_size: number of returned hyps, default is beam_size
    debug: log intermediate values with tpu_summary.tensor()

  Returns:
    (loop_vars, dec_state, nbest) where
    nbest = (topk_ids, topk_len, topk_score)
  """
  assert beam_size > 0
  assert batch_size > 0
  assert max_steps > 0

  buf_size = beam_size * max_steps
  output_len = max_steps

  if prefix is None:
    assert prefix_len is None
    # Create prefix of start tokens.
    prefix = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    prefix += tf.one_hot(beam_size - 1, beam_size, dtype=tf.int32) * bos_id
    prefix_len = tf.ones([batch_size], dtype=tf.int32)
  else:
    assert int(prefix.shape[0]) == batch_size, (batch_size, prefix.shape)
    assert int(prefix_len.shape[0]) == batch_size, (batch_size,
                                                    prefix_len.shape)
    output_len += int(prefix.shape[1])

  if debug:
    tpu_summary.tensor('prefix', prefix)
    tpu_summary.tensor('prefix_len', prefix_len)

  with tf.name_scope('init_state'):
    t = tf.constant(0)
    tgt_id = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_id += bos_id
    tgt_pos = tf.zeros([batch_size, beam_size], dtype=tf.int32)
    tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
    tgt_mask += tf.one_hot(tf.range(beam_size), buf_size, dtype=fprop_dtype)
    hyp_score = tf.zeros([batch_size, beam_size], dtype=fprop_dtype)
    # penalize all hyps except the first
    hyp_score -= tf.cast(
        tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)
    nbest_size = nbest_size or beam_size
    nbest_score = tf.zeros([batch_size, nbest_size], dtype=fprop_dtype)
    nbest_score -= 1e9
    nbest_score_norm = nbest_score
    nbest_mask = tf.zeros([batch_size, nbest_size, buf_size],
                          dtype=fprop_dtype)

  with tf.name_scope('init_ext'):
    # Initialize the extension buffer.
    #
    # Extension buffer stores a (potentially large) set of 'extensions',
    # which consist of a hypothesis (represented by ext_mask) and next token
    # (represented by ext_id). At each decoder iteration, top_k extensions
    # from each hypothesis are added to the buffer and sorted by score.
    #
    # Then top beam_size extensions are removed from the buffer and used
    # in the next decoder iteration. And top 'ext_size' remaining extensions
    # are carried over to be possibly evaluated at a later step.
    #
    # As a result of this manipulation, the decoder is no longer restricted
    # to always compare hyps of the same token length at each iteration.
    # In particular, for a fixed length N it can generate more than beam_size
    # terminated hyps.
    #
    # Setting ext_size = 0 disables this feature.
    if ext_size:
      ext_id = tf.zeros([batch_size, ext_size], dtype=tf.int32)
      ext_score = tf.zeros([batch_size, ext_size], dtype=fprop_dtype)
      ext_score -= 1e9
      ext_mask = tf.zeros([batch_size, ext_size, buf_size], dtype=fprop_dtype)
    else:
      ext_size = ext_id = ext_score = ext_mask = 0

  with tf.name_scope('init_prefix'):
    # rename prefix->pfx for shorter variables
    pfx = tf.cast(prefix, tf.int32)
    pfx_len = tf.cast(prefix_len, tf.int32)
    del prefix, prefix_len
    # Before the first call to dec_callback() the prefix shall be packed into
    # the tgt_id buffer as follows:
    #
    # [ - - - - - - P P P P P P P* - - - ]   ^
    # [ - - P P P P P P P P P P P* - - - ]   | batch
    # [ - - - - - - - - - - - P P* - - - ]   V
    # |<---- prefix len ---->  |<-- beam -->
    #
    # The last meaningful token in the prefix (P*)
    # must be located at the same position in all batch rows.
    #
    # We then make one dec_callback() with full prefix (minus P*)
    # which will populate the initial dec_state
    # (for transformer -- self-attention key/value cache)
    #
    # The last block [batch, beam] then becomes the first tgt_id for the loop.
    pfx_max = int(pfx.shape[1])
    pfx_mul = pfx_max // beam_size
    assert pfx_max == pfx_mul * beam_size, (pfx_max, pfx_mul, beam_size)
    pfx_time = tf.range(pfx_max)
    pfx_indexes = pfx_time - pfx_max + tf.expand_dims(pfx_len - 1, 1)
    pfx_pad = tf.cast(tf.greater_equal(pfx_indexes, 0),
                      tf.int32)  # Exclude final pfx token.
    pfx_id = tf.roll(pfx, shift=1, axis=-1) * pfx_pad
    pfx_last = pfx[:, -1]

    buf_time = tf.range(buf_size)
    pfx_time_mask = tf.cast(
        tf.less_equal(
            tf.expand_dims(buf_time, 0), tf.expand_dims(pfx_time, 1)),
        fprop_dtype)
    pfx_mask = tf.einsum('BQ,QK->BQK', tf.cast(pfx_pad, fprop_dtype),
                         pfx_time_mask)
    # Remove padding.
    assert buf_size > pfx_max
    pfx_pad_long = tf.pad(
        pfx_pad, [(0, 0), (0, buf_size - pfx_max)], constant_values=1)
    # Cast to fprop_dtype (not a hard-coded float32) so that the product keeps
    # the dtype of pfx_mask for every fprop_dtype; the dtype assert in
    # 'prefix_fprop' below relies on this.
    pfx_mask *= tf.cast(tf.expand_dims(pfx_pad_long, axis=1), fprop_dtype)
    pfx_segment_id = pfx_pad
    pfx_pos = pfx_indexes * pfx_pad

    if debug:
      tpu_summary.tensor('pfx_id', pfx_id)
      tpu_summary.tensor('pfx_len', pfx_len)
      tpu_summary.tensor('pfx_pos', pfx_pos)
      tpu_summary.tensor('pfx_last', pfx_last)

  # Now call decoder with prefix minus P*:
  # 'dec_state' now shall contain the key/value cache for prefix tokens
  # (for transformer models), and 'logits' we can either discard or
  # roll into the initial hyp_score. Discard is simpler.
  with tf.name_scope('prefix_fprop'):
    # TODO(krikun): remove extra type checks
    assert (pfx_id.dtype == tf.int32), (pfx_id.dtype)
    assert (pfx_segment_id.dtype == tf.int32), (pfx_segment_id.dtype)
    assert (pfx_pos.dtype == tf.int32), (pfx_pos.dtype)
    assert (pfx_mask.dtype == fprop_dtype), (pfx_mask.dtype)
    assert (t.dtype == tf.int32), (t.dtype)
    logits, dec_state = dec_callback(pfx_id, pfx_segment_id, pfx_pos, pfx_mask,
                                     dec_state, t)
    del logits

  # Now construct the initial state for the rest of the beam search loop.
  # 'tgt_id' is simply 'pfx_last' padded to [batch, beam] shape
  # 'tgt_pos' is different for each batch row and is equal to prefix_len
  # 'tgt_segment_id' always 1 (no packing)
  # 'hyp_score' is 0 for beam=0 and negative for beam>=1
  tgt_id = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
      pfx_last, 1)
  tgt_pos = tf.zeros([batch_size, beam_size], tf.int32) + tf.expand_dims(
      (pfx_len - 1), 1)
  hyp_score = tf.zeros(
      [batch_size, beam_size], dtype=fprop_dtype) - tf.cast(
          tf.range(beam_size, dtype=tf.float32) * 1e5, dtype=fprop_dtype)

  # TODO(krikun) Here we make initial 't' constant and determined by the
  # shape of the prefix tensor 'pfx_max'. It is possible to make it dynamic
  # as t ~ max(pfx_len) / beam_size and this will allow more steps for beam
  # search, however 'max' results in a very slow all-to-all for 'max' on 16x16
  # and variable number of decoder steps may result in bad latency.
  t = tf.cast(tf.math.ceil(pfx_max / beam_size), tf.int32)

  # Initial tgt_mask is such that each token P* has attention on itself
  # (as usual) and on all prefix tokens before it, which are not padding.
  tgt_mask = tf.zeros([batch_size, beam_size, buf_size], dtype=fprop_dtype)
  tgt_mask += tf.cast(
      tf.expand_dims(
          tf.pad(pfx_pad, [[0, 0], [0, (buf_size - pfx_max)]]), 1),
      fprop_dtype)
  tgt_mask += tf.one_hot(
      tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

  if debug:
    tpu_summary.tensor('tgt_id', tgt_id)
    tpu_summary.tensor('tgt_pos', tgt_pos)
    tpu_summary.tensor('tgt_mask', tgt_mask)
    tpu_summary.tensor('t', t)

  with tf.name_scope('init_hist'):
    # h_tgt_id is used to recover topk_ids from nbest_mask
    h_tgt_id = tf.TensorArray(dtype=tf.int32, size=max_steps)
    h_tgt_pos = tf.TensorArray(dtype=tf.int32, size=max_steps)

    # When non-trivial prefix is present we also write prefix ids to
    # h_tgt_id so that the full sequence including prefix can be recovered
    # by unmask() below. When prefix is empty, pfx_id shape is [batch, 0]
    # and the loop below becomes a no-op.
    # TODO(krikun): maybe a tf.while_loop is more appropriate here.
    for i, x_i in enumerate(tf.split(pfx_id, pfx_mul, 1)):
      h_tgt_id = h_tgt_id.write(i, x_i)
    for i, x_i in enumerate(tf.split(pfx_pos, pfx_mul, 1)):
      h_tgt_pos = h_tgt_pos.write(i, x_i)

  hist = (h_tgt_id, h_tgt_pos)
  tf.logging.info('hist=%r', hist)

  nbest_hyps = (nbest_mask, nbest_score, nbest_score_norm)
  tf.logging.info('nbest_hyps=%r', nbest_hyps)

  ext = (ext_id, ext_score, ext_mask)
  tf.logging.info('ext=%r', ext)

  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)
  tf.logging.info('loop_vars=%r', loop_vars)

  def loop_step(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
     hist) = loop_vars
    (ext_id, ext_score, ext_mask) = ext
    (h_tgt_id, h_tgt_pos) = hist
    h_tgt_id = h_tgt_id.write(t, tgt_id, name='h_tgt_id')
    h_tgt_pos = h_tgt_pos.write(t, tgt_pos, name='h_tgt_pos')
    # not using tf.ones() here because of XLA compilation error
    tgt_segment_id = tgt_id * 0 + 1
    logits, dec_state = dec_callback(tgt_id, tgt_segment_id, tgt_pos, tgt_mask,
                                     dec_state, t)
    # take predicted EOS score for each hyp and compute normalized score
    eos_score = hyp_score + tf.cast(logits[:, :, eos_id], hyp_score.dtype)

    def length_norm(t):
      t = tf.cast(t, fprop_dtype)
      alpha = length_norm_alpha
      tf.logging.info('length_norm.alpha=%r', alpha)
      return tf.math.pow((t + 5.) / 5., alpha)

    hyp_len = tgt_pos - tf.expand_dims((pfx_len - 1), -1)
    eos_score_norm = eos_score / length_norm(hyp_len)
    # update the n-best list
    # NOTE(review): the unnormalized score stored with each n-best entry is
    # hyp_score (without the EOS log-prob) while the normalized score is
    # derived from eos_score -- confirm this asymmetry is intended.
    nbest_hyps = update_nbest(nbest_hyps,
                              (tgt_mask, hyp_score, eos_score_norm))

    if debug:
      tpu_summary.tensor('eos_score', eos_score)
      tpu_summary.tensor('hyp_len', hyp_len)

    # take top k tokens for each hyp
    k = beam_size
    with tf.name_scope('topk1'):
      top_score, top_id = top_k_fn(logits, k)
      top_score = tf.cast(top_score, fprop_dtype)
      top_score += tf.expand_dims(hyp_score, -1)
      # EOS extensions are handled by the n-best list above, so push them out
      # of the candidate set.
      top_score -= 1e9 * tf.cast(tf.equal(top_id, eos_id), fprop_dtype)
      top_score = tf.reshape(top_score, [batch_size, beam_size * k])
      top_id = tf.reshape(top_id, [batch_size, beam_size * k])
      top_mask = tf.repeat(tgt_mask, beam_size, 1)

    if debug:
      tpu_summary.tensor('top_id', top_id)
      tpu_summary.tensor('top_score', top_score)
      # tpu_summary.tensor('top_mask', top_mask)

    with tf.name_scope('update_ext'):
      # combine top k tokens with extension buffer (if any)
      if ext_size:
        ext_id = tf.concat([ext_id, top_id], 1)
        ext_score = tf.concat([ext_score, top_score], 1)
        ext_mask = tf.concat([ext_mask, top_mask], 1)
      else:
        ext_id, ext_score, ext_mask = top_id, top_score, top_mask

      # sort by score
      ext_score, i = tf.math.top_k(ext_score, ext_size + beam_size)
      i1 = tf.one_hot(i, ext_size + beam_size * k, dtype=fprop_dtype)
      ext_mask = tf.einsum('bkt,bjk->bjt', ext_mask, i1)
      ext_id = einsum_i32('bk,bjk->bj', ext_id, i1)

      # pick top beam_size extensions to evaluate at next iteration
      if ext_size:
        hyp_score = ext_score[:, :beam_size]
        ext_score = ext_score[:, beam_size:]
        tgt_id = ext_id[:, :beam_size]
        ext_id = ext_id[:, beam_size:]
        tgt_mask = ext_mask[:, :beam_size]
        ext_mask = ext_mask[:, beam_size:]
      else:
        hyp_score, tgt_id, tgt_mask = ext_score, ext_id, ext_mask
        ext_score = ext_id = ext_mask = 0

      tgt_pos = tf.reduce_sum(tgt_mask, -1)
      tgt_pos = tf.cast(tgt_pos, tf.int32)

    t += 1
    with tf.name_scope('tgt_mask_extend'):
      tgt_mask += tf.one_hot(
          tf.range(beam_size) + t * beam_size, buf_size, dtype=fprop_dtype)

    ext = (ext_id, ext_score, ext_mask)
    hist = (h_tgt_id, h_tgt_pos)
    loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext,
                 hist)
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    return loop_vars, dec_state

  def loop_cond(loop_vars, dec_state):  # pylint: disable=missing-docstring
    tf.logging.info('loop_vars=%r', loop_vars)
    tf.logging.info('dec_state=%r', dec_state)
    if beam_gap is None:
      (t, _, _, _, _, _, _, _) = loop_vars
      return t < max_steps
    else:
      # hyp_score must be taken from loop_vars: reading it from the enclosing
      # scope would compare against the initial scores at every iteration.
      (t, _, _, _, hyp_score, nbest_hyps, _, _) = loop_vars
      (_, nbest_score, _) = nbest_hyps
      # stop early if all current hyps are significantly worse than nbest
      diff = tf.reduce_min(
          tf.reduce_min(nbest_score, -1) - tf.reduce_max(hyp_score, -1))
      return tf.math.logical_and(t < max_steps, diff < beam_gap)

  with tf.name_scope('flat_beam_search_loop'):
    (loop_vars, dec_state) = tf.while_loop(
        loop_cond,
        loop_step,
        loop_vars=(loop_vars, dec_state),
        back_prop=False,
        swap_memory=False,
        maximum_iterations=max_steps)

  # flatten all tensorarrays into tensors
  (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist) = loop_vars
  (nbest_mask, nbest_score, nbest_score_norm) = nbest_hyps
  (h_tgt_id, h_tgt_pos) = hist
  h_tgt_id = h_tgt_id.stack()
  h_tgt_pos = h_tgt_pos.stack()
  hist = (h_tgt_id, h_tgt_pos)
  loop_vars = (t, tgt_id, tgt_pos, tgt_mask, hyp_score, nbest_hyps, ext, hist)

  # recover topk_ids from nbest_mask and tgt_id history
  h = tf.transpose(h_tgt_id, [1, 0, 2])
  h = tf.reshape(h, [batch_size, buf_size])

  def unmask(h, m):
    with tf.name_scope('unmask'):
      tpu_summary.tensor('unmask_h', h)
      tpu_summary.tensor('unmask_m', m)
      # Output position of each selected buffer slot (-1 for unselected
      # slots, which one_hot maps to an all-zero row).
      t = tf.cumsum(m, -1) * m - 1
      mh = einsum_i32('bkt,bt->bkt', m, h)
      t2 = tf.one_hot(tf.cast(t, tf.int32), output_len, dtype=fprop_dtype)
      x = einsum_i32('bkt,bktT->bkT', mh, t2)
      return tf.cast(x, h.dtype)

  topk_ids = unmask(h, nbest_mask)
  topk_len = tf.reduce_sum(nbest_mask, -1)
  topk_len = tf.cast(topk_len, tf.int32)
  # add eos, because nbest_mask does not encode eos
  topk_ids += eos_id * tf.one_hot(topk_len, output_len, dtype=tf.int32)
  topk_len += 1
  topk_len = tf.minimum(topk_len, output_len)
  topk_score = nbest_score_norm

  nbest = (topk_ids, topk_len, topk_score)

  return loop_vars, dec_state, nbest
def FProp(self, theta, input_batch):
  """Embeds source ids and runs them through the transformer stack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].
      - segment_ids, segment_pos: Read when p.packed_input is set;
        segment_ids is also read when p.individually_tagged_input is set.

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    # ids and paddings must agree in shape and be rank 2.
    ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    seg_ids = None
    seg_pos = None
    truncate_inputs = (not py_utils.use_tpu() and
                       FLAGS.transformer_encoder_truncates_inputs)
    if truncate_inputs:
      # Length of the longest non-padded row in the batch.
      longest = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything beyond `longest` must be padding before it is cut off.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, longest:] > 0.5))
      ], input_batch.paddings)
      ids = ids[:, :longest]
      paddings = paddings[:, :longest]
      if p.packed_input:
        seg_ids = input_batch.segment_ids[:, :longest]
        seg_pos = input_batch.segment_pos[:, :longest]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        seg_ids = input_batch.segment_ids
        seg_pos = input_batch.segment_pos

    max_time = tf.shape(ids)[1]

    # Token embeddings come from the softmax layer when embeddings are shared.
    if p.shared_emb:
      emb_layer, emb_theta = self.softmax, theta.softmax
    else:
      emb_layer, emb_theta = self.token_emb, theta.token_emb
    embs = emb_layer.EmbLookup(emb_theta, tf.reshape(ids, [-1]))
    embs = tf.reshape(embs, [-1, max_time, p.token_emb.embedding_dim])

    # [time, batch, dim]; returned as-is, before positions are added.
    embs_without_position = tf.transpose(embs, [1, 0, 2])

    if p.packed_input:
      pos_embs = self.position_emb.FPropWithPosition(theta.position_emb,
                                                     seg_pos)
    else:
      pos_embs = self.position_emb.FProp(theta.position_emb, max_time)
      pos_embs = tf.reshape(pos_embs,
                            [1, max_time, p.token_emb.embedding_dim])

    # Token and position embeddings are combined by plain addition.
    embs += pos_embs

    if p.individually_tagged_input:
      assert not p.packed_input
      # Look up tag embeddings; this assumes that the tags arriving on
      # input_batch.segment_ids (originating as common.source_segment_id
      # in the input NMTExample) have been reserved in the WPM vocabulary
      # as context tags, e.g. the ids for <src_token> and <ctxt_token> in
      # wide source context experiments.
      tags = py_utils.with_dependencies([
          py_utils.assert_shape_match(
              tf.shape(input_batch.segment_ids), tf.shape(input_batch.ids)),
          py_utils.assert_equal(tf.rank(input_batch.segment_ids), 2)
      ], input_batch.segment_ids)
      tag_embs = self.token_emb.EmbLookup(theta.token_emb,
                                          tf.reshape(tags, [-1]))
      tag_embs = tf.reshape(tag_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
      # Append the tag embeddings along the depth axis, then project back to
      # the original embedding dimensionality.
      embs = self.concat_emb_and_tag_proj.FProp(
          theta.concat_emb_and_tag_proj, tf.concat([embs, tag_embs], -1))

    if p.ln_input:
      embs = self.layer_norm_input.FProp(theta.layer_norm_input, embs)

    if p.task_emb:
      embs += self.task_emb.EmbLookup(theta.task_emb, input_batch.task_ids)

    summary_utils.histogram('input_embs', embs)
    if p.model_dim != p.token_emb.embedding_dim:
      embs = self.emb_proj.FProp(theta.emb_proj, embs)
      summary_utils.histogram('emb_proj', embs)

    # Switch paddings (and segment ids) to time-major layout.
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      seg_ids = tf.transpose(seg_ids)
    embs = self.input_dropout.FProp(theta.input_dropout, embs)

    # [time, batch, dim]
    stack_input = tf.transpose(embs, [1, 0, 2])

  if not self.do_eval and p.apply_source_mask:
    # Treat positions holding the source mask id as additional padding.
    pad_dtype = paddings.dtype
    masked = tf.where(
        tf.equal(ids, p.source_mask_id),
        tf.ones_like(ids, dtype=pad_dtype),
        tf.zeros_like(ids, dtype=pad_dtype))
    # Keep the combined padding within [0, 1].
    paddings = tf.clip_by_value(paddings + tf.transpose(masked), 0.0, 1.0)

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, stack_input, paddings, seg_ids)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=embs_without_position)
def _SingleClassDecodeWithNMS(predicted_bboxes,
                              classification_scores,
                              nms_iou_threshold,
                              score_threshold,
                              max_boxes_per_class=None):
  """Runs class-agnostic NMS over predicted boxes and their class scores.

  A single NMS pass is run per example, scoring each box by its highest class
  score (zeroed when class 0 is the most likely class). The results are then
  laid out with a leading [batch_size, num_classes] shape.

  Args:
    predicted_bboxes: [batch_size, num_boxes, 7] float Tensor containing
      predicted bounding box coordinates.
    classification_scores: [batch_size, num_boxes, num_classes] float Tensor
      containing predicted classification scores for each box.
    nms_iou_threshold: IoU threshold to use when determining whether two boxes
      overlap for purposes of suppression. Must be a Python float.
    score_threshold: The score threshold passed to NMS that allows NMS to
      quickly ignore irrelevant boxes. Must be a Python float.
    max_boxes_per_class: The maximum number of boxes per example to emit. If
      None, this value is set to num_boxes from the shape of predicted_bboxes.

  Returns:
    predicted_bboxes: Filtered bboxes after NMS of shape
      [batch_size, num_classes, max_boxes_per_class, 7].
    bbox_scores: A float32 Tensor with the score for each box of shape
      [batch_size, num_classes, max_boxes_per_class].
    valid_mask: A float32 Tensor with 1/0 values indicating the validity of
      each box. 1 indicates valid, and 0 invalid. Tensor of shape
      [batch_size, num_classes, max_boxes_per_class].

  Raises:
    ValueError: If `nms_iou_threshold` or `score_threshold` is not a float.
  """
  utils_3d = detection_3d_lib.Utils3D()
  predicted_bboxes = py_utils.HasShape(predicted_bboxes, [-1, -1, 7])
  batch_size, num_predicted_boxes, _ = py_utils.GetShape(predicted_bboxes)
  classification_scores = py_utils.HasShape(
      classification_scores, [batch_size, num_predicted_boxes, -1])
  num_classes = py_utils.GetShape(classification_scores)[2]

  # Only plain float thresholds are accepted here.
  threshold_checks = (
      (nms_iou_threshold, 'Single class NMS only supports a scalar '
       '`nms_iou_threshold`.'),
      (score_threshold, 'Single class NMS only supports a scalar '
       '`score_threshold`.'),
  )
  for threshold, error_message in threshold_checks:
    if not isinstance(threshold, float):
      raise ValueError(error_message)

  if max_boxes_per_class is None:
    max_boxes_per_class = num_predicted_boxes

  # TODO(jngiam): Change to be per-class bboxes, and hence, per-class NMS, and
  # per-class thresholding.
  # [batch, num_predicted_boxes]: best class score of each box.
  best_scores = tf.reduce_max(classification_scores, axis=-1)

  # Index of the highest-scoring class of each box.
  best_labels = tf.argmax(classification_scores, axis=-1)

  # Boxes whose most likely class is background (label 0) get a zero NMS
  # score so that they do not dominate the NMS.
  is_foreground = tf.cast(best_labels > 0, tf.float32)
  nms_scores = best_scores * is_foreground

  # One NMS pass per example in the batch.
  nms_indices, valid_mask = utils_3d.BatchedNMSIndices(
      predicted_bboxes,
      nms_scores,
      nms_iou_threshold=nms_iou_threshold,
      score_threshold=score_threshold,
      max_num_boxes=max_boxes_per_class)

  # Gather boxes and scores in the order selected by NMS.
  kept_bboxes = tf.array_ops.batch_gather(predicted_bboxes, nms_indices)
  kept_scores = tf.array_ops.batch_gather(classification_scores, nms_indices)

  # Match the output format of MultiClassOrientedDecodeWithNMS, which emits a
  # per-class NMS result with leading shape
  # [batch_size, num_classes, max_boxes_per_class]. Because this NMS is not
  # class specific, boxes and validity are tiled num_classes times, while the
  # scores only need their class axis moved in front of the box axis.
  tiled_bboxes = tf.tile(kept_bboxes[:, tf.newaxis, :, :],
                         [1, num_classes, 1, 1])
  per_class_scores = tf.transpose(kept_scores, (0, 2, 1))
  per_class_scores = py_utils.HasShape(
      per_class_scores, [batch_size, num_classes, max_boxes_per_class])
  tiled_valid_mask = tf.tile(valid_mask[:, tf.newaxis, :],
                             [1, num_classes, 1])
  return tiled_bboxes, per_class_scores, tiled_valid_mask
def FProp(self, theta, *args):
  """Run multiple cells in different devices in a pipelining manner.

  In eval mode (`p.is_eval`) the before-layers and the cells are simply
  applied one after another. Otherwise the inputs are split into
  `p.num_micro_batches` micro batches along `p.batch_dim` and the cells are
  executed through `recurrent.StackedRecurrent`, one device per cell.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    In eval mode, whatever the last layer's FProp returns. Otherwise a single
    output tensor when the last cell has one output, else a tuple of output
    tensors (with None for outputs that have no shape).
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if p.is_eval:
    outputs = _ToTuple(args)
    for layer_group in (self._before_layers, self._cells):
      for name, layer in layer_group:
        outputs = _ToTuple(outputs)
        outputs = layer.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Work out the micro batch size and the per-micro-batch input shapes.
  input_tensors = _ToTuple(args)
  first_input = input_tensors[0]
  mini_batch_size = first_input.get_shape().as_list()[p.batch_dim]
  state_dtype = p.state_dtype or first_input.dtype
  # There cannot be more micro batches than examples in the mini batch.
  if p.num_micro_batches > mini_batch_size:
    p.num_micro_batches = mini_batch_size
  micro_batch_size = mini_batch_size // p.num_micro_batches

  def _MicroBatchShape(tensor):
    """Returns the static shape of `tensor` with a micro-batch sized batch."""
    if tensor is None:
      return None
    dims = tensor.get_shape().as_list()
    dims[p.batch_dim] = micro_batch_size
    return tf.TensorShape(dims)

  input_shapes = tuple(_MicroBatchShape(x) for x in input_tensors)

  # state_shapes[i] holds the input shapes of cell i and state_shapes[i + 1]
  # its output shapes.
  state_shapes = self._CalculateOutputShapes(input_shapes)

  def GetCellFn(i):
    """Get the ith feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn is executed inside of StackedRecurrent."""
      del state0
      cell_inputs = []
      for input_idx, in_shape in enumerate(state_shapes[i]):
        if in_shape is None:
          cell_inputs.append(None)
          continue
        in_tensor = inputs['s{}'.format(input_idx)]
        in_tensor.set_shape(in_shape)
        cell_inputs.append(in_tensor)

      with CellFnFropOpReplacementWrapper():
        tf.logging.info('cell {} input {}'.format(i, cell_inputs))
        mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
        SetOverWriteGlobalStep(mb_tensor)
        _, cell = self._cells[i]
        cell_outputs = cell.FProp(theta, *cell_inputs)

      state1 = py_utils.NestedMap()
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      cell_outputs = _ToTuple(cell_outputs)
      assert len(cell_outputs) == len(state_shapes[i + 1])
      for output_idx, out_tensor in enumerate(cell_outputs):
        if out_tensor is not None:
          state1['s{}'.format(output_idx)] = out_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])
    # Zero initial state shaped like the outputs of this cell.
    init_state = py_utils.NestedMap()
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    for output_idx, out_shape in enumerate(state_shapes[cell_idx + 1]):
      if out_shape is not None:
        init_state['s{}'.format(output_idx)] = tf.zeros(
            out_shape, dtype=state_dtype)
    init_states.append(init_state)
    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    before_outputs = input_tensors
    for name, layer in self._before_layers:
      before_outputs = layer.FProp(theta[name], *before_outputs)
      before_outputs = _ToTuple(before_outputs)
    inputs = py_utils.NestedMap()
    gs_tensor = py_utils.GetGlobalStep()
    # One value per micro batch, derived from the global step.
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + mb_idx, dtype=state_dtype)
        for mb_idx in range(p.num_micro_batches)
    ])

    # TODO(huangyp, dehao): apply dehao's trick to reshape the input tensor
    # to [p.num_micro_batches, -1, 128].
    for output_idx, before_output in enumerate(before_outputs):
      if before_output is None:
        continue
      # Stack the micro batches along a new leading axis.
      micro_batches = tf.split(
          before_output, p.num_micro_batches, axis=p.batch_dim)
      inputs['s{}'.format(output_idx)] = tf.stack(micro_batches)

  output, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):
    output_tensors = []
    for output_idx, state_shape in enumerate(state_shapes[-1]):
      if state_shape is None:
        output_tensors.append(None)
        continue
      output_tensor = output['s{}'.format(output_idx)]
      if p.batch_dim != 0:
        # Move the leading micro-batch axis next to the batch axis.
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, len(state_shape) + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      # Merge the micro-batch axis into the batch axis.
      state_shape[p.batch_dim] *= p.num_micro_batches
      output_tensors.append(tf.reshape(output_tensor, state_shape))
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    if len(output_tensors) == 1:
      return output_tensors[0]
    return tuple(output_tensors)
def FProp(self, theta, batch, state0=None):
  """Encodes source as represented by 'inputs' and 'paddings'.

  The input goes through, in order: optional SpecAugment (training only),
  optional end padding, the conv layers, the conv-lstm layers, a reshape to
  rank 3, and finally the stack of rnn layers (with optional residual/highway
  skips, projections and a stacking layer).

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    batch: A NestedMap with fields:
      - src_inputs - The inputs tensor. It is expected to be of shape [batch,
        time, feature_dim, channels].
      - paddings - The paddings tensor. It is expected to be of shape [batch,
        time].
    state0: Recurrent input state. Not supported/ignored by this encoder.

  Returns:
    A NestedMap containing

    - 'encoded': a feature tensor of shape [time, batch, depth]
    - 'padding': a 0/1 tensor of shape [time, batch]
    - 'state': the updated recurrent state (always an empty NestedMap here)
    - '${layer_type}_${layer_index}': The per-layer encoder output, only
      present when p.extra_per_layer_outputs is set. Each one is a NestedMap
      containing 'encoded' and 'padding' similar to regular final outputs,
      except that 'encoded' from conv or conv_lstm layers are of shape
      [time, batch, depth, channels].
  """
  p = self.params
  inputs, paddings = batch.src_inputs, batch.paddings
  outputs = py_utils.NestedMap()
  with tf.name_scope(p.name):
    # Adding specAugmentation. Skipped in eval mode.
    if p.use_specaugment and not self.do_eval:
      inputs, paddings = self.specaugment.FProp(
          theta.specaugment, inputs, paddings)
    # Add a few extra padded timesteps at the end. This is for ensuring the
    # correctness of the conv-layers at the edges.
    if p.pad_steps > 0:
      # inplace_update() is not supported by TPU for now. Since we have done
      # padding on the input_generator, we may avoid this additional padding.
      assert not py_utils.use_tpu()
      # Same shape as the tensor being padded, except that dimension 1 (time)
      # is replaced by p.pad_steps.
      inputs_pad = tf.zeros(
          inplace_ops.inplace_update(tf.shape(inputs), 1, p.pad_steps),
          inputs.dtype)
      paddings_pad = tf.ones(
          inplace_ops.inplace_update(tf.shape(paddings), 1, p.pad_steps),
          paddings.dtype)
      inputs = tf.concat([inputs, inputs_pad], 1, name='inputs')
      paddings = tf.concat([paddings, paddings_pad], 1)

    # Plots are accumulated layer by layer and emitted once at the end.
    plots = [
        summary_utils.PrepareSequenceForPlot(
            tf.transpose(inputs, [0, 1, 3, 2]), paddings, 'inputs')
    ]

    # The conv layers.
    conv_out = inputs
    out_padding = paddings
    for i, conv_layer in enumerate(self.conv):
      conv_out, out_padding = conv_layer.FProp(
          theta.conv[i], conv_out, out_padding)
      if p.extra_per_layer_outputs:
        # Zero out padded frames before exposing the per-layer output.
        conv_out *= (1.0 - out_padding[:, :, tf.newaxis, tf.newaxis])
        outputs['conv_%d' % i] = py_utils.NestedMap(
            encoded=tf.transpose(conv_out, [1, 0, 2, 3]),  # to [t, b, d, c]
            padding=tf.transpose(out_padding))
      plots.append(
          summary_utils.PrepareSequenceForPlot(
              tf.transpose(conv_out, [0, 1, 3, 2]), out_padding,
              'conv_%d_out' % i))

    def TransposeFirstTwoDims(t):
      """Swaps dims 0 and 1 of `t`, leaving all trailing dims untouched."""
      first_dim = tf.shape(t)[0]
      second_dim = tf.shape(t)[1]
      # Flatten the trailing dims so a fixed rank-3 permutation can be used,
      # then restore the original trailing shape.
      t_new = tf.transpose(
          tf.reshape(t, [first_dim, second_dim, -1]), [1, 0, 2])
      t_shape_new = tf.concat([[second_dim], [first_dim], tf.shape(t)[2:]], 0)
      return tf.reshape(t_new, t_shape_new)

    # Now the conv-lstm part.
    conv_lstm_out = conv_out
    conv_lstm_out_padding = out_padding
    for i, (rnn, cnn) in enumerate(
        zip(self.conv_lstm_rnn, self.conv_lstm_cnn)):
      conv_lstm_in = conv_lstm_out
      # Move time dimension to be the first.
      conv_lstm_in = TransposeFirstTwoDims(conv_lstm_in)
      conv_lstm_in = tf.expand_dims(conv_lstm_in, 2)
      conv_lstm_in_padding = tf.expand_dims(
          tf.transpose(conv_lstm_out_padding), 2)
      lstm_out = rnn.FProp(theta.conv_lstm_rnn[i], conv_lstm_in,
                           conv_lstm_in_padding)
      # Move time dimension to be the second.
      cnn_in = TransposeFirstTwoDims(lstm_out)
      cnn_in = tf.squeeze(cnn_in, 2)
      cnn_in_padding = conv_lstm_out_padding
      cnn_out, cnn_out_padding = cnn.FProp(theta.conv_lstm_cnn[i], cnn_in,
                                           cnn_in_padding)
      conv_lstm_out, conv_lstm_out_padding = cnn_out, cnn_out_padding
      if p.extra_per_layer_outputs:
        # Zero out padded frames before exposing the per-layer output.
        conv_lstm_out *= (
            1.0 - conv_lstm_out_padding[:, :, tf.newaxis, tf.newaxis])
        outputs['conv_lstm_%d' % i] = py_utils.NestedMap(
            encoded=tf.transpose(conv_lstm_out,
                                 [1, 0, 2, 3]),  # to [t, b, d, c]
            padding=tf.transpose(conv_lstm_out_padding))
      plots.append(
          summary_utils.PrepareSequenceForPlot(
              conv_lstm_out, conv_lstm_out_padding, 'conv_lstm_%d_out' % i))

    # Need to do a reshape before starting the rnn layers.
    # The two trailing dims are merged into a single feature dim.
    conv_lstm_out = py_utils.HasRank(conv_lstm_out, 4)
    conv_lstm_out_shape = tf.shape(conv_lstm_out)
    new_shape = tf.concat([conv_lstm_out_shape[:2], [-1]], 0)
    conv_lstm_out = tf.reshape(conv_lstm_out, new_shape)
    if self._first_lstm_input_dim_pad:
      # Zero-pad the feature dim up to the size checked just below.
      conv_lstm_out = tf.pad(
          conv_lstm_out,
          [[0, 0], [0, 0], [0, self._first_lstm_input_dim_pad]])
    conv_lstm_out = py_utils.HasShape(
        conv_lstm_out, [-1, -1, self._first_lstm_input_dim])

    # Transpose to move the time dimension to be the first.
    rnn_in = tf.transpose(conv_lstm_out, [1, 0, 2])
    rnn_padding = tf.expand_dims(tf.transpose(conv_lstm_out_padding), 2)
    # rnn_in is of shape [time, batch, depth]
    # rnn_padding is of shape [time, batch, 1]

    # Now the rnn layers.
    # num_skips indexes into self.highway_skip; it advances once per applied
    # highway connection.
    num_skips = 0
    for i in range(p.num_lstm_layers):
      rnn_out = self.rnn[i].FProp(theta.rnn[i], rnn_in, rnn_padding)
      residual_index = i - p.residual_start + 1
      if p.residual_start > 0 and residual_index >= 0:
        # Remember the input at the start of each residual group ...
        if residual_index % p.residual_stride == 0:
          residual_in = rnn_in
        # ... and add it back at the end of the group.
        if residual_index % p.residual_stride == p.residual_stride - 1:
          # Highway skip connection.
          if p.highway_skip:
            rnn_out = self.highway_skip[num_skips].FProp(
                theta.highway_skip[num_skips], residual_in, rnn_out)
            num_skips += 1
          else:
            # Residual skip connection.
            rnn_out += py_utils.HasShape(
                residual_in, tf.shape(rnn_out))
      if p.project_lstm_output and (i < p.num_lstm_layers - 1):
        # Projection layers.
        rnn_out = self.proj[i].FProp(theta.proj[i], rnn_out, rnn_padding)
      if i == p.num_lstm_layers - 1:
        # Zero out padded frames of the final layer output.
        rnn_out *= (1.0 - rnn_padding)
      if p.extra_per_layer_outputs:
        rnn_out *= (1.0 - rnn_padding)
        outputs['rnn_%d' % i] = py_utils.NestedMap(
            encoded=rnn_out, padding=tf.squeeze(rnn_padding, [2]))
      # Stacking layer connection.
      if p.layer_index_before_stacking == i:
        # Stacking layer expects input tensor shape as [batch, time, feature].
        # So transpose the tensors before and after the layer.
        rnn_out, rnn_padding = self.stacking.FProp(
            tf.transpose(rnn_out, [1, 0, 2]),
            tf.transpose(rnn_padding, [1, 0, 2]))
        rnn_out = tf.transpose(rnn_out, [1, 0, 2])
        rnn_padding = tf.transpose(rnn_padding, [1, 0, 2])

      plots.append(
          summary_utils.PrepareSequenceForPlot(
              tf.transpose(rnn_out, [1, 0, 2]),
              tf.transpose(rnn_padding, [1, 0, 2]), 'rnn_%d_out' % i))
      rnn_in = rnn_out
    final_out = rnn_in

    # Plots were collected input-first; they are handed over in reverse order.
    summary_utils.PlotSequenceFeatures(list(reversed(plots)),
                                       'encoder_example',
                                       xlabel='Time')

    outputs['encoded'] = final_out
    outputs['padding'] = tf.squeeze(rnn_padding, [2])
    outputs['state'] = py_utils.NestedMap()
    return outputs
def _process(record):
  """Unpickles `record` into an int32 tensor and derives its bucket key.

  Args:
    record: A serialized (pickled) record.

  Returns:
    A pair `([tensor, transposed_tensor], bucket_key)` where
    `transposed_tensor` swaps the first two dims of the rank-3 tensor and
    `bucket_key` is the size of the tensor's first dim.
  """
  # NOTE(review): pickle.loads can execute arbitrary code; this must only be
  # fed trusted records.
  unpickled = tf.py_func(pickle.loads, [record], tf.int32)
  leading_dim = tf.shape(unpickled)[0]
  swapped = tf.transpose(unpickled, [1, 0, 2])
  return [unpickled, swapped], leading_dim
def Transpose(paddings):
  """Transposes one padding tensor or every tensor in a list of them.

  Args:
    paddings: A single tensor, or a list of tensors.

  Returns:
    A list with `tf.transpose` applied to each input tensor. A single
    (non-list) input yields a one-element list.
  """
  if not isinstance(paddings, list):
    paddings = [paddings]
  transposed = []
  for padding in paddings:
    transposed.append(tf.transpose(padding))
  return transposed
def Sample(self, decoder_theta, encoder_outputs, random_seed,
           init_state_callback, pre_step_callback, post_step_callback):
  """Draws one sampled target sequence for every source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      A 'segment_id' entry whose value is None is deleted from it.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.

  Returns:
    A NestedMap containing the following tensors

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but maybe ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
      a padded timestep.
  """
  p = self.params
  assert p.temperature > 0

  segment_id_is_none = getattr(encoder_outputs, 'segment_id', 1) is None
  if segment_id_is_none:
    # recurrent does not support None values, so drop the entry.
    del encoder_outputs['segment_id']

  # The callback is allowed to mutate 'encoder_outputs' (e.g., by inserting
  # 'packed_src'), so it has to run before the loop theta is assembled.
  bs_result, initial_bs_state = init_state_callback(
      decoder_theta, encoder_outputs, num_hyps_per_beam=1)

  # Everything that stays constant across timesteps of the loop below: the
  # layer theta, the seed and the encoder outputs.
  loop_theta = py_utils.NestedMap(
      theta=decoder_theta,
      random_seed=random_seed,
      encoder_outputs=encoder_outputs)

  num_seqs = tf.shape(bs_result.log_probs)[0]
  # Decoding starts from target_sos_id for every sequence.
  sos_ids = tf.fill([num_seqs], tf.cast(p.target_sos_id, tf.int32))
  loop_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=bs_result.log_probs,
      ids=sos_ids,
      bs_state=initial_bs_state)
  # The per-step input carries no information; it only fixes the number of
  # loop iterations to p.target_seq_len.
  loop_inputs = py_utils.NestedMap(
      dummy=tf.zeros([p.target_seq_len, num_seqs]))

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    del inputs
    with tf.name_scope('single_sampler_step'):
      # Logits and decoder states for this step. ids are fed as [batch, 1].
      step_result, pre_bs_state = pre_step_callback(
          recurrent_theta.theta,
          recurrent_theta.encoder_outputs,
          tf.expand_dims(state0.ids, 1),
          state0.bs_state,
          num_hyps_per_beam=1)
      step_batch = tf.shape(step_result.log_probs)[0]
      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = step_result.log_probs

      # Draw one id per sequence; the seed is (random_seed, timestep).
      step_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
      drawn = tf.random.stateless_categorical(
          state1.logits / p.temperature,
          num_samples=1,
          seed=step_seed,
          dtype=state0.ids.dtype,
          name='sample_next_id')
      state1.ids = tf.reshape(drawn, [step_batch])

      if 'is_last_chunk' in step_result and p.target_eoc_id >= 0:
        # An end-of-chunk id sampled on the last chunk becomes eos.
        eoc_on_last_chunk = tf.math.logical_and(
            step_result.is_last_chunk,
            tf.equal(state1.ids, p.target_eoc_id))
        state1.ids = tf.where(
            eoc_on_last_chunk,
            tf.fill(tf.shape(state1.ids), p.target_eos_id),
            state1.ids)

      state1.bs_state = post_step_callback(
          recurrent_theta.theta, recurrent_theta.encoder_outputs,
          state1.ids, pre_bs_state)
    return state1, py_utils.NestedMap()

  accumulated_states, _ = recurrent.Recurrent(
      loop_theta,
      loop_state0,
      loop_inputs,
      Step,
      allow_implicit_capture=True)

  # Accumulated states are time-major; switch them to batch-major.
  batch_major_logits = tf.transpose(accumulated_states.logits, [1, 0, 2])
  batch_major_ids = tf.transpose(accumulated_states.ids)
  result = py_utils.NestedMap(logits=batch_major_logits, ids=batch_major_ids)
  result.paddings = tf.cast(
      _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
  # Force ids to be eos_id if the timestep is padded.
  is_real_step = tf.equal(result.paddings, 0)
  result.ids = tf.where(
      is_real_step, result.ids,
      tf.fill(tf.shape(result.ids), p.target_eos_id))

  static_batch_size = bs_result.log_probs.shape[0]
  static_shape = [static_batch_size, p.target_seq_len]
  result.ids.set_shape(static_shape)
  result.paddings.set_shape(static_shape)
  return result
def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                    core_bs_states, other_states, num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback):
  """Extend beam search hyps for one step.

    | num_beams = Number of source sequences to be decoded.
    | num_hyps_per_beam = Number of hyps to keep per source sequence.
    | num_hyps = num_beams * num_hyps_per_beam
    | src_seq_len = Number of time steps in the source sequence.
    | src_batch = Number of examples in the source sequence.
    | tgt_seq_len = Maximum allowed time steps in the target sequence.
    | tgt_batch = num_hyps_per_beam * src_batch

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of core beam search states. This list is
      maintained by this helper class.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. See class header comments for more details.

  Returns:
    A tuple of following elements for the next beam search step,
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  # Let the client compute this step's log probs / attention probs and update
  # its own states.
  bs_results, other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
   in_done_hyps, in_atten_probs) = core_bs_states

  # The op extends every hyp by one token and returns the updated core
  # states plus a flag telling whether the search is finished.
  (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
   out_prev_hyps, out_done_hyps, out_atten_probs,
   all_done) = ops.beam_search_step(
       tf.cast(bs_results.log_probs, dtype=p.dtype),
       tf.cast(bs_results.atten_probs, dtype=p.dtype),
       best_scores,
       cumulative_scores,
       in_scores,
       in_hyps,
       in_prev_hyps,
       in_done_hyps,
       in_atten_probs,
       # An empty list is passed when the model does not use an eoc id.
       bs_results.is_last_chunk if self._model_uses_eoc_id else [],
       cur_step,
       eoc_id=p.target_eoc_id,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
       ensure_full_beam=p.ensure_full_beam,
       force_eos_in_last_step=p.force_eos_in_last_step,
       local_eos_threshold=p.local_eos_threshold)

  # The ids chosen at this step become the input ids of the next step.
  new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
  new_step_ids.set_shape(step_ids.get_shape())

  # For every hyp kept at this step, the index of the hyp it extends.
  old_hyp_ids = tf.reshape(
      tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

  if p.batch_major_compute:
    # Transformed the indices into the key/value cache for fast decoding
    # (prefix_states in other_states) due to the num_hyps dimension of
    # cache is computed as num_beams by num_hyps_per_beam, which is different
    # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
    # Both transpose and recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids."""
    # Only tensors of known rank >= 1 are reordered; everything else is
    # passed through unchanged.
    if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
        x_in.shape.ndims > 0):
      if x_in.shape.ndims > 2 and not p.batch_major_state:
        # Use corrected indices only here for batch major compute as key/value
        # caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)
        # In this case the hyps are gathered along axis 1.
        x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
      else:
        x_out = tf.gather(x_in, old_hyp_ids)
      x_out.set_shape(x_in.get_shape())
      return x_out
    else:
      return x_in

  new_other_states = other_states.Transform(ReOrderHyps)

  # Give the client a chance to update its states with the new ids.
  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
def ComputePredictions(self, theta, input_batch):
  """Encodes the spelling (plus optional neighbours) and decodes it.

  Besides returning the decoder predictions this method runs beam search and
  records several entries in `self.per_example_tensors`, and stores the
  predictions in `self.prediction_values`.

  Args:
    theta: A NestedMap of weights; `theta.encoder` and `theta.decoder` are
      passed to the encoder and decoder.
    input_batch: The input batch. Fields read here: `spelling`,
      `pronunciation`, `cognate_id` and, when p.use_neighbors is set,
      `neighbor_spellings` and `neighbor_pronunciations`.

  Returns:
    The result of `self.decoder.ComputePredictions`, with `input_batch`
    attached as its `batch` attribute.
  """
  p = self.params
  batch_size = p.input.batch_size
  self._shape_batch(input_batch)

  # Prepend SOS token, this is not done by the Transformer layer for you
  # since this is usually done by the input pipeline in Babelfish.
  pronunciation = self._AddStartToken(input_batch.pronunciation)

  if p.use_neighbors:
    spellings = input_batch.neighbor_spellings
    pronunciations = input_batch.neighbor_pronunciations

  # Encoder input fields; turned into a NestedMap right before encoding.
  inp = {
      "ids": input_batch.spelling,
  }

  if (p.use_neighbors and p.also_shuffle_neighbors and
      (p.neigh_att_type == "CONCAT" or p.use_neigh_id_emb)):
    # If we use neighbor IDs, shuffle the neighbours to stop the model
    # overfitting to the ordering of the neighbours.
    # Concat then shuffle and split so that the spelling and pronunciation
    # are shuffled the same way and the IDs are aligned.
    neighbor_info = tf.concat([spellings, pronunciations], axis=-1)
    # Transpose the max_neighbors dimension to the front and shuffle.
    neighbor_info = tf.transpose(
        tf.random.shuffle(tf.transpose(neighbor_info, (1, 2, 0))), (2, 0, 1))
    spellings, pronunciations = (
        neighbor_info[:, :, :p.max_spelling_len],
        neighbor_info[:, :, p.max_spelling_len:])

  if p.use_neighbors and p.neigh_att_type == "CONCAT":
    # Interleave and flatten the neighbours info
    # ->(batch_size, max_neighbors, max_spelling_len + max_pronunciation_len)
    neigh_info = tf.concat([spellings, pronunciations], axis=2)
    # ->(batch_size, max_neighbors*(max_spelling_len + max_pronunciation_len))
    neigh_info = tf.reshape(neigh_info, (batch_size, -1))
    inp["ids"] = tf.concat([inp["ids"], neigh_info], axis=1)

    # If we are just concatenating everything then the main encoder needs
    # neighbors IDs.
    neigh_ids = tf.range(p.max_neighbors)[:, tf.newaxis]
    neigh_ids = tf.tile(
        neigh_ids, (batch_size, p.max_spelling_len + p.max_pronunciation_len))
    neigh_ids = tf.reshape(neigh_ids, (batch_size, -1))
    # Add the ids for the main input
    # The main input uses the id p.max_neighbors, one past the neighbour ids.
    main_ids = tf.tile([[p.max_neighbors]], (batch_size, p.max_spelling_len))
    inp["task_ids"] = tf.concat([main_ids, neigh_ids], axis=1)

  inp["paddings"] = self._GetPaddings(inp["ids"], dtype=tf.int32)

  enc_out = self.encoder.FProp(theta.encoder, py_utils.NestedMap(inp))

  # Auxiliary inputs that the decoder can attend to, currently can be
  # neighbour summaries.
  aux_inputs = []
  aux_paddings = []

  if p.use_neighbors and p.neigh_att_type != "CONCAT":
    neigh_enc, padding = self._GetAxiliaryNeighInputs(
        spellings, pronunciations, enc_out, theta, batch_size)
    aux_inputs.extend(neigh_enc)
    aux_paddings.extend(padding)

  if aux_inputs:
    # Append the auxiliary encodings to the encoder output along axis 0.
    aux_inputs = tf.concat(aux_inputs, axis=0)
    aux_paddings = tf.concat(aux_paddings, axis=0)
    if p.aux_dropout_prob and not self.do_eval:
      # The noise shape has size 1 in the last dim, so one dropout decision
      # is shared across that dim.
      aux_inputs = tf.nn.dropout(
          aux_inputs,
          p.aux_dropout_prob,
          noise_shape=(aux_inputs.get_shape().as_list()[0], batch_size, 1))
    enc_out.encoded = tf.concat([enc_out.encoded, aux_inputs], axis=0)
    enc_out.padding = tf.concat([enc_out.padding, aux_paddings], axis=0)
    # NOTE(review): placement inside this branch is presumed - confirm
    # against the original file layout.
    enc_out.embedded_inputs = None  # to verify this is not used

  predictions = self.decoder.ComputePredictions(
      theta.decoder, enc_out,
      py_utils.NestedMap({
          "ids": pronunciation,
          "paddings": self._GetPaddings(pronunciation),
          "weights": tf.ones_like(input_batch.pronunciation,
                                  dtype=tf.float32),
      }))

  beam_out = self.decoder.BeamSearchDecode(enc_out, p.beam_size)
  top_ids = tf.reshape(beam_out.topk_ids,
                       [batch_size, -1, p.max_pronunciation_len])
  # Just take the top beam decodings
  top_ids = top_ids[:, 0, :]

  if p.is_inference:
    self.BuildInferenceInfo(top_ids, input_batch.pronunciation, enc_out)

  # NOTE(review): the assignments below are presumed to be unconditional
  # (not part of the is_inference branch) - confirm.
  self.per_example_tensors["beam_scores"] = beam_out.topk_scores
  self.per_example_tensors["hyp"] = top_ids
  self.per_example_tensors["cognate_id"] = input_batch.cognate_id
  self.per_example_tensors["inp"] = input_batch.spelling
  self.per_example_tensors["ref"] = input_batch.pronunciation
  if p.use_neighbors:  # Note that cannot return None!
    self.per_example_tensors[
        "neighbor_spellings"] = input_batch.neighbor_spellings
    self.per_example_tensors[
        "neighbor_pronunciations"] = input_batch.neighbor_pronunciations
  self.prediction_values = predictions

  predictions.batch = input_batch
  return predictions
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Combines the hyps of several decoders and keeps the best ones per beam.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam.

  Raises:
    ValueError: if a field is set in some but not all of the outputs, or if a
      field that this function does not know how to reshape is present.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]

  # field name -> list of [source_batch, hyps_per_beam, -1] tensors, one per
  # decoder output that has the field set.
  per_field = {}
  for decoder_output in beam_search_outputs:
    batch_matches = py_utils.assert_equal(
        source_batch, tf.shape(decoder_output.topk_hyps)[0])
    hyps_per_beam = py_utils.with_dependencies(
        [batch_matches], tf.shape(decoder_output.topk_hyps)[1])
    for field, tensor in decoder_output._asdict().items():
      if tensor is None:
        continue
      if field == 'done_hyps':
        tensor = tf.transpose(tensor)
      per_field.setdefault(field, []).append(
          tf.reshape(tensor, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  num_outputs = len(beam_search_outputs)
  for field, tensors in per_field.items():
    if len(tensors) != num_outputs:
      raise ValueError('Incomplete values for %s: %s' %
                       (field, beam_search_outputs))
    concatenated[field] = tf.concat(tensors, axis=1)

  # Hyps of length 0 get a very low score so they are never preferred.
  scores = concatenated['topk_scores']
  is_empty_hyp = tf.equal(concatenated['topk_lens'], 0)
  scores = tf.where(is_empty_hyp, tf.fill(tf.shape(scores), -1e6), scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  merged = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  flat_fields = ('topk_lens', 'topk_scores', 'topk_decoded')
  for field, tensor in concatenated.items():
    selected = tf.gather_nd(tensor, gather_indices)
    if field == 'done_hyps':
      selected = tf.transpose(tf.reshape(selected, [total_hyps, -1]))
    elif field == 'topk_hyps':
      selected = tf.reshape(selected, [source_batch, max_hyps_per_beam])
    elif field == 'topk_ids':
      selected = tf.reshape(selected, [total_hyps, -1])
    elif field in flat_fields:
      selected = tf.reshape(selected, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % field)
    merged[field] = selected
  return BeamSearchDecodeOutput(**merged)
def _BroadcastAcrossPoints(z):
  """Tiles `z` `num_points` times along its second dim, then transposes it.

  `num_points` is taken from the enclosing scope.
  """
  tiled = tf.tile(z, [1, num_points])
  return tf.transpose(tiled)
def _ConstructWarpMatrix(self, batch_size, matrix_size, origin, destination,
                         choose_range, dtype):
  """Builds a batch of piecewise-linear warp matrices.

  Each matrix in the batch maps its origin anchor point `origin` to
  `destination` while keeping the boundary coordinates 0 and `choose_range`
  fixed. Coordinates at or beyond `choose_range` are left untouched.

  For the warping matrix to be non-singular, destination must lie in the
  range 1 <= destination <= choose_range - 1, so a destination out of this
  range is clipped into it before the matrix is constructed.

  With the slopes

    1) slope_0 = origin / destination.
    2) slope_1 = (choose_range - origin) / (choose_range - destination).
    3) slope_2 = 1.0.

  the origin point orig_i of the mapped coordinate i is:

    1) i < destination: orig_i = slope_0 * i.
    2) destination <= i < choose_range:
       orig_i = slope_1 * i - (slope_1 - slope_0) * destination.
    3) i >= choose_range: orig_i = i.

  Denoting n_i = ceil(orig_i), the warp matrix element warp[i][j] is:

    1) j = n_i: 1 - n_i + orig_i.
    2) j = n_i - 1: n_i - orig_i.
    3) Otherwise: 0.

  Applying the warp matrix to an array of pixels, i.e.,
  warped_pixel[i] = sum_j warp[i][j] * pixel[j], one would get
  warped_pixel[i] = (n_i-orig_i) pixel[n_i-1] + (1-n_i+orig_i) pixel[n_i].

  Args:
    batch_size: Batch size. Integer number.
    matrix_size: Dimension of the vector space the warp matrix is applied to.
      Integer number.
    origin: Origin anchor point for warping. Tensor of shape (batch_size,) and
      data type dtype.
    destination: Destination of the origin anchor point upon warping. Tensor
      of shape (batch_size,) and data type dtype.
    choose_range: Range within which the warp reference points must lie.
      Tensor of shape (batch_size,) data type dtype.
    dtype: Data type of origin, destination, choose_range and the output warp
      matrix.

  Returns:
    warp_matrix: An array of fixed size warp matrices with shape
    (batch_size, matrix_size, matrix_size). It is cast to p.fprop_dtype when
    that is set and differs from dtype.
  """
  p = self.params

  # Clip destination into [1, choose_range - 1] so that the warp matrix
  # stays non-singular.
  lower_clipped = tf.maximum(destination, 1.0)
  destination = tf.minimum(lower_clipped, choose_range - 1.0)

  # Per-example breakpoints of the piece-wise linear map, repeated along the
  # coordinate axis so they can be subtracted from the coordinate grid.
  bc_shape = (matrix_size, batch_size)
  destination_bc = tf.transpose(tf.broadcast_to(destination, bc_shape))
  choose_range_bc = tf.transpose(tf.broadcast_to(choose_range, bc_shape))

  # Slopes of the three linear pieces.
  slope_0 = origin / destination
  slope_1 = (choose_range - origin) / (choose_range - destination)
  slope_2 = 1.0

  coords = tf.cast(tf.range(matrix_size), dtype=dtype)

  # x holds, for every example, the origin coordinate of each coordinate i:
  #   1) i < dest: slope_0 * i.
  #   2) dest <= i < choose_range: slope_1 * i - (slope_1 - slope_0) * dest.
  #   3) i >= choose_range: i.
  # The relu terms switch the slope at each breakpoint.
  x = tf.broadcast_to(coords, (batch_size, matrix_size))
  first_piece = self.EinsumBBmBm(slope_0, x)
  second_piece = self.EinsumBBmBm(slope_1 - slope_0,
                                  tf.nn.relu(x - destination_bc))
  third_piece = self.EinsumBBmBm(slope_2 - slope_1,
                                 tf.nn.relu(x - choose_range_bc))
  x = first_piece + second_piece + third_piece
  # Repeat the origin coordinates so that x[b][i][j] = orig_i for every j.
  x = tf.broadcast_to(x, (matrix_size, batch_size, matrix_size))
  x = tf.transpose(x, perm=[1, 2, 0])

  # y is a batch of coordinate matrices: y[b][i][j] = j.
  y = tf.broadcast_to(coords, (batch_size, matrix_size, matrix_size))

  # The warp matrix is the hat function applied element-wise to (x - y),
  # which yields the interpolation weights described in the docstring.
  warp_matrix = _hat(x - y)
  needs_cast = p.fprop_dtype is not None and p.fprop_dtype != dtype
  if needs_cast:
    warp_matrix = tf.cast(warp_matrix, p.fprop_dtype)
  return warp_matrix
def testDecoderFPropWithAdapters(self):
  """Builds a decoder that has adapter layers and checks that FProp runs."""
  # All target data is written time-major ([5, 4]) and transposed to
  # [batch=4, time=5] below.
  ids_time_major = [[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                    [5, 6, 7, 8], [10, 5, 2, 5]]
  labels_time_major = [[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                       [5, 7, 8, 10], [10, 5, 2, 4]]
  paddings_time_major = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0],
                         [0, 1, 0, 0], [1, 1, 1, 0]]
  src_paddings = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    noise_config = py_utils.VariationalNoiseParams(
        None, True, False, seed=12345)
    params = _DecoderParams(num_rnn_layers=2, vn_config=noise_config)
    params.rnn_cell_dim = 3
    params.adapter_layer_tpl.Set(
        bottleneck_dim=4,
        num_tasks=16,
        projection_params_init=py_utils.WeightInit.Gaussian(0.01))
    params.adapter_task_id_field = 'domain_ids'
    dec = params.Instantiate()
    fprop_dtype = py_utils.FPropDtype(params)

    # Encoder outputs: 5 source steps, batch of 2, depth 8, plus one domain
    # id per source sequence for the adapters.
    src_seq_len = 5
    src_enc = tf.random.normal([src_seq_len, 2, 8],
                               seed=982774838,
                               dtype=fprop_dtype)
    src_enc_padding = tf.constant(src_paddings, dtype=fprop_dtype)
    domain_ids = tf.constant(np.random.randint(low=0, high=16, size=[2]))
    encoder_outputs = py_utils.NestedMap(
        encoded=src_enc, padding=src_enc_padding, domain_ids=domain_ids)

    # ids/labels/weights/paddings are all in [batch, time] = [4, 5] shape.
    target_ids = tf.transpose(tf.constant(ids_time_major, dtype=tf.int32))
    target_labels = tf.transpose(
        tf.constant(labels_time_major, dtype=tf.int32))
    target_paddings = tf.transpose(
        tf.constant(paddings_time_major, dtype=fprop_dtype))
    target_transcripts = tf.constant(['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
    target_weights = 1.0 - target_paddings
    targets = py_utils.NestedMap({
        'ids': target_ids,
        'labels': target_labels,
        'weights': target_weights,
        'paddings': target_paddings,
        'transcripts': target_transcripts,
    })

    decoder_outputs = dec.FPropDefaultTheta(encoder_outputs, targets)
    metrics = decoder_outputs.metrics
    per_sequence_loss = decoder_outputs.per_sequence['loss']
    self.assertIn('fraction_of_correct_next_step_preds', metrics)

    self.evaluate(tf.global_variables_initializer())
    metrics_val, per_sequence_loss_val = self.evaluate(
        [metrics, per_sequence_loss])
    tf.logging.info('metrics=%s, per_sequence_loss=%s', metrics_val,
                    per_sequence_loss_val)
    self.assertEqual(metrics_val['loss'], metrics_val['log_pplx'])
    # One loss value per target sequence; the target batch size is 4.
    self.assertEqual(per_sequence_loss_val.shape, (4,))
def _testDecoderFPropGradientCheckerHelper(self, func_inline=False):
  """Checks the decoder loss and its symbolic gradients against numeric ones.

  The loss is compared with a golden value twice (to catch non-determinism),
  then the symbolic gradient of every trainable variable that has one is
  compared with a numerically estimated gradient of the same variable.

  Args:
    func_inline: Value of the `do_function_inlining` optimizer option of the
      session used for the check.
  """
  config = tf.ConfigProto(
      graph_options=tf.GraphOptions(
          optimizer_options=tf.OptimizerOptions(
              do_function_inlining=func_inline)))
  with self.session(use_gpu=False, config=config) as sess:
    tf.set_random_seed(8372749040)
    np.random.seed(274854)
    vn_config = py_utils.VariationalNoiseParams(None, False, False)
    p = self._DecoderParams(vn_config)
    # float64 keeps the numeric gradient estimate accurate enough to compare.
    p.dtype = tf.float64

    dec = p.Instantiate()
    src_seq_len = 5
    src_enc = tf.constant(
        np.random.uniform(size=(src_seq_len, 2, 8)), tf.float64)
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=tf.float64)
    encoder_outputs = py_utils.NestedMap(
        encoded=src_enc, padding=src_enc_padding)
    # Target tensors are written time-major and transposed to [batch, time].
    target_ids = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 15],
                     [5, 6, 7, 8], [10, 5, 2, 5]],
                    dtype=tf.int32))
    target_labels = tf.transpose(
        tf.constant([[0, 1, 2, 3], [1, 2, 3, 4], [10, 11, 12, 13],
                     [5, 7, 8, 10], [10, 5, 2, 4]],
                    dtype=tf.int32))
    target_paddings = tf.transpose(
        tf.constant([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0],
                     [1, 1, 1, 1]],
                    dtype=tf.float64))
    target_transcripts = tf.constant(
        ['abcd', 'bcde', 'klmp', 'fghi', 'kfcf'])
    target_weights = 1.0 - target_paddings

    targets = py_utils.NestedMap({
        'ids': target_ids,
        'labels': target_labels,
        'weights': target_weights,
        'paddings': target_paddings,
        'transcripts': target_transcripts,
    })
    metrics = dec.FPropDefaultTheta(encoder_outputs, targets).metrics
    loss = metrics['loss'][0]
    all_vars = tf.trainable_variables()
    grads = tf.gradients(loss, all_vars)

    def DenseGrad(var, grad):
      """Returns `grad` as a dense tensor, or None if there is no gradient."""
      if isinstance(grad, tf.Tensor):
        return grad
      elif isinstance(grad, tf.IndexedSlices):
        # Sum the slices into a dense tensor with one row per row of `var`.
        return tf.unsorted_segment_sum(grad.values, grad.indices,
                                       tf.shape(var)[0])
      return None

    dense_grads = [DenseGrad(x, y) for (x, y) in zip(all_vars, grads)]

    tf.global_variables_initializer().run()

    test_utils.CompareToGoldenSingleFloat(self, 3.458078, loss.eval())
    # Second run to make sure the function is deterministic.
    test_utils.CompareToGoldenSingleFloat(self, 3.458078, loss.eval())

    # Drop variables without a gradient from BOTH sides of the comparison.
    # Filtering only the symbolic gradients would shift the pairing and
    # compare gradients of different variables with each other.
    checked = [(var, grad)
               for var, grad in zip(all_vars, dense_grads)
               if grad is not None]
    symbolic_grads = [grad.eval() for _, grad in checked]
    numerical_grads = [
        test_utils.ComputeNumericGradient(sess, loss, var)
        for var, _ in checked
    ]
    for x, y in zip(symbolic_grads, numerical_grads):
      self.assertAllClose(x, y)
def FProp(self, theta, input_batch):
  """Embeds the source ids and encodes them with the transformer stack.

  Args:
    theta: A `.NestedMap` holding the weights of this layer and of its
      children layers.
    input_batch: A `.NestedMap` with:
      ids - The input ids, shape [batch, time].
      paddings - The paddings of the ids, shape [batch, time].
      segment_ids, segment_pos - Read only when p.packed_input is set.

  Returns:
    A `.NestedMap` with:
      encoded - The encoded features, shape [time, batch, dim] when
        p.output_data_format is 'TBC', otherwise [batch, time, dim].
      padding - The padding of the encoded features, [time, batch] or
        [batch, time].
      seq_lengths - int32 sum of (1 - padding) over axis 1, computed before
        any transpose to time-major layout.
      segment_id - The segment ids of packed inputs, [time, batch] or
        [batch, time], or None when p.packed_input is not set.
      embedded_inputs - The token embeddings without positional encodings,
        [time, batch, dim] or [batch, time, dim].
  """
  p = self.params
  packed = p.packed_input
  with tf.name_scope(p.name):
    # ids, paddings and (optional) segment ids are all [batch, time].
    ids = input_batch.ids
    pad = input_batch.paddings
    seg_ids = None
    if packed:
      seg_ids = input_batch.segment_ids

    batch_size = py_utils.GetShape(ids)[0]
    seq_len = py_utils.GetShape(ids)[1]

    # Token embeddings, [batch, time, dim]. With shared embeddings the
    # softmax layer doubles as the lookup table.
    if p.shared_emb:
      token_embs = self.softmax.EmbLookup(theta.softmax, ids)
    else:
      token_embs = self.token_emb.EmbLookup(theta.token_emb, ids)

    # Positional embeddings: looked up per token position for packed inputs,
    # otherwise generated for positions [0, time).
    if packed:
      pos_embs = self.position_emb.FPropWithPosition(theta.position_emb,
                                                     input_batch.segment_pos)
    else:
      pos_embs = self.position_emb.FProp(theta.position_emb, seq_len)
    # A leading axis of size 1 lets the positional term broadcast.
    pos_embs = tf.expand_dims(pos_embs, 0)

    stack_input = token_embs + pos_embs

    dropout_dtype = p.input_dropout_tpl.fprop_dtype
    if dropout_dtype:
      stack_input = tf.cast(stack_input, dropout_dtype)
      pad = tf.cast(pad, dropout_dtype)

    stack_input = self.input_dropout.FProp(theta.input_dropout, stack_input)
    # Pin the shape to [batch, time, dim] explicitly; otherwise tf.einsum in
    # the Transformer layers can hit an unknown-shape error on non-TPU
    # devices.
    stack_input = tf.reshape(stack_input,
                             [batch_size, seq_len, p.model_dim])

    # The self-attention segment mask is computed once for the whole stack.
    if packed:
      seg_mask = batch_major_attention.SegmentMask(
          seg_ids, seg_ids, dtype=stack_input.dtype)
    else:
      seg_mask = tf.zeros([batch_size, 1, seq_len, seq_len])

    encoded, out_padding = self.transformer_stack.FProp(
        theta.transformer_stack, stack_input, pad, seg_mask)

    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)

    lengths = tf.cast(tf.reduce_sum(1.0 - out_padding, axis=1), tf.int32)

    embedded_inputs = token_embs
    if p.output_data_format == 'TBC':
      # Switch every returned tensor to time-major layout.
      encoded = tf.transpose(encoded, [1, 0, 2])
      out_padding = tf.transpose(out_padding)
      if packed:
        seg_ids = tf.transpose(seg_ids)
      embedded_inputs = tf.transpose(embedded_inputs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=out_padding,
        seq_lengths=lengths,  # used by beam_search_helper.
        segment_id=seg_ids,
        embedded_inputs=embedded_inputs)
def FarthestPointSampler(points,
                         padding,
                         num_sampled_points,
                         precomputed_squared_distance=None,
                         num_seeded_points=0,
                         random_seed=None):
  """Samples num_sampled_points from points using farthest point sampling.

  Algorithm:
  1. Start by selecting a random real (non-padded) point and adding it to
     the selected set. If num_seeded_points > 0, the first
     num_seeded_points iterations instead select indices
     0 .. num_seeded_points - 1 for every batch element.
  2. For all remaining points, find the furthest point from those selected.
  3. Add furthest point to selected.
  4. Repeat 2-3 until num_sampled_points are selected.

  More details at https://en.wikipedia.org/wiki/Farthest-first_traversal

  The output of this function can be used with tf.batch_gather to extract the
  desired points, for example: tf.batch_gather(points, sampled_idx)

  Args:
    points: floating point tf.Tensor of shape [N, P1, dims]
    padding: A floating point tf.Tensor of shape [N, P1] with 0 if the point
      is real, and 1 otherwise.
    num_sampled_points: integer number of points to sample.
    precomputed_squared_distance: optional tf.Tensor of shape [N, P1, P1] of
      distances between each point. if None, distances will be computed on
      the fly.
    num_seeded_points: If num_seeded_points > 0, then the first
      num_seeded_points in points are considered to be seeded in the FPS
      sampling. Note that we assume that these points are *not* padded, and
      do not check padding when seeding them.
    random_seed: optional integer random seed to use with all the random ops.

  Returns:
    A tuple of tf.Tensors (sampled_idx, closest_idx) of types
    (tf.int32, tf.int32).

    sampled_idx is of shape [N, num_sampled_points] representing the indices
    selected using the sampler. Values lie in [0, P1).

    closest_idx is of shape [N, P1] representing the indices of the closest
    sampled points for each input point. closest_idx is used in PCNN as part
    of the pooling operation: each point is assigned to the closest sampled
    point and a max is taken over them. Values lie in
    [0, num_sampled_points): each entry is the loop step at which the
    closest sampled point was selected.
  """
  points = py_utils.HasRank(points, 3)
  batch_size, num_points, dims = py_utils.GetShape(points, 3)

  # Sampling more points than exist is not possible.
  points = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(num_points, num_sampled_points)], points)

  # Add a tiny bit of noise to the distance matrix or points so all
  # points are unique. This will also ensure true repeated points
  # like padded points are only selected after all valid points are selected.
  if precomputed_squared_distance is not None:
    precomputed_squared_distance = py_utils.HasShape(
        precomputed_squared_distance, [batch_size, num_points, num_points])
    # The noise has shape [N, P1, 1], so one value is broadcast across the
    # last axis of the distance matrix.
    precomputed_squared_distance += tf.random.uniform(
        (batch_size, num_points, 1),
        minval=1e-6,
        maxval=1e-5,
        dtype=tf.float32,
        seed=random_seed)
  else:
    # Independent noise per coordinate of every point.
    points += tf.random.uniform((batch_size, num_points, dims),
                                minval=1e-6,
                                maxval=1e-5,
                                dtype=tf.float32,
                                seed=random_seed)

  # TensorArray to store the sampled indices in the loop.
  sampled_idx = tf.TensorArray(tf.int32, num_sampled_points)

  # Initialize distance_to_selected to inf for all points.
  distance_to_selected = float('inf') * tf.ones((batch_size, num_points))

  # For tracking the index to the closest selected point.
  closest_idx = tf.zeros((batch_size, num_points), dtype=tf.int32)

  # Current loop index counter.
  curr_idx = tf.constant(0, dtype=tf.int32)

  # Get number of valid points (1 is padded, so num_points - num_padded).
  num_valid_points = tf.cast(
      tf.cast(num_points, dtype=tf.float32) - tf.reduce_sum(padding, axis=1),
      dtype=tf.int32)

  def _BodyFn(curr_idx, distance_to_selected, sampled_idx, closest_idx):
    """Loop body for farthest point sampler.

    Selects one index per batch element, records it in sampled_idx, then
    updates distance_to_selected and closest_idx with the new selection.

    Args:
      curr_idx: scalar int32 loop counter.
      distance_to_selected: [N, P1] distance from each point to its closest
        selected point so far.
      sampled_idx: TensorArray of the indices selected so far.
      closest_idx: [N, P1] loop step of the closest selected point.

    Returns:
      The updated loop variables, in the same order as the arguments.
    """

    def _GetRandomRealPoint():
      """Select the first point.

      For the first point, we want any random real (non padded) point, so we
      create a random values per point, and then set all padded ones to
      some large value (more than the maxval). We then take the min per batch
      element to get the first points.

      Returns:
        Tensor containing the index of a random point selected for each
        example in the batch.
      """
      random_values = tf.random.uniform((batch_size, num_points),
                                        minval=0,
                                        maxval=1,
                                        dtype=tf.float32,
                                        seed=random_seed)
      # Padded entries become padding * 10, which exceeds maxval when the
      # padding value is 1, so argmin never picks them over a real point.
      random_values = tf.where(
          tf.equal(padding, 0.0), random_values, padding * 10)
      return tf.argmin(random_values, axis=1, output_type=tf.int32)

    def _GetFurthestPoint():
      """Get point that is furthest from those already selected.

      We also bias the sampling towards real points by setting the distance
      to padded points negative until we are out of real points.

      Returns:
        Tensor containing the index of the next farthest point selected for
        each example in the batch.
      """
      # Set padded points distance to negative so they aren't selected.
      padding_masked_distance_to_selected = tf.where(
          tf.equal(padding, 0.0), distance_to_selected, -1.0 * tf.ones(
              (batch_size, num_points), dtype=tf.float32))
      # But only do this when we still have valid points left.
      padding_masked_distance_to_selected = tf.where(
          tf.less(curr_idx, num_valid_points),
          padding_masked_distance_to_selected, distance_to_selected)
      return tf.argmax(
          padding_masked_distance_to_selected, axis=-1, output_type=tf.int32)

    def _GetSeededPoint():
      """Select a seeded point.

      Seeded points are assumed to be at the beginning of the original points.

      Returns:
        Tensor containing the index of the next seeded point to select for
        each example in the batch.
      """
      return tf.ones((batch_size,), dtype=tf.int32) * curr_idx

    # Select indices for this loop iteration.
    def _Seeded():
      # Seeded mode: take the seeds in order, then fall back to the farthest
      # point once curr_idx reaches num_seeded_points.
      return tf.cond(
          tf.less(curr_idx, num_seeded_points), _GetSeededPoint,
          _GetFurthestPoint)

    def _Real():
      # Unseeded mode: a random real point first, the farthest point after.
      return tf.cond(
          tf.equal(curr_idx, 0), _GetRandomRealPoint, _GetFurthestPoint)

    new_selected = tf.cond(tf.greater(num_seeded_points, 0), _Seeded, _Real)
    sampled_idx = sampled_idx.write(curr_idx, new_selected)

    # Extract the distance to the latest point selected to update
    # distance_to_selected.
    # Rows of (batch index, selected point index) for tf.gather_nd.
    new_selected_gather_idx = tf.stack([tf.range(batch_size), new_selected],
                                       axis=1)
    if precomputed_squared_distance is not None:
      new_distance = tf.gather_nd(precomputed_squared_distance,
                                  new_selected_gather_idx)
    else:
      new_points = tf.reshape(
          tf.gather_nd(points, new_selected_gather_idx), [batch_size, 1, dims])
      new_distance = tf.reshape(
          SquaredDistanceMatrix(points, new_points), [batch_size, num_points])

    is_newly_closest = tf.less(new_distance, distance_to_selected)
    distance_to_selected = tf.minimum(distance_to_selected, new_distance)

    # Track the index to the closest selected point.
    # The value stored is the loop step (curr_idx), not the point's index.
    new_selected_tiled = tf.tile([[curr_idx]], [batch_size, num_points])
    closest_idx = tf.cond(
        tf.equal(curr_idx, 0),
        # At the first loop iteration, the init points are the closest.
        lambda: new_selected_tiled,
        # Otherwise, update with the new points based on the distances.
        lambda: tf.where(is_newly_closest, new_selected_tiled, closest_idx))
    return curr_idx + 1, distance_to_selected, sampled_idx, closest_idx

  _, _, sampled_idx, closest_idx = tf.while_loop(
      lambda curr_idx, *args: tf.less(curr_idx, num_sampled_points),
      _BodyFn,
      loop_vars=(curr_idx, distance_to_selected, sampled_idx, closest_idx),
      back_prop=False,
      maximum_iterations=num_sampled_points)

  sampled_idx = sampled_idx.stack()  # [num_sampled_points, N]
  sampled_idx = tf.transpose(sampled_idx, [1, 0])  # [N, num_sampled_points]

  # Restore the static shape when both dimensions are Python ints.
  if isinstance(batch_size, int) and isinstance(num_sampled_points, int):
    sampled_idx.set_shape((batch_size, num_sampled_points))

  return sampled_idx, closest_idx
def ComputePredictions(self,
                       encoder_outputs,
                       pronunciations,
                       is_inference=False):
  """Runs the decoder over every target step and computes the loss.

  Despite the name, this function does the bulk of the decoding and loss
  computation: it calls `self.Decode` once per target position inside a
  `tf.while_loop`, collects the per-step predictions and attention weights,
  and then computes the masked loss over the whole sequence.

  Args:
    encoder_outputs: a NestedMap with the outputs of the
      FeatureNeighborhoodEncoder. It is passed unchanged to `self.Decode`;
      this function itself reads
        state - the initial decoder state, and
        batch - the raw batch, which is passed through to the result.
    pronunciations: NestedMap with
      pronunciations - [*, max_pronunciation_len] tensor of pronunciations.
    is_inference: If False then uses teacher forcing (the next decoder input
      is the target at the current step), else does autoregression (the next
      decoder input is the argmax of the current predictions).

  Returns:
    NestedMap with loss, per_sequence_losses, labels (a
    [*, max_pronunciation_len] tensor of predictions), attention
    ([*, max_pronunciation_len, max_spelling_len]), neighbor_attention
    ([*, max_pronunciation_len, max_neighbors] when p.use_neighbors is set),
    and batch, the raw batch passed through from the encoder.
  """
  p = self.params
  targets = pronunciations.pronunciations
  num_steps = int(targets.get_shape().as_list()[1])

  # Loop state: step counter plus one TensorArray per collected output.
  step0 = tf.constant(0)
  attn_ta = tf.TensorArray(dtype=tf.float32, size=num_steps)
  nbr_attn_ta = tf.TensorArray(dtype=tf.float32, size=num_steps)
  preds_ta = tf.TensorArray(dtype=tf.float32, size=num_steps)

  # pylint: disable=missing-docstring
  def _HasMoreSteps(step, unused_dec_input, *unused_rest):
    return tf.less(step, num_steps)

  # Every sequence starts from the sentence-start symbol.
  first_input = tf.convert_to_tensor([p.start] * p.input.batch_size)
  first_state = encoder_outputs.state

  def _Step(step, prev_token, attn_acc, nbr_attn_acc, dec_state, preds_acc):
    result = self.Decode(encoder_outputs, prev_token, dec_state)
    preds_acc = preds_acc.write(step, result.predictions)
    attn_acc = attn_acc.write(step, result.attention_weights)
    nbr_attn_acc = nbr_attn_acc.write(
        step, tf.cast(result.neighbor_attention_weights, dtype=tf.float32))

    if is_inference:
      # Autoregression: feed back the most likely symbol.
      next_token = tf.cast(tf.argmax(result.predictions, 1), tf.int32)
    else:
      # Teacher forcing: feed the reference symbol for this step.
      next_token = targets[:, step]
    return (step + 1, next_token, attn_acc, nbr_attn_acc, result.state,
            preds_acc)

  _, _, attn_ta, nbr_attn_ta, _, preds_ta = tf.while_loop(
      _HasMoreSteps,
      _Step,
      loop_vars=[
          step0, first_input, attn_ta, nbr_attn_ta, first_state, preds_ta
      ])

  # Stacking puts the step axis first; swap it with the second axis.
  outputs = tf.transpose(preds_ta.stack(), [1, 0, 2])
  labels = tf.argmax(outputs, axis=-1)

  # Target positions equal to 0 carry no weight in the loss.
  mask = tf.cast(
      tf.math.logical_not(tf.math.equal(targets, 0)), dtype=tf.float32)
  token_losses = self._loss_object(targets, outputs, sample_weight=mask)
  summed_losses = tf.reduce_sum(token_losses, axis=1)
  per_sequence_losses = summed_losses / num_steps

  predictions = py_utils.NestedMap(
      loss=tf.reduce_mean(per_sequence_losses),
      per_sequence_losses=per_sequence_losses,
      labels=labels,
      attention=tf.transpose(tf.squeeze(attn_ta.stack()), perm=[1, 0, 2]))

  nbr_attention = tf.squeeze(nbr_attn_ta.stack())
  if p.use_neighbors:
    nbr_attention = tf.transpose(nbr_attention, perm=[1, 0, 2])
  predictions.neighbor_attention = nbr_attention

  # Expose this for subsequent data analysis
  predictions.batch = encoder_outputs.batch
  return predictions