def grow_topk(i, alive_seq, alive_log_probs, states):
  """Inner beam search loop."""
  # (batch_size * beam_size, decoded_length)
  flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

  # Run the model on the flattened ids, merging/unmerging the beam
  # dimension around the call when decoder states are threaded through.
  if states:
    flat_states = nest.map_structure(utils.merge_beam_dim, states)
    flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
    states = nest.map_structure(
        lambda t: utils.unmerge_beam_dim(t, batch_size, beam_size),
        flat_states)
  else:
    flat_logits = symbols_to_logits_fn(flat_ids)
  logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

  # Convert logits to log-probabilities and accumulate them along each beam.
  candidate_log_probs = utils.log_prob_from_logits(logits)
  log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

  # Length penalty: ((5 + length) / 6) ** alpha.
  length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)
  curr_scores = log_probs / length_penalty

  # Flatten (beam, vocab) so top_k selects over all candidate extensions;
  # keep 2 * beam_size candidates in case some of them finish with EOS.
  flat_curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])
  topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)
  topk_log_probs = topk_scores * length_penalty

  # Recover which beam each candidate came from and which token extends it.
  topk_beam_index = topk_ids // vocab_size
  topk_ids %= vocab_size

  # Unflatten the ids.
  batch_pos = utils.compute_batch_indices(batch_size, beam_size * 2)
  topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)
  topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
  if states:
    states = nest.map_structure(
        lambda state: tf.gather_nd(state, topk_coordinates), states)

  # Append the new token and flag candidates that just emitted EOS.
  topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
  topk_finished = tf.equal(topk_ids, eos_id)
  return topk_seq, topk_log_probs, topk_scores, topk_finished, states
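# The `// vocab_size` and `%= vocab_size` arithmetic above decodes a flat
# index over the merged (beam, vocab) axis back into separate beam and
# token coordinates. A minimal standalone sketch of that unflattening,
# using made-up sizes rather than the closure variables used above:

import tensorflow as tf

demo_beam_size, demo_vocab_size = 3, 5  # hypothetical sizes, illustration only
flat_scores = tf.random.uniform([1, demo_beam_size * demo_vocab_size])
scores, flat_ids = tf.nn.top_k(flat_scores, k=demo_beam_size * 2)
beam_index = flat_ids // demo_vocab_size  # beam each candidate came from
token_ids = flat_ids % demo_vocab_size    # token extending that beam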
def call(self, inputs: list, **kwargs) -> typing.Any:
    """
    The computation logic of DynamicPoolingLayer.

    :param inputs: two input tensors.
    """
    self._validate_dpool_size()
    x, dpool_index = inputs
    dpool_shape = tf.shape(dpool_index)
    # Build a batch index of shape (batch, msize1, msize2, 1) so that each
    # (row, col) entry in dpool_index can be extended into a full
    # (batch, row, col) coordinate for gather_nd.
    batch_index_one = tf.expand_dims(
        tf.expand_dims(tf.range(dpool_shape[0]), axis=-1),
        axis=-1)
    batch_index = tf.expand_dims(
        tf.tile(batch_index_one, [1, self._msize1, self._msize2]),
        axis=-1)
    dpool_index_ex = tf.concat([batch_index, dpool_index], axis=3)
    # Rearrange x according to the dynamic pooling index, then max-pool the
    # rearranged map down to (psize1, psize2) with non-overlapping windows.
    x_expand = tf.gather_nd(x, dpool_index_ex)
    stride1 = self._msize1 // self._psize1
    stride2 = self._msize2 // self._psize2
    x_pool = tf.nn.max_pool(x_expand,
                            [1, stride1, stride2, 1],
                            [1, stride1, stride2, 1],
                            "VALID")
    return x_pool
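# A hedged usage sketch for the layer above. It assumes the constructor
# takes the target pool sizes (psize1, psize2) and that _msize1/_msize2
# are inferred from the input shape when the layer is built, as the
# attribute names suggest; the identity dpool_index below is a placeholder
# for whatever index the surrounding model actually computes.

import numpy as np
import tensorflow as tf

batch, msize1, msize2, channels = 2, 8, 8, 4  # msize must divide evenly by psize
layer = DynamicPoolingLayer(psize1=2, psize2=2)

x = tf.random.uniform([batch, msize1, msize2, channels])
rows, cols = np.meshgrid(np.arange(msize1), np.arange(msize2), indexing='ij')
identity_index = np.tile(np.stack([rows, cols], axis=-1), (batch, 1, 1, 1))
dpool_index = tf.constant(identity_index, dtype=tf.int32)

x_pool = layer([x, dpool_index])  # -> (batch, psize1, psize2, channels)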
def transform_preprocess(labels=None, blank_index=None, num_class=None):
  '''Ensure that blank_index falls in a valid range, and convert dense
  labels to a tf.SparseTensor.'''
  if blank_index is None or blank_index < 0:
    raise ValueError('blank_index must be greater than or equal to zero')
  if num_class is not None and blank_index > (num_class - 1):
    raise ValueError('blank_index must be less than or equal to num_class - 1')
  if labels is None:
    return None

  if not isinstance(labels, tf.SparseTensor):
    # Treat zeros as padding: keep only the non-zero entries and record
    # their coordinates, values, and the original dense shape.
    labels = tf.cast(labels, tf.int32)
    labels_idx = tf.where(tf.not_equal(labels, 0))
    labels_values = tf.gather_nd(labels, labels_idx)
    labels_shape = tf.cast(tf.shape(labels), dtype=tf.int64)
    labels = tf.SparseTensor(
        indices=labels_idx, values=labels_values, dense_shape=labels_shape)
  return labels
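# A minimal usage sketch, assuming zeros in the dense label batch are
# padding (which is what the tf.not_equal(labels, 0) test above implies)
# and that the resulting SparseTensor feeds a CTC-style loss:

import tensorflow as tf

dense_labels = tf.constant([[3, 7, 2, 0, 0],
                            [5, 1, 0, 0, 0]])  # two zero-padded sequences
sparse_labels = transform_preprocess(
    labels=dense_labels, blank_index=0, num_class=10)
# sparse_labels.indices holds the positions of the non-zero entries and
# sparse_labels.values the corresponding label ids.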
def gather(tensor, name):
  # Nested helper: `top_coordinates` and `prefix` are read from the
  # enclosing scope.
  return tf.gather_nd(tensor, top_coordinates, name=(prefix + name))