def _random_mask(x, axis, min_num, max_num, max_dims):
    """
    Sample a random boolean mask along ``axis``, as the union of several spans.

    For every sequence in the batch, a count is drawn uniformly from
    ``[min_num, max_num]``, and that many distinct positions along ``axis``
    are picked at random (Gumbel-max trick). ``_get_mask`` turns each picked
    position into a mask, and all masks of one sequence are OR-ed together.

    NOTE(review): a later ``def _random_mask`` in this file rebinds this name,
    so this definition is not the one callers get after import -- confirm
    which one is intended.

    :param tf.Tensor x: input; axis 0 is taken as the batch axis
    :param int axis: axis along which positions are sampled
    :param int|tf.Tensor min_num: minimum number of masks per sequence (inclusive)
    :param int|tf.Tensor max_num: maximum number of masks per sequence (inclusive);
      effectively capped at the length of ``axis``
    :param int max_dims: forwarded to ``_get_mask`` as ``max_amount``
    :return: bool mask of shape (batch, dim), with dim = ``tf.shape(x)[axis]``
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf

    n_batch = tf.shape(x)[0]
    dim = tf.shape(x)[axis]
    # Per-sequence number of masks, uniform in [min_num, max_num].
    num = tf.random_uniform(shape=(n_batch,), minval=min_num, maxval=max_num + 1, dtype=tf.int32)
    # tf.nn.top_k fails if k exceeds the axis length, and there are only `dim`
    # distinct positions anyway, so cap both k and the loop bound there.
    num_max = tf.minimum(tf.reduce_max(num), dim)
    # Sample distinct positions without replacement via the Gumbel-max trick:
    # https://github.com/tensorflow/tensorflow/issues/9260
    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
    z = -tf.log(-tf.log(tf.random_uniform((n_batch, dim), 0, 1)))
    # indices: (batch, num_max), int32 in [0, dim). They are ordered by their
    # random score, not by position -- irrelevant here since masks are OR-ed.
    _, indices = tf.nn.top_k(z, num_max)

    def _body(i, mask_acc):
        # Add the i-th mask only for sequences that sampled more than i masks;
        # the (batch,) condition of v1 `where` selects whole rows.
        new_mask = tf.math.logical_or(
            mask_acc, _get_mask(x, axis=axis, pos=indices[:, i], max_amount=max_dims))
        return i + 1, tf.where(tf.less(i, num), new_mask, mask_acc)

    init_mask = tf.zeros(shape=[n_batch, dim], dtype=tf.bool)  # all False
    _, res_mask = tf.while_loop(
        cond=lambda i, _: tf.less(i, num_max),
        body=_body,
        loop_vars=(0, init_mask))
    return res_mask  # (batch,dim)
def random_mask(x, axis, min_num, max_num, max_dims):
    """
    Apply a random number of masks along ``axis`` of ``x``.

    For every sequence in the batch, a count is drawn uniformly from
    ``[min_num, max_num]``, and that many distinct positions along ``axis``
    are picked at random (Gumbel-max trick); ``_mask`` is applied at each one.

    :param tf.Tensor x: (batch,time,feature)
    :param int axis:
    :param int|tf.Tensor min_num:
    :param int|tf.Tensor max_num: inclusive; effectively capped at the length of ``axis``
    :param int max_dims: inclusive
    :return: ``x`` with the masks applied (same shape as ``x``)
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf

    n_batch = tf.shape(x)[0]
    dim = tf.shape(x)[axis]
    # Per-sequence number of masks, uniform in [min_num, max_num].
    num = tf.random_uniform(shape=(n_batch,), minval=min_num, maxval=max_num + 1, dtype=tf.int32)
    # tf.nn.top_k fails if k exceeds the axis length, and there are only `dim`
    # distinct positions anyway, so cap both k and the loop bound there.
    num_max = tf.minimum(tf.reduce_max(num), dim)
    # Sample distinct positions without replacement via the Gumbel-max trick:
    # https://github.com/tensorflow/tensorflow/issues/9260
    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
    z = -tf.log(-tf.log(tf.random_uniform((n_batch, dim), 0, 1)))
    # indices: (batch, num_max), int32 in [0, dim), ordered by random score
    # (not by position).
    _, indices = tf.nn.top_k(z, num_max)

    def _body(i, x_acc):
        # Apply the i-th mask only for sequences that sampled more than i masks;
        # the (batch,) condition of v1 `where` selects whole rows.
        masked = _mask(x_acc, axis=axis, pos=indices[:, i], max_amount=max_dims)
        return i + 1, tf.where(tf.less(i, num), masked, x_acc)

    _, x = tf.while_loop(
        cond=lambda i, _: tf.less(i, num_max),
        body=_body,
        loop_vars=(0, x))
    return x
def _random_mask(x, axis, min_num, max_num, max_dims):
    """
    Sample a random boolean mask along ``axis`` with at most one span per sequence.

    A count is drawn uniformly from ``[min_num, max_num]`` for every sequence,
    but only whether it is zero is used. Sequences with a count of 0 stay
    unmasked; all others get exactly one ``_get_mask`` span at a random
    position along ``axis``. Counts above 1 are NOT honored.

    NOTE(review): this redefinition shadows the earlier ``_random_mask`` in this
    file, which ORs up to ``num`` spans per sequence -- confirm which is intended.

    :param tf.Tensor x: input; axis 0 is taken as the batch axis
    :param int axis: axis along which the position is sampled
    :param int|tf.Tensor min_num: minimum sampled count per sequence (inclusive)
    :param int|tf.Tensor max_num: maximum sampled count per sequence (inclusive)
    :param int max_dims: forwarded to ``_get_mask`` as ``max_amount``
    :return: bool mask; presumably (batch, dim) as returned by ``_get_mask`` -- TODO confirm
    :rtype: tf.Tensor
    """
    from returnn.tf.compat import v1 as tf

    n_batch = tf.shape(x)[0]
    dim = tf.shape(x)[axis]
    # Per-sequence number of masks, uniform in [min_num, max_num].
    num = tf.random_uniform(shape=(n_batch,), minval=min_num, maxval=max_num + 1, dtype=tf.int32)
    # Random position per sequence via the Gumbel-max trick:
    # https://github.com/tensorflow/tensorflow/issues/9260
    # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
    z = -tf.log(-tf.log(tf.random_uniform((n_batch, dim), 0, 1)))
    # Only the top-1 index is used, so request k=1 (same index as the first
    # column of any larger k). Requesting max(num) failed when it was 0 or
    # exceeded the axis length.
    _, indices = tf.nn.top_k(z, 1)  # (batch,1)
    mask = _get_mask(x, axis=axis, pos=indices[:, 0], max_amount=max_dims)
    # Sequences that sampled zero masks stay unmasked.
    has_mask = tf.expand_dims(tf.greater(num, 0), axis=1)  # (batch,1)
    return tf.logical_and(mask, has_mask)