def _fast_matrix_shift(x, funnel_factor=1, is_upsampling=False):
  """Shifts a rank-4 tensor by zero-padding, reshaping and slicing.

  Args:
    x: Rank-4 array. On the plain path (`funnel_factor == 1` and not
      upsampling) its last axis is the keys axis; otherwise the last axis is
      treated as having `2 * klen - 1` entries.
    funnel_factor: Integer factor. It is the zero-pad width when not
      upsampling and the stride of the final slice when upsampling.
    is_upsampling: Selects which of the two roles `funnel_factor` plays.

  Returns:
    On the plain path, an array of shape
    `(x.shape[0], x.shape[1], x.shape[3], x.shape[2])`. Otherwise, a slice of
    `klen` positions (taken with the selected stride) along the last axis of
    the re-laid-out tensor.
  """
  n_batch, heads = x.shape[0], x.shape[1]
  n_queries, last_dim = x.shape[2], x.shape[3]

  def pad_and_roll(pad_width, rolled_len):
    # Prepend `pad_width` zero columns, reinterpret the buffer with the last
    # two axes exchanged in size, then drop the first `pad_width` rows.
    padding = jnp.zeros((n_batch, heads, n_queries, pad_width))
    widened = jnp.concatenate([padding, x], axis=3)
    rolled = widened.reshape(n_batch, heads, rolled_len + pad_width, n_queries)
    return rolled[:, :, pad_width:, :]

  if funnel_factor == 1 and not is_upsampling:
    return pad_and_roll(1, last_dim)

  stride, pad_width = (
      (funnel_factor, 1) if is_upsampling else (1, funnel_factor))
  n_keys = (last_dim + 1) // 2
  rolled = pad_and_roll(pad_width, 2 * n_keys - 1)
  flat = rolled.reshape(n_batch, heads, n_queries, n_keys * 2 - 1)
  first = pad_width - 1
  return flat[:, :, :, first:first + n_keys:stride]
def init(self, w):
  """Creates zero-initialized slot arrays for the weight tensor `w`.

  Args:
    w: Weight array; its `shape` and `dtype` determine the slots.

  Returns:
    A tuple `(momentum, v1s, v2s)`. `momentum` is a zero array shaped like `w`
    when `self._has_momentum` is set and an empty list otherwise. `v1s` holds
    one zero vector per axis of `w`. `v2s` holds a second such list when
    `self._graft` is set and is an empty list otherwise.
  """
  def per_axis_zeros():
    return [jnp.zeros(axis_len, dtype=w.dtype) for axis_len in w.shape]

  momentum = jnp.zeros_like(w) if self._has_momentum else []
  first_accumulators = per_axis_zeros()
  second_accumulators = per_axis_zeros() if self._graft else []
  return (momentum, first_accumulators, second_accumulators)
def _fast_inference_init_state(input_signature, buffer_length):
  """Returns an initial state for causal attention layer fast inference.

  Args:
    input_signature: Sequence of three signatures. The batch size is read from
      the first one; the last dimension and dtype of the second and third
      define the key and value buffers.
    buffer_length: Size of axis 1 of the key/value buffers and of the last
      axis of the mask.

  Returns:
    A tuple `(k, v, mask, seq_indices)` of zero arrays: `k` and `v` have shape
    `(batch, buffer_length, d_feature)`, `mask` has shape
    `(batch, 1, buffer_length)` and `seq_indices` is an int32 vector with one
    entry per batch element.
  """
  n_batch = input_signature[0].shape[0]

  def empty_buffer(signature):
    feature_shape, feature_dtype = signature.as_tuple()
    return jnp.zeros(
        (n_batch, buffer_length, feature_shape[-1]), dtype=feature_dtype)

  keys_buffer = empty_buffer(input_signature[1])
  values_buffer = empty_buffer(input_signature[2])
  attention_mask = jnp.zeros((n_batch, 1, buffer_length))
  positions = jnp.zeros((n_batch,), dtype=jnp.int32)
  return (keys_buffer, values_buffer, attention_mask, positions)
def init(self, weights):
  """Creates the zero-initialized slot list for one weight tensor.

  Args:
    weights: Weight array whose shape (and, for unfactored slots, dtype)
      defines the slots.

  Returns:
    A list of arrays. With `self._factored` set and at least two axes it
    starts with two float32 accumulators: one with the last axis removed and
    one with the second-to-last axis removed. Otherwise it starts with a
    single array shaped like `weights`. When `self._do_momentum` is set, one
    more array shaped like `weights` is appended.
  """
  dims = weights.shape
  if self._factored and len(dims) >= 2:
    row_stats = jnp.zeros(dims[:-1], dtype=jnp.float32)
    col_stats = jnp.zeros(dims[:-2] + dims[-1:], dtype=jnp.float32)
    slots = [row_stats, col_stats]
  else:
    slots = [jnp.zeros_like(weights)]
  if self._do_momentum:
    slots.append(jnp.zeros_like(weights))
  return slots
