def f(model_output, target_category):  # pylint: disable=invalid-name
    """Returns the batch-averaged binary cross-entropy.

    Computes `-(1 / N) * (t^T log(p) + (1 - t)^T log(1 - p))`, where `p` is
    `model_output`, `t` is `target_category` and `N` is the size of the
    leading axis of `model_output`.

    Args:
        model_output: Predicted values; must have the same shape as
            `target_category` (checked via `shapes.assert_same_shape`).
        target_category: Target values.

    Returns:
        The loss, with size-1 axes squeezed out.
    """
    shapes.assert_same_shape(model_output, target_category)
    n_examples = model_output.shape[0]

    # Log-likelihood contributions of the "1" targets and the "0" targets.
    positive_term = jnp.dot(jnp.transpose(target_category),
                            jnp.log(model_output))
    negative_term = jnp.dot(jnp.transpose(1 - target_category),
                            jnp.log(1 - model_output))

    log_likelihood = positive_term + negative_term
    return -1.0 / n_examples * jnp.squeeze(log_likelihood)
def forward(self, inputs):
    """Runs a single GRU step.

    Args:
        inputs: A pair `(x, gru_state)`; the two are concatenated along the
            last axis before each dense map.

    Returns:
        A pair `(new_gru_state, new_gru_state)`: the updated state is
        returned both as the output and as the next state.
    """
    x, prev_state = inputs
    gate_kernel, gate_bias, cand_kernel, cand_bias = self.weights

    # One dense map over [x, state] produces both gates at once.
    gate_input = jnp.concatenate([x, prev_state], axis=-1)
    gates = fastmath.sigmoid(jnp.dot(gate_input, gate_kernel) + gate_bias)
    update_gate, reset_gate = jnp.split(gates, 2, axis=-1)

    # The candidate sees the state only through the reset gate.
    cand_input = jnp.concatenate([x, reset_gate * prev_state], axis=-1)
    candidate = jnp.tanh(jnp.dot(cand_input, cand_kernel) + cand_bias)

    next_state = update_gate * prev_state + (1 - update_gate) * candidate
    return next_state, next_state
def forward(self, x):
    """Applies this layer's linear (or affine) map to `x`.

    Args:
        x: Tensor whose last axis is contracted against the weight matrix.

    Returns:
        `jnp.dot(x, w)` when the layer has no bias, `jnp.dot(x, w) + b`
        otherwise.

    Raises:
        ValueError: If the layer uses a bias but `self.weights` is not a
            tuple or list.
    """
    if not self._use_bias:
        # Linear map: the weights are the bare matrix.
        return jnp.dot(x, self.weights)

    if not isinstance(self.weights, (tuple, list)):
        raise ValueError(f'Weights should be a (w, b) tuple or list; '
                         f'instead got: {self.weights}')

    # Affine map.
    kernel, bias = self.weights
    return jnp.dot(x, kernel) + bias
def predict(question1, question2, threshold, model, vocab,
            data_generator=data_generator, verbose=False):
    """Predicts whether two questions are duplicates.

    Args:
        question1 (str): First question.
        question2 (str): Second question.
        threshold (float): Similarity above which the pair counts as a
            duplicate.
        model: The Siamese model; called on a `(Q1, Q2)` pair and expected
            to return one output per question.
        vocab: Mapping from token to index; must contain `'<PAD>'`.
        data_generator (function): Batch generator used to pad the encoded
            questions. Defaults to data_generator.
        verbose (bool, optional): If True, prints the padded inputs, the
            similarity and the result. Defaults to False.

    Returns:
        The result of `d > threshold`, where `d` is the dot product of the
        first model output with the transpose of the second.
    """
    tokens_q1 = nltk.word_tokenize(question1)
    tokens_q2 = nltk.word_tokenize(question2)

    # Encode every token through the vocabulary.
    Q1 = [vocab[token] for token in tokens_q1]
    Q2 = [vocab[token] for token in tokens_q2]

    # A one-element batch, padded by the generator.
    batches = data_generator([Q1], [Q2], batch_size=1, pad=vocab['<PAD>'])
    Q1, Q2 = next(batches)

    v1, v2 = model((Q1, Q2))

    d = fastnp.dot(v1, v2.T)
    res = d > threshold

    if verbose:
        print("Q1 = ", Q1, "\nQ2 = ", Q2)
        print("d = ", d)
        print("res = ", res)

    return res
