Example #1
def f(model_output, target_category):  # pylint: disable=invalid-name
  """Binary cross-entropy loss, averaged over the batch."""
  shapes.assert_same_shape(model_output, target_category)
  batch_size = model_output.shape[0]
  j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output))
  j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output))
  j = -1.0 / batch_size * jnp.squeeze(j)
  return j
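
A minimal usage sketch for the loss above (a sketch only; the import paths assume a recent trax release and are not part of the original snippet):

from trax import shapes
from trax.fastmath import numpy as jnp

model_output = jnp.array([[0.9], [0.2], [0.8]])     # predicted probabilities
target_category = jnp.array([[1.0], [0.0], [1.0]])  # binary labels
print(f(model_output, target_category))             # scalar cross-entropy over the batch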
Example #2
  def forward(self, inputs):
    x, gru_state = inputs

    # Dense layer on the concatenation of x and h.
    w1, b1, w2, b2 = self.weights
    y = jnp.dot(jnp.concatenate([x, gru_state], axis=-1), w1) + b1

    # Update and reset gates.
    u, r = jnp.split(fastmath.sigmoid(y), 2, axis=-1)

    # Candidate.
    c = jnp.dot(jnp.concatenate([x, r * gru_state], axis=-1), w2) + b2

    new_gru_state = u * gru_state + (1 - u) * jnp.tanh(c)
    return new_gru_state, new_gru_state
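
This forward pass belongs to a GRU-cell layer; below is a hedged sketch of driving it through trax's standard layer API, assuming it matches trax.layers.GRUCell (details may vary across trax versions):

import numpy as np
from trax import layers as tl
from trax import shapes

cell = tl.GRUCell(n_units=8)
x = np.zeros((2, 8), dtype=np.float32)      # batch of input vectors
state = np.zeros((2, 8), dtype=np.float32)  # previous hidden state
cell.init(shapes.signature((x, state)))
new_state, _ = cell((x, state))             # both outputs are the new state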
Example #3
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input, except the final dimension
      is the layer's `n_units` value.
    """
        if self._use_bias:
            if not isinstance(self.weights, (tuple, list)):
                raise ValueError(f'Weights should be a (w, b) tuple or list; '
                                 f'instead got: {self.weights}')
            w, b = self.weights
            return jnp.dot(x, w) + b  # Affine map.
        else:
            w = self.weights
            return jnp.dot(x, w)  # Linear map.
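
For reference, a short end-to-end sketch of this Dense forward pass through trax's layer API (the init-then-call pattern is standard trax usage):

import numpy as np
from trax import layers as tl
from trax import shapes

layer = tl.Dense(4)                    # n_units=4, uses a bias by default
x = np.ones((2, 3), dtype=np.float32)
layer.init(shapes.signature(x))
y = layer(x)                           # shape (2, 4): the affine map x.w + b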
Example #4
def predict(question1,
            question2,
            threshold,
            model,
            vocab,
            data_generator=data_generator,
            verbose=False):
    """Function for predicting if two questions are duplicates.

    Args:
        question1 (str): First question.
        question2 (str): Second question.
        threshold (float): Desired threshold.
        model (trax.layers.combinators.Parallel): The Siamese model.
        vocab (collections.defaultdict): The vocabulary used.
        data_generator (function): Data generator function. Defaults to data_generator.
        verbose (bool, optional): If the results should be printed out. Defaults to False.

    Returns:
        bool: True if the questions are duplicates, False otherwise.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###
    # use `nltk` word tokenize function to tokenize
    q1 = nltk.word_tokenize(question1)  # tokenize
    q2 = nltk.word_tokenize(question2)  # tokenize
    Q1, Q2 = [], []
    for word in q1:  # encode q1
        # look up each word's index in `vocab`
        Q1 += [vocab[word]]
    for word in q2:  # encode q2
        # look up each word's index in `vocab`
        Q2 += [vocab[word]]

    # Call the data generator (built in Ex 01) using next()
    # pass [Q1] & [Q2] as Q1 & Q2 arguments of the data generator. Set batch size as 1
    # Hint: use `vocab['<PAD>']` for the `pad` argument of the data generator
    Q1, Q2 = next(data_generator([Q1], [Q2], batch_size=1, pad=vocab['<PAD>']))
    # Call the model
    v1, v2 = model((Q1, Q2))
    # take dot product to compute cos similarity of each pair of entries, v1, v2
    # don't forget to transpose the second argument
    d = fastnp.dot(v1, v2.T)
    # is d greater than the threshold?
    res = d > threshold

    ### END CODE HERE ###

    if verbose:
        print("Q1  = ", Q1, "\nQ2  = ", Q2)
        print("d   = ", d)
        print("res = ", res)

    return res
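
The tokenize-and-encode step above can be exercised on its own. A self-contained sketch (the toy `vocab` here is hypothetical, and nltk.word_tokenize needs the `punkt` tokenizer data downloaded):

from collections import defaultdict
import nltk

vocab = defaultdict(int, {'<PAD>': 1, 'when': 2, 'will': 3, 'you': 4})
q = nltk.word_tokenize('when will you ?')  # ['when', 'will', 'you', '?']
encoded = [vocab[w] for w in q]            # unseen tokens map to 0
print(encoded)                             # [2, 3, 4, 0]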
Example #5
def classify(test_Q1,
             test_Q2,
             y,
             threshold,
             model,
             vocab,
             data_generator=data_generator,
             batch_size=64):
    """Function to test the accuracy of the model.

    Args:
        test_Q1 (numpy.ndarray): Array of Q1 questions.
        test_Q2 (numpy.ndarray): Array of Q2 questions.
        y (numpy.ndarray): Array of actual target.
        threshold (float): Desired threshold.
        model (trax.layers.combinators.Parallel): The Siamese model.
        vocab (collections.defaultdict): The vocabulary used.
        data_generator (function): Data generator function. Defaults to data_generator.
        batch_size (int, optional): Size of the batches. Defaults to 64.

    Returns:
        float: Accuracy of the model.
    """
    accuracy = 0
    ### START CODE HERE (Replace instances of 'None' with your code) ###
    for i in range(0, len(test_Q1), batch_size):
        # Call the data generator (built in Ex 01) with shuffle=False using next()
        # use batch-size chunks of questions as the Q1 & Q2 arguments of the data generator, e.g. x[i:i + batch_size]
        # Hint: use `vocab['<PAD>']` for the `pad` argument of the data generator
        q1, q2 = next(
            data_generator(test_Q1[i:i + batch_size],
                           test_Q2[i:i + batch_size],
                           batch_size,
                           pad=vocab['<PAD>'],
                           shuffle=False))
        # use batch-size chunks of the actual output targets (same syntax as the example above)
        y_test = y[i:i + batch_size]
        # Call the model
        v1, v2 = model((q1, q2))

        for j in range(batch_size):
            # take dot product to compute cos similarity of each pair of entries, v1[j], v2[j]
            # don't forget to transpose the second argument
            d = fastnp.dot(v1[j], v2[j].T)
            # is d greater than the threshold?
            res = d > threshold
            # increment accuracy if y_test[j] equals `res`
            accuracy += y_test[j] == res
    # compute accuracy using accuracy and total length of test questions
    accuracy = accuracy / len(test_Q1)
    ### END CODE HERE ###

    return accuracy
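
The batching pattern in the loop above, isolated as a tiny sketch; note the final slice may be shorter than batch_size, which the data generator is expected to handle by padding:

import numpy as np

data = np.arange(10)
batch_size = 4
for i in range(0, len(data), batch_size):
    chunk = data[i:i + batch_size]  # [0 1 2 3], then [4 5 6 7], then [8 9]
    print(chunk)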
Example #6
    def forward(self, inputs):
        """Returns the input activations, with added positional information."""
        weights = self.weights
        if self._d_feature is not None:
            weights, ff = weights
            weights = jnp.dot(weights[:inputs.shape[1], :], ff)
        if len(weights.shape) < 3:  # old checkpoints already have 1 in the first dim
            weights = weights[None, :, :]  # [1, self._max_len, d_feature]
        if self._mode != 'predict':
            x = inputs
            symbol_size = jnp.shape(x)[1]
            if self._mode != 'train' or self._start_from_zero_prob >= 1.0:
                px = weights[:, :symbol_size, :]
            else:
                rng1, rng2 = fastmath.random.split(self.rng, 2)
                start = fastmath.random.randint(rng1, (), 0,
                                                self._max_offset_to_add)
                start_from_zero = fastmath.random.uniform(
                    rng2, (), jnp.float32, 0, 1)
                start = jnp.where(start_from_zero < self._start_from_zero_prob,
                                  jnp.zeros((), dtype=jnp.int32), start)
                px = fastmath.dynamic_slice_in_dim(weights,
                                                   start,
                                                   symbol_size,
                                                   axis=1)
            if self._dropout == 0:
                return x + px
            else:
                noise_shape = list(px.shape)
                for dim in self._dropout_broadcast_dims:
                    noise_shape[dim] = 1
                keep_prob = 1.0 - self._dropout
                keep = fastmath.random.bernoulli(self.rng, keep_prob,
                                                 tuple(noise_shape))
                multiplier = keep.astype(x.dtype) / keep_prob
                return x + px * multiplier
        else:
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that case,
            # the model is called with consecutive elements position-by-position.
            # This positional encoding layer stores the index of the current
            # position and increments it on each call.
            emb = fastmath.dynamic_slice_in_dim(weights,
                                                self.state,
                                                inputs.shape[1],
                                                axis=1)
            self.state += inputs.shape[1]
            return inputs + emb
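
The dropout branch above samples its keep mask with a reduced noise shape so whole slices share one keep/drop decision. A standalone sketch of that trick in plain JAX (jax.random here stands in for trax's fastmath wrappers):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
px = jnp.ones((4, 10, 16))                      # batch, length, d_feature
noise_shape = (4, 1, 16)                        # broadcast over the length dim
keep_prob = 0.9
keep = jax.random.bernoulli(rng, keep_prob, noise_shape)
multiplier = keep.astype(px.dtype) / keep_prob  # rescale to preserve the mean
out = px * multiplier                           # mask broadcasts along axis 1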
Example #7
  def forward(self, inputs):
    x, lstm_state = inputs

    # LSTM state consists of c and h.
    c, h = jnp.split(lstm_state, 2, axis=-1)

    # Dense layer on the concatenation of x and h.
    w, b = self.weights
    y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = jnp.split(y, 4, axis=-1)

    new_c = c * fastmath.sigmoid(f) + fastmath.sigmoid(i) * jnp.tanh(j)
    new_h = jnp.tanh(new_c) * fastmath.sigmoid(o)
    return new_h, jnp.concatenate([new_c, new_h], axis=-1)
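
The same LSTM step, rewritten as a self-contained jnp sketch with random placeholder weights; the shapes follow the gate layout above, so w is (d_in + n_units, 4 * n_units):

import jax
import jax.numpy as jnp

batch, d_in, n_units = 2, 8, 8
rng = jax.random.PRNGKey(0)
w = jax.random.normal(rng, (d_in + n_units, 4 * n_units))  # placeholder weights
b = jnp.zeros((4 * n_units,))
x = jnp.ones((batch, d_in))
c = jnp.zeros((batch, n_units))                 # cell state
h = jnp.zeros((batch, n_units))                 # hidden state

y = jnp.dot(jnp.concatenate([x, h], axis=-1), w) + b
i, j, f, o = jnp.split(y, 4, axis=-1)
new_c = c * jax.nn.sigmoid(f) + jax.nn.sigmoid(i) * jnp.tanh(j)
new_h = jnp.tanh(new_c) * jax.nn.sigmoid(o)     # each (batch, n_units)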
Example #8
def TripletLossFn(v1, v2, margin=0.25):
    """Custom Loss function.

    Args:
        v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.
        v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.
        margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
        jax.interpreters.xla.DeviceArray: Triplet Loss.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###

    # use fastnp to take the dot product of the two batches (don't forget to transpose the second argument)
    scores = fastnp.dot(v1, fastnp.transpose(v2))  # pairwise cosine sim
    # calculate new batch size
    batch_size = len(scores)
    # use fastnp to grab all positive `diagonal` entries in `scores`
    positive = fastnp.diagonal(scores)  # the positive ones (duplicates)
    # multiply `fastnp.eye(batch_size)` by 2.0 and subtract it out of `scores`,
    # so the diagonal (similarities lie in [-1, 1]) can never be a row maximum
    negative_without_positive = scores - 2.0 * fastnp.eye(batch_size)
    # take the row by row `max` of `negative_without_positive`.
    # Hint: negative_without_positive.max(axis = [?])
    closest_negative = negative_without_positive.max(axis=[1])
    # subtract `fastnp.eye(batch_size)` out of 1.0 and do element-wise multiplication with `scores`
    negative_zero_on_duplicate = (1.0 - fastnp.eye(batch_size)) * scores
    # use `fastnp.sum` on `negative_zero_on_duplicate` for `axis=1` and divide it by `(batch_size - 1)`
    mean_negative = fastnp.sum(negative_zero_on_duplicate,
                               axis=1) / (batch_size - 1)
    # compute `fastnp.maximum` among 0.0 and `A`
    # A = subtract `positive` from `margin` and add `closest_negative`
    triplet_loss1 = fastnp.maximum((margin - positive + closest_negative), 0.0)
    # compute `fastnp.maximum` among 0.0 and `B`
    # B = subtract `positive` from `margin` and add `mean_negative`
    triplet_loss2 = fastnp.maximum((margin - positive + mean_negative), 0.0)
    # add the two losses together and take the `fastnp.mean` of it
    triplet_loss = fastnp.mean(triplet_loss1 + triplet_loss2)

    ### END CODE HERE ###

    return triplet_loss
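
A tiny smoke test for the loss above; the rows of v1 and v2 are L2-normalized by hand so the dot products are cosine similarities (fastnp is assumed to be trax.fastmath.numpy, imported wherever TripletLossFn is defined):

import numpy as np

v1 = np.array([[1.0, 2.0], [2.0, 1.0]])
v2 = np.array([[1.0, 2.0], [-2.0, -1.0]])
v1 = v1 / np.linalg.norm(v1, axis=1, keepdims=True)  # unit rows
v2 = v2 / np.linalg.norm(v2, axis=1, keepdims=True)
print(TripletLossFn(v1, v2, margin=0.25))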
Example #9
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
          initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        m1, m2, mb, w1, w2, b2 = self.weights
        if self._mode != 'predict':
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        x_shape = x.shape
        x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: should we add bias and/or put relu after the low-rank m1 dot?
        mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
            # We will sometimes (quant_prob of the batches) use the soft-mask instead
            # of the quantized mask to improve training stability (see paper above).
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

        if self._mode == 'train':
            # In training, run full matmul to get benefits from the above tricks.
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2
        elif self._mode == 'predict':
            # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
            # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
            # This implementation mimics inference. It's not efficient for a large
            # joint_batch, but at inference time that will be 1 most of the time.
            # Shapes:
            # quant_mask is [joint_batch, self._d1]
            # w1 is [d_model, self._d1, self._d2]
            # we'll index w1 with advanced numpy indexing, first range over
            # self._d1 times the batch size, second range being quant_mask
            batch_size = quant_mask.shape[0]
            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
            # flatten indices and select from w1
            idx1 = jnp.reshape(idx1, [-1])
            idx2 = jnp.reshape(quant_mask, [-1])
            w = w1[idx1, idx2, :]  # now we have per-element weights with batch dim
            w = jnp.reshape(w, [batch_size, self._d1, -1])
            mid = jnp.einsum('ai,aji->aj', x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [self._d1, self._d2, d_model]
            v = w2[idx1, idx2, :]
            v = jnp.reshape(v, [batch_size, self._d1, -1])
            res = jnp.einsum('ai,aij->aj', relu, v) + b2
        else:
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
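
The key trick in the train branch is the straight-through Gumbel-softmax. A compact standalone version in plain JAX (the function name is my own, not from the source):

import jax
import jax.numpy as jnp

def straight_through_gumbel(rng, logits, temperature=0.1):
    # Forward pass sees a hard one-hot mask; gradients flow through the soft mask.
    log_mask = jax.nn.log_softmax(logits, axis=-1)
    u = jax.random.uniform(rng, logits.shape, jnp.float32, 1e-6, 1.0 - 1e-6)
    g = -jnp.log(-jnp.log(u))                         # Gumbel(0, 1) noise
    hard = jax.nn.one_hot(jnp.argmax(log_mask + g * temperature, axis=-1),
                          logits.shape[-1])
    soft = jnp.exp(log_mask)
    return soft + jax.lax.stop_gradient(hard - soft)  # straight-through estimator

mask = straight_through_gumbel(jax.random.PRNGKey(0), jnp.zeros((2, 4)))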
Example #10
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

    Args:
      x: Tensor of same shape and dtype as the input signature used to
        initialize this layer.

    Returns:
      Tensor of same shape and dtype as the input.
    """
        m1, w1, w2, b2 = self.weights
        x_shape = x.shape
        x = jnp.reshape(x, [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: check if we need bias and/or put relu after the m1 dot?
        mask_logits = jnp.dot(x, m1)
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        # TODO(lukaszkaiser, chowdhery): Extract this block and share
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        selected_experts = jnp.argmax(log_mask + g * self._temperature,
                                      axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(mask)  # straight-through
            # We will sometimes (50% of the batches) use the soft-mask instead of
            # the quantized mask to improve training stability (see the paper above).
            # Q: is selecting 50% of batches the best? Other %? Mixed in-batch?
            select = fastmath.random.uniform(rng2, (), jnp.float32, -1.0, 1.0)
            quant_mask = jnp.where(select > 0.0, quant_mask, mask)
        else:
            quant_mask = tl.one_hot(selected_experts, self._num_experts)
        quant_mask = jnp.reshape(quant_mask, [-1, self._num_experts, 1])
        quant_mask_shape = quant_mask.shape
        batch_size = quant_mask.shape[0]

        if self._mode == 'predict' and batch_size == 1:
            # This implementation mimics inference for batch_size 1.
            start_idx = selected_experts[0] * self._n_elements_in_block
            # w1 is [d_model, d_ff], w is [d_model, n_elements_in_block]
            w = fastmath.dynamic_slice(
                w1, [0, start_idx], [w1.shape[0], self._n_elements_in_block])
            mid = jnp.dot(x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [d_ff, d_model], v is [n_elements_in_block, d_model]
            v = fastmath.dynamic_slice(
                w2, [start_idx, 0], [self._n_elements_in_block, w2.shape[-1]])
            v = jnp.reshape(v, [self._n_elements_in_block, -1])
            res = jnp.dot(relu, v) + b2
        else:
            expanded_mask = jnp.broadcast_to(
                quant_mask, (quant_mask_shape[0], quant_mask.shape[1],
                             self._n_elements_in_block))
            expanded_mask = jnp.reshape(expanded_mask, (-1, self._d_ff))
            mid = jnp.dot(x, w1) * expanded_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
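
The predict branch slices one expert's block of weights at a runtime-computed offset; the same idea with JAX's lax.dynamic_slice, which is roughly what the fastmath wrapper above delegates to:

import jax
import jax.numpy as jnp

d_model, n_elements_in_block = 4, 3
w1 = jnp.arange(4 * 12, dtype=jnp.float32).reshape(d_model, 12)  # [d_model, d_ff]
selected_expert = jnp.int32(2)
start_idx = selected_expert * n_elements_in_block
w = jax.lax.dynamic_slice(w1, (0, start_idx),
                          (d_model, n_elements_in_block))        # [d_model, block]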
Example #11
#
# The process is pretty straightforward:
#    - Iterate over each one of the elements in the batch
#    - Compute the cosine similarity between the predictions
#        - For computing the cosine similarity, the two output vectors must be L2-normalized, so their magnitudes are 1. The Siamese network you will build in the assignment takes care of this, so the cosine similarity here reduces to the dot product between the two vectors. You can verify this by implementing the usual cosine similarity formula and checking that the results match (see the sketch after this list).
#    - Determine if this value is greater than the threshold (if it is, consider the two questions the same and return 1, else 0)
#    - Compare against the actual target and, if the prediction matches, add 1 to the accuracy (increment the correct-prediction counter)
#    - Divide the accuracy by the number of processed elements
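
# A quick standalone check of the normalization claim above (a sketch, not part
# of the assignment): for unit vectors, the full cosine similarity formula
# reduces to a plain dot product.

import numpy as np

a = np.array([3.0, 4.0]); a = a / np.linalg.norm(a)
b = np.array([1.0, 2.0]); b = b / np.linalg.norm(b)
cos_sim = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
assert np.isclose(cos_sim, np.dot(a, b))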

# In[8]:

for j in range(batch_size):  # Iterate over each one of the elements in the batch

    # Compute the cosine similarity between the predictions; v1[j] and v2[j] are
    # L2-normalized (||v1[j]|| == ||v2[j]|| == 1), so only the dot product is needed
    d = np.dot(v1[j], v2[j])
    # Determine if this value is greater than the threshold
    # (if it is, consider the two questions as the same)
    res = d > threshold
    # Compare against the actual target; if the prediction matches, add 1 to the accuracy
    accuracy += (y_test[j] == res)

accuracy = accuracy / batch_size  # Divide the accuracy by the number of processed elements

# In[9]:

print(f'The accuracy of the model is: {accuracy}')

# **Congratulations on finishing this lecture notebook!**
#
# Now you should have a clearer understanding of how to evaluate your Siamese language models using the accuracy metric.