def get_masked():
    """Randomly mask `x` along its time axis, then along its feature axis.

    Relies on `x`, `data`, `_random_mask` and the mask-count settings from the
    enclosing scope.
    """
    batch_axis = data.batch_dim_axis
    time_axis = data.time_dim_axis
    # The upper bound on time masks scales with the sequence length
    # (one per `mask_each_n_frames` frames) but never drops below `min_frame_masks`.
    time_masks_upper = tf.maximum(
        tf.shape(x)[time_axis] // mask_each_n_frames, min_frame_masks)
    time_masked = _random_mask(
        x, batch_axis=batch_axis, axis=time_axis,
        min_num=min_frame_masks, max_num=time_masks_upper,
        max_dims=max_frames_per_mask)
    return _random_mask(
        time_masked, batch_axis=batch_axis, axis=data.feature_dim_axis,
        min_num=min_feature_masks, max_num=max_feature_masks,
        max_dims=max_features_per_mask)
# Example no. 2
# 0
 def get_masked():
     x_masked = x
     x_masked = random_mask(x_masked,
                            axis=1,
                            min_num=1,
                            max_num=tf.maximum(tf.shape(x)[1] // 100, 1),
                            max_dims=20)
     x_masked = random_mask(x_masked,
                            axis=2,
                            min_num=1,
                            max_num=2,
                            max_dims=40 // 5)
     return x_masked
 def get_masked():
     """Randomly mask `x` along its time axis, then along its feature axis.

     The number of masks depends on `step1` and `step2`. The maximum mask
     width on the time axis is divided by `time_factor`. All of these, plus
     `x`, `data`, `random_mask` and `tf`, come from the enclosing scope.
     """
     b_axis = data.batch_dim_axis
     t_axis = data.time_dim_axis
     f_axis = data.feature_dim_axis
     lower = step1 + step2
     t_upper = tf.maximum(tf.shape(x)[t_axis] // 100, 2) * (1 + step1 + step2 * 2)
     out = random_mask(x, batch_axis=b_axis, axis=t_axis,
                       min_num=lower, max_num=t_upper,
                       max_dims=20 // time_factor)
     f_upper = 2 + step1 + step2 * 2
     out = random_mask(out, batch_axis=b_axis, axis=f_axis,
                       min_num=lower, max_num=f_upper,
                       max_dims=data.dim // 5)
     return out
def get_contrastive_loss_mask(source, **_kwargs):
    """Build a random boolean mask over the time axis of the first source.

    :param source: callable; ``source(0, as_data=True, auto_convert=False)``
        must return an object with ``batch_dim_axis`` (== 0),
        ``time_dim_axis`` (== 1) and ``placeholder``.
    :return: bool tensor of shape (batch, time). Positions covered by the
        spans from ``_get_mask`` are True (presumably the masked frames;
        confirm against ``_get_mask``).
    """
    from returnn.tf.compat import v1 as tf

    def _random_mask(x, axis, min_num, max_num, max_dims):
        """Return a (batch, dim) bool mask made of a random number of spans.

        For each batch entry, draws ``num`` in ``[min_num, max_num]``. The
        draw is clipped to the length of ``axis`` so that ``top_k`` cannot
        fail. It then picks ``num`` distinct random positions along ``axis``
        and ORs in the span that ``_get_mask`` builds at each one, using
        ``max_dims`` as its ``max_amount``.
        """
        n_batch = tf.shape(x)[0]
        dim = tf.shape(x)[axis]
        num = tf.random_uniform(shape=(n_batch, ),
                                minval=min_num,
                                maxval=max_num + 1,
                                dtype=tf.int32)
        # top_k below requires k <= dim; the original code failed e.g. on an
        # all-empty batch (dim == 0) because max_num is at least 1.
        num = tf.minimum(num, dim)
        max_num_in_batch = tf.reduce_max(num)  # built once, reused below
        # Gumbel-max trick: top-k over i.i.d. Gumbel noise yields k distinct,
        # uniformly random positions.
        # https://github.com/tensorflow/tensorflow/issues/9260
        # https://timvieira.github.io/blog/post/2014/08/01/gumbel-max-trick-and-weighted-reservoir-sampling/
        z = -tf.log(-tf.log(tf.random_uniform((n_batch, dim), 0, 1)))
        _, indices = tf.nn.top_k(z, max_num_in_batch)
        # indices: (batch, max_num_in_batch) int32 entries in [0, dim). They
        # are ordered by noise value, NOT by position.

        def _body(i, res_mask):
            span_mask = _get_mask(
                x, axis=axis, pos=indices[:, i], max_amount=max_dims)
            # Only batch entries whose own count exceeds i receive span i.
            return (i + 1,
                    tf.where(tf.less(i, num),
                             tf.math.logical_or(res_mask, span_mask),
                             res_mask))

        res_mask = tf.zeros(shape=[n_batch, dim], dtype=tf.bool)  # all False
        _, res_mask = tf.while_loop(
            cond=lambda i, _: tf.less(i, max_num_in_batch),
            body=_body,
            loop_vars=(0, res_mask))
        return res_mask  # (batch,dim)

    data = source(0, as_data=True, auto_convert=False)
    assert (data.batch_dim_axis, data.time_dim_axis) == (0, 1)
    x = data.placeholder
    # Up to one mask start per 100 frames, but always at least one.
    mask = _random_mask(x,
                        axis=1,
                        min_num=1,
                        max_num=tf.maximum(tf.shape(x)[1] // 100, 1),
                        max_dims=20)
    return mask