def _mask(x, batch_axis, axis, pos, max_amount):
  """
  Mask one random contiguous span along ``axis`` for every batch entry.

  For batch entry b the span starts at ``pos[b]``. Its length is drawn uniformly
  from [1, max_amount]. The exclusive end is clipped to the length of ``axis``.
  Positions inside the span are replaced by 0.0 via ``where_bc``. That helper is
  presumably ``tf.where`` with broadcasting; confirm in TFUtil.

  :param tf.Tensor x: (batch,time,feature)
  :param int batch_axis:
  :param int axis:
  :param tf.Tensor pos: (batch,)
  :param int|tf.Tensor max_amount: inclusive
  :return: x with the selected spans replaced by 0.0
  :rtype: tf.Tensor
  """
  from returnn.tf.compat import v1 as tf
  from TFUtil import where_bc
  rank = x.get_shape().ndims
  batch_size = tf.shape(x)[batch_axis]
  axis_len = tf.shape(x)[axis]
  # random_uniform's maxval is exclusive, hence the +1 to make max_amount inclusive.
  span_len = tf.random_uniform(
    shape=(batch_size,), minval=1, maxval=max_amount + 1, dtype=tf.int32)
  span_start = pos
  span_end = tf.minimum(span_start + span_len, axis_len)  # exclusive end, clipped to the axis
  positions = tf.expand_dims(tf.range(0, axis_len), 0)  # (1,dim)
  in_span = tf.logical_and(
    tf.greater_equal(positions, tf.expand_dims(span_start, 1)),  # (batch,1) start
    tf.less(positions, tf.expand_dims(span_end, 1)))  # (batch,dim)
  if axis < batch_axis:
    # Put the mask axes in the same relative order as they appear in x.
    in_span = tf.transpose(in_span)  # (dim,batch)
  # Keep the batch and masked axes; every other axis gets size 1 so it broadcasts.
  kept_axes = (batch_axis, axis)
  broadcast_shape = [tf.shape(x)[i] if i in kept_axes else 1 for i in range(rank)]
  in_span = tf.reshape(in_span, broadcast_shape)
  return where_bc(in_span, 0.0, x)
def summary(name, x):
  """
  Add TensorBoard summaries describing ``x``.

  Writes an image (at most 10 batch entries) and scalars for the max absolute
  value, the mean and the standard deviation. It also writes a histogram of the
  max absolute value over the last axis, one value per (batch, time) position.

  :param str name: prefix used for every summary tag
  :param tf.Tensor x: (batch,time,feature)
  """
  from returnn.tf.compat import v1 as tf
  # tf.summary.image wants [batch_size, height, width, channels].
  # x is (batch,time,feature): add a channel axis, then swap time and feature
  # so features run along the image height.
  with_channel = tf.expand_dims(x, axis=3)  # (batch,time,feature,1)
  image = tf.transpose(with_channel, [0, 2, 1, 3])  # (batch,feature,time,1)
  tf.summary.image(name, image, max_outputs=10)

  magnitude = tf.abs(x)
  tf.summary.scalar("%s_max_abs" % name, tf.reduce_max(magnitude))

  avg = tf.reduce_mean(x)
  tf.summary.scalar("%s_mean" % name, avg)

  centered = x - avg
  tf.summary.scalar("%s_stddev" % name, tf.sqrt(tf.reduce_mean(tf.square(centered))))

  per_frame_peak = tf.reduce_max(magnitude, axis=2)  # (batch,time)
  tf.summary.histogram("%s_hist" % name, per_frame_peak)