def test_one_hot(self):
  targets = np.array([2, 0, 1])
  n_categories = 5
  target_distributions = tl.one_hot(targets, n_categories)
  self.assertEqual(
      tl.to_list(target_distributions),
      [[0., 0., 1., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.]])
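# For reference, a minimal sketch of a `one_hot` helper with the behavior
# tested above (a hypothetical standalone version, not trax's actual
# implementation):
import jax.numpy as jnp

def one_hot_sketch(x, n_categories, dtype=jnp.float32):
  """Maps integer ids in `x` to one-hot rows along a new last axis."""
  # Broadcast-compare each id against the range 0 .. n_categories - 1.
  return jnp.array(x[..., jnp.newaxis] == jnp.arange(n_categories), dtype)

# one_hot_sketch(jnp.array([2, 0, 1]), 5) reproduces the expected list above.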
def log_prob(self, inputs, point):
  inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs))
  return jnp.sum(
      # Select the logits specified by point.
      inputs * tl.one_hot(point, self._n_categories),
      # Sum over the parameter dimensions.
      axis=[-a for a in range(1, len(self._shape) + 2)],
  )
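# A small numeric sketch of the multiply-by-one-hot selection used above
# (toy values in plain jax.numpy; shapes [batch, n_categories] and [batch]):
import jax.numpy as jnp

log_probs = jnp.log(jnp.array([[0.1, 0.7, 0.2],
                               [0.5, 0.25, 0.25]]))
point = jnp.array([1, 0])
one_hot_point = jnp.array(point[:, None] == jnp.arange(3), jnp.float32)
# The one-hot zeroes every non-selected category, so the sum picks out
# log_probs[i, point[i]] for each i.
selected = jnp.sum(log_probs * one_hot_point, axis=-1)
# selected ~= [log 0.7, log 0.5]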
def _symbols_to_logits(self, symbols):
  # Assert that symbols are discrete.
  assert np.issubdtype(symbols.dtype, np.integer)
  # Assert that 0 <= symbols < vocab_size.
  np.testing.assert_array_less(-1, symbols)
  np.testing.assert_array_less(symbols, self._vocab_size)
  # Return almost-deterministic logits:
  # e^1000 / (e^1000 + vocab_size) ~= 1
  return tl.one_hot(symbols, n_categories=self._vocab_size) * 1000.0
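# A quick sanity check of the "almost-deterministic" claim (a standalone
# sketch, not part of the original code): softmax of 1000 * one_hot puts
# essentially all probability mass on the selected symbol.
import jax
import jax.numpy as jnp

vocab_size = 10
logits = jnp.array([1000.0] + [0.0] * (vocab_size - 1))
probs = jax.nn.softmax(logits)  # stable: subtracts the max logit internally
# probs[0] = e^1000 / (e^1000 + (vocab_size - 1)) ~= 1.0; the rest ~= 0.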
def test_model(preds, target):
    """Function to test the model.

    Args:
        preds (jax.interpreters.xla.DeviceArray): Predictions of a list of batches of tensors corresponding to lines of text.
        target (jax.interpreters.xla.DeviceArray): Actual list of batches of tensors corresponding to lines of text.

    Returns:
        float: log_perplexity of the model.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###
    total_log_ppx = np.sum(preds * tl.one_hot(target, preds.shape[-1]), axis=-1)  # HINT: tl.one_hot() should replace one of the Nones

    non_pad = 1.0 - np.equal(target, 0)  # You should check if the target equals 0
    ppx = total_log_ppx * non_pad  # Get rid of the padding

    log_ppx = np.sum(ppx) / np.sum(non_pad)
    ### END CODE HERE ###
    return -log_ppx
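# A toy usage sketch for `test_model` (hypothetical values, not from the
# assignment): one 4-token line over a 3-symbol vocabulary, with id 0 used
# as padding, and `preds` holding log-probabilities as the model would emit.
toy_preds = np.log(np.array([[[0.1, 0.2, 0.7],
                              [0.2, 0.6, 0.2],
                              [0.2, 0.7, 0.1],
                              [0.9, 0.05, 0.05]]]))  # shape [1, 4, 3]
toy_target = np.array([[2, 1, 1, 0]])  # shape [1, 4]; the trailing 0 is pad
# test_model(toy_preds, toy_target) averages log p(target) over the three
# non-pad tokens and negates it: -(log 0.7 + log 0.6 + log 0.7) / 3 ~= 0.41.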
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, m2, mb, w1, w2, b2 = self.weights
  if self._mode != 'predict':
    w1 = jnp.reshape(w1.T, (-1, self._d_ff))
    w2 = jnp.reshape(w2, (self._d_ff, -1))
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

  # Q: should we add bias and/or put relu after the low-rank m1 dot?
  mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
  mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)
  # Gumbel-softmax with straight-through discretization.
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (quant_prob of the batches) use the soft-mask instead
    # of the quantized mask to improve training stability (see paper above).
    select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
    quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

  if self._mode == 'train':
    # In training, run full matmul to get benefits from the above tricks.
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2
  elif self._mode == 'predict':
    # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
    # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
    # This implementation mimics inference. It's not efficient for a large
    # joint_batch, but at inference that will be 1 most of the time.
    # Shapes:
    #   quant_mask is [joint_batch, self._d1]
    #   w1 is [d_model, self._d1, self._d2]
    # We'll index w1 with advanced numpy indexing: the first index array
    # repeats range(self._d1) batch_size times, the second is quant_mask.
    batch_size = quant_mask.shape[0]
    idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
    # Flatten the indices and select rows from w1.
    idx1 = jnp.reshape(idx1, [-1])
    idx2 = jnp.reshape(quant_mask, [-1])
    w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
    w = jnp.reshape(w, [batch_size, self._d1, -1])
    mid = jnp.einsum('ai,aji->aj', x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [self._d1, self._d2, d_model]
    v = w2[idx1, idx2, :]
    v = jnp.reshape(v, [batch_size, self._d1, -1])
    res = jnp.einsum('ai,aij->aj', relu, v) + b2
  else:
    quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
    quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
    mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # un-flatten if needed
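# The Gumbel-softmax straight-through trick used above, in isolation (a
# minimal sketch over a [batch, n] `log_mask` using plain jax; names here
# are illustrative, and the temperature scales the noise, as in the layer):
import jax
import jax.numpy as jnp

def straight_through_gumbel(rng, log_mask, n, temperature=1.0):
  u = jax.random.uniform(rng, log_mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))  # Gumbel(0, 1) noise
  hard = jax.nn.one_hot(jnp.argmax(log_mask + g * temperature, axis=-1), n)
  soft = jnp.exp(log_mask)  # the softmax mask
  # Forward pass sees the discrete `hard` mask; gradients flow through `soft`.
  return jax.lax.stop_gradient(hard) + soft - jax.lax.stop_gradient(soft)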
def forward(self, x):
  """Executes this layer as part of a forward pass through the model.

  Args:
    x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

  Returns:
    Tensor of same shape and dtype as the input.
  """
  m1, w1, w2, b2 = self.weights
  x_shape = x.shape
  x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

  # Q: check if we need bias and/or put relu after the m1 dot?
  mask_logits = jnp.dot(x, m1)
  # Softmax.
  mask_logsumexp = fastmath.logsumexp(mask_logits, axis=-1, keepdims=True)
  log_mask = mask_logits - mask_logsumexp
  mask = jnp.exp(log_mask)
  # Gumbel-softmax with straight-through discretization.
  # TODO(lukaszkaiser, chowdhery): Extract this block and share.
  rng1, rng2 = fastmath.random.split(self.rng, 2)
  u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
  g = -jnp.log(-jnp.log(u))
  selected_experts = jnp.argmax(log_mask + g * self._temperature, axis=-1)
  if self._mode == 'train':
    # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
    quant_mask = tl.one_hot(selected_experts, self._num_experts)
    quant_mask = fastmath.stop_gradient(quant_mask)
    quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
    # We will sometimes (50% of the batches) use the soft-mask instead of
    # the quantized mask to improve training stability (see the paper above).
    # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
    select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
    quant_mask = jnp.where(select > 0.0, quant_mask, mask)
  else:
    quant_mask = tl.one_hot(selected_experts, self._num_experts)
  quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
  quant_mask_shape = quant_mask.shape
  batch_size = quant_mask.shape[0]

  if self._mode == 'predict' and batch_size == 1:
    # This implementation mimics inference for batch_size 1.
    start_idx = selected_experts[0] * self._n_elements_in_block
    # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
    w = fastmath.dynamic_slice(
        w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
    mid = jnp.dot(x, w)
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
    v = fastmath.dynamic_slice(
        w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
    v = jnp.reshape(v, [self._n_elements_in_block, -1])
    res = jnp.dot(relu, v) + b2
  else:
    expanded_mask = jnp.broadcast_to(
        quant_mask,
        (quant_mask_shape[0], quant_mask.shape[1], self._n_elements_in_block))
    expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
    mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
    relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
    res = jnp.dot(relu, w2) + b2

  return jnp.reshape(res, x_shape)  # un-flatten if needed
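# A quick equivalence sketch for the batch_size == 1 fast path above (a
# hypothetical standalone check using plain jax in place of fastmath):
# slicing the selected expert's block out of w1 yields the same values as
# the masked full matmul restricted to that block.
import jax
import jax.numpy as jnp

d_model, n_experts, block = 4, 3, 2  # so d_ff = n_experts * block = 6
x = jax.random.normal(jax.random.PRNGKey(0), (1, d_model))
w1 = jax.random.normal(jax.random.PRNGKey(1), (d_model, n_experts * block))
expert = 1
start = expert * block
# Fast path: multiply only by the selected expert's columns ...
sliced = x @ jax.lax.dynamic_slice(w1, (0, start), (d_model, block))
# ... equals the full matmul under a one-hot block mask, at those columns.
mask = jnp.zeros(n_experts * block).at[start:start + block].set(1.0)
full = (x @ w1) * mask
assert jnp.allclose(sliced, full[:, start:start + block])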
targets = np.array(targets)

# Print shapes
print(f'predictions has shape: {predictions.shape}')
print(f'targets has shape: {targets.shape}')

# Notice that the predictions have an extra dimension with the same length
# as the size of the vocabulary used.
#
# Because of this you will need a way of reshaping `targets` to match this
# shape. For this you can use `trax.layers.one_hot()`.
#
# Notice that `predictions.shape[-1]` will return the size of the last
# dimension of `predictions`.

# In[5]:

reshaped_targets = tl.one_hot(targets, predictions.shape[-1])  # trax's one_hot function takes the input as one_hot(x, n_categories, dtype=optional)
print(f'reshaped_targets has shape: {reshaped_targets.shape}')

# By calculating the product of the predictions and the reshaped targets
# and summing across the last dimension, the total log perplexity can be
# computed:

# In[6]:

total_log_ppx = np.sum(predictions * reshaped_targets, axis=-1)

# Now you will need to account for the padding so this metric is not
# artificially deflated (since a lower perplexity means a better model). To
# identify which elements are padding and which are not, you can use
# `np.equal()` and get a tensor with `1s` in the positions of actual values
# and `0s` where there are paddings.

# In[7]:

non_pad = 1.0 - np.equal(targets, 0)
print(f'non_pad has shape: {non_pad.shape}\n')
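# To finish the calculation (a sketch mirroring `test_model` above, not a
# cell from the original notebook): mask out the padding positions, average
# over the real tokens only, and negate to get the log perplexity.
real_log_ppx = np.sum(total_log_ppx * non_pad) / np.sum(non_pad)
print(f'log perplexity: {-real_log_ppx}')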