Example #1
class EncoderLayer(Model):
    def __init__(self, nM=300, nH=6, device='cpu'):
        Model.__init__(self)
        self.attn = MultiHeadedAttention(nM=nM, nH=nH)
        self.ffd = PositionwiseFeedForward(nM, 4 * nM)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM, device=device))
        self.nM = nM
        self.layers_ = [self.attn, self.ffd, self.norm]

    def begin_update(self, input, drop=0.1):
        X0, mask = input
        X1, b_X1 = self.attn.begin_update((X0, mask, None), drop=drop)
        X2, b_X2 = self.norm.begin_update(X1)
        X3 = X0 + X2

        X4, b_X4 = self.ffd.begin_update(X3, drop=drop)
        X5, b_X5 = self.norm.begin_update(X4)
        X6 = X3 + X5

        def finish_update(dX6, sgd=None):
            dX5 = dX6
            dX4 = b_X5(dX5, sgd=sgd)
            dX3 = b_X4(dX4, sgd=sgd)
            dX3 += dX6

            dX2 = dX3
            dX1 = b_X2(dX2, sgd=sgd)
            dX0 = b_X1(dX1, sgd=sgd)

            dX0 += dX3
            return dX0

        return (X6, mask), finish_update
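
The `dX3 += dX6` and `dX0 += dX3` lines above are the backprop of the residual (skip) connections: for each `X_out = X_in + sublayer(X_in)`, the incoming gradient flows both through the sublayer's backprop and directly to `X_in`. A minimal NumPy sketch of that pattern, with a toy linear sublayer standing in for the attention/feed-forward blocks (illustration only, not thinc code):

import numpy

def toy_sublayer(W):
    """Return a begin_update-style callable for f(X) = X @ W."""
    def begin_update(X):
        Y = X @ W
        def backprop(dY):
            return dY @ W.T          # gradient w.r.t. X (weight gradients omitted)
        return Y, backprop
    return begin_update

rng = numpy.random.default_rng(0)
X0 = rng.normal(size=(4, 8)).astype("float32")
f = toy_sublayer(rng.normal(size=(8, 8)).astype("float32"))

# Forward: a residual connection, as in X3 = X0 + X2 above.
X1, bp_f = f(X0)
X2 = X0 + X1

# Backward: the gradient flows through the sublayer *and* around it.
dX2 = numpy.ones_like(X2)
dX0 = bp_f(dX2) + dX2    # same pattern as `dX0 = b_X1(dX1)` followed by `dX0 += dX3`
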
Example #2
class DecoderLayer(Model):
    def __init__(self, nM=300, nH=6, device='cpu'):
        Model.__init__(self)
        self.y_attn = MultiHeadedAttention(nM=nM, nH=nH)
        self.x_attn = MultiHeadedAttention(nM=nM, nH=nH)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM, device=device))
        self.ffd = PositionwiseFeedForward(nM, 4 * nM)
        self.layers_ = [self.norm, self.y_attn, self.x_attn, self.ffd]

    def begin_update(self, input, drop=0.1):
        Y0, X0, X_mask, Y_mask = input
        Y1, b_Y1 = self.y_attn.begin_update((Y0, Y_mask, None), drop=drop)
        Y2, b_Y2 = self.norm.begin_update(Y1)
        Y3 = Y0 + Y2
        Y4, b_Y4 = self.x_attn.begin_update((Y3, X0, X_mask, None, None),
                                            drop=drop)
        Y5, b_Y5 = self.norm.begin_update(Y4)
        Y6 = Y3 + Y5
        Y7, b_Y7 = self.ffd.begin_update(Y6, drop=drop)

        def finish_update(dI, sgd=None):
            dY7, dX = dI
            dY6 = b_Y7(dY7, sgd=sgd)
            dY5 = dY6
            dY4 = b_Y5(dY5, sgd=sgd)
            dY3, dX0 = b_Y4(dY4, sgd=sgd)
            dY3 += dY6
            dY2 = dY3
            dY1 = b_Y2(dY2, sgd=sgd)
            dY0 = b_Y1(dY1, sgd=sgd)
            dY0 += dY3
            dX0 += dX
            return (dY0, dX0)

        return (Y7, X0, X_mask, Y_mask), finish_update
Example #3
class Categorizer(Model):
    def __init__(self, nS=12, nM=768, nH=12, nO=2, device='cpu'):
        Model.__init__(self)
        self.nM = nM
        self.nO = nO
        self.nS = nS
        self.enc = clone(EncoderLayer(nM=nM, nH=nH, device=device), nS)
        self.affine = Affine(nI=nM, nO=nM)
        self.softmax = Softmax(nI=nM, nO=nO)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM=nM, device=device))
        self.slicer = PyTorchWrapper(PytorchSlicer())
        self.device = device
        self.layers_ = [self.enc]

    def begin_update(self, inputs, drop=0.0):
        X0, Xmask = inputs
        (X1, _), b_X1 = self.enc.begin_update((X0, Xmask))
        X2, b_X2 = self.norm.begin_update(X1)
        X3, b_X3 = self.slicer.begin_update(X2)
        X4, b_X4 = self.affine.begin_update(X3)
        X5, b_X5 = self.softmax.begin_update(X4)

        def finish_update(dX5, sgd=None):
            dX4 = b_X5(dX5, sgd=sgd)
            dX3 = b_X4(dX4, sgd=sgd)
            dX2 = b_X3(dX3)
            dX1 = b_X2(dX2, sgd=sgd)
            dX0 = b_X1(dX1, sgd=sgd)
            return dX0

        return X5, finish_update
Example #4
def main(length=1000, nO=32, nI=32):
    pt_model = nn.Linear(nI, nO)
    optimizer = torch.optim.Adam(pt_model.parameters())

    model = PyTorchWrapper(pt_model)

    X = numpy.ones((length, nI), dtype='f')
    y = 1. / X
    for i in range(10):
        yh, get_dX = model.begin_update(X)
        dY = (yh - y) / len(y)
        dX = get_dX(dY)
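
The loop above leans on PyTorchWrapper to move data between NumPy and PyTorch. A minimal sketch of the idea behind the wrapper (not thinc's actual implementation): convert the input to a tensor, run the module, and return a callback that drives autograd and hands the input gradient back as a NumPy array.

import numpy
import torch
from torch import nn

class TinyTorchWrapper:
    """Illustrative only: mimics the begin_update contract used above."""
    def __init__(self, module):
        self.module = module

    def begin_update(self, X):
        X_t = torch.tensor(X, requires_grad=True)
        Y_t = self.module(X_t)

        def get_dX(dY):
            Y_t.backward(torch.tensor(dY))      # push the output gradient through autograd
            return X_t.grad.detach().numpy()    # input gradient, back in NumPy
        return Y_t.detach().numpy(), get_dX

# The same toy regression as main() above, using the sketch:
pt_model = nn.Linear(32, 32)
model = TinyTorchWrapper(pt_model)
X = numpy.ones((8, 32), dtype="f")
y = 1.0 / X
yh, get_dX = model.begin_update(X)
dX = get_dX((yh - y) / len(y))
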
Example #5
def test_wrapper(nN=2, nI=3, nO=4):
    if PyTorchWrapper is None:
        return
    model = PyTorchWrapper(torch.nn.Linear(nI, nO))
    sgd = SGD(model.ops, 0.001)
    X = numpy.zeros((nN, nI), dtype="f")
    X += numpy.random.uniform(size=X.size).reshape(X.shape)
    Y = numpy.zeros((nN, nO), dtype="f")
    Yh, get_dX = model.begin_update(X)
    assert Yh.shape == (nN, nO)
    dYh = (Yh - Y) / Yh.shape[0]
    dX = get_dX(dYh, sgd=sgd)
    assert dX.shape == (nN, nI)
    check_learns_zero_output(model, sgd, X, Y)
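
check_learns_zero_output is a test helper that is not shown in these examples. A plausible minimal version (purely hypothetical, for illustration): take a few gradient steps toward the all-zero target and assert that the output shrinks.

def check_learns_zero_output(model, sgd, X, Y):
    """Hypothetical helper: the model should drive its output toward Y == 0."""
    Yh, _ = model.begin_update(X)
    start = float((Yh ** 2).sum())
    for _ in range(10):
        Yh, get_dX = model.begin_update(X)
        get_dX((Yh - Y) / Yh.shape[0], sgd=sgd)   # gradient of the squared error vs. the zero target
    Yh, _ = model.begin_update(X)
    assert float((Yh ** 2).sum()) < start
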
Example #6
def test_wrapper(nN=2, nI=3, nO=4):
    if PyTorchWrapper is None:
        return
    model = PyTorchWrapper(torch.nn.Linear(nI, nO))
    sgd = SGD(model.ops, 0.001)
    X = numpy.zeros((nN, nI), dtype='f')
    X += numpy.random.uniform(size=X.size).reshape(X.shape)
    Y = numpy.zeros((nN, nO), dtype='f')
    Yh, get_dX = model.begin_update(X)
    assert Yh.shape == (nN, nO)
    dYh = (Yh-Y) / Yh.shape[0]
    dX = get_dX(dYh, sgd=sgd)
    assert dX.shape == (nN, nI)
    check_learns_zero_output(model, sgd, X, Y)
Example #7
class Encoder(Model):
    def __init__(self, nM=300, nH=6, nS=6, device='cpu'):
        Model.__init__(self)
        self.stack = clone(EncoderLayer(nM=nM, nH=nH, device=device), nS)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM=nM, device=device))

    def begin_update(self, input, drop=0.1):
        X0, mask = input
        (X1, _), b_X1 = self.stack.begin_update((X0, mask), drop=drop)
        X2, b_X2 = self.norm.begin_update(X1)

        def finish_update(dX2, sgd=None):
            dX1 = b_X2(dX2, sgd=sgd)
            dX0 = b_X1(dX1, sgd=sgd)
            return dX0

        return X2, finish_update
Example #8
class EncoderDecoder(Model):
    def __init__(self, nS=1, nH=6, nM=300, nTGT=10000, device='cpu'):
        '''
        EncoderDecoder consists of an encoder stack, a decoder stack and an
        output layer (a linear projection followed by a softmax).
        Parameters:
            nS: number of encoder/decoder layers in each stack
            nH: number of heads in the multi-headed attention
            nM: token embedding size
            nTGT: number of unique words in the output vocabulary
        '''
        Model.__init__(self)
        self.nS = nS
        self.nH = nH
        self.nM = nM
        self.nTGT = nTGT
        self.device = device
        self.enc = Encoder(nM=nM, nH=nH, device=device, nS=nS)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM=nM, device=device))
        self.dec = clone(DecoderLayer(nM=nM, nH=nH, device=device), nS)
        self.proj = with_reshape(Softmax(nO=nTGT, nI=nM))
        self._layers = [self.enc, self.dec, self.proj]

    def begin_update(self, inputs, drop=0.1):
        '''
        A batch object flows through the network. It contains the input, the
        gold output and the corresponding masks. The input changes as it
        travels through the network; the output stays the gold output.
        Input: nB x nL x nM
        '''
        X0, Xmask, Y0, Ymask = inputs
        X1, backprop_encode = self.enc.begin_update((X0, Xmask), drop=drop)
        (Y1, _, _, _), backprop_decode = self.dec.begin_update(
            (Y0, X1, Xmask, Ymask), drop=drop)
        Y2, b_Y2 = self.norm.begin_update(Y1)
        word_probs, backprop_output = self.proj.begin_update(Y2, drop=drop)

        def finish_update(d_word_probs, sgd=None):
            dY2 = backprop_output(d_word_probs, sgd=sgd)
            dY1 = b_Y2(dY2, sgd=sgd)
            zeros = Model.ops.xp.zeros(X0.shape, dtype=Model.ops.xp.float32)
            dY0, dX1 = backprop_decode((dY1, zeros), sgd=sgd)
            dX0 = backprop_encode(dX1, sgd=sgd)
            return (dX0, dY0)

        return (word_probs, Xmask), finish_update
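
The examples do not show how Xmask and Ymask are built. From the `(1 - mask) * 1e9` penalty in the commented-out _mask code and the `mask & self.mask_floor(nB, nL)` usage further down, they appear to be binary (nB, nL, nL) arrays where 1 means "may attend". Under that assumption, a padding mask and a causal (decoder) mask could be sketched like this (illustration only, not code from these examples):

import numpy

def pad_mask(lengths, nL):
    """1 where a (query, key) pair involves only real tokens, 0 on padding."""
    nB = len(lengths)
    valid = numpy.zeros((nB, nL), dtype=numpy.uint8)
    for i, n in enumerate(lengths):
        valid[i, :n] = 1
    return valid[:, None, :] * valid[:, :, None]      # (nB, nL, nL)

def causal_mask(nB, nL):
    """Lower-triangular mask: position i may only attend to positions <= i."""
    tri = numpy.tril(numpy.ones((nL, nL), dtype=numpy.uint8))
    return numpy.repeat(tri[None, :, :], nB, axis=0)  # (nB, nL, nL)

Xmask = pad_mask([5, 3], nL=6)
Ymask = pad_mask([4, 6], nL=6) & causal_mask(2, 6)    # decoder self-attention mask
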
Example #9
def main(length=1000, nO=32, nI=32):
    if CupyOps.xp is not None:
        print("Use GPU")
        Model.ops = CupyOps()
        Model.Ops = CupyOps
        torch.set_default_tensor_type("torch.cuda.FloatTensor")

    pt_model = nn.Linear(nI, nO)
    optimizer = torch.optim.Adam(pt_model.parameters())  # noqa: F841

    model = PyTorchWrapper(pt_model)

    X = Model.ops.xp.ones((length, nI), dtype="f")
    y = 1.0 / X
    for i in range(10):
        yh, get_dX = model.begin_update(X)
        dY = (yh - y) / len(y)
        dX = get_dX(dY)  # noqa: F841
Example #10
def main(depth=2, width=512, nb_epoch=30):
    prefer_gpu()
    torch.set_num_threads(1)

    train_data, dev_data, _ = datasets.mnist()
    train_X, train_y = Model.ops.unzip(train_data)
    dev_X, dev_y = Model.ops.unzip(dev_data)

    dev_y = to_categorical(dev_y)
    model = PyTorchWrapper(
        PyTorchFeedForward(
            depth=depth,
            width=width,
            input_size=train_X.shape[1],
            output_size=dev_y.shape[1],
        ))
    with model.begin_training(train_X, train_y,
                              L2=1e-6) as (trainer, optimizer):
        epoch_loss = [0.0]

        def report_progress():
            # with model.use_params(optimizer.averages):
            print(epoch_loss[-1], model.evaluate(dev_X, dev_y),
                  trainer.dropout)
            epoch_loss.append(0.0)

        trainer.each_epoch.append(report_progress)
        trainer.nb_epoch = nb_epoch
        trainer.dropout = 0.3
        trainer.batch_size = 128
        trainer.dropout_decay = 0.0
        train_X = model.ops.asarray(train_X, dtype="float32")
        y_onehot = to_categorical(train_y)
        for X, y in trainer.iterate(train_X, y_onehot):
            yh, backprop = model.begin_update(X, drop=trainer.dropout)
            loss = ((yh - y)**2.0).sum() / y.shape[0]
            backprop(yh - y, optimizer)
            epoch_loss[-1] += loss
        with model.use_params(optimizer.averages):
            print("Avg dev.: %.3f" % model.evaluate(dev_X, dev_y))
            with open("out.pickle", "wb") as file_:
                pickle.dump(model, file_, -1)
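
PyTorchFeedForward is not defined in these examples. A plausible stand-in matching the depth/width/input_size/output_size signature used above (an assumption, for illustration only) is a plain ReLU MLP:

import torch.nn as nn

class PyTorchFeedForward(nn.Sequential):
    """Hypothetical: a simple MLP with `depth` hidden layers of size `width`."""
    def __init__(self, depth, width, input_size, output_size):
        layers = []
        last = input_size
        for _ in range(depth):
            layers += [nn.Linear(last, width), nn.ReLU()]
            last = width
        layers.append(nn.Linear(last, output_size))
        super().__init__(*layers)
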
Example #11
def main(length=1000, nO=32, nI=32):
    ''' Driver function '''
    if CupyOps.xp is not None:
        print("Use GPU")
        Model.ops = CupyOps()
        Model.Ops = CupyOps
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print("GPU not available. Running on CPU.")
    pt_model = nn.Linear(nI, nO)
    optimizer = torch.optim.Adam(pt_model.parameters())  # noqa: F841

    model = PyTorchWrapper(pt_model)

    X = Model.ops.xp.ones((length, nI), dtype="f")
    y = 1.0 / X
    for i in range(10):
        yh, get_dX = model.begin_update(X)
        dY = (yh - y) / len(y)
        dX = get_dX(dY)  # noqa: F841
Example #12
def main(depth=2, width=512, nb_epoch=30):
    prefer_gpu()
    torch.set_num_threads(1)

    train_data, dev_data, _ = datasets.mnist()
    train_X, train_y = Model.ops.unzip(train_data)
    dev_X, dev_y = Model.ops.unzip(dev_data)

    dev_y = to_categorical(dev_y)
    model = PyTorchWrapper(
        PyTorchFeedForward(
            depth=depth,
            width=width,
            input_size=train_X.shape[1],
            output_size=dev_y.shape[1],
        )
    )
    with model.begin_training(train_X, train_y, L2=1e-6) as (trainer, optimizer):
        epoch_loss = [0.0]

        def report_progress():
            # with model.use_params(optimizer.averages):
            print(epoch_loss[-1], model.evaluate(dev_X, dev_y), trainer.dropout)
            epoch_loss.append(0.0)

        trainer.each_epoch.append(report_progress)
        trainer.nb_epoch = nb_epoch
        trainer.dropout = 0.3
        trainer.batch_size = 128
        trainer.dropout_decay = 0.0
        train_X = model.ops.asarray(train_X, dtype="float32")
        y_onehot = to_categorical(train_y)
        for X, y in trainer.iterate(train_X, y_onehot):
            yh, backprop = model.begin_update(X, drop=trainer.dropout)
            loss = ((yh - y) ** 2.0).sum() / y.shape[0]
            backprop(yh - y, optimizer)
            epoch_loss[-1] += loss
        with model.use_params(optimizer.averages):
            print("Avg dev.: %.3f" % model.evaluate(dev_X, dev_y))
            with open("out.pickle", "wb") as file_:
                pickle.dump(model, file_, -1)
Example #13
class MultiHeadedAttention(Model):
    ''' This class implements multi-headed attention. It can be used for
    self-attention or outer attention, depending on our needs. There is no
    left or right context width: we attend over the whole sentence and use
    the masks to adjust appropriately. There are no separate weight matrices
    per head; instead a single larger weight matrix covers all heads, and
    reshaping into that larger dimension is what produces the multiple heads.
    For the time being, the key, query and value matrices are assumed to have
    the same length.
    '''
    def __init__(self, nM=300, nH=6):
        Model.__init__(self)
        self.nH = nH
        self.nM = nM  # model size: the length of the embeddings
        self.nD = nM // nH
        self.get_queries = with_reshape(Affine(nM, nM))
        self.get_keys = with_reshape(Affine(nM, nM))
        self.get_values = with_reshape(Affine(nM, nM))
        self.get_output = with_reshape(Affine(nM, nM))
        self._layers = [self.get_queries, self.get_keys, self.get_values, self.get_output]
        self._softmax = PyTorchWrapper(nn.Softmax(dim=-1))

        ''' mask conf '''
        i_grad = [1, 0]
        o_xp = None
        b_map = None
        ret_x = [0]
        conf = [i_grad, o_xp, b_map, ret_x]
        self._mask = PyTorchWrapper(PytorchMaskScores(), conf=conf)

    def begin_update(self, input, drop=0.1):
        # TESTED
        # Queries come from x0; keys and values come from y0 (y0 == x0 for self-attention)
        if len(input) == 3:
            x0, mask, sentX = input
            sentY = sentX
            y0 = x0
            self_attention = True
        else:
            self_attention = False
            x0, y0, mask, sentX, sentY = input
        ''' Shapes '''
        # x0: nB, nL, nM
        # q0: nB, nL, nM
        # k0: nB, nL, nM
        # v0: nB, nL, nM
        # q1: nB, nH, nL, nD
        # k1: nB, nH, nL, nD
        # v1: nB, nH, nL, nD
        nB, nL, nD, nH = x0.shape[0], x0.shape[1], self.nD, self.nH
        q0, get_dx0 = self.get_queries.begin_update(x0)
        q1 = q0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        k0, get_dy0_1 = self.get_keys.begin_update(y0)
        k1 = k0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        v0, get_dy0_2 = self.get_values.begin_update(y0)
        v1 = v0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        x1, get_dq1_dk1_dv1 = self.attn(q1, k1, v1, mask=mask, sentX=sentX,
                                        sentY=sentY, self_attn=self_attention)
        x2 = x1.transpose(0, 2, 1, 3).reshape((nB, nL, nH*nD))
        x3, get_dx2 = self.get_output.begin_update(x2)

        def finish_update(dx3, sgd=None):
            dx2 = get_dx2(dx3, sgd=sgd)
            dx1 = dx2.reshape((nB, nL, nH, nD)).transpose(0, 2, 1, 3)
            dq1, dk1, dv1 = get_dq1_dk1_dv1(dx1)
            nM = nH * nD
            dq0 = dq1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dk0 = dk1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dv0 = dv1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dy0 = get_dy0_2(dv0, sgd=sgd) + get_dy0_1(dk0, sgd=sgd)
            dx0 = get_dx0(dq0, sgd=sgd)
            if self_attention:
                return dx0 + dy0
            else:
                return (dx0, dy0)
        return x3, finish_update

    def attn(self, Q, K, V, mask=None, sentX=None, sentY=None, self_attn=True):
        '''
        Compute attention over (query, key, value) triplets.
        The similarity of the (Q, K) pairs is used to compute an
        attention matrix, which is then used to rescale V.
        '''
        # TESTED
        S0, bp_scaled_dp = self._scaled_dot_prod(Q, K)
        S1, bp_mask = self._mask.begin_update((S0, mask))
        S2, bp_softmax = self._softmax.begin_update(S1)
        S3, bp_apply_attn = self._apply_attn(S2, V)

        def backprop_attn(dS3):
            ''' Attention three inputs, one output '''
            dS2, dV = bp_apply_attn(dS3)
            dS1 = bp_softmax(dS2)
            dS0 = bp_mask(dS1)
            dQ, dK = bp_scaled_dp(dS0)
            return dQ, dK, dV
        return S3, backprop_attn

    def _scaled_dot_prod(self, Q0, K0):
        # TESTED
        # Q0: nB, nH, nL, nD
        # K0: nB, nH, nL, nD
        nB, nH, nL, nD = Q0.shape
        # Q1: nB*nH, nL, nD
        Q1 = Q0.reshape((nB*nH, nL, nD))
        # K1: (nB*nH, nD, nL)
        K1 = K0.transpose(0, 1, 3, 2).reshape((nB*nH, nD, nL))
        # K2: (nB*nH, nD, nL)
        K2 = (K1 / self.ops.xp.sqrt(self.nM)).astype("float32")
        # S0: (nB*nH, nL, nL)
        S0 = self.ops.xp.matmul(Q1, K2)

        # S1 shape: (nB, nH, nL, nL)
        S1 = S0.reshape((nB, nH, nL, nL))

        def backprop_attn1(dS1):
            dS0 = dS1.reshape((nB*nH, nL, nL))
            dQ1 = self.ops.xp.matmul(dS0, K2.transpose(0, 2, 1))
            dK2 = self.ops.xp.matmul(Q1.transpose(0, 2, 1), dS0)
            dK1 = (dK2 / self.ops.xp.sqrt(self.nM)).astype("float32")
            dK0 = dK1.reshape((nB, nH, nD, nL)).transpose(0, 1, 3, 2)
            dQ0 = dQ1.reshape((nB, nH, nL, nD))
            return dQ0, dK0
        return S1, backprop_attn1

    # def _mask(self, S0, mask):
    #     S1 = S0.transpose(1, 0, 2, 3)
    #     S2 = S1 - (1 - mask) * (1e9)
    #     S3 = S2.transpose(1, 0, 2, 3)
    #
    #     def backprop_attn2(dS3):
    #         dS2 = dS3.transpose(1, 0, 2, 3)
    #         dS1 = dS2
    #         dS0 = dS1.transpose(1, 0, 2, 3)
    #         return dS0
    #
    #     return S3, backprop_attn2

    def _apply_attn(self, S0, V0):
        ''' Multiplication with values '''
        # TESTED
        # S0: (nB, nH, nL, nL)
        # V0: (nB, nH, nL, nD)
        # S1: (nB*nH, nL, nL)
        # V1:  (nB*nH, nL, nD)
        # S2: (nB*nH, nL, nD)
        # S3: (nB, nH, nL, nD)
        nB, nH, nL, nL = S0.shape
        nD = V0.shape[-1]
        V1 = V0.reshape((nB*nH, nL, nD))
        S1 = S0.reshape((nB*nH, nL, nL))
        S2 = self.ops.xp.matmul(S1, V1)

        S3 = S2.reshape((nB, nH, nL, nD))

        def backprop_attn4(dS3):
            dS2 = dS3.reshape((nB*nH, nL, nD))
            # (nB*nH, nL, nD) @ (nB*nH, nL, nD).T --> (nB*nH, nL, nL)
            dS1 = self.ops.xp.matmul(dS2, V1.transpose(0, 2, 1))
            # (nB*nH, nL, nL).T @ (nB*nH, nL, nD) --> (nB*nH, nL, nD)
            dV1 = self.ops.xp.matmul(S1.transpose(0, 2, 1), dS2)
            dS0 = dS1.reshape((nB, nH, nL, nL))
            dV0 = dV1.reshape((nB, nH, nL, nD))
            return dS0, dV0

        return S3, backprop_attn4
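
The attn() pipeline above is scaled dot-product attention, softmax(Q K^T / scale) V, computed head-by-head on (nB*nH, nL, nD) blocks; note that this code scales by sqrt(nM) rather than the more usual sqrt(nD). A single-head NumPy sketch of the same computation (illustration only):

import numpy

def scaled_dot_product_attention(Q, K, V, mask=None, scale=None):
    """Q, K, V: (nL, nD). Returns softmax(Q K^T / scale) V."""
    nD = Q.shape[-1]
    scale = numpy.sqrt(nD) if scale is None else scale
    S = Q @ K.T / scale                       # (nL, nL) similarity scores
    if mask is not None:
        S = S - (1 - mask) * 1e9              # as in the commented-out _mask above
    S = S - S.max(axis=-1, keepdims=True)     # numerically stable softmax
    P = numpy.exp(S)
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                              # (nL, nD)

rng = numpy.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 8)).astype("float32") for _ in range(3))
out = scaled_dot_product_attention(Q, K, V)
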
Example #14
class SparseAttention(Model):
    ''' This class implements multiheaded attention in steps, factorizing
    the attention matrix. '''
    def __init__(self, nM=300, nH=6):
        Model.__init__(self)
        self.nH = nH
        self.nM = nM  # model size: the length of the embeddings
        self.nD = nM // nH
        self.get_queries = with_reshape(Affine(nM, nM))
        self.get_keys = with_reshape(Affine(nM, nM))
        self.get_values = with_reshape(Affine(nM, nM))
        self.get_output = with_reshape(Affine(nM, nM))
        self._layers = [self.get_queries, self.get_keys, self.get_values, self.get_output]
        self._softmax = PyTorchWrapper(nn.Softmax(dim=-1))
        ''' mask conf '''
        i_grad = [1, 0]
        o_xp = None
        b_map = None
        ret_x = [0]
        conf = [i_grad, o_xp, b_map, ret_x]
        self._mask = PyTorchWrapper(PytorchMaskScores(), conf=conf)

    def begin_update(self, input, drop=0.1):
        if len(input) == 3:
            x0, mask, sentX = input
            sentY = sentX
            y0 = x0
            self_attention = True
        else:
            self_attention = False
            x0, y0, mask, sentX, sentY = input
        ''' Shapes '''
        # x0: nB, nL, nM
        # q0: nB, nL, nM
        # k0: nB, nL, nM
        # v0: nB, nL, nM
        # q1: nB, nH, nL, nD
        # k1: nB, nH, nL, nD
        # v1: nB, nH, nL, nD
        nB, nL, nD, nH = x0.shape[0], x0.shape[1], self.nD, self.nH
        q0, get_dx0 = self.get_queries.begin_update(x0)
        q1 = q0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        k0, get_dy0_1 = self.get_keys.begin_update(y0)
        k1 = k0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        v0, get_dy0_2 = self.get_values.begin_update(y0)
        v1 = v0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        # The second attention step attends over the output of the first (q2),
        # consistent with the chained backprop in finish_update below.
        q2, get_dq1_dk1_dv1 = self.attn(q1, k1, v1, mask=mask & self.mask_floor(nB, nL), sentX=sentX,
                                        sentY=sentY, self_attn=self_attention)
        x1, get_dq2_dk1_dv1 = self.attn(q2, k1, v1, mask=mask & self.mask_repetitive(nB, nL), sentX=sentX,
                                        sentY=sentY, self_attn=self_attention)

        x2 = x1.transpose(0, 2, 1, 3).reshape((nB, nL, nH*nD))
        x3, get_dx2 = self.get_output.begin_update(x2)

        def finish_update(dx3, sgd=None):
            dx2 = get_dx2(dx3, sgd=sgd)
            dx1 = dx2.reshape((nB, nL, nH, nD)).transpose(0, 2, 1, 3)
            dq2, dk11, dv11 = get_dq2_dk1_dv1(dx1)
            dq1, dk12, dv12 = get_dq1_dk1_dv1(dq2)
            dk1 = dk11 + dk12
            dv1 = dv11 + dv12
            nM = nH * nD
            dq0 = dq1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dk0 = dk1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dv0 = dv1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dy0 = get_dy0_2(dv0, sgd=sgd) + get_dy0_1(dk0, sgd=sgd)
            dx0 = get_dx0(dq0, sgd=sgd)
            if self_attention:
                return dx0 + dy0
            else:
                return (dx0, dy0)
        return x3, finish_update

    def attn(self, Q, K, V, mask=None, sentX=None, sentY=None, self_attn=True):
        '''
        Compute attention over (query, key, value) triplets.
        The similarity of the (Q, K) pairs is used to compute an
        attention matrix, which is then used to rescale V.
        '''
        # TESTED
        S0, bp_scaled_dp = self._scaled_dot_prod(Q, K)
        S1, bp_mask = self._mask.begin_update((S0, mask))
        S2, bp_softmax = self._softmax.begin_update(S1)
        S3, bp_apply_attn = self._apply_attn(S2, V)

        def backprop_attn(dS3):
            ''' Attention three inputs, one output '''
            dS2, dV = bp_apply_attn(dS3)
            dS1 = bp_softmax(dS2)
            dS0 = bp_mask(dS1)
            dQ, dK = bp_scaled_dp(dS0)
            return dQ, dK, dV
        return S3, backprop_attn

    def _scaled_dot_prod(self, Q0, K0):
        # TESTED
        # Q0: nB, nH, nL, nD
        # K0: nB, nH, nL, nD
        nB, nH, nL, nD = Q0.shape
        # Q1: nB*nH, nL, nD
        Q1 = Q0.reshape((nB*nH, nL, nD))
        # K1: (nB*nH, nD, nL)
        K1 = K0.transpose(0, 1, 3, 2).reshape((nB*nH, nD, nL))
        # K2: (nB*nH, nD, nL)
        K2 = (K1 / self.ops.xp.sqrt(self.nM)).astype("float32")
        # S0: (nB*nH, nL, nL)
        S0 = self.ops.xp.matmul(Q1, K2)

        # S1 shape: (nB, nH, nL, nL)
        S1 = S0.reshape((nB, nH, nL, nL))

        def backprop_attn1(dS1):
            dS0 = dS1.reshape((nB*nH, nL, nL))
            dQ1 = self.ops.xp.matmul(dS0, K2.transpose(0, 2, 1))
            dK2 = self.ops.xp.matmul(Q1.transpose(0, 2, 1), dS0)
            dK1 = (dK2 / self.ops.xp.sqrt(self.nM)).astype("float32")
            dK0 = dK1.reshape((nB, nH, nD, nL)).transpose(0, 1, 3, 2)
            dQ0 = dQ1.reshape((nB, nH, nL, nD))
            return dQ0, dK0
        return S1, backprop_attn1

    def _apply_attn(self, S0, V0):
        ''' Multiplication with values '''
        # TESTED
        # S0: (nB, nH, nL, nL)
        # V0: (nB, nH, nL, nD)
        # S1: (nB*nH, nL, nL)
        # V1:  (nB*nH, nL, nD)
        # S2: (nB*nH, nL, nD)
        # S3: (nB, nH, nL, nD)
        nB, nH, nL, nL = S0.shape
        nD = V0.shape[-1]
        V1 = V0.reshape((nB*nH, nL, nD))
        S1 = S0.reshape((nB*nH, nL, nL))
        S2 = self.ops.xp.matmul(S1, V1)

        S3 = S2.reshape((nB, nH, nL, nD))

        def backprop_attn4(dS3):
            dS2 = dS3.reshape((nB*nH, nL, nD))
            # (nB*nH, nL, nD) @ (nB*nH, nL, nD).T --> (nB*nH, nL, nL)
            dS1 = self.ops.xp.matmul(dS2, V1.transpose(0, 2, 1))
            # (nB*nH, nL, nL).T @ (nB*nH, nL, nD) --> (nB*nH, nL, nD)
            dV1 = self.ops.xp.matmul(S1.transpose(0, 2, 1), dS2)
            dS0 = dS1.reshape((nB, nH, nL, nL))
            dV0 = dV1.reshape((nB, nH, nL, nD))
            return dS0, dV0

        return S3, backprop_attn4

    def mask_floor(self, nB, nL):
        stride = math.ceil(math.sqrt(nL))
        floor_mask = Model.ops.xp.zeros((nB, nL, nL), dtype=Model.ops.xp.uint8)
        for i in range(nL):
            lower = max(0, i - (i % stride))
            higher = i + 1
            floor_mask[:, i, lower:higher] = 1
        return floor_mask

    def mask_repetitive(self, nB, nL, c=1, mode='left'):
        ''' Every stride tokens, mask one (independent of row) '''
        # NOTE: the defaults c=1 and mode='left' are assumptions.
        stride = math.ceil(math.sqrt(nL))
        repetitive_mask = Model.ops.xp.zeros((nB, nL, nL), dtype=Model.ops.xp.uint8)
        for j in range(nL):
            if (j % stride) >= (stride - c):
                if mode == 'left':
                    repetitive_mask[:, j:, j] = 1
        return repetitive_mask
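
mask_floor and mask_repetitive factorize full self-attention into two sparse steps: the floor mask lets each position attend (causally) within its own block of `stride` positions, and the repetitive mask lets later positions attend to the last column(s) of each block. A small self-contained sketch of the floor pattern for nL = 9, mirroring mask_floor above for a single batch element:

import math
import numpy

nL = 9
stride = math.ceil(math.sqrt(nL))          # 3
floor = numpy.zeros((nL, nL), dtype=numpy.uint8)
for i in range(nL):
    lower = max(0, i - (i % stride))       # start of this position's block
    floor[i, lower:i + 1] = 1              # attend within the block, up to the current position
print(floor)
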