def _mask(x, batch_axis, axis, pos, max_amount):
    """
    Zero out one random-length span along `axis`, independently per batch entry.

    For every batch entry b, an amount is drawn uniformly from [1, max_amount]
    and all indices i along `axis` with pos[b] <= i < min(pos[b] + amount, dim)
    are replaced by 0.0. The mask is broadcast over all other axes of x.

    :param tf.Tensor x: (batch,time,feature)
    :param int batch_axis: batch axis of x. negative values count from the end
    :param int axis: axis of x along which to mask. negative values count from the end
    :param tf.Tensor pos: (batch,), first masked index per batch entry
    :param int|tf.Tensor max_amount: inclusive
    :return: x with the selected spans replaced by 0.0
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf
    # Taken from the same `returnn` package namespace as the compat import above,
    # instead of the legacy top-level `TFUtil` module name.
    from returnn.tf.util.basic import where_bc

    ndim = x.get_shape().ndims
    # The axis comparison and the reshape target below work on non-negative
    # indices, so map negative axes to their positive equivalent first.
    if batch_axis < 0:
        batch_axis += ndim
    if axis < 0:
        axis += ndim
    n_batch = tf.shape(x)[batch_axis]
    dim = tf.shape(x)[axis]
    # maxval is exclusive in tf.random_uniform, hence +1 to make max_amount inclusive.
    amount = tf.random_uniform(shape=(n_batch, ),
                               minval=1,
                               maxval=max_amount + 1,
                               dtype=tf.int32)
    # Clip the (exclusive) end of the span to the axis size.
    pos2 = tf.minimum(pos + amount, dim)
    idxs = tf.expand_dims(tf.range(0, dim), 0)  # (1,dim)
    pos_bc = tf.expand_dims(pos, 1)  # (batch,1)
    pos2_bc = tf.expand_dims(pos2, 1)  # (batch,1)
    cond = tf.logical_and(tf.greater_equal(idxs, pos_bc),
                          tf.less(idxs, pos2_bc))  # (batch,dim)
    if batch_axis > axis:
        # Bring the two mask dims into the same relative order as in x.
        cond = tf.transpose(cond)  # (dim,batch)
    # Insert size-1 dims for every other axis of x, so that cond broadcasts against x.
    cond = tf.reshape(cond, [
        tf.shape(x)[i] if i in (batch_axis, axis) else 1 for i in range(ndim)
    ])
    x = where_bc(cond, 0.0, x)
    return x
def _get_mask(x, axis, pos, max_amount):
    """
    Build a boolean mask marking one random-length span per batch entry.

    For every batch entry b, a length is drawn uniformly from [1, max_amount].
    The mask is True at index i iff pos[b] <= i < min(pos[b] + length, dim),
    where dim is the size of x along `axis`. The batch axis of x is taken to be 0.

    :param tf.Tensor x: (batch,time,feature)
    :param int axis:
    :param tf.Tensor pos: (batch,)
    :param int max_amount: inclusive
    :return: bool mask, (batch,dim)
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf
    batch_size = tf.shape(x)[0]
    axis_len = tf.shape(x)[axis]
    # Upper bound of tf.random_uniform is exclusive, so +1 keeps max_amount inclusive.
    span_len = tf.random_uniform(
        shape=(batch_size, ), minval=1, maxval=max_amount + 1, dtype=tf.int32)
    # Exclusive end of the span, clipped to the axis size.
    span_end = tf.minimum(pos + span_len, axis_len)
    positions = tf.expand_dims(tf.range(0, axis_len), axis=0)  # (1,dim)
    start_bc = tf.expand_dims(pos, axis=1)  # (batch,1)
    end_bc = tf.expand_dims(span_end, axis=1)  # (batch,1)
    at_or_after_start = positions >= start_bc  # (batch,dim)
    before_end = positions < end_bc  # (batch,dim)
    return tf.logical_and(at_or_after_start, before_end)  # (batch,dim)
def summary(name, x):
    """
    Register TF summaries for x: an image, the max-abs / mean / stddev scalars,
    and a histogram of the max-abs value over the feature axis.

    :param str name: prefix used for all summary names
    :param tf.Tensor x: (batch,time,feature)
    """
    from returnn.tf.compat import v1 as tf
    # tf.summary.image wants [batch_size, height,  width, channels],
    # we have (batch, time, feature).
    # Swap time and feature first, then append the channel dim.
    feature_major = tf.transpose(x, [0, 2, 1])  # (batch,feature,time)
    image = tf.expand_dims(feature_major, axis=3)  # (batch,feature,time,1)
    tf.summary.image(name, image, max_outputs=10)

    max_abs = tf.reduce_max(tf.abs(x))
    tf.summary.scalar("%s_max_abs" % name, max_abs)

    mean = tf.reduce_mean(x)
    tf.summary.scalar("%s_mean" % name, mean)

    # Population standard deviation over all entries of x.
    variance = tf.reduce_mean(tf.square(x - mean))
    stddev = tf.sqrt(variance)
    tf.summary.scalar("%s_stddev" % name, stddev)

    max_abs_per_frame = tf.reduce_max(tf.abs(x), axis=2)  # (batch,time)
    tf.summary.histogram("%s_hist" % name, max_abs_per_frame)