def init_weights_and_state(self, input_signature):
  """Initializes the transform weights and, in predict mode, the state.

  Args:
    input_signature: Signature of the input; its last dimension gives the
      feature depth and, in predict mode, its first dimension gives the batch
      size.
  """
  depth = input_signature.shape[-1]
  if self._transform == 'diag':
    # Initialize it to a small value because JAX has a bug in softplus.
    weights = jnp.zeros((depth,), dtype=jnp.float32) + 1e-4
  elif self._transform == 'any':
    make_orthogonal = trax.layers.initializers.OrthogonalInitializer()
    weights = make_orthogonal((depth, depth), self.rng)
  else:
    weights = layer_base.EMPTY_WEIGHTS

  if self._mode == 'predict':
    n_batch = input_signature.shape[0]
    # State pairs one int32 zero per batch element with the layer's rng.
    self.state = (jnp.zeros((n_batch,), dtype=jnp.int32), self.rng)
  self.weights = weights
def ResidualZero(*layers, shortcut=None):
  """Wraps a series of layers with a ReZero-style residual connection.

  A classical residual connection computes `(shortcut) + (output of layers)`.
  This combinator computes `(shortcut) + alpha * (output of layers)` instead,
  where `alpha` is a learnable scalar that starts at zero.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the stack-top input itself is added to
      the scaled output of the layer series. If specified, the `shortcut`
      layer is applied to a copy of the inputs and its output is added
      (elementwise) instead.

  Returns:
    A layer representing a residual connection paired with a layer series.
  """
  flat_layers = _ensure_flat(layers)
  if len(flat_layers) == 1:
    main_path = flat_layers[0]
  else:
    main_path = tl.Serial(flat_layers)

  # TODO(jaszczur): perhaps change inner Serial to Branch?
  alpha = tl.Weights(lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32))
  gated_path = tl.Serial(main_path, alpha, tl.Multiply())
  return tl.Serial(
      tl.Branch(shortcut, gated_path),
      tl.Add(),  # pylint: disable=no-value-for-parameter
  )
def init_weights_and_state(self, input_signature):
  """Sets up the cache state used in predict mode; does nothing otherwise.

  Args:
    input_signature: Signature whose `as_tuple()` yields a rank-3 shape
      `(batch_size, _, d_feature)` and a dtype.
  """
  if self._mode != 'predict':
    return
  dims, dtype = input_signature.as_tuple()
  n_batch, _, depth = dims
  # The cache holds `2 * self._total_kv_pooling` positions per batch element;
  # it is paired with a scalar zero.
  buffer = jnp.zeros((n_batch, 2 * self._total_kv_pooling, depth), dtype=dtype)
  self.state = (buffer, jnp.array(0))
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
  """Builds the queries, keys, values and padding mask for an attention layer.

  Args:
    encoder_activations: Output of the input encoder; returned unchanged as
      both keys and values.
    decoder_activations: Output of the pre-attention decoder; returned
      unchanged as queries. Its axis 1 sets the decoder-length axis of the
      mask.
    inputs: Rank-2 array of padded input tokens, `(batch, encoder length)`.
      Entries equal to 0 are treated as padding.

  Returns:
    A tuple `(queries, keys, values, mask)`. `mask` has shape
    `(batch, 1, decoder length, encoder length)` and is nonzero exactly where
    the corresponding input token is not padding.
  """
  not_padding = inputs != 0
  n_batch, encoder_len = not_padding.shape[0], not_padding.shape[1]
  decoder_len = decoder_activations.shape[1]

  # Insert singleton axes for attention heads and decoder positions.
  mask = fastnp.reshape(not_padding, (n_batch, 1, 1, encoder_len))
  # Adding zeros broadcasts the mask along the decoder-length axis.
  mask = mask + fastnp.zeros((1, 1, decoder_len, 1))

  return decoder_activations, encoder_activations, encoder_activations, mask
def f(x):  # pylint: disable=invalid-name
  """Returns a float32 zero array sized from a rank-3 input.

  Args:
    x: Object with a rank-3 `shape`, interpreted as
      `(batch_size, sequence_length, feature_depth)`.

  Returns:
    Zeros of shape `(batch_size, depth_multiplier * feature_depth)`, where
    `depth_multiplier` comes from the enclosing scope.

  Raises:
    ValueError: If `x` is not rank 3.
  """
  if len(x.shape) == 3:
    n_batch, depth = x.shape[0], x.shape[-1]
    return jnp.zeros((n_batch, depth_multiplier * depth), dtype=jnp.float32)
  raise ValueError(f'Layer input should be a rank 3 tensor representing'
                   f' (batch_size, sequence_length, feature_depth); '
                   f'instead got shape {x.shape}.')
