def padding_attention_mask(from_tensor, to_mask, *, dtype):
    """Create a 4D attention mask from a 2D padding mask.

    Args:
      from_tensor: A Tensor of shape
        [batch_size, heads, from_seq_length, to_seq_length].
      to_mask: int32 Tensor of shape [batch_size, to_seq_length].

    Returns:
      float Tensor of shape [batch_size, 1, from_seq_length, to_seq_length].
    """
    from_shape = shape_list(from_tensor)
    batch_size = from_shape[0]
    from_seq_length = from_shape[2]

    to_shape = shape_list(to_mask)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(
        tf.reshape(to_mask, [batch_size, 1, to_seq_length]), dtype=dtype)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to*
    # padding tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(
        shape=[batch_size, from_seq_length, 1], dtype=dtype)

    # Broadcasting [batch_size, from_seq_length, 1] against
    # [batch_size, 1, to_seq_length] yields the full
    # [batch_size, from_seq_length, to_seq_length] mask.
    mask = broadcast_ones * to_mask

    # Add a broadcastable heads dimension: [batch_size, 1, from, to].
    mask = tf.expand_dims(mask, axis=1)
    return mask
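# A minimal usage sketch (not part of the module): the standard additive
# masking idiom built on `padding_attention_mask`. The shapes, the `scores`
# placeholder, and the -1e9 constant are illustrative assumptions; `tf` and
# `shape_list` are assumed to be in module scope as above.
def _example_padding_mask():
    batch_size, heads, seq_len = 2, 4, 8
    scores = tf.zeros([batch_size, heads, seq_len, seq_len])  # raw logits
    # 1 marks a real token, 0 marks padding; the second sequence is unpadded.
    to_mask = tf.constant([[1] * 6 + [0] * 2, [1] * 8])
    mask = padding_attention_mask(scores, to_mask, dtype=scores.dtype)
    # Push logits for padded *to* positions toward -inf before the softmax.
    masked_scores = scores + (1.0 - mask) * -1e9
    return tf.nn.softmax(masked_scores, axis=-1)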
def conv1d(x, scope, nf, *, w_init_stdev=0.02):
    """Position-wise linear projection of the last dimension of x to nf channels.

    Equivalent to a dense layer applied independently at every position,
    implemented with a [1, nx, nf] kernel.
    """
    with tf.variable_scope(scope):
        *start, nx = shape_list(x)
        w = tf.get_variable(
            'w', [1, nx, nf],
            initializer=tf.random_normal_initializer(stddev=w_init_stdev))
        b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
        # Flatten to [batch * positions, nx], multiply by the [nx, nf] kernel,
        # then restore the leading dimensions with nf channels at the end.
        c = tf.reshape(
            tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
            start + [nf])
        return c
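# Illustrative sketch only: a typical call site for `conv1d`, projecting
# hidden states to a fused query/key/value tensor as in GPT-2-style
# attention. The scope name 'c_attn' and the sizes are assumptions.
def _example_conv1d():
    x = tf.zeros([2, 10, 768])            # [batch, sequence, features]
    qkv = conv1d(x, 'c_attn', 3 * 768)    # -> [2, 10, 2304]
    q, k, v = tf.split(qkv, 3, axis=2)    # three [2, 10, 768] tensors
    return q, k, v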
def left_to_right_attention_mask(from_tensor, to_tensor, *, dtype):
    """1's in the lower triangle, counting from the lower right corner.

    Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns - nd), but doesn't
    produce garbage on TPUs.

    Return shape: [1, 1, nd, ns].
    """
    assert from_tensor is to_tensor, "from_tensor and to_tensor must be the same tensor."
    _, _, nd, ns = shape_list(from_tensor)
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    # Position i (of nd) may attend to position j (of ns) iff j does not run
    # ahead of i once the two ranges are right-aligned, i.e. the mask is
    # causal in the aligned coordinates.
    m = i >= j - ns + nd
    m = tf.cast(m, dtype)
    m = tf.reshape(m, [1, 1, nd, ns])
    return m
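# Hedged sketch of applying the causal mask to self-attention logits,
# mirroring the usual w * mask - large * (1 - mask) idiom. The logits tensor
# and the 1e10 constant are illustrative assumptions. In self-attention the
# same logits tensor is passed as both arguments, which the assert enforces.
def _example_causal_mask():
    logits = tf.zeros([2, 4, 8, 8])       # [batch, heads, nd, ns]
    causal = left_to_right_attention_mask(logits, logits, dtype=logits.dtype)
    logits = logits * causal - tf.cast(1e10, logits.dtype) * (1 - causal)
    return tf.nn.softmax(logits, axis=-1)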
def merge_states(x):
    """Smash the last two dimensions of x into a single dimension."""
    *start, a, b = shape_list(x)
    return tf.reshape(x, start + [a * b])
def split_states(x, n):
    """Reshape the last dimension of x into [n, x.shape[-1] / n]."""
    *start, m = shape_list(x)
    return tf.reshape(x, start + [n, m // n])
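# A sketch of the usual round trip through `split_states` / `merge_states`
# for multi-head attention; `n_head` and the shapes are illustrative.
def _example_split_merge():
    n_head = 4
    x = tf.zeros([2, 10, 64])              # [batch, sequence, features]
    h = tf.transpose(split_states(x, n_head), [0, 2, 1, 3])
    # h is [2, 4, 10, 16]: [batch, heads, sequence, features_per_head].
    # ... attention over h would happen here ...
    y = merge_states(tf.transpose(h, [0, 2, 1, 3]))
    # y is back to [2, 10, 64].
    return y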