def _InputBatch(self):
  targets = tf.ones([self.params.batch_size, 1024], dtype=tf.int32)
  input_batch = py_utils.NestedMap()
  input_batch.tgt = py_utils.NestedMap()
  input_batch.tgt.ids = tf.roll(targets, 1, axis=1)
  input_batch.tgt.labels = targets
  input_batch.tgt.segment_ids = tf.minimum(targets, 1)
  input_batch.tgt.segment_pos = targets
  input_batch = input_batch.Transform(
      lambda t: tf.ensure_shape(t, (self.params.batch_size, 1024)))
  return input_batch
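# A minimal standalone sketch (values illustrative, not from the code above)
# of the tf.roll trick used in _InputBatch: shifting the labels right by one
# position along the time axis yields the teacher-forcing input ids.
import tensorflow as tf

labels = tf.constant([[11, 12, 13, 14]], dtype=tf.int32)  # [batch=1, time=4]
ids = tf.roll(labels, shift=1, axis=1)  # [[14, 11, 12, 13]]
# ids[:, t] == labels[:, t - 1] (wrapping at t=0), so each position's label
# is the "next token" for the corresponding input id.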
def __init__(self, params):
  super().__init__(params)
  p = self.params
  (utt_ids, src_frames,
   src_paddings), self._bucket_keys = self._BuildDataSource()

  self._sample_ids = utt_ids

  src_frames, src_paddings = self._MaybePadSourceInputs(
      src_frames, src_paddings)

  # We expect src_inputs to be of shape
  # [batch_size, num_frames, feature_dim, channels].
  src_frames = tf.expand_dims(src_frames, axis=-1)

  if p.pad_to_max_seq_length:
    assert p.source_max_length
    assert p.target_max_length

    if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
      # Set the input batch size as an int rather than a tensor.
      src_frames_shape = (self.InfeedBatchSize(), p.source_max_length,
                          p.frame_size, 1)
      src_paddings_shape = (self.InfeedBatchSize(), p.source_max_length)
    else:
      tf.logging.warning(
          'Could not set static input shape since not all bucket batch sizes '
          'are the same: %s', p.bucket_batch_limit)
      src_frames_shape = None
      src_paddings_shape = None

    src_frames = py_utils.PadBatchDimension(src_frames,
                                            self.InfeedBatchSize(), 0)
    src_paddings = py_utils.PadBatchDimension(src_paddings,
                                              self.InfeedBatchSize(), 1)
    # Pad utterance ids with the smallest value their dtype can hold.
    self._sample_ids = py_utils.PadBatchDimension(
        self._sample_ids, self.InfeedBatchSize(), self._sample_ids.dtype.min)
    src_frames = py_utils.PadSequenceDimension(
        src_frames, p.source_max_length, 0, shape=src_frames_shape)
    src_paddings = py_utils.PadSequenceDimension(
        src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
    self._sample_ids = tf.ensure_shape(self._sample_ids,
                                       self.InfeedBatchSize())

  src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

  self._src = src
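# A standalone sketch of the pad-to-static-shape pattern above, using plain
# TF ops instead of the py_utils helpers (PadSequenceDimension is assumed to
# behave roughly like this for a [batch, time] tensor).
import tensorflow as tf


def pad_sequence_dim(x, max_len, pad_value):
  # Pad the time axis out to max_len, then bake the static length into the
  # graph so downstream consumers (e.g. a TPU infeed) see a known shape.
  pad_amount = max_len - tf.shape(x)[1]
  x = tf.pad(x, [[0, 0], [0, pad_amount]], constant_values=pad_value)
  return tf.ensure_shape(x, [None, max_len])


paddings = tf.zeros([2, 5])  # 2 utterances, 5 frames each
padded = pad_sequence_dim(paddings, 8, 1.0)  # [2, 8]; the tail is marked padded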
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """Calls an embedding lookup and updates the state of token history.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    prepared_inputs: unused.
    step_inputs: A NestedMap containing a list called inputs. This list
      should contain a single float32 (will be converted to int32 later)
      tensor of shape [batch], where each value represents an index into
      the embedding table. (By convention, all Steps that can be used with
      StackStep must store inputs in step_inputs.inputs[], but in this step
      it does not make sense for that list to have more than one tensor in
      it.)
    padding: unused.
    state0: A NestedMap containing the state of previous tokens.

      - prev_ids: A Tensor containing the n previous token ids, of shape
        [batch, num_prev_tokens]. Each row holds the token ids at
        t-1, ..., t-n.

  Returns:
    Embedding vectors of shape [batch, p.embedding_dim] and the new state.
  """
  p = self.params

  # Prepare token ids.
  if p.include_current_token:
    ids = tf.concat([
        tf.cast(step_inputs.inputs[0][:, None], tf.float32),
        tf.cast(state0.prev_ids, tf.float32)
    ], axis=-1)
  else:
    ids = state0.prev_ids

  # Look up embeddings. ids.shape is [batch, num_tokens].
  ids = tf.cast(ids, tf.int32)
  embedding = self.emb.EmbLookup(theta.emb, ids)
  embedding = tf.reshape(embedding, [-1, p.embedding_dim])

  # Update state.
  state1 = state0.copy()
  if p.num_prev_tokens > 0:
    state1.prev_ids = tf.concat([
        tf.cast(step_inputs.inputs[0][:, None], tf.float32),
        tf.cast(state0.prev_ids[:, :-1], tf.float32)
    ], axis=-1)
    state1.prev_ids = tf.ensure_shape(
        state1.prev_ids, [None, p.num_prev_tokens],
        name='prev_ids_shape_validation')
  state1.embedding = embedding
  return py_utils.NestedMap(output=embedding), state1
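# A standalone sketch of the token-history update in FProp above: the newest
# id is prepended and the oldest dropped, keeping [batch, num_prev_tokens].
# Values are illustrative.
import tensorflow as tf

num_prev_tokens = 3
prev_ids = tf.constant([[7., 5., 2.]])  # ids at t-1, t-2, t-3
new_id = tf.constant([9.])  # id consumed at time t
updated = tf.concat([new_id[:, None], prev_ids[:, :-1]], axis=-1)
updated = tf.ensure_shape(updated, [None, num_prev_tokens])  # [[9., 7., 5.]]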
def Decode(self, encoder_outputs, dec_input, state):
  """The decoder model.

  Args:
    encoder_outputs: a NestedMap containing the following fields:

      - encoded: the encoding of the spelling of the target feature.
      - state: hidden state of the encoder output.
      - neighbor_spellings_encoded: encoding of neighbor spellings, or
        tf.constant(0).
      - neighbor_pronunciations_encoded: encoding of neighbor
        pronunciations, or tf.constant(0).
      - dec_input: initial prediction of [<s>], of shape [*, 1].
    dec_input: if not None, use this instead of the dec_input in
      encoder_outputs.
    state: previous state of decoding.

  Returns:
    res: a NestedMap() containing

      - predictions: of shape [*, output_vocab_size].
      - state: updated hidden state.
      - attention_weights
      - neighbor_attention_weights
  """
  p = self.params
  context_vector, attention_weights = self.attention(
      state, encoder_outputs.encoded)
  (neighbor_context_vector,
   neighbor_attention_weights) = self._NeighborModelAttention(
       encoder_outputs, state)
  x = self.shared_out_emb(dec_input)  # pylint: disable=not-callable

  if p.use_neighbors:
    x = tf.concat([context_vector, neighbor_context_vector, x], axis=-1)
  else:
    x = tf.concat([context_vector, x], axis=-1)

  output, state = self._gru_cell(x, state)
  x = self._fc(output)
  # If this fails, the checkpoint contains incorrect shapes or the hparams
  # are incompatible. TF does not appear to check this itself anymore.
  x = tf.ensure_shape(x, [None, p.output_vocab_size])

  res = py_utils.NestedMap()
  res.predictions = x
  res.state = state
  res.attention_weights = attention_weights
  res.neighbor_attention_weights = neighbor_attention_weights
  return res
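# A standalone shape sketch of the concatenation feeding the GRU above; the
# feature dimensions are made up for illustration.
import tensorflow as tf

batch = 2
context_vector = tf.zeros([batch, 256])  # spelling attention context
neighbor_context_vector = tf.zeros([batch, 256])  # neighbor attention context
x = tf.zeros([batch, 128])  # embedded decoder input
x = tf.concat([context_vector, neighbor_context_vector, x], axis=-1)
assert x.shape == (2, 640)  # GRU input: both contexts plus the embedding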
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to
      override `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please
      refer to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput`.
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids,
                   unused_core_bs_states, unused_other_states_list):
    return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # TODO(rpang): avoid inspecting 'encoder_outputs'.
  source_paddings = encoder_outputs.padding
  if isinstance(source_paddings, py_utils.NestedMap):
    source_seq_lengths = tf.cast(
        tf.round(
            tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                          1)), tf.int32)
  else:
    source_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
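# A standalone sketch of the tf.while_loop shape_invariants pattern used by
# BeamSearchDecode above: fixed-shape loop state passes its own shape as the
# invariant, while state whose shape changes across iterations (or is only
# known at runtime) passes a partially unknown tf.TensorShape.
import tensorflow as tf

i0 = tf.constant(0)
acc0 = tf.zeros([1, 4])


def body(i, acc):
  # acc grows by one row per step, so its row count cannot be an invariant.
  return i + 1, tf.concat([acc, tf.zeros([1, 4])], axis=0)


_, acc = tf.while_loop(
    lambda i, acc: i < 3,
    body,
    loop_vars=(i0, acc0),
    shape_invariants=(i0.get_shape(), tf.TensorShape([None, 4])))
# acc ends with 4 rows after 3 iterations.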
def Cast(x):
  x = tf.ensure_shape(x, [p.batch_size, p.max_sequence_length])
  if x.dtype.is_floating:
    x = tf.cast(x, py_utils.FPropDtype(p))
  return x
def ShapeAndCast(x):
  x = tf.ensure_shape(x, (self.InfeedBatchSize(), p.source_max_length))
  if x.dtype.is_floating:
    x = tf.cast(x, py_utils.FPropDtype(p))
  return x
def _EnsureTgtShape(x):
  if x.dtype == tf.string:
    return tf.ensure_shape(x, [self._ScaledBatchSize()])
  return tf.ensure_shape(
      x, [self._ScaledBatchSize(), self.params.target_max_length])
def _EnsureSrcShape(x):
  if x.dtype == tf.string:
    return tf.ensure_shape(x, [self._ScaledBatchSize()])
  return tf.ensure_shape(
      x, [self._ScaledBatchSize(), self.params.source_max_length])
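# A standalone sketch of what the small helpers above rely on:
# tf.ensure_shape both checks compatibility at run time and refines the
# static shape, which is why rank-1 string tensors and rank-2 id tensors
# need separate cases.
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])
y = tf.ensure_shape(x, [None, 3])  # compatible: passes through unchanged
try:
  tf.ensure_shape(x, [None, 4])  # incompatible second dimension
except Exception as e:
  print(type(e).__name__)  # raises instead of silently reshaping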
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Performs greedy-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please
      refer to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is the same
    as src_batch_size.

    - hyp_ids: [num_hyps, max_step]. Hyps end with an <eos> token if the
      <eos> token is encountered during search.
    - hyp_lens: [num_hyps].
    - done_hyps: [num_hyps], whether or not an <eos> was encountered.
  """
  p = self.params
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta,
      encoder_outputs,
      1  # num_hyps_per_beam
  )

  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  cur_step = tf.constant(0, dtype=tf.int32)
  done_hyps = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  hyp_lens = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  hyp_ids = inplace_ops.empty(
      shape=[max_steps, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                   unused_hyp_lens, done_hyps, unused_other_states_list):
    return tf.logical_and(cur_step < max_steps,
                          tf.logical_not(tf.reduce_all(done_hyps)))

  def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
               other_states_list):
    (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
     new_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
         done_hyps, other_states.Pack(other_states_list),
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
            new_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        tf.TensorShape(hyp_ids.get_shape()),
                        tf.TensorShape(hyp_lens.get_shape()),
                        tf.TensorShape(done_hyps.get_shape()),
                        _GetShapes(flat_other_states, none_shapes=True)))

  # Transpose hyp_ids so it matches BeamSearchDecode's output.
  final_hyp_ids = tf.transpose(final_hyp_ids)
  return final_hyp_ids, final_hyp_lens, final_done_hyps
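# A standalone sketch of the greedy-loop termination above: decoding stops
# once every hypothesis has finished (all done flags True) or max_steps is
# reached, whichever comes first. The "decoder" here is a stand-in constant.
import tensorflow as tf

max_steps = tf.constant(10)
eos_id = 2  # illustrative EOS token id


def cond(step, ids, done):
  return tf.logical_and(step < max_steps,
                        tf.logical_not(tf.reduce_all(done)))


def body(step, ids, done):
  new_ids = tf.constant([2, 2])  # stand-in for one greedy decode step
  done = tf.logical_or(done, tf.equal(new_ids, eos_id))
  return step + 1, new_ids, done


step, _, done = tf.while_loop(
    cond, body,
    loop_vars=(tf.constant(0), tf.constant([0, 0]),
               tf.constant([False, False])))
# Both hypotheses emit EOS on the first step, so the loop exits with step == 1.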
def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
  """Differentiable max assignment with margin and upper bound constraints.

  Args:
    score: a 3D tensor of size [batch_size, n_rows, n_columns].
      score[i, j, k] denotes the weight if the assignment on this entry is
      non-zero.
    elementwise_upper_bound: a 3D tensor of size
      [batch_size, n_rows, n_columns]. Each entry denotes the maximum value
      assignment[i, j, k] can take and must be a non-negative value. For
      example, upper_bound[i, j, k] = 1.0 for a binary assignment problem.
    row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum
      constraint. The output assignment p[i, j, :] must sum to
      row_sums[i, j].
    col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
      constraint. The output assignment p[i, :, k] must sum to
      col_sums[i, k].
    epsilon: the epsilon coefficient of entropy regularization. The value
      should be within the range (0, 1]. `0.01` might work better than
      `0.1`; `0.1` may not push the assignment close enough to 0 or 1.
    num_iterations: the maximum number of iterations to perform.
    use_epsilon_scaling: whether to use epsilon scaling. In practice, the
      convergence of the iterative algorithm is much better if we start by
      solving the optimization with a larger epsilon value and re-use the
      solution (i.e. dual variables) for the instance with a smaller
      epsilon. This is called the epsilon scaling trick. See
      [Schmitzer 2019](https://arxiv.org/pdf/1610.06519.pdf) as a
      reference. Here, if use_epsilon_scaling=True, after each iteration we
      decrease the running epsilon by a constant factor until it reaches
      the target epsilon value. We found this to work well for gradient
      backward propagation, while the original scaling trick doesn't.

  Returns:
    A tuple with the following values.

    - assignment: a 3D tensor of size [batch_size, n_rows, n_columns]. The
      output assignment.
    - used_iter: a scalar tensor indicating the number of iterations used.
    - eps: a scalar tensor indicating the stopping epsilon value.
    - delta: a scalar tensor indicating the stopping delta value (the
      relative change on the margins of assignment p in the last
      iteration).
  """
  # Check if all shapes are correct.
  score_shape = score.shape
  bsz = score_shape[0]
  n = score_shape[1]
  m = score_shape[2]
  score = tf.ensure_shape(score, [bsz, n, m])
  elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                            [bsz, n, m])
  row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
  col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

  # The total sum of row sums must equal the total sum of column sums.
  sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums, axis=2)
  sum_diff = tf.abs(sum_diff)
  tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

  # Convert the upper_bound constraint into another margin constraint by
  # adding auxiliary variables & scores. Tensors `a`, `b` and `c` represent
  # the margins (i.e. reduced sums) of the 3 axes respectively.
  max_row_sums = tf.reduce_sum(elementwise_upper_bound, axis=-1,
                               keepdims=True)
  max_col_sums = tf.reduce_sum(elementwise_upper_bound, axis=-2,
                               keepdims=True)
  score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
  a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
  b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
  c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

  # Clip log(0) to a large negative value (-1e+36) to avoid getting inf or
  # NaN values in the computation. We cannot use a more negative value
  # because float32 would turn it into `-inf`.
  tf.Assert(tf.reduce_all(a >= 0), [a])
  tf.Assert(tf.reduce_all(b >= 0), [b])
  tf.Assert(tf.reduce_all(c >= 0), [c])
  log_a = tf.maximum(tf.math.log(a), -1e+36)
  log_b = tf.maximum(tf.math.log(b), -1e+36)
  log_c = tf.maximum(tf.math.log(c), -1e+36)

  # Initialize the dual variables of the margin constraints.
  u = tf.zeros_like(a)
  v = tf.zeros_like(b)
  w = tf.zeros_like(c)

  eps = tf.constant(1.0 if use_epsilon_scaling else epsilon,
                    dtype=score.dtype)
  epsilon = tf.constant(epsilon, dtype=score.dtype)

  def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
    # Epsilon scaling, i.e. gradually decreasing `eps` until it reaches the
    # target `epsilon` value.
    cur_iter = tf.cast(cur_iter, u.dtype)
    scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
    eps = tf.maximum(epsilon, eps * scaling)
    score_div_eps = score_ / eps

    # Update u.
    log_q_1 = score_div_eps + (w + v) / eps
    log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
    new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

    # Update v.
    log_q_2 = score_div_eps + (w + new_u) / eps
    log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
    new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

    # Update w.
    log_q_3 = score_div_eps + (new_u + new_v) / eps
    log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
    new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps
    return eps, new_u, new_v, new_w

  def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
    prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
    sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

    # Compute the relative changes on the margins of P. This will be used
    # as a stopping criterion. Note the last update on w guarantees that
    # the margin constraint c is satisfied, so we don't need to check it
    # here.
    p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
    p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
    p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
    delta_a = tf.abs(a - p_a) / (a + 1e-6)
    delta_b = tf.abs(b - p_b) / (b + 1e-6)
    new_delta = tf.reduce_max(delta_a)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

    # Compute the relative changes on the assignment solution P. This will
    # also be used as a stopping criterion.
    delta_p = tf.abs(tf.exp(prev_sum_uvw) -
                     tf.exp(sum_uvw)) / (tf.exp(sum_uvw) + 1e-6)
    new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
    return new_delta

  for cur_iter in tf.range(num_iterations):
    prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
    eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps,
                                     u, v, w)
  cur_iter = num_iterations

  assignment = tf.exp((score_ + u + v + w) / eps)
  assignment = assignment[:, 0]
  return assignment, cur_iter, eps, delta
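# A standalone sketch of calling max_assignment above on a tiny instance in
# eager mode (values illustrative). With unit row/column sums and a unit
# elementwise upper bound this relaxes a 2x2 matching problem, and the
# returned assignment should concentrate near the identity permutation.
import tensorflow as tf

score = tf.constant([[[5.0, 1.0],
                      [1.0, 5.0]]])  # [batch=1, n_rows=2, n_cols=2]
assignment, used_iter, eps, delta = max_assignment(
    score,
    elementwise_upper_bound=tf.ones([1, 2, 2]),
    row_sums=tf.ones([1, 2]),
    col_sums=tf.ones([1, 2]),
    epsilon=0.01,
    num_iterations=100)
# assignment[0] is approximately [[1, 0], [0, 1]].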
def __init__(self, params):
  super().__init__(params)
  p = self.params
  (utt_ids, audio_document_ids, num_utterances_in_audio_document, tgt_ids,
   tgt_labels, tgt_paddings, src_frames,
   src_paddings), self._bucket_keys = self._BuildDataSource()

  self._sample_ids = utt_ids

  src_frames, src_paddings = self._MaybePadSourceInputs(
      src_frames, src_paddings)

  # We expect src_inputs to be of shape
  # [batch_size, num_frames, feature_dim, channels].
  src_frames = tf.expand_dims(src_frames, axis=-1)

  # Convert target ids, labels, and paddings from shape
  # [batch_size, 1, num_frames] to [batch_size, num_frames].
  tgt_ids = tf.squeeze(tgt_ids, axis=1)
  tgt_labels = tf.squeeze(tgt_labels, axis=1)
  tgt_paddings = tf.squeeze(tgt_paddings, axis=1)

  if p.pad_to_max_seq_length:
    assert p.source_max_length
    assert p.target_max_length

    if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
      # Set the input batch size as an int rather than a tensor.
      src_frames_shape = (self.InfeedBatchSize(), p.source_max_length,
                          p.frame_size, 1)
      src_paddings_shape = (self.InfeedBatchSize(), p.source_max_length)
      tgt_shape = (self.InfeedBatchSize(), p.target_max_length)
    else:
      tf.logging.warning(
          'Could not set static input shape since not all bucket batch sizes '
          'are the same: %s', p.bucket_batch_limit)
      src_frames_shape = None
      src_paddings_shape = None
      tgt_shape = None

    src_frames = py_utils.PadBatchDimension(src_frames,
                                            self.InfeedBatchSize(), 0)
    src_paddings = py_utils.PadBatchDimension(src_paddings,
                                              self.InfeedBatchSize(), 1)
    tgt_ids = py_utils.PadBatchDimension(tgt_ids, self.InfeedBatchSize(), 0)
    tgt_labels = py_utils.PadBatchDimension(tgt_labels,
                                            self.InfeedBatchSize(), 0)
    tgt_paddings = py_utils.PadBatchDimension(tgt_paddings,
                                              self.InfeedBatchSize(), 1)

    self._sample_ids = py_utils.PadBatchDimension(self._sample_ids,
                                                  self.InfeedBatchSize(),
                                                  type(self).PAD_INDEX)
    # PadBatchDimension returns a [batch_size, 1] tensor here, so squeeze it
    # back down to [batch_size].
    self._sample_ids = tf.squeeze(self._sample_ids, axis=1)
    self._sample_ids = tf.ensure_shape(self._sample_ids,
                                       self.InfeedBatchSize())

    audio_document_ids = py_utils.PadBatchDimension(audio_document_ids,
                                                    self.InfeedBatchSize(),
                                                    type(self).PAD_INDEX)
    # Likewise, squeeze [batch_size, 1] back down to [batch_size].
    audio_document_ids = tf.squeeze(audio_document_ids, axis=1)
    audio_document_ids = tf.ensure_shape(audio_document_ids,
                                         self.InfeedBatchSize())

    num_utterances_in_audio_document = py_utils.PadBatchDimension(
        num_utterances_in_audio_document, self.InfeedBatchSize(),
        type(self).PAD_INDEX)
    # Likewise, squeeze [batch_size, 1] back down to [batch_size].
    num_utterances_in_audio_document = tf.squeeze(
        num_utterances_in_audio_document, axis=1)
    num_utterances_in_audio_document = tf.ensure_shape(
        num_utterances_in_audio_document, self.InfeedBatchSize())

    src_frames = py_utils.PadSequenceDimension(
        src_frames, p.source_max_length, 0, shape=src_frames_shape)
    src_paddings = py_utils.PadSequenceDimension(
        src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
    tgt_ids = py_utils.PadSequenceDimension(
        tgt_ids, p.target_max_length, 0, shape=tgt_shape)
    tgt_labels = py_utils.PadSequenceDimension(
        tgt_labels, p.target_max_length, 0, shape=tgt_shape)
    tgt_paddings = py_utils.PadSequenceDimension(
        tgt_paddings, p.target_max_length, 1, shape=tgt_shape)

  tgt = py_utils.NestedMap(
      ids=tgt_ids,
      labels=tgt_labels,
      paddings=tgt_paddings,
      weights=1.0 - tgt_paddings)
  src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

  self._tgt = tgt
  self._src = src
  self._audio_document_ids = audio_document_ids
  self._num_utterances_in_audio_document = num_utterances_in_audio_document
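# A standalone sketch of the weights construction above: per-position loss
# weights are simply the complement of the paddings, so padded positions
# contribute nothing to the loss.
import tensorflow as tf

tgt_paddings = tf.constant([[0., 0., 1., 1.]])  # 2 real tokens, 2 padded
weights = 1.0 - tgt_paddings  # [[1., 1., 0., 0.]]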
def DecodeWavPyFunc(input_bytes):
  sample_rate, audio = tf.py_function(read_wave_via_scipy, [input_bytes],
                                      [tf.int32, tf.float32])
  sample_rate = tf.ensure_shape(sample_rate, [])
  audio = tf.ensure_shape(audio, [None, 1])
  return sample_rate, audio
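# A standalone sketch of the reader the wrapper above assumes.
# `read_wave_via_scipy` is defined elsewhere in the original code; a minimal
# stand-in using scipy could look like this (returning an int32 sample rate
# and un-normalized float32 samples of shape [num_samples, 1]).
import io

import numpy as np
import tensorflow as tf
from scipy.io import wavfile


def read_wave_via_scipy(input_bytes):
  # tf.py_function hands us an EagerTensor; .numpy() yields the raw bytes.
  sample_rate, audio = wavfile.read(io.BytesIO(input_bytes.numpy()))
  audio = audio.astype(np.float32).reshape(-1, 1)  # assume mono: [N] -> [N, 1]
  return np.int32(sample_rate), audio


# Usage sketch:
#   wav_bytes = tf.io.read_file('utterance.wav')
#   sample_rate, audio = DecodeWavPyFunc(wav_bytes)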