def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general cases."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Convert to [B, N, T, T].
  # Part 1.
  term_bd_left = term_bd[:, :, :, :t]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  term_bd_left = RelShift(term_bd_left)  # [B, N, T, T]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  # Part 2.
  term_bd_right = term_bd[:, :, :, t - 1:]  # [B, N, T, T]
  term_bd_right = RelShift(term_bd_right)
  # [lower triangle]
  mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

  # Stitching the two parts together.
  return tf.where(mask > 0, term_bd_left, term_bd_right)

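# A hypothetical O(T^2) reference for _RelPositionBias above, not part of the
# library: it gathers the embedding of relative distance i - j for every
# query/key pair explicitly instead of using the reverse/RelShift trick. With
# a small, statically known T it can be used to sanity-check the stitched
# result.
def _NaiveRelPositionBias(query, abs_pos_emb):
  t = query.shape[1]  # Assumes a statically known sequence length.
  # rel[i, j] is the embedding of relative distance i - j; abs_pos_emb index k
  # holds the embedding of relative distance k - (T - 1).
  rel = tf.stack([
      tf.stack([abs_pos_emb[i - j + t - 1] for j in range(t)])
      for i in range(t)
  ])  # [T, T, N, H]
  return tf.einsum('BINH,IJNH->BNIJ', query, rel)  # [B, N, T, T]
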
def _GetExpertDist(self, theta, inputs, *args):
  """Computes the soft expert weights (expert_dist) from the inputs tensors."""
  # TODO(huangyp): support the more general case when batch size is not 1.
  # Input shape can be either [batch, length, dim] or [length, batch, dim].
  reshaped_inputs = tf.reshape(inputs, [-1, self.params.cond_dim])
  if self.params.nonzeros_mean:
    per_example_emb = tf.reduce_sum(reshaped_inputs, 0)
    nonzeros = tf.cast(
        tf.math.count_nonzero(reshaped_inputs, 0), dtype=tf.float32)
    per_example_emb /= (nonzeros + 1e-10)
  else:
    per_example_emb = tf.reduce_mean(reshaped_inputs, 0)
  expert_dist = tf.nn.sigmoid(tf.einsum('i,ij->j', per_example_emb, theta.w))
  return expert_dist

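# A small made-up illustration of the nonzeros_mean branch above: dividing the
# per-dimension sum by the per-dimension count of nonzero entries means that
# all-zero (padding) rows do not dilute the mean.
def _NonzerosMeanExample():
  x = tf.constant([[1., 2.], [3., 4.], [0., 0.]])              # last row is padding
  sum_emb = tf.reduce_sum(x, 0)                                # [4., 6.]
  nonzeros = tf.cast(tf.math.count_nonzero(x, 0), tf.float32)  # [2., 2.]
  return sum_emb / (nonzeros + 1e-10)                          # [2., 3.]
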
def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
  """Computes attention logits with relative position embedding.

  Transformer-XL (https://arxiv.org/pdf/1901.02860.pdf, section 3.3) version
  of self attention with relative position embedding. Notice padding is
  supposed to be masked by the caller of this function.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    query: [B, T, N, H].
    key: [B, T, N, H].
    abs_pos_emb: [2T - 1, N, H]. The sinusoid positional embedding from
      https://arxiv.org/abs/1706.03762. abs_pos_emb[i] is the emb of relative
      distance i - (T-1).
    content_bias: [N, H] or None.
    positional_bias: [N, H] or None.
    is_causal: A Python bool or a scalar bool Tensor. True for causal self
      attention.

  Returns:
    The attention logits tensor. [B, N, T, T].
  """
  b, t, n, h = py_utils.GetShape(query)
  key = py_utils.HasShape(key, [b, t, n, h])

  if content_bias is not None:
    content_bias = py_utils.HasShape(content_bias, [n, h])
  else:
    content_bias = 0
  if positional_bias is not None:
    positional_bias = py_utils.HasShape(positional_bias, [n, h])
  else:
    positional_bias = 0

  # [B, N, T, S=T]
  term_ac = tf.einsum('BTNH,BSNH->BNTS', query + content_bias, key)
  term_bd = RelPositionBias(query + positional_bias, abs_pos_emb, is_causal)
  return term_ac + term_bd

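# The RelPositionBias dispatcher called above is not part of this excerpt. A
# minimal sketch of what it is assumed to do, given the two bias helpers in
# this file, handling is_causal as either a Python bool or a scalar bool
# Tensor (the library's own version may differ):
def _RelPositionBiasDispatchSketch(query, abs_pos_emb, is_causal):
  if isinstance(is_causal, bool):
    if is_causal:
      return _RelPositionBiasCausal(query, abs_pos_emb)
    return _RelPositionBias(query, abs_pos_emb)
  return tf.cond(is_causal,
                 lambda: _RelPositionBiasCausal(query, abs_pos_emb),
                 lambda: _RelPositionBias(query, abs_pos_emb))
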
def FProp(self, theta, inputs, *args):
  p = self.params
  with tf.name_scope(p.name) as scope:
    expert_dist = self._GetExpertDist(theta, inputs, *args)
    if not self.do_eval:
      summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

    # Excludes non-variable extra_theta like global_step.
    var_set = set([key for key, _ in self.body.vars.FlattenItems()])
    values = []
    for key, value in theta.body.FlattenItems():
      if key in var_set and value is not None:
        # Weighted average for all variables created in the body layer.
        value = tf.einsum('i,i...->...', expert_dist, value)
      values.append(value)
    weighted_theta = theta.body.Pack(values)
    return self.body.FProp(weighted_theta, inputs, *args)

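# A hypothetical illustration of the mixing einsum in FProp above (names and
# values are made up): body variables are assumed to carry a leading
# num-experts axis, and 'i,i...->...' collapses that axis into a weighted
# average, so the body layer runs with a single soft-mixed set of weights.
def _SoftMixExample():
  num_experts, d_in, d_out = 4, 8, 16
  expert_dist = tf.constant([0.7, 0.1, 0.1, 0.1])              # [num_experts]
  per_expert_w = tf.random.normal([num_experts, d_in, d_out])
  return tf.einsum('i,i...->...', expert_dist, per_expert_w)   # [d_in, d_out]
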
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Computes relative position bias for causal self attention."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Retain only half and change order to [T-1, T-2, ... 0]
  # [T, N, H]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

  # [B, N, T, L=T]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Perform shifting.
  term_bd = tf.reverse(term_bd, [2, 3])
  term_bd = RelShift(term_bd)
  return tf.reverse(term_bd, [2, 3])

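# RelShift is referenced by both bias helpers above but is not shown in this
# excerpt. A minimal sketch of the standard pad-and-reshape relative shift it
# is assumed to perform (the library's own version may differ in details):
def _RelShiftSketch(x):
  """Shifts row i of the last dim left by i: out[..., i, j] = x[..., i, j - i].

  Only entries with j >= i are meaningful; entries below the diagonal hold
  wrap-around values that callers mask out or reverse away.
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))  # [B, N, W, W + 1]
  x = tf.reshape(x, [b, n, w + 1, w])              # realigns the rows
  return x[:, :, :w, :]                            # [B, N, W, W]
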
def ProjectInputSequence(self, theta, inputs):
  """Applies input projection for the entire sequence.

  Args:
    theta: A NestedMap of layer weights. Notably, it's expected to contain
      separate weight tensors for input and hidden state projections, for
      performance reasons, under the keys 'wm_i' (input) and 'wm_h' (hidden
      state).
    inputs: A NestedMap with the following fields:
      - act: A list of Tensors of shape [seqlen, batch, input_dim].

  Returns:
    A Tensor of shape [seqlen, batch, 4 * hidden_dim].
  """
  assert isinstance(inputs.act, list)
  if len(inputs.act) > 1:
    x = tf.concat(inputs.act, -1)
  else:
    x = inputs.act[0]
  # [T, B, 4 * H]
  proj_inputs = tf.einsum('TBD,DH->TBH', x, theta.wm_i)
  return proj_inputs

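# A small made-up shape check for the projection above: the einsum contracts
# the input dimension D against the combined gate weight matrix wm_i,
# producing the pre-activations for all four gates of the whole sequence at
# once.
def _ProjectInputSequenceExample():
  seqlen, batch, input_dim, hidden_dim = 5, 2, 8, 16
  x = tf.random.normal([seqlen, batch, input_dim])
  wm_i = tf.random.normal([input_dim, 4 * hidden_dim])
  return tf.einsum('TBD,DH->TBH', x, wm_i)  # [5, 2, 64]
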
def EinsumBxycBzxBzyc(self, a, b, name=None):
  """Einsum of a: [B, X, Y, C] with b: [B, Z, X] -> [B, Z, Y, C]."""
  return tf.einsum('bxyc,bzx->bzyc', a, b, name=name)

def EinsumBxyBxBxy(self, a, b, name=None):
  """Einsum of a: [B, X, Y] with b: [B, X] -> [B, X, Y]."""
  return tf.einsum('bxy,bx->bxy', a, b, name=name)

def EinsumBxycBxBxyc(self, a, b, name=None):
  """Einsum of a: [B, X, Y, C] with b: [B, X] -> [B, X, Y, C]."""
  return tf.einsum('bxyc,bx->bxyc', a, b, name=name)

def EinsumBmtBmBt(self, a, b, name=None):
  """Einsum of a: [B, M, T] with b: [B, M] -> [B, T]."""
  return tf.einsum('bmt,bm->bt', a, b, name=name)

def EinsumBBmBm(self, a, b, name=None):
  """Einsum of a: [B] with b: [B, M] -> [B, M]."""
  return tf.einsum('b,bm->bm', a, b, name=name)

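# A made-up shape example for the Einsum* helpers above: each method name
# spells out the operand and result shapes, e.g. EinsumBxycBzxBzyc contracts a
# [B, X, Y, C] tensor with a [B, Z, X] tensor into a [B, Z, Y, C] tensor.
def _EinsumBxycBzxBzycExample():
  a = tf.random.normal([2, 3, 5, 7])        # [B=2, X=3, Y=5, C=7]
  b = tf.random.normal([2, 4, 3])           # [B=2, Z=4, X=3]
  return tf.einsum('bxyc,bzx->bzyc', a, b)  # [2, 4, 5, 7]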