def _random_mask(x, axis, min_num, max_num, max_dims):
    """
    :param tf.Tensor x: (batch,time,feature)
    :param int axis:
    :param int|tf.Tensor min_num:
    :param int|tf.Tensor max_num: inclusive
    :param int max_dims: inclusive
    :return: bool mask of shape (batch,dim), where dim = tf.shape(x)[axis]
    """
    from returnn.tf.compat import v1 as tf
    n_batch = tf.shape(x)[0]
    # Per sequence, how many positions to mask (inclusive range [min_num, max_num]).
    num = tf.random_uniform(shape=(n_batch,),
                            minval=min_num,
                            maxval=max_num + 1,
                            dtype=tf.int32)
    # Sample distinct positions without replacement via the Gumbel-max trick:
    # https://github.com/tensorflow/tensorflow/issues/9260
    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
    z = -tf.log(
        -tf.log(tf.random_uniform((n_batch, tf.shape(x)[axis]), 0, 1)))
    _, indices = tf.nn.top_k(z, tf.reduce_max(num))
    # indices: shape (batch, max(num)), entries (int32) in [0, dim)
    # indices = tf.Print(indices, ["indices", indices, tf.shape(indices)])

    res_mask = tf.zeros(shape=[n_batch, tf.shape(x)[axis]],
                        dtype=tf.bool)  # all False
    # OR one span mask per sampled position into res_mask; sequences whose own
    # num is already reached (i >= num) keep their mask unchanged.
    _, res_mask = tf.while_loop(
        cond=lambda i, _: tf.less(i, tf.reduce_max(num)),
        body=lambda i, res_mask: (
            i + 1,
            tf.where(
                tf.less(i, num),
                tf.math.logical_or(
                    res_mask,
                    _get_mask(x, axis=axis, pos=indices[:, i], max_amount=max_dims)),
                res_mask)),
        loop_vars=(0, res_mask))
    return res_mask  # (batch,dim)
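
# `_get_mask` is referenced above but not shown in this snippet. The following is a
# minimal sketch of what such a helper could look like, assuming it returns a bool
# mask of shape (batch, dim) that is True over a span of random length in
# [1, max_amount] starting at `pos` along `axis`; names and details here are
# assumptions, not the original implementation.
def _get_mask(x, axis, pos, max_amount):
    # sketch only; the original _get_mask is not part of this snippet
    from returnn.tf.compat import v1 as tf
    n_batch = tf.shape(x)[0]
    dim = tf.shape(x)[axis]
    # random span length per sequence, in [1, max_amount]
    amount = tf.random_uniform(shape=(n_batch,), minval=1, maxval=max_amount + 1,
                               dtype=tf.int32)
    pos2 = tf.minimum(pos + amount, dim)  # clip the span end to the axis size
    idxs = tf.expand_dims(tf.range(0, dim), 0)  # (1, dim)
    # True inside [pos, pos2) for each sequence, via broadcasting to (batch, dim)
    return tf.logical_and(tf.greater_equal(idxs, tf.expand_dims(pos, 1)),
                          tf.less(idxs, tf.expand_dims(pos2, 1)))
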
def random_mask(x, axis, min_num, max_num, max_dims):
    """
  :param tf.Tensor x: (batch,time,feature)
  :param int axis:
  :param int|tf.Tensor min_num:
  :param int|tf.Tensor max_num: inclusive
  :param int max_dims: inclusive
  """
    from returnn.tf.compat import v1 as tf
    n_batch = tf.shape(x)[0]
    num = tf.random_uniform(shape=(n_batch, ),
                            minval=min_num,
                            maxval=max_num + 1,
                            dtype=tf.int32)
    # https://github.com/tensorflow/tensorflow/issues/9260
    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
    z = -tf.log(-tf.log(tf.random_uniform((n_batch, tf.shape(x)[axis]), 0, 1)))
    _, indices = tf.nn.top_k(z, tf.reduce_max(num))
    # indices: shape (batch, max(num)), entries (int32) in [0, dim)
    # indices = tf.Print(indices, ["indices", indices, tf.shape(indices)])
    _, x = tf.while_loop(
        cond=lambda i, _: tf.less(i, tf.reduce_max(num)),
        body=lambda i, x:
        (i + 1,
         tf.where(tf.less(i, num),
                  _mask(x, axis=axis, pos=indices[:, i], max_amount=max_dims),
                  x)),
        loop_vars=(0, x))
    return x
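
# Similarly, `_mask` is not part of this snippet. A minimal sketch under the
# assumptions that x is (batch, time, feature), axis is 1 (time) or 2 (feature),
# and that "masking" means zeroing out the span; the real helper may differ
# (e.g. filling with a mean value instead of zeros).
def _mask(x, axis, pos, max_amount):
    # sketch only; the original _mask is not part of this snippet
    from returnn.tf.compat import v1 as tf
    n_batch = tf.shape(x)[0]
    dim = tf.shape(x)[axis]
    # random span length per sequence, in [1, max_amount]
    amount = tf.random_uniform(shape=(n_batch,), minval=1, maxval=max_amount + 1,
                               dtype=tf.int32)
    pos2 = tf.minimum(pos + amount, dim)
    idxs = tf.expand_dims(tf.range(0, dim), 0)  # (1, dim)
    keep = tf.logical_or(tf.less(idxs, tf.expand_dims(pos, 1)),
                         tf.greater_equal(idxs, tf.expand_dims(pos2, 1)))  # (batch, dim)
    keep = tf.cast(keep, x.dtype)
    if axis == 1:
        keep = tf.expand_dims(keep, 2)  # (batch, time, 1)
    else:
        keep = tf.expand_dims(keep, 1)  # (batch, 1, feature)
    return x * keep  # zero out the masked span
# With that sketch, e.g. random_mask(x, axis=1, min_num=1, max_num=2, max_dims=20)
# would zero out up to two time spans of at most 20 frames each per sequence.
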
Example No. 3
def _random_mask(x, axis, min_num, max_num, max_dims):
    from returnn.tf.compat import v1 as tf
    n_batch = tf.shape(x)[0]
    num = tf.random_uniform(shape=(n_batch,),
                            minval=min_num,
                            maxval=max_num + 1,
                            dtype=tf.int32)
    # https://github.com/tensorflow/tensorflow/issues/9260
    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
    z = -tf.log(
        -tf.log(tf.random_uniform((n_batch, tf.shape(x)[axis]), 0, 1)))
    _, indices = tf.nn.top_k(z, tf.reduce_max(num))
    # indices: shape (batch, max(num)), entries (int32) in [0, dim)
    # indices = tf.Print(indices, ["indices", indices, tf.shape(indices)])
    # Note: this variant uses only the first sampled position, i.e. it returns a
    # single span mask per sequence instead of OR-ing `num` of them.
    return _get_mask(x, axis=axis, pos=indices[:, 0], max_amount=max_dims)
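
# Why the top-k of Gumbel noise works (the trick referenced by the links in the
# comments): adding independent Gumbel noise -log(-log(U)) to the (log-)weights and
# taking the k largest values samples k distinct indices without replacement; with
# uniform weights, as here, that is simply k positions drawn uniformly without
# replacement. A small self-contained NumPy illustration:
import numpy as np

rng = np.random.default_rng(0)
dim, k = 10, 3
u = rng.uniform(size=dim)          # U ~ Uniform[0, 1)
z = -np.log(-np.log(u))            # Gumbel(0, 1) noise
positions = np.argsort(-z)[:k]     # indices of the k largest noise values
print(sorted(positions.tolist()))  # k distinct positions in [0, dim)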