def init_weights_and_state(self, input_signature):
  """Builds the sinusoidal positional-encoding table and the decoding state.

  The table has shape `(self._max_len, d_feature)`: even columns hold sines
  and odd columns hold cosines of `position * div_term`. When `d_feature` is
  odd there is one fewer cosine column than sine column, so the cosine
  frequencies are truncated to fit instead of raising a shape-mismatch error.

  If `self._d_feature` is set, it overrides the table depth and a
  Glorot-uniform matrix of shape `(d_feature, input depth)` is stored next to
  the table; the weights are then the pair `(table, matrix)`. In predict mode
  the state is a scalar int32 zero.

  Args:
    input_signature: :py:class:`ShapeDtype` instance characterizing the input
      this layer should compute on; only its last dimension is read.
  """
  input_depth = input_signature.shape[-1]
  d_feature = input_depth if self._d_feature is None else self._d_feature

  position = np.arange(0, self._max_len)[:, np.newaxis]
  div_term = np.exp(
      np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
  angles = position * div_term
  pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
  pe[:, 0::2] = np.sin(angles)
  # `pe[:, 1::2]` has `d_feature // 2` columns, one fewer than `angles` when
  # `d_feature` is odd; drop the surplus frequency so the assignment fits.
  pe[:, 1::2] = np.cos(angles[:, :d_feature // 2])  # [self._max_len, d_feature]
  if self._use_bfloat16:
    pe = pe.astype(jnp.bfloat16)
  w = jnp.array(pe)  # Trainable parameters, initialized above.

  if self._d_feature is not None:
    ff = init.GlorotUniformInitializer()(
        (d_feature, input_signature.shape[-1]), self.rng)
    self.weights = w, ff
  else:
    self.weights = w

  if self._mode == 'predict':
    self.state = jnp.zeros((), dtype=jnp.int32)
def NoUpsampling(shorten_factor, d_model, *args, **kwargs):
  """Returns a layer named 'ReturnZero' that outputs zeros of upsampled length.

  Args:
    shorten_factor: Factor by which axis 1 of the input is multiplied to get
      the output length.
    d_model: Unused.
    *args: Unused.
    **kwargs: Unused.

  Returns:
    A `core.Fn` layer mapping a rank-3 input `x` to zeros of shape
    `(x.shape[0], x.shape[1] * shorten_factor, x.shape[2])` with `x`'s dtype.
  """
  del d_model, args, kwargs

  def return_zero(x):
    n_batch, short_len, depth = x.shape[0], x.shape[1], x.shape[2]
    return jnp.zeros((n_batch, short_len * shorten_factor, depth),
                     dtype=x.dtype)

  return core.Fn('ReturnZero', return_zero)
def init_weights_and_state(self, input_signature):
  """Initializes the state in predict mode; sets no weights.

  Args:
    input_signature: `ShapeDtype` instance characterizing the input this layer
      should compute on. It is not read here.
  """
  if self._mode != 'predict':
    return
  # Scalar int32 zero; only created for predict mode.
  self.state = jnp.zeros((), dtype=jnp.int32)
def F(encoder_activations, decoder_activations, input_tokens):
  """Assembles queries, keys, values and a padding mask for attention.

  Args:
    encoder_activations: Returned unchanged as both keys and values.
    decoder_activations: Returned unchanged as queries; its axis 1 sets the
      decoder-length axis of the mask.
    input_tokens: Rank-2 token array `(batch, encoder length)`; zeros mark
      padding.

  Returns:
    `(queries, keys, values, mask)` where `mask` has shape
    `(batch, 1, decoder length, encoder length)`.
  """
  # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
  is_token = (input_tokens != 0)
  n_batch, encoder_len = is_token.shape[0], is_token.shape[1]
  decoder_len = decoder_activations.shape[1]
  # We need to add axes to the mask for attention heads and decoder length.
  mask = jnp.reshape(is_token, (n_batch, 1, 1, encoder_len))
  # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
  mask = mask + jnp.zeros((1, 1, decoder_len, 1))
  return decoder_activations, encoder_activations, encoder_activations, mask
def _fast_matrix_shift(x):
  """Implements the shift needed for relative positional attention.

  Args:
    x: Rank-4 array; axis 2 is the queries axis and axis 3 the keys axis.

  Returns:
    An array of shape `(x.shape[0], x.shape[1], keys_len, queries_len)`
    obtained by prepending one zero column, reinterpreting the buffer with the
    last two axes exchanged in size, and dropping the first row.
  """
  pad_width = 1
  n_batch, heads = x.shape[0], x.shape[1]
  n_queries, n_keys = x.shape[2], x.shape[3]
  padding = jnp.zeros((n_batch, heads, n_queries, pad_width))
  widened = jnp.concatenate([padding, x], axis=3)
  rolled = widened.reshape(n_batch, heads, n_keys + pad_width, n_queries)
  return rolled[:, :, pad_width:, :]
def _run_value_model(self, observations, dist_inputs):
  """Evaluates the value model, optionally on sampled or enumerated actions.

  Args:
    observations: Observation array. Its first two axes shape the default
      `dist_inputs` (presumably batch and time -- TODO confirm).
    dist_inputs: Inputs to the policy distribution, or None. When None, zeros
      of shape `observations.shape[:2] + (self._policy_dist.n_inputs,)` are
      used instead.

  Returns:
    A tuple `(values, actions, log_probs)`. `values` is the output of the
    jitted value model scaled by `self._value_network_scale`, with its last
    (singleton) axis squeezed. `actions` and `log_probs` are None when
    `self._q_value` is false.
  """
  if dist_inputs is None:
    dist_inputs = jnp.zeros(observations.shape[:2] +
                            (self._policy_dist.n_inputs, ))
  actions = None
  if self._q_value:
    if self._sample_all_discrete_actions:
      # Since we want to sample all actions, start by creating their list.
      act = np.arange(self._vocab_size)
      # Now act is a vector [0, ..., vocab_size-1], but we'll need to tile it.
      # Add extra dimensions so it's the same dimensionality as dist_inputs.
      act = jnp.reshape(act, [-1] + [1] * (len(dist_inputs.shape) - 1))
      # Now act is [vocab_size, 1, ..., 1], dimensionality of dist_inputs.
    # Add a leading axis of size `self._q_value_n_samples`.
    dist_inputs = jnp.broadcast_to(
        dist_inputs, (self._q_value_n_samples, ) + dist_inputs.shape)
    if self._sample_all_discrete_actions:
      # Adding int32 zeros broadcasts `act` to dist_inputs' leading axes.
      actions = act + jnp.zeros(dist_inputs.shape[:-1], dtype=jnp.int32)
      actions = jnp.swapaxes(actions, 0, 1)
    # Swapping the n_samples and batch_size axes, so the input is split
    # between accelerators along the batch_size axis.
    dist_inputs = jnp.swapaxes(dist_inputs, 0, 1)
    if not self._sample_all_discrete_actions:
      actions = self._policy_dist.sample(dist_inputs)
    log_probs = self._policy_dist.log_prob(dist_inputs, actions)
    obs = observations
    # Insert a singleton axis at position 1 of the observations.
    obs = jnp.reshape(obs, [obs.shape[0], 1] + list(obs.shape[1:]))
    inputs = (obs, actions)
  else:
    log_probs = None
    inputs = (observations, )
  n_devices = fastmath.device_count()
  weights = tl.for_n_devices(self._value_eval_model.weights, n_devices)
  state = tl.for_n_devices(self._value_eval_model.state, n_devices)
  rng = self._value_eval_model.rng
  values, _ = self._value_eval_jit(inputs, weights, state, rng)
  values *= self._value_network_scale
  values = jnp.squeeze(values, axis=-1)  # Remove the singleton depth dim.
  return (values, actions, log_probs)
def forward(self, inputs):
  """Returns the input activations, with added positional information.

  Outside predict mode a window of the table `self.weights` is added to
  `inputs`; in train mode the window may start at a random offset and may be
  masked by dropout. In predict mode `self.state` holds one current position
  per batch element and is advanced by the number of positions consumed.

  Args:
    inputs: Activations; axis 0 is batch and axis 1 is sequence position.

  Returns:
    `inputs` plus the selected positional encodings.

  Raises:
    ValueError: If called in predict mode with a nonzero dropout rate.
  """
  if self._mode == 'predict':
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer needs to store the index of the current
    # position then and increment it on each call -- that's how state is used
    # and updated below.
    positions = self.state
    if inputs.shape[1] == 1:
      self.state = positions + 1
      return inputs + jnp.expand_dims(self.weights[0, positions, :], 1)
    per_example = []
    for row in range(inputs.shape[0]):
      per_example.append(
          fastmath.dynamic_slice_in_dim(
              self.weights[0], positions[row], inputs.shape[1], axis=0))
    self.state = positions + inputs.shape[1]
    return inputs + jnp.stack(per_example, 0)

  seq_len = jnp.shape(inputs)[1]
  if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
    encoding = self.weights[:, :seq_len, :]
  else:
    offset_rng, zero_rng = fastmath.random.split(self.rng, 2)
    offset = fastmath.random.randint(
        offset_rng, (), 0, self._max_offset_to_add)
    zero_draw = fastmath.random.uniform(zero_rng, (), jnp.float32, 0, 1)
    offset = jnp.where(zero_draw < self._start_from_zero_prob,
                       jnp.zeros((), dtype=jnp.int32), offset)
    encoding = fastmath.dynamic_slice_in_dim(
        self.weights, offset, seq_len, axis=1)

  if self._dropout == 0:
    return inputs + encoding

  mask_shape = list(encoding.shape)
  for dim in self._dropout_broadcast_dims:
    mask_shape[dim] = 1
  keep_prob = 1.0 - self._dropout
  keep = fastmath.random.bernoulli(self.rng, keep_prob, tuple(mask_shape))
  scale = keep.astype(inputs.dtype) / keep_prob
  return inputs + encoding * scale
def f(decoder_input, mask):
  """Expands an encoder mask along a decoder-length axis.

  Args:
    decoder_input: Rank-3 array `(batch_size, decoder_sequence_length,
      d_model)`; only its shape is used.
    mask: Array with `batch_size * encoder_sequence_length` elements, where
      the two sizes are read from its first and last axes.

  Returns:
    The mask reshaped to `(batch_size, 1, 1, encoder_sequence_length)` and
    broadcast (by adding zeros) to
    `(batch_size, 1, decoder_sequence_length, encoder_sequence_length)`.

  Raises:
    ValueError: If `decoder_input` is not rank 3.
  """
  if len(decoder_input.shape) != 3:
    raise ValueError(
        f'Decoder input to EncoderDecoderMask must be a rank 3 tensor with '
        f'shape (batch_size, decoder_sequence_length, d_model); instead got '
        f'shape {decoder_input.shape}.')
  n_batch, encoder_len = mask.shape[0], mask.shape[-1]
  decoder_len = decoder_input.shape[1]
  expanded = mask.reshape((n_batch, 1, 1, encoder_len))
  return expanded + jnp.zeros((1, 1, decoder_len, 1))
def init_weights_and_state(self, input_signature):
  """Helper to initialize batch norm weights and state.

  Weights are `(beta, gamma)`, shaped like the input with the axes in
  `self._axis` removed; each is an empty tuple when `self._center` /
  `self._scale` is off. State is `(running_mean, running_var, n_batches)`,
  where the running statistics keep the input rank but have size 1 along the
  axes in `self._axis`, and `n_batches` is a scalar zero.

  Args:
    input_signature: Signature of the input; only its `shape` is read.
  """
  reduce_axes = self._axis
  if jnp.isscalar(reduce_axes):
    reduce_axes = (reduce_axes,)
  in_shape = input_signature.shape

  kept_shape = tuple(
      size for idx, size in enumerate(in_shape) if idx not in reduce_axes)
  stats_shape = tuple(
      1 if idx in reduce_axes else size for idx, size in enumerate(in_shape))

  # TODO(jonni): Should beta and gamma match the dtype in the input signature?
  beta = jnp.zeros(kept_shape, dtype='float32') if self._center else ()
  gamma = jnp.ones(kept_shape, dtype='float32') if self._scale else ()

  running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
  running_var = jnp.ones(stats_shape, dtype=jnp.float32)
  n_batches = jnp.zeros((), dtype=jnp.int64)

  self.weights = (beta, gamma)
  self.state = (running_mean, running_var, n_batches)
def test_custom_initializer_shape(self):
  """Checks `tl.Weights` with an explicit shape and custom initializers."""
  all_zeros = [[0., 0.], [0., 0.]]

  # An initializer that returns zeros must yield a 2x2 block of zeros.
  zero_layer = tl.Weights(
      lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32), (2, 2))
  zero_layer.init(())
  zero_output = zero_layer(())
  self.assertEqual(zero_output.tolist(), all_zeros)

  # A random-normal initializer must keep the shape but not be all zeros.
  random_layer = tl.Weights(init.RandomNormalInitializer(), (2, 2))
  random_layer.init(())
  random_output = random_layer(())
  self.assertEqual(random_output.shape, (2, 2))
  self.assertNotEqual(random_output.tolist(), all_zeros)
def forward(self, inputs):
  """Returns the input activations, with added positional information.

  When `self._d_feature` is set, the weights are a `(table, matrix)` pair and
  the first `inputs.shape[1]` rows of the table are projected through the
  matrix before use. Outside predict mode a window of the table is added to
  `inputs` (randomly offset and dropout-masked in train mode, depending on
  configuration). In predict mode the window starts at `self.state`, which is
  then advanced by the number of positions consumed.

  Args:
    inputs: Activations; axis 1 is the sequence position.

  Returns:
    `inputs` plus the selected positional encodings.

  Raises:
    ValueError: If called in predict mode with a nonzero dropout rate.
  """
  table = self.weights
  if self._d_feature is not None:
    table, projection = table
    table = jnp.dot(table[:inputs.shape[1], :], projection)
  if len(table.shape) < 3:  # old checkpoints have 1 in first dim already
    table = table[None, :, :]  # [1, self._max_len, d_feature]

  if self._mode == 'predict':
    if self._dropout != 0:
      raise ValueError(f'In predict mode, but dropout rate '
                       f'({self._dropout}) is not zero.')
    # State in this class is only used for fast inference. In that case,
    # the model is called with consecutive elements position-by-position.
    # This positional encoding layer stores the index of the current
    # position and increments it on each call.
    window = fastmath.dynamic_slice_in_dim(
        table, self.state, inputs.shape[1], axis=1)
    self.state += inputs.shape[1]
    return inputs + window

  seq_len = jnp.shape(inputs)[1]
  if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
    encoding = table[:, :seq_len, :]
  else:
    offset_rng, zero_rng = fastmath.random.split(self.rng, 2)
    offset = fastmath.random.randint(
        offset_rng, (), 0, self._max_offset_to_add)
    zero_draw = fastmath.random.uniform(zero_rng, (), jnp.float32, 0, 1)
    offset = jnp.where(zero_draw < self._start_from_zero_prob,
                       jnp.zeros((), dtype=jnp.int32), offset)
    encoding = fastmath.dynamic_slice_in_dim(table, offset, seq_len, axis=1)

  if self._dropout == 0:
    return inputs + encoding

  mask_shape = list(encoding.shape)
  for dim in self._dropout_broadcast_dims:
    mask_shape[dim] = 1
  keep_prob = 1.0 - self._dropout
  keep = fastmath.random.bernoulli(self.rng, keep_prob, tuple(mask_shape))
  scale = keep.astype(inputs.dtype) / keep_prob
  return inputs + encoding * scale
def favor(query, key, value):
  """Computes renormalized attention from prefix-sum numerator and denominator.

  Queries and keys are passed through `relu` and offset by
  `numerical_stabilizer`. Axis 1 of each operand is moved to the front before
  calling `favor_numerator` / `favor_denominator` (presumably the sequence
  axis -- TODO confirm), and moved back afterwards.

  Args:
    query: Query array.
    key: Key array; its first and last dimensions size the initial prefix
      sums.
    value: Value array; its last dimension sizes the numerator prefix sum.

  Returns:
    The numerator multiplied by the reciprocal of the denominator, with the
    denominator expanded by one trailing axis for broadcasting.
  """
  query_prime = relu(query) + numerical_stabilizer
  key_prime = relu(key) + numerical_stabilizer

  n_batch, key_depth, value_depth = key.shape[0], key.shape[-1], value.shape[-1]
  numerator_init = jnp.zeros((n_batch, key_depth, value_depth))
  denominator_init = jnp.zeros((n_batch, key_depth))

  query_front = jnp.moveaxis(query_prime, 1, 0)
  key_front = jnp.moveaxis(key_prime, 1, 0)
  value_front = jnp.moveaxis(value, 1, 0)

  numerator = favor_numerator(
      numerator_init, precision, query_front, key_front, value_front)
  denominator = favor_denominator(
      denominator_init, precision, query_front, key_front)

  numerator = jnp.moveaxis(numerator, 0, 1)
  denominator = jnp.moveaxis(denominator, 0, 1)

  inverse = jnp.reciprocal(denominator)
  inverse = jnp.expand_dims(inverse, len(inverse.shape))
  return numerator * inverse
def test_weights_and_state_signature(self):
  """Checks the shapes reported by `weights_and_state_signature`."""

  class MyLayer(tl.Layer):
    """Layer with fixed-shape weights and input-shaped state."""

    def init_weights_and_state(self, input_signature):
      self.weights = jnp.zeros((2, 3))
      self.state = jnp.ones(input_signature.shape)

    def forward(self, inputs):
      return self.weights + self.state

  sample_input = jnp.zeros((3, 4))
  weights_sig, state_sig = MyLayer().weights_and_state_signature(sample_input)
  # Weights keep their fixed shape; state follows the input shape.
  self.assertEqual(weights_sig.shape, (2, 3))
  self.assertEqual(state_sig.shape, (3, 4))
def init_weights_and_state(self, input_signature):
  """Initializes the per-channel `gamma`, `beta` and optional epsilon weights.

  Args:
    input_signature: Signature of the input; its last dimension is taken as
      the number of channels.
  """
  # Usually (B, W, H, C)
  per_channel = (input_signature.shape[-1],)
  gamma = jnp.ones(per_channel, dtype=jnp.float32)
  beta = jnp.zeros(per_channel, dtype=jnp.float32)
  # The epsilon weight exists only when it is configured as learnable.
  if self._learn_epsilon:
    epsilon_l = (self._init_learnt_epsilon,)
  else:
    epsilon_l = base.EMPTY_WEIGHTS
  self.weights = (gamma, beta, epsilon_l)
def init_weights_and_state(self, input_signature):
  """Builds the sinusoidal positional-encoding table and the decoding state.

  The weights are a table of shape `(1, self._max_len, d_feature)`: even
  feature columns hold sines and odd columns hold cosines of
  `position * div_term`. When `d_feature` is odd there is one fewer cosine
  column than sine column, so the cosine frequencies are truncated to fit
  instead of raising a shape-mismatch error. In predict mode the state is an
  int32 zero vector with one entry per batch element.

  Args:
    input_signature: Signature of the input; its last dimension gives
      `d_feature` and, in predict mode, its first dimension gives the batch
      size.
  """
  d_feature = input_signature.shape[-1]
  position = np.arange(0, self._max_len)[:, np.newaxis]
  div_term = np.exp(
      np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
  angles = position * div_term
  pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
  pe[:, 0::2] = np.sin(angles)
  # `pe[:, 1::2]` has `d_feature // 2` columns, one fewer than `angles` when
  # `d_feature` is odd; drop the surplus frequency so the assignment fits.
  pe[:, 1::2] = np.cos(angles[:, :d_feature // 2])
  pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
  self.weights = jnp.array(pe)  # Trainable parameters, initialized above.
  if self._mode == 'predict':
    batch_size = input_signature.shape[0]
    self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
def init_weights_and_state(self, input_signature):
  """Initializes per-base digit weights, a start vector and predict state.

  The feature depth must be divisible by `self._n_digits`. For every entry of
  `self._bases`, `self._n_digits` weight blocks of shape
  `(1, d_feature // n_digits)` are drawn with `self._initializer`; a separate
  `(1, 1, d_feature)` vector is drawn to mark the starting position. In
  predict mode the state is a scalar int32 zero.

  Args:
    input_signature: Signature of the input; only its last dimension is read.
  """
  d_feature = input_signature.shape[-1]
  assert d_feature % self._n_digits == 0
  d_weight = d_feature // self._n_digits

  weights_rng, start_rng = fastmath.random.split(self.rng, 2)
  base_weights = []
  for _ in self._bases:
    # The same parent key is split for every base.
    digit_rngs = fastmath.random.split(weights_rng, self._n_digits)
    base_weights.append(
        [self._initializer((1, d_weight), digit_rng)
         for digit_rng in digit_rngs])

  # Special vector to mark the starting position.
  start_vec = self._initializer((1, 1, d_feature), start_rng)
  self.weights = (base_weights, start_vec)

  if self._mode == 'predict':
    self.state = jnp.zeros((), dtype=jnp.int32)
def prepare_attention_input(encoder_activations, decoder_activations, inputs):
  """Prepares queries, keys, values and the padding mask for attention.

  Args:
    encoder_activations: Returned unchanged as both keys and values.
    decoder_activations: Returned unchanged as queries; its axis 1 sets the
      decoder-length axis of the mask.
    inputs: Rank-2 array of padded input tokens; zeros mark padding.

  Returns:
    `(queries, keys, values, mask)`, where `mask` has shape
    `(batch size, 1, decoder length, encoder length)`.
  """
  # Distinguish real tokens from padding.
  real_tokens = (inputs != 0)
  n_batch, encoder_len = real_tokens.shape[0], real_tokens.shape[1]
  decoder_len = decoder_activations.shape[1]

  # Add axes to the mask for attention heads and decoder length.
  mask = fastnp.reshape(real_tokens, (n_batch, 1, 1, encoder_len))
  # Broadcast along the decoder-length axis by adding zeros.
  mask = mask + fastnp.zeros((1, 1, decoder_len, 1))

  return decoder_activations, encoder_activations, encoder_activations, mask
def _fast_inference_init_state(input_signature, buffer_length,
                               predict_mask=None):
  """Returns an initial state for causal attention layer fast inference.

  Args:
    input_signature: Sequence of three signatures. The batch size is read from
      the first one; the last dimension and dtype of the second and third
      define the key and value buffers.
    buffer_length: Size of axis 1 of the key/value buffers (and length of the
      predict mask, when one is requested).
    predict_mask: When not None, an all-False boolean vector of length
      `buffer_length` is placed at the front of the returned state. Only its
      being None or not is checked.

  Returns:
    `(k, v, index)` or, with `predict_mask`, `(mask, k, v, index)`, where `k`
    and `v` are zero buffers of shape `(batch, buffer_length, d_feature)` and
    `index` is a scalar zero.
  """
  n_batch = input_signature[0].shape[0]

  def empty_buffer(signature):
    feature_shape, feature_dtype = signature.as_tuple()
    return jnp.zeros(
        (n_batch, buffer_length, feature_shape[-1]), dtype=feature_dtype)

  keys_buffer = empty_buffer(input_signature[1])
  values_buffer = empty_buffer(input_signature[2])
  if predict_mask is None:
    return (keys_buffer, values_buffer, jnp.array(0))
  # Comparing zeros against 0 yields an all-False boolean vector.
  mask_for_predict = jnp.zeros((buffer_length,)) != 0
  return (mask_for_predict, keys_buffer, values_buffer, jnp.array(0))
def policy_inputs(self, trajectory, values):
  """Create inputs to policy model from a TrajectoryNp and values.

  Args:
    trajectory: Object exposing `rewards`, `returns`, `dones`, `observations`,
      `actions`, `mask` and `dist_inputs` (the last may be None).
    values: Value estimates passed to the advantage estimator.

  Returns:
    A tuple `(observations, actions, advantages, dist_inputs, mask)`, with the
    trajectory fields trimmed along axis 1 to the length of `advantages`.
    When the trajectory has no `dist_inputs`, zeros of shape
    `advantages.shape + (self._policy_dist.n_inputs,)` are used.

  Raises:
    ValueError: If `advantages` is not rank 2, or if any of the other outputs
      disagrees with it on the leading two dimensions.
  """
  # How much TD to use is determined by the added policy slice length,
  # as the policy batches need to be this much longer to calculate TD.
  advantages = self._advantage_estimator(
      rewards=trajectory.rewards,
      returns=trajectory.returns,
      values=values,
      dones=trajectory.dones,
      gamma=self._task.gamma,
      n_extra_steps=self._added_policy_slice_length,
  )

  # Observations should be the same length as advantages - so if we are
  # using n_extra_steps, we need to trim the length to match.
  observations = trajectory.observations[:, :advantages.shape[1]]
  actions = trajectory.actions[:, :advantages.shape[1]]
  padding_mask = trajectory.mask[:, :advantages.shape[1]]  # Zero-out padding.
  if trajectory.dist_inputs is None:
    dist_inputs = jnp.zeros(
        advantages.shape + (self._policy_dist.n_inputs,))
  else:
    dist_inputs = trajectory.dist_inputs[:, :advantages.shape[1]]

  # Shape checks to help debugging.
  expected = advantages.shape
  if len(expected) != 2:
    raise ValueError('Advantages are expected to have shape ' +
                     '[batch_size, length], got: %s' % str(advantages.shape))
  if actions.shape[0:2] != expected:
    raise ValueError(
        'First 2 dimensions of actions should be the same as in '
        'advantages, %s != %s' % (actions.shape[0:2], advantages.shape))
  if observations.shape[0:2] != expected:
    raise ValueError(
        'First 2 dimensions of observations should be the same '
        'as in advantages, %s != %s' %
        (observations.shape[0:2], advantages.shape))
  if dist_inputs.shape[:2] != expected:
    raise ValueError(
        'First 2 dimensions of dist_inputs should be the same '
        'as in advantages, %s != %s' %
        (dist_inputs.shape[:2], advantages.shape))
  if padding_mask.shape != expected:
    raise ValueError('Mask and advantages shapes should be the same'
                     ', %s != %s' % (padding_mask.shape, advantages.shape))
  return (observations, actions, advantages, dist_inputs, padding_mask)
def init_weights_and_state(self, input_signature):
  """Initializes the sinusoidal positional-encoding table and decoding state.

  The weights are a table of shape `(1, self._max_len, d_feature)`: even
  feature columns hold sines and odd columns hold cosines of
  `position * div_term`. The table is computed deterministically (no random
  numbers are drawn). When `d_feature` is odd there is one fewer cosine column
  than sine column, so the cosine frequencies are truncated to fit instead of
  raising a shape-mismatch error. In predict mode the state is an int32 zero
  vector with one entry per batch element.

  Args:
    input_signature: `ShapeDtype` instance characterizing the input this layer
      should compute on; its last dimension gives `d_feature` and, in predict
      mode, its first dimension gives the batch size.
  """
  d_feature = input_signature.shape[-1]
  position = np.arange(0, self._max_len)[:, np.newaxis]
  div_term = np.exp(
      np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
  angles = position * div_term
  pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
  pe[:, 0::2] = np.sin(angles)
  # `pe[:, 1::2]` has `d_feature // 2` columns, one fewer than `angles` when
  # `d_feature` is odd; drop the surplus frequency so the assignment fits.
  pe[:, 1::2] = np.cos(angles[:, :d_feature // 2])
  pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
  self.weights = jnp.array(pe)  # Trainable parameters, initialized above.
  if self._mode == 'predict':
    batch_size = input_signature.shape[0]
    self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
def test_forward(self):
  """Checks the three ways of running a `tl.PureLayer` that doubles its input."""
  doubler = tl.PureLayer(lambda x: 2 * x)

  # Use Layer.__call__.
  via_call = doubler(np.array([1, 2]), weights=jnp.zeros((2, 3)))
  self.assertEqual(via_call.tolist(), [2, 4])
  self.assertEmpty(doubler.weights)

  # Use PureLayer.forward.
  via_forward = doubler.forward(np.array([3, 4]))
  self.assertEqual(via_forward.tolist(), [6, 8])

  # Use Layer.pure_fn
  via_pure_fn, _ = doubler.pure_fn(
      np.array([5, 6]), tl.EMPTY_WEIGHTS, tl.EMPTY_WEIGHTS, None)
  self.assertEqual(via_pure_fn.tolist(), [10, 12])