def targetb_recomb_recog(layer, batch_dim, scores_in, scores_base,
                         base_beam_in, end_flags, **kwargs):
    """
    Add the (recombined) previous beam scores to the per-label frame scores.

    Before combining, the existing beam scores are passed through
    ``get_filtered_score_cpp`` together with the previous-step hypothesis
    strings and the label spellings built below. Per the inline comment this
    is a fast pre-filter with a sum approximation (TODO confirm exact
    semantics against the definition of ``get_filtered_score_cpp``).

    :param ChoiceLayer layer:
    :param tf.Tensor batch_dim: scalar
    :param tf.Tensor scores_in: (batch,base_beam_in,dim). log prob frame distribution
    :param tf.Tensor scores_base: (batch,base_beam_in,1). existing beam scores
    :param tf.Tensor base_beam_in: int32 scalar, 1 or prev beam size
    :param tf.Tensor end_flags: (batch,base_beam_in). Not used by the active
        code path; only the commented-out alternatives below use it.
    :param kwargs: ignored
    :rtype: tf.Tensor
    :return: (batch,base_beam_in,dim), combined scores
    """
    from returnn.tf.compat import v1 as tf
    from returnn.datasets.generating import Vocabulary

    # Hypothesis strings of the previous step: flat [B*beam] -> (batch, base_beam_in).
    prev_str = layer.explicit_search_sources[0].output  # [B*beam], str
    prev_str_t = tf.reshape(prev_str.placeholder,
                            (batch_dim, -1))[:, :base_beam_in]

    # Spelling of each output label as UTF-8 bytes, indexed by label id:
    # "foo@@" (word-internal BPE unit) -> b"foo", "bar" (word-final) -> b"bar ".
    # The vocab labels exclude blank; blank is appended last and spells b"".
    dataset = get_dataset("train")
    vocab = Vocabulary.create_vocab(**dataset["bpe"])
    labels = [(label + " ").replace("@@ ", "").encode("utf8")
              for label in vocab.labels] + [b""]

    # Pre-filter approx (should be much faster), sum approx (better).
    scores_base = tf.reshape(
        get_filtered_score_cpp(
            prev_str_t, tf.reshape(scores_base, (batch_dim, base_beam_in)),
            labels), (batch_dim, base_beam_in, 1))

    scores = scores_in + scores_base  # (batch,beam,dim)

    # Slower alternatives, kept for reference. To enable one, also add:
    #   from returnn.tf.util.basic import where_bc
    #   prev_out = layer.explicit_search_sources[1].output  # [B*beam], int32
    #   prev_out_t = tf.reshape(prev_out.placeholder,
    #                           (batch_dim, -1))[:, :base_beam_in]
    # Mask -> max approx, in all possible options, slow:
    #   mask = get_score_mask_cpp(prev_str_t, prev_out_t, scores, labels)
    #   masked_scores = where_bc(mask, scores, float("-inf"))
    # Sum approx in all possible options, slow:
    #   masked_scores = get_new_score_cpp(prev_str_t, prev_out_t, scores, labels)
    # Then keep the unmasked scores for already-ended hypotheses:
    #   scores = where_bc(end_flags[:, :, None], scores, masked_scores)

    return scores
def _mask(x, batch_axis, axis, pos, max_amount):
    """
    Zero out a random-length span along ``axis``, independently per batch entry.

    For every batch entry ``b``, a span length ``amount[b]`` is drawn uniformly
    from ``[1, max_amount]``. Positions ``i`` along ``axis`` with
    ``pos[b] <= i < min(pos[b] + amount[b], dim)`` are set to 0.

    :param tf.Tensor x: typically (batch,time,feature); any tensor with a
        static rank in which ``batch_axis`` and ``axis`` are distinct axes
    :param int batch_axis:
    :param int axis: the axis along which the span is masked
    :param tf.Tensor pos: (batch,) int32, start of the masked span per batch entry
    :param int|tf.Tensor max_amount: maximum span length, inclusive
    :return: same shape as ``x``, with the masked span set to 0
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf
    # Import from the same module the rest of this file uses; the bare
    # ``TFUtil`` name is a legacy RETURNN module alias that may not be importable.
    from returnn.tf.util.basic import where_bc

    ndim = x.get_shape().ndims
    n_batch = tf.shape(x)[batch_axis]
    dim = tf.shape(x)[axis]
    # maxval is exclusive for tf.random_uniform, hence +1 so max_amount is included.
    amount = tf.random_uniform(shape=(n_batch, ),
                               minval=1,
                               maxval=max_amount + 1,
                               dtype=tf.int32)
    pos2 = tf.minimum(pos + amount, dim)  # clip the span end to the axis length
    idxs = tf.expand_dims(tf.range(0, dim), 0)  # (1,dim)
    pos_bc = tf.expand_dims(pos, 1)  # (batch,1)
    pos2_bc = tf.expand_dims(pos2, 1)  # (batch,1)
    cond = tf.logical_and(tf.greater_equal(idxs, pos_bc),
                          tf.less(idxs, pos2_bc))  # (batch,dim)
    if batch_axis > axis:
        # Match the relative order of the two axes inside x.
        cond = tf.transpose(cond)  # (dim,batch)
    # Make cond broadcastable to x: real sizes on batch_axis/axis, 1 elsewhere.
    cond = tf.reshape(cond, [
        tf.shape(x)[i] if i in (batch_axis, axis) else 1 for i in range(ndim)
    ])
    x = where_bc(cond, 0.0, x)
    return x
def _mask(x, axis, pos, max_amount):
    """
    Zero out the positions along ``axis`` selected by ``_get_mask``.

    NOTE(review): this definition reuses the name of the preceding
    ``_mask(x, batch_axis, axis, pos, max_amount)`` in this file and replaces
    it when the module is executed, so calls using that five-argument form
    raise ``TypeError``. Confirm which variant is meant to survive.

    :param tf.Tensor x: batch axis must be 0 (the mask is reshaped onto
        axes 0 and ``axis`` only)
    :param int axis: the axis along which to mask
    :param tf.Tensor pos: forwarded to ``_get_mask``; presumably (batch,)
        start positions -- TODO confirm
    :param int|tf.Tensor max_amount: forwarded to ``_get_mask``
    :return: same shape as ``x``, with the masked positions set to 0
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf
    # Import from the same module the rest of this file uses; the bare
    # ``TFUtil`` name is a legacy RETURNN module alias that may not be importable.
    from returnn.tf.util.basic import where_bc

    ndim = x.get_shape().ndims
    # Presumably a (batch, dim) bool mask -- TODO confirm against _get_mask.
    cond = _get_mask(x, axis, pos, max_amount)
    # Make cond broadcastable to x: real sizes on axes 0 and `axis`, 1 elsewhere.
    cond = tf.reshape(
        cond, [tf.shape(x)[i] if i in (0, axis) else 1 for i in range(ndim)])
    x = where_bc(cond, 0.0, x)
    return x