def classify(test_Q1, test_Q2, y, threshold, model, vocab,
             data_generator=data_generator, batch_size=64):
    """Computes the accuracy of the model on a set of question pairs.

    Args:
        test_Q1 (numpy.ndarray): Array of Q1 questions.
        test_Q2 (numpy.ndarray): Array of Q2 questions.
        y (numpy.ndarray): Array of actual targets, aligned with the
            questions.
        threshold (float): Similarity above which a pair is predicted to be
            a duplicate.
        model: The Siamese model; called on a `(q1, q2)` batch and expected
            to return one batch of vectors per side.
        vocab: Mapping from token to index; must contain `'<PAD>'`.
        data_generator (function): Batch generator. Defaults to
            data_generator.
        batch_size (int, optional): Size of the batches. Defaults to 64.

    Returns:
        Number of matching predictions divided by `len(test_Q1)`.
    """
    accuracy = 0
    for i in range(0, len(test_Q1), batch_size):
        # One batch of padded questions, in the original order.
        q1, q2 = next(
            data_generator(test_Q1[i:i + batch_size],
                           test_Q2[i:i + batch_size],
                           batch_size,
                           pad=vocab['<PAD>'],
                           shuffle=False))
        y_test = y[i:i + batch_size]

        v1, v2 = model((q1, q2))

        # Score only the rows that have a label. The last slice of `y` is
        # shorter than `batch_size` whenever the test set size is not a
        # multiple of it, and indexing `y_test` past its end would fail.
        for j in range(len(y_test)):
            d = fastnp.dot(v1[j], v2[j].T)
            res = d > threshold
            accuracy += y_test[j] == res

    accuracy = accuracy / len(test_Q1)
    return accuracy
def forward(self, inputs):
    """Returns the input activations, with added positional information.

    Args:
        inputs: Activations; axis 1 is used as the position axis.

    Returns:
        `inputs` plus a slice of the positional table taken along axis 1.

    Raises:
        ValueError: In 'predict' mode when the dropout rate is not zero.
    """
    table = self.weights
    if self._d_feature is not None:
        # Weights are a (table, projection) pair; project the needed rows.
        table, projection = table
        table = jnp.dot(table[:inputs.shape[1], :], projection)
    if len(table.shape) < 3:  # old checkpoints have 1 in first dim already
        table = table[None, :, :]

    if self._mode == 'predict':
        if self._dropout != 0:
            raise ValueError(f'In predict mode, but dropout rate '
                             f'({self._dropout}) is not zero.')
        # `self.state` holds the index of the current position; it is read
        # to pick the slice and then advanced by the number of new positions.
        emb = fastmath.dynamic_slice_in_dim(table, self.state,
                                            inputs.shape[1], axis=1)
        self.state += inputs.shape[1]
        return inputs + emb

    seq_len = jnp.shape(inputs)[1]
    if self._mode == 'train' and not self._start_from_zero_prob >= 1.0:
        # Random start offset, forced to zero with probability
        # `_start_from_zero_prob`.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        start = fastmath.random.randint(rng1, (), 0, self._max_offset_to_add)
        start_from_zero = fastmath.random.uniform(
            rng2, (), jnp.float32, 0, 1)
        start = jnp.where(start_from_zero < self._start_from_zero_prob,
                          jnp.zeros((), dtype=jnp.int32), start)
        positional = fastmath.dynamic_slice_in_dim(table, start, seq_len,
                                                   axis=1)
    else:
        positional = table[:, :seq_len, :]

    if self._dropout == 0:
        return inputs + positional

    # Dropout on the positional term only; the noise has size 1 along the
    # broadcast dims.
    noise_shape = list(positional.shape)
    for dim in self._dropout_broadcast_dims:
        noise_shape[dim] = 1
    keep_prob = 1.0 - self._dropout
    keep = fastmath.random.bernoulli(self.rng, keep_prob, tuple(noise_shape))
    multiplier = keep.astype(inputs.dtype) / keep_prob
    return inputs + positional * multiplier
