Example #1
def f(model_output, target_category):  # pylint: disable=invalid-name
    """Binary cross-entropy loss averaged over the batch."""
    shapes.assert_same_shape(model_output, target_category)
    batch_size = model_output.shape[0]
    j = jnp.dot(jnp.transpose(target_category), jnp.log(model_output))
    j += jnp.dot(jnp.transpose(1 - target_category), jnp.log(1 - model_output))
    j = -1.0 / batch_size * jnp.squeeze(j)
    return j
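A minimal usage sketch (not part of the original snippet): it assumes `jnp` is `jax.numpy`, `shapes` is `trax.shapes`, and the toy column vectors below are purely illustrative.

import jax.numpy as jnp
from trax import shapes  # provides assert_same_shape

model_output = jnp.array([[0.9], [0.2], [0.7]])     # predicted probabilities
target_category = jnp.array([[1.0], [0.0], [1.0]])  # binary labels
print(f(model_output, target_category))             # scalar cross-entropy loss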
Example #2
    def Init(shape, rng):
        """Returns orthogonalized random normal values with the given `shape`."""
        # Have at least 2 elements in shape.
        cur_shape = list(shape)
        while len(cur_shape) < 2:
            cur_shape = [1] + cur_shape

        # Flatten the input shape with the last dimension remaining.
        n_rows = 1
        for dim in cur_shape[:-1]:
            n_rows *= dim
        n_cols = cur_shape[-1]
        flat_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)

        # Generate a random matrix
        a = random.normal(rng, flat_shape, dtype=jnp.float32)

        # Compute the qr factorization
        q, r = jnp.linalg.qr(a)

        # Make Q uniform
        d = jnp.diag(r)
        q *= jnp.sign(d)

        # Transpose and reshape back q if needed.
        if n_rows < n_cols:
            q = jnp.transpose(q)
        q = jnp.reshape(q, shape)

        # Return scaled as requested.
        return stddev * q
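A quick check of the property this initializer targets, as a hedged sketch: it assumes the `Init` above is in scope, that `random` is `jax.random`, and that the enclosing closure supplies `stddev` (set to 1.0 here for illustration).

import jax.numpy as jnp
from jax import random

stddev = 1.0  # normally captured from the enclosing initializer's closure
w = Init((128, 64), random.PRNGKey(0))
print(w.shape)                                        # (128, 64)
# With n_rows >= n_cols the columns are orthonormal, so w.T @ w is the identity.
print(jnp.allclose(w.T @ w, jnp.eye(64), atol=1e-4))  # True (up to float32 error)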
Example #3
def compute_attention_heads(x):
    # Reshape (batch_size, seqlen, n_heads * d_head) into
    # (batch_size * n_heads, seqlen, d_head); n_heads and d_head come
    # from the enclosing closure.
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))
    x = jnp.reshape(x, (-1, seqlen, d_head))
    return x
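A shape check for this helper, as an illustrative sketch: `n_heads` and `d_head` are free variables in the original code, so they are bound explicitly here with made-up values.

import jax.numpy as jnp

n_heads, d_head = 4, 8                   # assumed values for illustration
batch_size, seqlen = 2, 10

x = jnp.zeros((batch_size, seqlen, n_heads * d_head))
print(compute_attention_heads(x).shape)  # (8, 10, 8) == (batch_size * n_heads, seqlen, d_head)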
Example #4
def compute_attention_output(x):
    """Compute the attention output.

    Args:
        x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size X n_heads, seqlen, d_head).

    Returns:
        jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size, seqlen, n_heads X d_head).
    """
    # Length of the sequence
    seqlen = x.shape[1]
    batch_size = x.shape[0] // n_heads
    x = jnp.reshape(x, (batch_size, n_heads, seqlen, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
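And the inverse direction, under the same illustrative assumptions about the closure variables `n_heads` and `d_head`:

import jax.numpy as jnp

n_heads, d_head = 4, 8                    # assumed values for illustration
y = jnp.zeros((2 * n_heads, 10, d_head))  # shape produced by compute_attention_heads
print(compute_attention_output(y).shape)  # (2, 10, 32) == (batch_size, seqlen, n_heads * d_head)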
Example #5
def compute_attention_heads(x):
    """Compute the attention heads.

    Args:
        x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size, seqlen, n_heads X d_head).

    Returns:
        jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size X n_heads, seqlen, d_head).
    """
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))
    x = jnp.reshape(x, (batch_size * n_heads, seqlen, d_head))
    return x
Example #6
    def compute_attention_output(x):
        """ Compute the attention output.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size X n_heads, seqlen, d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size, seqlen, n_heads X d_head).
        """

        # Length of the sequence
        # Should be size of x's first dimension without counting the batch dim
        seqlen = x.shape[1]
        # Reshape x using jnp.reshape() to shape (batch_size, n_heads, seqlen, d_head)
        x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
        # Transpose x using jnp.transpose() to shape (batch_size, seqlen, n_heads, d_head)
        x = jnp.transpose(x, (0, 2, 1, 3))

        # Reshape to allow to concatenate the heads
        return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
Example #7
def TripletLossFn(v1, v2, margin=0.25):
    """Custom Loss function.

    Args:
        v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.
        v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.
        margin (float, optional): Desired margin. Defaults to 0.25.

    Returns:
        jax.interpreters.xla.DeviceArray: Triplet Loss.
    """
    ### START CODE HERE (Replace instances of 'None' with your code) ###

    # use fastnp to take the dot product of the two batches (don't forget to transpose the second argument)
    scores = fastnp.dot(v1, fastnp.transpose(v2))  # pairwise cosine sim
    # calculate new batch size
    batch_size = len(scores)
    # use fastnp to grab all positive `diagonal` entries in `scores`
    positive = fastnp.diagonal(scores)  # the positive ones (duplicates)
    # multiply `fastnp.eye(batch_size)` by 2.0 and subtract it out of `scores`
    # (2.0 guarantees the diagonal can never be the row maximum, since the
    # cosine similarities lie in [-1, 1])
    negative_without_positive = scores - 2.0 * fastnp.eye(batch_size)
    # take the row by row `max` of `negative_without_positive`.
    # Hint: negative_without_positive.max(axis = [?])
    closest_negative = negative_without_positive.max(axis=1)
    # subtract `fastnp.eye(batch_size)` out of 1.0 and do element-wise multiplication with `scores`
    negative_zero_on_duplicate = (1.0 - fastnp.eye(batch_size)) * scores
    # use `fastnp.sum` on `negative_zero_on_duplicate` for `axis=1` and divide it by `(batch_size - 1)`
    mean_negative = fastnp.sum(negative_zero_on_duplicate,
                               axis=1) / (batch_size - 1)
    # compute `fastnp.maximum` among 0.0 and `A`
    # A = subtract `positive` from `margin` and add `closest_negative`
    triplet_loss1 = fastnp.maximum((margin - positive + closest_negative), 0.0)
    # compute `fastnp.maximum` among 0.0 and `B`
    # B = subtract `positive` from `margin` and add `mean_negative`
    triplet_loss2 = fastnp.maximum((margin - positive + mean_negative), 0.0)
    # add the two losses together and take the `fastnp.mean` of it
    triplet_loss = fastnp.mean(triplet_loss1 + triplet_loss2)

    ### END CODE HERE ###

    return triplet_loss
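A toy call, as a sketch only: it assumes the snippet's `fastnp` is bound to `jax.numpy` (in the original notebook it is `trax.fastmath.numpy`) and that the rows of `v1` and `v2` are already L2-normalized encodings, so the dot products are cosine similarities.

import jax.numpy as fastnp  # stand-in for trax.fastmath.numpy

v1 = fastnp.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0]])
v2 = fastnp.array([[0.8, 0.6, 0.0],
                   [0.6, 0.8, 0.0],
                   [0.0, 0.0, 1.0]])       # rows already unit-length
print(TripletLossFn(v1, v2, margin=0.25))  # small scalar loss; matching rows are most similar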
Example #8
    def compute_attention_heads(x):
        """ Compute the attention heads.
        Args:
            x (jax.interpreters.xla.DeviceArray): tensor with shape (batch_size, seqlen, n_heads X d_head).
        Returns:
            jax.interpreters.xla.DeviceArray: reshaped tensor with shape (batch_size X n_heads, seqlen, d_head).
        """

        batch_size, seqlen = x.shape[0], x.shape[1]
        # Reshape x using jnp.reshape()
        # batch_size, seqlen, n_heads*d_head -> batch_size, seqlen, n_heads, d_head
        x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
        # Transpose x using jnp.transpose()
        # batch_size, seqlen, n_heads, d_head -> batch_size, n_heads, seqlen, d_head
        # Note that the values within the tuple are the indexes of the dimensions of x and you must rearrange them
        x = jnp.transpose(x, (0, 2, 1, 3))
        # Reshape x using jnp.reshape()
        # batch_size, n_heads, seqlen, d_head -> batch_size*n_heads, seqlen, d_head
        x = jnp.reshape(x, (-1, seqlen, d_head))

        return x
Example #9
def swapaxes(x):
    # Swap the leading axis with `axis`, which comes from the enclosing scope.
    transposed_axes = list(range(len(x.shape)))
    transposed_axes[axis] = 0
    transposed_axes[0] = axis
    return jnp.transpose(x, axes=transposed_axes)
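A quick illustration; since `axis` is a free variable in the snippet, it is bound explicitly here:

import jax.numpy as jnp

axis = 2                  # comes from the enclosing scope in the original code
x = jnp.zeros((2, 3, 4, 5))
print(swapaxes(x).shape)  # (4, 3, 2, 5): axes 0 and 2 are swapped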
Example #10
    def forward(self, x):
        """Executes this layer as part of a forward pass through the model.

        Args:
            x: Tensor of same shape and dtype as the input signature used to
                initialize this layer.

        Returns:
            Tensor of same shape and dtype as the input.
        """
        m1, m2, mb, w1, w2, b2 = self.weights
        if self._mode != 'predict':
            w1 = jnp.reshape(w1.T, (-1, self._d_ff))
            w2 = jnp.reshape(w2, (self._d_ff, -1))
        else:
            # This is a work-around for a bug in the previous if branch, which leaves
            # the w1 array shuffled. Fixing it properly would invalidate previous
            # checkpoints, so this is a temporary work-around.
            w1 = jnp.transpose(w1, (1, 0, 2))
            w1 = jnp.reshape(w1, (self._d1, self._d2, -1))

        x_shape = x.shape
        x = jnp.reshape(x,
                        [-1, x_shape[-1]])  # Easier to operate on flattened x.

        # Q: should we add bias and/or put relu after the low-rank m1 dot?
        mask_logits = jnp.dot(jnp.dot(x, m1), m2) + mb
        mask_logits = jnp.reshape(mask_logits, [-1, self._d1, self._d2])
        # Softmax.
        mask_logsumexp = fastmath.logsumexp(mask_logits,
                                            axis=-1,
                                            keepdims=True)
        log_mask = mask_logits - mask_logsumexp
        mask = jnp.exp(log_mask)
        # Gumbel-softmax with straight-through discretization.
        rng1, rng2 = fastmath.random.split(self.rng, 2)
        u = fastmath.random.uniform(rng1, mask.shape, jnp.float32, 1e-6,
                                    1.0 - 1e-6)
        g = -jnp.log(-jnp.log(u))
        quant_mask = jnp.argmax(log_mask + g * self._temperature, axis=-1)
        if self._mode == 'train':
            # Tricks from Section 2.1 in https://arxiv.org/abs/1801.09797
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = fastmath.stop_gradient(quant_mask)
            quant_mask += mask - fastmath.stop_gradient(
                mask)  # straight-through
            # We will sometimes (quant_prob of the batches) use the soft-mask instead
            # of the quantized mask to improve training stability (see paper above).
            select = fastmath.random.uniform(rng2, (), jnp.float32, 0.0, 1.0)
            quant_mask = jnp.where(select < self._quant_prob, quant_mask, mask)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])

        if self._mode == 'train':
            # In training, run full matmul to get benefits from the above tricks.
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2
        elif self._mode == 'predict':
            # w1 = jnp.reshape(w1.T, (self._d1, self._d2, -1))
            # w2 = jnp.reshape(w2, (self._d1, self._d2, -1))
            # This implementation mimics inference. It's not efficient for a large
            # joint_batch, but at inference that will be 1 most of the time.
            # Shapes:
            # quant_mask is [joint_batch, self._d1]
            # w1 is [self._d1, self._d2, d_model]
            # we'll index w1 with advanced numpy indexing, first range over
            # self._d1 times the batch size, second range being quant_mask
            batch_size = quant_mask.shape[0]
            idx1 = jnp.array([jnp.arange(self._d1)] * batch_size)
            # flatten indices and select from w1
            idx1 = jnp.reshape(idx1, [-1])
            idx2 = jnp.reshape(quant_mask, [-1])
            w = w1[idx1,
                   idx2, :]  # now we have per-element weights with batch dim
            w = jnp.reshape(w, [batch_size, self._d1, -1])
            mid = jnp.einsum('ai,aji->aj', x, w)
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            # w2 is [self._d1, self._d2, d_model]
            v = w2[idx1, idx2, :]
            v = jnp.reshape(v, [batch_size, self._d1, -1])
            res = jnp.einsum('ai,aij->aj', relu, v) + b2
        else:
            quant_mask = tl.one_hot(quant_mask, self._n_elements_in_block)
            quant_mask = jnp.reshape(quant_mask, [-1, self._d_ff])
            mid = jnp.dot(x, w1) * quant_mask  # [joint_batch, d_ff]
            relu = jnp.where(mid <= 0, jnp.zeros_like(mid), mid)
            res = jnp.dot(relu, w2) + b2

        return jnp.reshape(res, x_shape)  # un-flatten if needed
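The `quant_mask += mask - fastmath.stop_gradient(mask)` line above is the usual straight-through trick: the forward value is the hard mask, while gradients flow through the soft mask. A tiny standalone sketch of the same idea in plain JAX (independent of this layer, names are illustrative):

import jax
import jax.numpy as jnp

def straight_through_one_hot(soft):
    """Forward: one-hot argmax of `soft`; backward: gradient of `soft` itself."""
    hard = jax.nn.one_hot(jnp.argmax(soft, axis=-1), soft.shape[-1])
    return hard + soft - jax.lax.stop_gradient(soft)

logits = jnp.array([1.0, 2.0, 0.5])
soft = jax.nn.softmax(logits)
print(straight_through_one_hot(soft))  # [0., 1., 0.] in the forward pass
# Gradients are those of the soft mask; the argmax/one_hot path contributes none.
print(jax.grad(lambda l: straight_through_one_hot(jax.nn.softmax(l))[1])(logits))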
Example #11
def compute_attention_output(x):
    # Undo compute_attention_heads: reshape (batch_size * n_heads, seqlen, d_head)
    # back into (batch_size, seqlen, n_heads * d_head); n_heads and d_head come
    # from the enclosing closure.
    seqlen = x.shape[1]
    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
    x = jnp.transpose(x, (0, 2, 1, 3))
    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))