def targetb_recomb_recog(layer, batch_dim, scores_in, scores_base, base_beam_in, end_flags, **kwargs):
  """
  Beam-search score combination with recognition-time recombination:
  hypotheses whose previous label strings collapse to the same sequence get
  their scores recombined (via the C++ helper) before the frame scores are added.

  :param ChoiceLayer layer:
  :param tf.Tensor batch_dim: scalar
  :param tf.Tensor scores_base: (batch,base_beam_in,1). existing beam scores
  :param tf.Tensor scores_in: (batch,base_beam_in,dim). log prob frame distribution
  :param tf.Tensor end_flags: (batch,base_beam_in)
  :param tf.Tensor base_beam_in: int32 scalar, 1 or prev beam size
  :rtype: tf.Tensor
  :return: (batch,base_beam_in,dim), combined scores
  """
  from returnn.tf.compat import v1 as tf
  from returnn.datasets.generating import Vocabulary
  # Previous (partial) output strings, one per beam hypothesis.
  prev_str = layer.explicit_search_sources[0].output  # [B*beam], str
  prev_str_t = tf.reshape(prev_str.placeholder, (batch_dim, -1))[:, :base_beam_in]
  # Previous output labels; only needed by the commented-out exact variants below.
  prev_out = layer.explicit_search_sources[1].output  # [B*beam], int32
  prev_out_t = tf.reshape(prev_out.placeholder, (batch_dim, -1))[:, :base_beam_in]
  dataset = get_dataset("train")
  vocab = Vocabulary.create_vocab(**dataset["bpe"])
  labels = vocab.labels  # bpe labels ("@@" at end, or not), excluding blank
  # Turn BPE pieces into byte strings (merge "@@ " continuations), append b"" for blank.
  labels = [(l + " ").replace("@@ ", "").encode("utf8") for l in labels] + [b""]
  # Pre-filter approx (should be much faster), sum approx (better).
  scores_base = tf.reshape(
    get_filtered_score_cpp(
      prev_str_t, tf.reshape(scores_base, (batch_dim, base_beam_in)), labels),
    (batch_dim, base_beam_in, 1))
  scores = scores_in + scores_base  # (batch,beam,dim)
  # Mask -> max approx, in all possible options, slow.
  #mask = get_score_mask_cpp(prev_str_t, prev_out_t, scores, labels)
  #masked_scores = where_bc(mask, scores, float("-inf"))
  # Sum approx in all possible options, slow.
  #masked_scores = get_new_score_cpp(prev_str_t, prev_out_t, scores, labels)
  #scores = where_bc(end_flags[:,:,None], scores, masked_scores)
  return scores
def _mask(x, batch_axis, axis, pos, max_amount):
  """
  SpecAugment-style masking: zeroes a contiguous span along ``axis`` of ``x``,
  starting at ``pos``, with a per-batch random length in [1, max_amount].

  :param tf.Tensor x: (batch,time,feature)
  :param int batch_axis:
  :param int axis:
  :param tf.Tensor pos: (batch,)
  :param int|tf.Tensor max_amount: inclusive
  """
  from returnn.tf.compat import v1 as tf
  # Use the modern import path, consistent with the returnn.tf.* imports elsewhere
  # in this file (was: from TFUtil import where_bc).
  from returnn.tf.util.basic import where_bc
  ndim = x.get_shape().ndims
  n_batch = tf.shape(x)[batch_axis]
  dim = tf.shape(x)[axis]
  # Random span length per batch entry, inclusive upper bound.
  amount = tf.random_uniform(shape=(n_batch,), minval=1, maxval=max_amount + 1, dtype=tf.int32)
  pos2 = tf.minimum(pos + amount, dim)  # clip span end to the axis size
  idxs = tf.expand_dims(tf.range(0, dim), 0)  # (1,dim)
  pos_bc = tf.expand_dims(pos, 1)  # (batch,1)
  pos2_bc = tf.expand_dims(pos2, 1)  # (batch,1)
  # True inside [pos, pos2) per batch entry.
  cond = tf.logical_and(tf.greater_equal(idxs, pos_bc), tf.less(idxs, pos2_bc))  # (batch,dim)
  if batch_axis > axis:
    cond = tf.transpose(cond)  # (dim,batch)
  # Reshape the 2D mask so it broadcasts against the full rank of x.
  cond = tf.reshape(
    cond, [tf.shape(x)[i] if i in (batch_axis, axis) else 1 for i in range(ndim)])
  x = where_bc(cond, 0.0, x)
  return x
def _mask(x, axis, pos, max_amount):
  """
  Zeroes a masked span along ``axis`` of ``x``, with the mask produced by
  ``_get_mask`` (defined elsewhere in this file). Batch is assumed to be
  axis 0 here (the reshape below keeps only axes 0 and ``axis``).
  NOTE(review): this redefines ``_mask`` from the earlier 5-arg variant;
  at module level the later definition wins — confirm that is intended.

  :param tf.Tensor x:
  :param int axis:
  :param tf.Tensor pos: (batch,)
  :param int|tf.Tensor max_amount: inclusive
  """
  from returnn.tf.compat import v1 as tf
  # Use the modern import path, consistent with the returnn.tf.* imports elsewhere
  # in this file (was: from TFUtil import where_bc).
  from returnn.tf.util.basic import where_bc
  ndim = x.get_shape().ndims
  cond = _get_mask(x, axis, pos, max_amount)
  # Broadcast the (batch,dim) mask to the full rank of x.
  cond = tf.reshape(
    cond, [tf.shape(x)[i] if i in (0, axis) else 1 for i in range(ndim)])
  x = where_bc(cond, 0.0, x)
  return x