def forward(self, inputs):
    """Runs a single LSTM step.

    Args:
        inputs: A pair `(x, lstm_state)`, where `lstm_state` is the cell
            state `c` and hidden state `h` concatenated along the last axis.

    Returns:
        A pair `(new_h, new_state)`, where `new_state` is `new_c` and
        `new_h` concatenated along the last axis.
    """
    x, lstm_state = inputs
    cell, hidden = jnp.split(lstm_state, 2, axis=-1)
    kernel, bias = self.weights

    # One dense map over [x, h] yields all four gate pre-activations,
    # laid out as: input gate, new input, forget gate, output gate.
    dense_input = jnp.concatenate([x, hidden], axis=-1)
    pre_activations = jnp.dot(dense_input, kernel) + bias
    in_gate, new_input, forget_gate, out_gate = jnp.split(
        pre_activations, 4, axis=-1)

    kept = cell * fastmath.sigmoid(forget_gate)
    written = fastmath.sigmoid(in_gate) * jnp.tanh(new_input)
    next_cell = kept + written
    next_hidden = jnp.tanh(next_cell) * fastmath.sigmoid(out_gate)

    return next_hidden, jnp.concatenate([next_cell, next_hidden], axis=-1)
def TripletLossFn(v1, v2, margin=0.25):
    """Triplet loss with hard-negative and mean-negative terms.

    Row `i` of `v1` and row `i` of `v2` form the positive pair; every other
    row of `v2` acts as a negative for row `i` of `v1`.

    Args:
        v1 (numpy.ndarray): Array with dimension (batch_size,
            model_dimension) associated to Q1.
        v2 (numpy.ndarray): Array with dimension (batch_size,
            model_dimension) associated to Q2.
        margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
        Scalar loss: the batch mean of
        `max(margin - positive + closest_negative, 0)` plus
        `max(margin - positive + mean_negative, 0)`. The mean over negatives
        divides by `batch_size - 1`, so the batch needs at least two rows.
    """
    # Pairwise similarities between every row of v1 and every row of v2.
    scores = fastnp.dot(v1, fastnp.transpose(v2))
    batch_size = len(scores)
    identity = fastnp.eye(batch_size)

    # The diagonal holds the similarities of the true (duplicate) pairs.
    positive = fastnp.diagonal(scores)

    # Take the diagonal out of the running for the row-wise max. Subtracting
    # 2.0 (not 1.0) keeps a positive entry from beating negative
    # off-diagonal similarities (assuming dot products lie in [-1, 1], i.e.
    # L2-normalized rows - TODO confirm against the model).
    negative_without_positive = scores - 2.0 * identity
    closest_negative = negative_without_positive.max(axis=1)

    # Mean over the negatives only: zero the diagonal, then average the
    # remaining batch_size - 1 entries per row.
    negative_zero_on_duplicate = (1.0 - identity) * scores
    mean_negative = (fastnp.sum(negative_zero_on_duplicate, axis=1)
                     / (batch_size - 1))

    triplet_loss1 = fastnp.maximum(margin - positive + closest_negative, 0.0)
    triplet_loss2 = fastnp.maximum(margin - positive + mean_negative, 0.0)

    return fastnp.mean(triplet_loss1 + triplet_loss2)
