def _AttenLogits(query,
                 key,
                 abs_pos_emb,
                 content_bias=None,
                 positional_bias=None,
                 is_causal=False):
  """Attention logits from ...

  Transformer-XL (https://arxiv.org/pdf/1901.02860.pdf, section 3.3) version
  of self attention with relative position embedding.

  Notice padding is supposed to be masked by the caller of this function.

  B: batch size
  T: sequence length
  N: num of attention heads.
  H: per-head attention dimension.

  Args:
    tensors of the following shapes:
    query:           [B, T, N, H]
    key:             [B, T, N, H]
    abs_pos_emb:     [2T - 1, N, H]. The sinusoid positional embedding from
      https://arxiv.org/abs/1706.03762. abs_pos_emb[i] is the emb of relative
      distance i - (T-1).
    content_bias:    [N, H] or None
    positional_bias: [N, H] or None
    is_causal: A Python bool or a scalar bool Tensor. True for causal self
      attention.

  Returns:
    The attention logits tensor. [B, N, T, T]
  """
  b, t, n, h = py_utils.GetShape(query)
  key = py_utils.HasShape(key, [b, t, n, h])

  if content_bias is not None:
    content_bias = py_utils.HasShape(content_bias, [n, h])
  else:
    content_bias = 0
  if positional_bias is not None:
    positional_bias = py_utils.HasShape(positional_bias, [n, h])
  else:
    positional_bias = 0

  # [B, N, T, S=T]
  term_ac = tf.einsum('BTNH,BSNH->BNTS', query + content_bias, key)
  term_bd = RelPositionBias(query + positional_bias, abs_pos_emb, is_causal)
  return term_ac + term_bd
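
# Illustrative usage sketch (not part of the original code): the tensor shapes
# _AttenLogits expects and produces. The sizes B=2, T=5, N=4, H=8, the dummy
# random tensors and the helper name _ExampleAttenLogits are made up for
# illustration.
def _ExampleAttenLogits():
  b, t, n, h = 2, 5, 4, 8
  query = tf.random.normal([b, t, n, h])
  key = tf.random.normal([b, t, n, h])
  # One embedding per relative distance in [-(T-1), T-1].
  abs_pos_emb = tf.random.normal([2 * t - 1, n, h])
  logits = _AttenLogits(query, key, abs_pos_emb, is_causal=True)
  # logits: [B, N, T, T] = [2, 4, 5, 5]; padding must still be masked by the
  # caller.
  return logits
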
def FProp(self, theta, *args):
  """Runs p.repeat copies of self.body.FProp independently.

  Args:
    theta: Layer model parameters. The shape of each variable in theta is
      always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of
      the i-th copy of self.body.
    *args: Input arguments. The shape of each tensor in args is always
      [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs
      to the i-th copy of self.body.FProp.

  Returns:
    The accumulated output_tensors. Each tensor t in the return has the shape
    [p.repeat, ....] and the tuple (t[i] for t in output_tensors) is the
    return tuple of the i-th self.body.FProp.
  """
  p = self.params
  for arg in args:
    if arg is not None:
      arg = py_utils.HasShape(arg, [p.repeat], ndims=1)

  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
  inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
  # Infer out_shapes from FPropMeta.
  out_shapes = self._InferOutShapes(args)

  def _CellFn(unused_theta, unused_state0, inputs):
    """Recurrent cell function wrapper of body.FProp."""
    # Sets shapes for both theta and inputs to self.body.FProp.
    for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                        list(args) + theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

    # Runs the actual body.FProp.
    fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
    fprop_outputs = _ToTuple(fprop_outputs)
    assert len(fprop_outputs) == len(out_shapes)

    # Passes fprop outputs to the next layer through state.
    state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
    return state1, py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Initialize state0 with inferred output shapes.
    state0 = py_utils.NestedMap(
        outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])
    # Runs body.FProp p.repeat times using Recurrent.
    acc_states, _ = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=state0,
        inputs=inputs,
        cell_fn=_CellFn)

    # Retrieves fprop outputs from state1 and sets shapes.
    output_tensors = tuple(acc_states.outputs)
    for out_idx in range(len(output_tensors)):
      output_tensors[out_idx].set_shape(
          tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list()))

    return output_tensors[0] if len(output_tensors) == 1 else tuple(
        output_tensors)
def _RelPositionBias(query, abs_pos_emb):
  """Computes relative position bias for general cases."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Change to [T-1, T-2, ... 0, -1, -2, ... -(T-2), -(T-1)]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])

  # [B, N, T, L=2T-1]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Convert to [B, N, T, T]
  # part 1
  term_bd_left = term_bd[:, :, :, :t]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  term_bd_left = RelShift(term_bd_left)  # [B, N, T, T]
  term_bd_left = tf.reverse(term_bd_left, [2, 3])
  # part 2
  term_bd_right = term_bd[:, :, :, t - 1:]  # [B, N, T, T]
  term_bd_right = RelShift(term_bd_right)
  # [lower triangle]
  mask = tf.linalg.band_part(tf.ones_like(term_bd_right), -1, 0)

  # Stitching together.
  return tf.where(mask > 0, term_bd_left, term_bd_right)
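
# Illustrative sketch (not part of the original code): shape walk-through of
# _RelPositionBias. The sizes and the helper name _ExampleRelPositionBias are
# made up for illustration.
def _ExampleRelPositionBias():
  b, t, n, h = 2, 4, 1, 8
  query = tf.random.normal([b, t, n, h])
  # abs_pos_emb[k] embeds relative distance k - (T-1), covering [-(T-1), T-1].
  abs_pos_emb = tf.random.normal([2 * t - 1, n, h])
  bias = _RelPositionBias(query, abs_pos_emb)
  # bias: [B, N, T, T]; bias[b, n, i, j] pairs query position i with the
  # embedding of relative distance i - j, i.e. abs_pos_emb[(i - j) + T - 1].
  return bias
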
def _InputBatch(self):
  p = self.params

  @tf.function
  def ReadData():
    x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                             [p.data_dtype, p.label_dtype])
    # Always convert to float32.
    return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

  # Loads data and label into memory and keeps it around.
  data, label = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  b, shape = self.InfeedBatchSize(), list(p.data_shape)
  data = tf.reshape(data, [-1] + shape)
  label = tf.reshape(label, [-1])
  label = py_utils.HasShape(label, [tf.shape(data)[0]])
  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=b,
      repeat=p.repeat,
      seed=p.random_seed if p.random_seed else 0)
  n = tf.shape(sample_ids)[0]
  raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
  ret = py_utils.NestedMap(
      raw=raw,
      data=self._Preprocess(raw),
      label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
      weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
  if not py_utils.use_tpu():
    ret['sample_ids'] = sample_ids
  return ret
def FProp(self, theta, inputs, paddings):
  """Apply convolution to inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor, expected to be of shape [batch, time].

  Returns:
    outputs, out_paddings pair.
  """
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(inputs),
            tf.concat([
                tf.shape(paddings),
                [-1, symbolic.ToStatic(self.input_channels)]
            ], 0))
    ], inputs)

    def _ApplyPadding(tensor_in, padding_in):
      padding_expanded = tf.expand_dims(tf.expand_dims(padding_in, -1), -1)
      return tensor_in * (1.0 - padding_expanded)

    # Zeroing out padded inputs.
    inputs = _ApplyPadding(inputs, paddings)

    # Apply conv on 'inputs'.
    out = self._ApplyConv(theta, inputs)

    if p.partial_conv:
      out = self._RescaleBoundary(out, paddings)
    # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
    # But there are likely no real problems. Trying to set it gives an error:
    # pooling with SAME padding is not implemented for dilation_rate > 1.
    # NOTE: we use window=p.filter_stride[0] to be compatible with the legacy
    # implementation. Consider updating it to be the actual shape.
    conv_padding = ComputeConvOutputPadding(
        paddings, window=p.filter_stride[0], stride=p.filter_stride[0])
    # Assuming padded nodes will be properly zeroed out if necessary by
    # subsequent layers.
    # out = _ApplyPadding(out, conv_padding)
    out = py_utils.HasShape(
        out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
    return out, conv_padding
def FProp(self, theta, inputs, paddings=None):
  """Apply global spatial pooling to inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor. It is expected to be of shape
      [batch, time]. Defaults to None, which means there are no paddings.

  Returns:
    outputs, out_paddings pair.
    - outputs: has shape [batch, 1, 1, channel].
    - out_paddings: None or has shape [batch, 1].
  """
  p = self.params
  assert p.pooling_type in ['MAX', 'AVG'], p.pooling_type
  b, t, f = py_utils.GetShape(inputs, ndims=3)

  if paddings is not None:
    paddings = py_utils.HasShape(paddings, [b, t])

  if paddings is not None:
    mask = 1.0 - paddings[..., tf.newaxis, tf.newaxis]
  else:
    mask = tf.ones([b, t, 1, 1], p.dtype)
  if p.pooling_type == 'AVG':
    global_sum = tf.reduce_sum(inputs * mask, axis=[1, 2], keepdims=True)
    f = tf.cast(tf.convert_to_tensor(f), p.dtype)
    count = f * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
    out_feature = global_sum / tf.maximum(1.0, count)
  elif p.pooling_type == 'MAX':
    large_negative = (
        tf.ones_like(inputs) * p.dtype.max * tf.constant(-0.7, dtype=p.dtype))
    padded_inputs = tf.where_v2(mask > 0.0, inputs, large_negative)
    out_feature = tf.reduce_max(padded_inputs, axis=[1, 2], keepdims=True)
  if paddings is None:
    out_paddings = None
  else:
    out_paddings = tf.reduce_min(paddings, axis=1, keepdims=True)
    out_feature *= 1.0 - out_paddings[..., tf.newaxis, tf.newaxis]
  return out_feature, out_paddings
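
# Illustrative sketch (not part of the original code): the masked global
# average pooling used in the 'AVG' branch above, written in plain TensorFlow
# for one [batch, time, frequency, channel] input. The sizes, padding pattern
# and the helper name _ExampleMaskedGlobalAvgPool are made up for illustration.
def _ExampleMaskedGlobalAvgPool():
  b, t, f, c = 2, 6, 3, 4
  inputs = tf.random.normal([b, t, f, c])
  # 1.0 marks padded time steps.
  paddings = tf.constant([[0., 0., 0., 1., 1., 1.],
                          [0., 0., 0., 0., 0., 1.]])
  mask = 1.0 - paddings[:, :, tf.newaxis, tf.newaxis]  # [b, t, 1, 1]
  # Sum over unpadded positions, then divide by (#unpadded steps * frequency).
  total = tf.reduce_sum(inputs * mask, axis=[1, 2], keepdims=True)
  count = float(f) * tf.reduce_sum(mask, axis=[1, 2], keepdims=True)
  return total / tf.maximum(1.0, count)  # [b, 1, 1, c]
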
def _RelPositionBiasCausal(query, abs_pos_emb):
  """Computes relative position bias for causal self attention."""
  _, t, n, h = py_utils.GetShape(query)
  abs_pos_emb = py_utils.HasShape(abs_pos_emb, [2 * t - 1, n, h])

  # abs_pos_emb is [-(T-1), -(T-2), ... 0, 1, 2, ... T-1]
  # Retain only half and change order to [T-1, T-2, ... 0]
  # [T, N, H]
  abs_pos_emb = tf.reverse(abs_pos_emb, [0])[:t]

  # [B, N, T, L=T]
  term_bd = tf.einsum('BTNH,LNH->BNTL', query, abs_pos_emb)

  # Perform shifting.
  term_bd = tf.reverse(term_bd, [2, 3])
  term_bd = RelShift(term_bd)
  return tf.reverse(term_bd, [2, 3])
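
# Illustrative check (not part of the original code): on the causal (lower
# triangular, j <= i) region the causal variant agrees with the general
# _RelPositionBias; entries above the diagonal are left for the causal mask to
# remove. The sizes and the helper name _ExampleCompareCausalBias are made up
# for illustration.
def _ExampleCompareCausalBias():
  b, t, n, h = 1, 4, 1, 8
  query = tf.random.normal([b, t, n, h])
  abs_pos_emb = tf.random.normal([2 * t - 1, n, h])
  full = _RelPositionBias(query, abs_pos_emb)          # [B, N, T, T]
  causal = _RelPositionBiasCausal(query, abs_pos_emb)  # [B, N, T, T]
  lower = tf.linalg.band_part(tf.ones_like(full), -1, 0) > 0
  diff = tf.where(lower, full - causal, tf.zeros_like(full))
  return tf.reduce_max(tf.abs(diff))  # ~0: both agree where j <= i.
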
def RelShift(x):
  """Performs relative shift on a 4D tensor (first 2 axes are batching dims).

  Given input of shape [?, ?, W, W], this does "relative shifting" for the
  last two dims, s.t. output[b, n, i, j] = input[b, n, i, j-i] for i <= j;
  positions with i > j hold shifted-in zeros or wrapped-around values that
  callers are expected to mask out or discard.

  Args:
    x: A Tensor of shape [?, ?, W, W]

  Returns:
    A Tensor of the same shape as input with its content shifted (as described
    above).
  """
  b, n, w, _ = py_utils.GetShape(x)
  x = py_utils.HasShape(x, [-1, -1, w, w])
  x = tf.pad(x, ((0, 0), (0, 0), (0, 0), (0, 1)))
  x = tf.reshape(x, [b, n, w + 1, w])
  x = x[:, :, :w, :]
  return x
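
# Illustrative sketch (not part of the original code): what the pad/reshape/
# slice trick in RelShift does to a tiny 3x3 example. The helper name
# _ExampleRelShift and the dummy values are made up for illustration.
def _ExampleRelShift():
  # One [1, 1, 3, 3] input whose rows are [1, 2, 3], [4, 5, 6], [7, 8, 9].
  x = tf.reshape(tf.range(1, 10, dtype=tf.float32), [1, 1, 3, 3])
  shifted = RelShift(x)
  # shifted rows are [1, 2, 3], [0, 4, 5], [6, 0, 7]: row i is shifted right
  # by i positions, so shifted[0, 0, i, j] == x[0, 0, i, j - i] for j >= i,
  # while the shifted-in slots hold zeros or wrapped-around values that
  # callers mask out.
  return shifted
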
def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
  """Computes LSTM forward pass.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: A single tensor or a tuple of tensors with cardinality equal to
      rnn_cell.inputs_arity. For every input tensor, the first dimension is
      assumed to be time, second dimension batch, and third dimension depth.
    paddings: A tensor. First dim is time, second dim is batch, and third dim
      is expected to be 1.
    state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
      the cell's zero-state.
    segment_id: A tensor to support packed inputs. First dim is time, second
      dim is batch, and third dim is expected to be 1.

  Returns:
    A tensor of [time, batch, dims].
    The final recurrent state.
  """
  p = self.params
  rcell = self.cell
  assert isinstance(rcell, (rnn_cell.RNNCell))

  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]

  # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular
  # LSTM baseline.
  # Keeping slicing within the loop gives only < 3% speedup.
  cell_theta = theta.cell.copy()
  num_input_nodes = p.cell.num_input_nodes
  cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
  cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
  tf.logging.vlog(1, 'cell_theta: %r', cell_theta)

  if p.packed_input:
    assert segment_id is not None
    reset_mask = rnn_layers.GeneratePackedInputResetMask(
        segment_id, is_reverse=False)
    reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
  else:
    reset_mask = tf.zeros_like(paddings)

  if p.reverse:
    inputs = [tf.reverse(x, [0]) for x in inputs]
    paddings = tf.reverse(paddings, [0])
    reset_mask = tf.reverse(reset_mask, [0])

  if not state0:
    batch_size = py_utils.GetShape(paddings)[1]
    state0 = rcell.zero_state(cell_theta, batch_size)

  # [T, B, H]
  proj_inputs = rcell.ProjectInputSequence(cell_theta,
                                           py_utils.NestedMap(act=inputs))
  proj_inputs = py_utils.NestedMap(
      proj_inputs=proj_inputs, padding=paddings, reset_mask=reset_mask)

  acc_state, final_state = recurrent.Recurrent(
      theta=cell_theta,
      state0=state0,
      inputs=proj_inputs,
      cell_fn=rcell.FPropWithProjectedInput,
      cell_type=rcell.layer_type,
      accumulator_layer=self,
      allow_implicit_capture=p.allow_implicit_capture)

  act = rcell.GetOutput(acc_state)
  if p.reverse:
    act = tf.reverse(act, [0])
  return act, final_state