def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    A low-rank controller (`m1`, `m2`, `mb`) picks, for every flattened
    input row, one element in each of `self._d1` groups of `self._d2`
    entries. Only the picked hidden units contribute to the feed-forward
    output.

    Args:
        x: Input tensor; it is flattened to 2-D over its last axis for the
            computation.

    Returns:
        Tensor reshaped back to the shape of `x`.
    """
    m1, m2, mb, w1, w2, b2 = self.weights
    predicting = self._mode == 'predict'
    if not predicting:
        # Outside predict mode the weights are used as flat 2-D matrices.
        w1 = jnp.reshape(w1.T, (-1, self._d_ff))
        w2 = jnp.reshape(w2, (self._d_ff, -1))
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

    def _relu(t):
        return jnp.where(t <= 0, jnp.zeros_like(t), t)

    # Controller logits, then a log-softmax within each group.
    # Q: should we add bias and/or put relu after the low-rank m1 dot?
    controller_out = jnp.dot(jnp.dot(x, m1), m2) + mb
    mask_logits = jnp.reshape(controller_out, [-1, self._d1, self._d2])
    log_mask = mask_logits - fastmath.logsumexp(
        mask_logits, axis=-1, keepdims=True)
    mask = jnp.exp(log_mask)

    # Gumbel-softmax with straight-through discretization.
    rng1, rng2 = fastmath.random.split(self.rng, 2)
    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                1.0 - 1e-6)
    gumbel = -jnp.log(-jnp.log(u))
    chosen = jnp.argmax(log_mask + gumbel * self._temperature, axis=-1)

    if predicting:
        # This implementation mimics inference. It's not efficient for a
        # large joint batch, but at inference that will be 1 most of the time.
        # `chosen` is [joint_batch, self._d1]. The advanced indexing below
        # pairs each group index with its chosen element, so it treats the
        # two leading axes of w1 and w2 as (group, element).
        # NOTE(review): the original comment gave w1 as
        # [d_model, self._d1, self._d2], which does not match this indexing;
        # confirm against the weight initialization.
        n_rows = chosen.shape[0]
        group_idx = jnp.array([jnp.arange(self._d1)] * n_rows)
        group_idx = jnp.reshape(group_idx, [-1])
        elem_idx = jnp.reshape(chosen, [-1])

        # Per-row weights for the chosen units, with a batch dim.
        w = jnp.reshape(w1[group_idx, elem_idx, :], [n_rows, self._d1, -1])
        hidden = _relu(jnp.einsum('ai,aji->aj', x, w))

        v = jnp.reshape(w2[group_idx, elem_idx, :], [n_rows, self._d1, -1])
        res = jnp.einsum('ai,aij->aj', hidden, v) + b2
    else:
        unit_mask = tl.one_hot(chosen, self._n_elements_in_block)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            unit_mask = fastmath.stop_gradient(unit_mask)
            unit_mask += mask - fastmath.stop_gradient(mask)  # straight-through
            # We will sometimes use the soft mask instead of the quantized
            # mask to improve training stability (see paper above); the
            # quantized mask is kept when `select < self._quant_prob`.
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            unit_mask = jnp.where(select < self._quant_prob, unit_mask, mask)
        unit_mask = jnp.reshape(unit_mask, [-1, self._d_ff])

        # Full matmul, with the unselected hidden units masked out.
        hidden = _relu(jnp.dot(x, w1) * unit_mask)  # [joint_batch, d_ff]
        res = jnp.dot(hidden, w2) + b2

    return jnp.reshape(res, x_shape)  # un-flatten if needed
def forward(self, x):
    """Executes this layer as part of a forward pass through the model.

    A controller (`m1`) selects one expert per flattened input row. Each
    expert owns a contiguous run of `self._n_elements_in_block` hidden
    units, and only the selected expert's units contribute to the output.

    Args:
        x: Input tensor; it is flattened to 2-D over its last axis for the
            computation.

    Returns:
        Tensor reshaped back to the shape of `x`.
    """
    m1, w1, w2, b2 = self.weights
    x_shape = x.shape
    x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.
    block = self._n_elements_in_block

    def _relu(t):
        return jnp.where(t <= 0, jnp.zeros_like(t), t)

    # Controller logits followed by a log-softmax over experts.
    # Q: check if we need bias and/or put relu after the m1 dot?
    mask_logits = jnp.dot(x, m1)
    log_mask = mask_logits - fastmath.logsumexp(
        mask_logits, axis=-1, keepdims=True)
    mask = jnp.exp(log_mask)

    # Gumbel-softmax with straight-through discretization.
    # TODO(lukaszkaiser, chowdhery): Extract this block and share
    rng1, rng2 = fastmath.random.split(self.rng, 2)
    u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                1.0 - 1e-6)
    gumbel = -jnp.log(-jnp.log(u))
    selected_experts = jnp.argmax(log_mask + gumbel * self._temperature,
                                  axis=-1)

    quant_mask = tl.one_hot(selected_experts, self._num_experts)
    if self._mode == 'train':
        # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
        quant_mask = fastmath.stop_gradient(quant_mask)
        quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
        # `select` is uniform on [-1, 1), so about half of the batches use
        # the soft mask instead of the quantized one, to improve training
        # stability (see the paper above).
        # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
        select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
        quant_mask = jnp.where(select > 0.0, quant_mask, mask)
    quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
    n_rows = quant_mask.shape[0]

    if self._mode == 'predict' and n_rows == 1:
        # This implementation mimics inference for batch_size 1: slice out
        # just the selected expert's columns of w1 and rows of w2.
        start_idx = selected_experts[0] * block
        # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
        w = fastmath.dynamic_slice(w1, [0, start_idx], [w1.shape[0], block])
        hidden = _relu(jnp.dot(x, w))
        # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
        v = fastmath.dynamic_slice(w2, [start_idx, 0], [block, w2.shape[-1]])
        v = jnp.reshape(v, [block, -1])
        res = jnp.dot(hidden, v) + b2
    else:
        # Repeat each expert's mask value over its block of hidden units.
        expanded_mask = jnp.broadcast_to(
            quant_mask, (n_rows, quant_mask.shape[1], block))
        expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
        hidden = _relu(jnp.dot(x, w1) * expanded_mask)  # [joint_batch, d_ff]
        res = jnp.dot(hidden, w2) + b2

    return jnp.reshape(res, x_shape)  # un-flatten if needed
# # The process is pretty straightforward: # - Iterate over each one of the elements in the batch # - Compute the cosine similarity between the predictions # - For computing the cosine similarity, the two output vectors should have been normalized using L2 normalization meaning their magnitude will be 1. This has been taken care off by the Siamese network you will build in the assignment. Hence the cosine similarity here is just dot product between two vectors. You can check by implementing the usual cosine similarity formula and check if this holds or not. # - Determine if this value is greater than the threshold (If it is, consider the two questions as the same and return 1 else 0) # - Compare against the actual target and if the prediction matches, add 1 to the accuracy (increment the correct prediction counter) # - Divide the accuracy by the number of processed elements # In[8]: for j in range( batch_size): # Iterate over each one of the elements in the batch d = np.dot( v1[j], v2[j] ) # Compute the cosine similarity between the predictions as l2 normalized, ||v1[j]||==||v2[j]||==1 so only dot product is needed res = d > threshold # Determine if this value is greater than the threshold (if it is consider the two questions as the same) accuracy += ( y_test[j] == res ) # Compare against the actual target and if the prediction matches, add 1 to the accuracy accuracy = accuracy / batch_size # Divide the accuracy by the number of processed elements # In[9]: print(f'The accuracy of the model is: {accuracy}') # **Congratulations on finishing this lecture notebook!** # # Now you should have a clearer understanding of how to evaluate your Siamese language models using the accuracy metric.