class EncoderLayer(Model):
    def __init__(self, nM=300, nH=6, device='cpu'):
        Model.__init__(self)
        self.attn = MultiHeadedAttention(nM=nM, nH=nH)
        self.ffd = PositionwiseFeedForward(nM, 4 * nM)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM, device=device))
        self.nM = nM
        self.layers_ = [self.attn, self.ffd, self.norm]

    def begin_update(self, input, drop=0.1):
        X0, mask = input
        # Self-attention sublayer with a residual connection.
        X1, b_X1 = self.attn.begin_update((X0, mask, None), drop=drop)
        X2, b_X2 = self.norm.begin_update(X1)
        X3 = X0 + X2
        # Feed-forward sublayer with a residual connection.
        X4, b_X4 = self.ffd.begin_update(X3, drop=drop)
        X5, b_X5 = self.norm.begin_update(X4)
        X6 = X3 + X5

        def finish_update(dX6, sgd=None):
            dX5 = dX6
            dX4 = b_X5(dX5, sgd=sgd)
            dX3 = b_X4(dX4, sgd=sgd)
            dX3 += dX6
            dX2 = dX3
            dX1 = b_X2(dX2, sgd=sgd)
            dX0 = b_X1(dX1, sgd=sgd)
            dX0 += dX3
            return dX0

        return (X6, mask), finish_update
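# A minimal smoke-test sketch for EncoderLayer. It assumes the helper layers
# used above (MultiHeadedAttention, PositionwiseFeedForward, PytorchLayerNorm,
# PyTorchWrapper) are in scope and that the mask has shape (nB, nL, nL); the
# batch and length values here are arbitrary.
def _encoder_layer_smoke_test(nB=2, nL=5, nM=300, nH=6):
    layer = EncoderLayer(nM=nM, nH=nH)
    X0 = Model.ops.xp.ones((nB, nL, nM), dtype="f")
    mask = Model.ops.xp.ones((nB, nL, nL), dtype="uint8")
    (X6, _), finish_update = layer.begin_update((X0, mask))
    assert X6.shape == (nB, nL, nM)
    dX0 = finish_update(Model.ops.xp.ones_like(X6))
    assert dX0.shape == X0.shape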
class DecoderLayer(Model):
    def __init__(self, nM=300, nH=6, device='cpu'):
        Model.__init__(self)
        self.y_attn = MultiHeadedAttention(nM=nM, nH=nH)
        self.x_attn = MultiHeadedAttention(nM=nM, nH=nH)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM, device=device))
        self.ffd = PositionwiseFeedForward(nM, 4 * nM)
        self.layers_ = [self.norm, self.y_attn, self.x_attn, self.ffd]

    def begin_update(self, input, drop=0.1):
        Y0, X0, X_mask, Y_mask = input
        Y1, b_Y1 = self.y_attn.begin_update((Y0, Y_mask, None), drop=drop)
        Y2, b_Y2 = self.norm.begin_update(Y1)
        Y3 = Y0 + Y2
        Y4, b_Y4 = self.x_attn.begin_update((Y3, X0, X_mask, None, None),
                                            drop=drop)
        Y5, b_Y5 = self.norm.begin_update(Y4)
        Y6 = Y3 + Y5
        Y7, b_Y7 = self.ffd.begin_update(Y6, drop=drop)

        def finish_update(dI, sgd=None):
            dY7, dX = dI
            dY6 = b_Y7(dY7, sgd=sgd)
            dY5 = dY6
            dY4 = b_Y5(dY5, sgd=sgd)
            dY3, dX0 = b_Y4(dY4, sgd=sgd)
            dY3 += dY6
            dY2 = dY3
            dY1 = b_Y2(dY2, sgd=sgd)
            dY0 = b_Y1(dY1, sgd=sgd)
            dY0 += dY3
            dX0 += dX
            return (dY0, dX0)

        return (Y7, X0, X_mask, Y_mask), finish_update
class Categorizer(Model):
    def __init__(self, nS=12, nM=768, nH=12, nO=2, device='cpu'):
        Model.__init__(self)
        self.nM = nM
        self.nO = nO
        self.nS = nS
        self.enc = clone(EncoderLayer(nM=nM, nH=nH, device=device), nS)
        self.affine = Affine(nI=nM, nO=nM)
        self.softmax = Softmax(nI=nM, nO=nO)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM=nM, device=device))
        self.slicer = PyTorchWrapper(PytorchSlicer())
        self.device = device
        self.layers_ = [self.enc]

    def begin_update(self, inputs, drop=0.0):
        X0, Xmask = inputs
        (X1, _), b_X1 = self.enc.begin_update((X0, Xmask))
        X2, b_X2 = self.norm.begin_update(X1)
        X3, b_X3 = self.slicer.begin_update(X2)
        X4, b_X4 = self.affine.begin_update(X3)
        X5, b_X5 = self.softmax.begin_update(X4)

        def finish_update(dX5, sgd=None):
            dX4 = b_X5(dX5, sgd=sgd)
            dX3 = b_X4(dX4, sgd=sgd)
            dX2 = b_X3(dX3)
            dX1 = b_X2(dX2, sgd=sgd)
            dX0 = b_X1(dX1, sgd=sgd)
            return dX0

        return X5, finish_update
def main(length=1000, nO=32, nI=32):
    pt_model = nn.Linear(nI, nO)
    optimizer = torch.optim.Adam(pt_model.parameters())
    model = PyTorchWrapper(pt_model)
    X = numpy.ones((length, nI), dtype='f')
    y = 1. / X
    for i in range(10):
        yh, get_dX = model.begin_update(X)
        dY = (yh - y) / len(y)
        dX = get_dX(dY)
def test_wrapper(nN=2, nI=3, nO=4):
    if PyTorchWrapper is None:
        return
    model = PyTorchWrapper(torch.nn.Linear(nI, nO))
    sgd = SGD(model.ops, 0.001)
    X = numpy.zeros((nN, nI), dtype="f")
    X += numpy.random.uniform(size=X.size).reshape(X.shape)
    Y = numpy.zeros((nN, nO), dtype="f")
    Yh, get_dX = model.begin_update(X)
    assert Yh.shape == (nN, nO)
    dYh = (Yh - Y) / Yh.shape[0]
    dX = get_dX(dYh, sgd=sgd)
    assert dX.shape == (nN, nI)
    check_learns_zero_output(model, sgd, X, Y)
class Encoder(Model):
    def __init__(self, nM=300, nH=6, nS=6, device='cpu'):
        Model.__init__(self)
        self.stack = clone(EncoderLayer(nM=nM, nH=nH, device=device), nS)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM=nM, device=device))

    def begin_update(self, input, drop=0.1):
        X0, mask = input
        (X1, _), b_X1 = self.stack.begin_update((X0, mask), drop=drop)
        X2, b_X2 = self.norm.begin_update(X1)

        def finish_update(dX2, sgd=None):
            dX1 = b_X2(dX2, sgd=sgd)
            dX0 = b_X1(dX1, sgd=sgd)
            return dX0

        return X2, finish_update
class EncoderDecoder(Model):
    def __init__(self, nS=1, nH=6, nM=300, nTGT=10000, device='cpu'):
        '''
        EncoderDecoder consists of an encoder stack, a decoder stack and an
        output layer (a linear projection followed by a softmax).
        Parameters:
            nS: number of encoder/decoder layers in each stack
            nH: number of heads in the multi-headed attention
            nM: the token embedding size
            nTGT: number of unique words in the output vocabulary
        '''
        Model.__init__(self)
        self.nS = nS
        self.nH = nH
        self.nM = nM
        self.nTGT = nTGT
        self.device = device
        self.enc = Encoder(nM=nM, nH=nH, device=device, nS=nS)
        self.norm = PyTorchWrapper(PytorchLayerNorm(nM=nM, device=device))
        self.dec = clone(DecoderLayer(nM=nM, nH=nH, device=device), nS)
        self.proj = with_reshape(Softmax(nO=nTGT, nI=nM))
        self._layers = [self.enc, self.dec, self.proj]

    def begin_update(self, inputs, drop=0.1):
        '''
        A batch object flows through the network. It contains the input, the
        gold output and the corresponding masks. The input changes as it
        travels through the network, while the gold output stays fixed.
        Input: nB x nL x nM
        '''
        X0, Xmask, Y0, Ymask = inputs
        X1, backprop_encode = self.enc.begin_update((X0, Xmask), drop=drop)
        (Y1, _, _, _), backprop_decode = self.dec.begin_update(
            (Y0, X1, Xmask, Ymask), drop=drop)
        Y2, b_Y2 = self.norm.begin_update(Y1)
        word_probs, backprop_output = self.proj.begin_update(Y2, drop=drop)

        def finish_update(d_word_probs, sgd=None):
            dY2 = backprop_output(d_word_probs, sgd=sgd)
            dY1 = b_Y2(dY2, sgd=sgd)
            zeros = Model.ops.xp.zeros(X0.shape, dtype=Model.ops.xp.float32)
            dY0, dX1 = backprop_decode((dY1, zeros), sgd=sgd)
            dX0 = backprop_encode(dX1, sgd=sgd)
            return (dX0, dY0)

        return (word_probs, Xmask), finish_update
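# A forward/backward sketch for EncoderDecoder on dummy data. It assumes the
# Encoder, DecoderLayer and helper layers defined above are in scope, and that
# the masks have shape (nB, nL, nL) as used by the attention layers; the sizes
# chosen here are arbitrary.
def _encoder_decoder_smoke_test(nB=2, nL=4, nM=300, nTGT=10000):
    model = EncoderDecoder(nS=1, nH=6, nM=nM, nTGT=nTGT)
    X0 = Model.ops.xp.ones((nB, nL, nM), dtype="f")
    Y0 = Model.ops.xp.ones((nB, nL, nM), dtype="f")
    Xmask = Model.ops.xp.ones((nB, nL, nL), dtype="uint8")
    Ymask = Model.ops.xp.ones((nB, nL, nL), dtype="uint8")
    (word_probs, _), finish_update = model.begin_update((X0, Xmask, Y0, Ymask))
    assert word_probs.shape == (nB, nL, nTGT)
    dX0, dY0 = finish_update(Model.ops.xp.zeros_like(word_probs))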
def main(length=1000, nO=32, nI=32):
    if CupyOps.xp is not None:
        print("Use GPU")
        Model.ops = CupyOps()
        Model.Ops = CupyOps
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    pt_model = nn.Linear(nI, nO)
    optimizer = torch.optim.Adam(pt_model.parameters())  # noqa: F841
    model = PyTorchWrapper(pt_model)
    X = Model.ops.xp.ones((length, nI), dtype="f")
    y = 1.0 / X
    for i in range(10):
        yh, get_dX = model.begin_update(X)
        dY = (yh - y) / len(y)
        dX = get_dX(dY)  # noqa: F841
def main(depth=2, width=512, nb_epoch=30):
    prefer_gpu()
    torch.set_num_threads(1)
    train_data, dev_data, _ = datasets.mnist()
    train_X, train_y = Model.ops.unzip(train_data)
    dev_X, dev_y = Model.ops.unzip(dev_data)
    dev_y = to_categorical(dev_y)
    model = PyTorchWrapper(
        PyTorchFeedForward(
            depth=depth,
            width=width,
            input_size=train_X.shape[1],
            output_size=dev_y.shape[1],
        )
    )
    with model.begin_training(train_X, train_y, L2=1e-6) as (trainer, optimizer):
        epoch_loss = [0.0]

        def report_progress():
            # with model.use_params(optimizer.averages):
            print(epoch_loss[-1], model.evaluate(dev_X, dev_y), trainer.dropout)
            epoch_loss.append(0.0)

        trainer.each_epoch.append(report_progress)
        trainer.nb_epoch = nb_epoch
        trainer.dropout = 0.3
        trainer.batch_size = 128
        trainer.dropout_decay = 0.0
        train_X = model.ops.asarray(train_X, dtype="float32")
        y_onehot = to_categorical(train_y)
        for X, y in trainer.iterate(train_X, y_onehot):
            yh, backprop = model.begin_update(X, drop=trainer.dropout)
            loss = ((yh - y) ** 2.0).sum() / y.shape[0]
            backprop(yh - y, optimizer)
            epoch_loss[-1] += loss
        with model.use_params(optimizer.averages):
            print("Avg dev.: %.3f" % model.evaluate(dev_X, dev_y))
            with open("out.pickle", "wb") as file_:
                pickle.dump(model, file_, -1)
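# PyTorchFeedForward is defined elsewhere in the examples; the sketch below is
# only an assumption of what such a module could look like (a stack of Linear
# + ReLU blocks with a softmax output), not the example's actual definition.
class _FeedForwardSketch(torch.nn.Module):
    def __init__(self, depth, width, input_size, output_size):
        super().__init__()
        layers = []
        in_size = input_size
        for _ in range(depth):
            layers.append(torch.nn.Linear(in_size, width))
            layers.append(torch.nn.ReLU())
            in_size = width
        layers.append(torch.nn.Linear(in_size, output_size))
        self.net = torch.nn.Sequential(*layers)

    def forward(self, X):
        # Return class probabilities over the output dimension.
        return torch.nn.functional.softmax(self.net(X), dim=-1)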
def main(length=1000, nO=32, nI=32):
    ''' Driver function '''
    if CupyOps.xp is not None:
        print("Use GPU")
        Model.ops = CupyOps()
        Model.Ops = CupyOps
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print("GPU not available. Running on CPU.")
    pt_model = nn.Linear(nI, nO)
    optimizer = torch.optim.Adam(pt_model.parameters())  # noqa: F841
    model = PyTorchWrapper(pt_model)
    X = Model.ops.xp.ones((length, nI), dtype="f")
    y = 1.0 / X
    for i in range(10):
        yh, get_dX = model.begin_update(X)
        dY = (yh - y) / len(y)
        dX = get_dX(dY)  # noqa: F841
class MultiHeadedAttention(Model):
    '''
    Multi-headed attention, usable either for self-attention or for outer
    (encoder-decoder) attention, depending on our needs. There is no left or
    right context width: we attend over the whole sentence and rely on the
    masks to adjust appropriately. There are no separate weight matrices per
    head; instead one larger weight matrix covers all heads, and reshaping to
    the bigger dimensions is what yields the multiple heads. For the time
    being, the key, query and value matrices are assumed to have the same
    length.
    '''
    def __init__(self, nM=300, nH=6):
        Model.__init__(self)
        self.nH = nH
        self.nM = nM  # model size: the length of the embeddings
        self.nD = nM // nH
        self.get_queries = with_reshape(Affine(nM, nM))
        self.get_keys = with_reshape(Affine(nM, nM))
        self.get_values = with_reshape(Affine(nM, nM))
        self.get_output = with_reshape(Affine(nM, nM))
        self._layers = [self.get_queries, self.get_keys, self.get_values,
                        self.get_output]
        self._softmax = PyTorchWrapper(nn.Softmax(dim=-1))

        # mask conf
        i_grad = [1, 0]
        o_xp = None
        b_map = None
        ret_x = [0]
        conf = [i_grad, o_xp, b_map, ret_x]
        self._mask = PyTorchWrapper(PytorchMaskScores(), conf=conf)

    def begin_update(self, input, drop=0.1):
        # TESTED
        # Queries come from input[0]; keys and values from input[1].
        if len(input) == 3:
            x0, mask, sentX = input
            sentY = sentX
            y0 = x0
            self_attention = True
        else:
            self_attention = False
            x0, y0, mask, sentX, sentY = input
        # Shapes:
        # x0: nB, nL, nM
        # q0: nB, nL, nM
        # k0: nB, nL, nM
        # v0: nB, nL, nM
        # q1: nB, nH, nL, nD
        # k1: nB, nH, nL, nD
        # v1: nB, nH, nL, nD
        nB, nL, nD, nH = x0.shape[0], x0.shape[1], self.nD, self.nH
        q0, get_dx0 = self.get_queries.begin_update(x0)
        q1 = q0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        k0, get_dy0_1 = self.get_keys.begin_update(y0)
        k1 = k0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        v0, get_dy0_2 = self.get_values.begin_update(y0)
        v1 = v0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        x1, get_dq1_dk1_dv1 = self.attn(
            q1, k1, v1, mask=mask, sentX=sentX, sentY=sentY,
            self_attn=self_attention)
        x2 = x1.transpose(0, 2, 1, 3).reshape((nB, nL, nH * nD))
        x3, get_dx2 = self.get_output.begin_update(x2)

        def finish_update(dx3, sgd=None):
            dx2 = get_dx2(dx3, sgd=sgd)
            dx1 = dx2.reshape((nB, nL, nH, nD)).transpose(0, 2, 1, 3)
            dq1, dk1, dv1 = get_dq1_dk1_dv1(dx1)
            nM = nH * nD
            dq0 = dq1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dk0 = dk1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dv0 = dv1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dy0 = get_dy0_2(dv0, sgd=sgd) + get_dy0_1(dk0, sgd=sgd)
            dx0 = get_dx0(dq0, sgd=sgd)
            if self_attention:
                return dx0 + dy0
            else:
                return (dx0, dy0)

        return x3, finish_update

    def attn(self, Q, K, V, mask=None, sentX=None, sentY=None, self_attn=True):
        '''
        Compute attention on (query, key, value) triplets. The similarity of
        the (Q, K) pairs is used to compute an attention matrix, which is
        used to rescale V.
        '''
        # TESTED
        S0, bp_scaled_dp = self._scaled_dot_prod(Q, K)
        S1, bp_mask = self._mask.begin_update((S0, mask))
        S2, bp_softmax = self._softmax.begin_update(S1)
        S3, bp_apply_attn = self._apply_attn(S2, V)

        def backprop_attn(dS3):
            ''' Attention: three inputs, one output '''
            dS2, dV = bp_apply_attn(dS3)
            dS1 = bp_softmax(dS2)
            dS0 = bp_mask(dS1)
            dQ, dK = bp_scaled_dp(dS0)
            return dQ, dK, dV

        return S3, backprop_attn

    def _scaled_dot_prod(self, Q0, K0):
        # TESTED
        # Q0: nB, nH, nL, nD
        # K0: nB, nH, nL, nD
        nB, nH, nL, nD = Q0.shape
        # Q1: nB*nH, nL, nD
        Q1 = Q0.reshape((nB * nH, nL, nD))
        # K1: (nB*nH, nD, nL)
        K1 = K0.transpose(0, 1, 3, 2).reshape((nB * nH, nD, nL))
        # K2: (nB*nH, nD, nL). Note the scaling uses sqrt(nM), the full model
        # size, rather than the per-head sqrt(nD).
        K2 = (K1 / self.ops.xp.sqrt(self.nM)).astype("float32")
        # S0: (nB*nH, nL, nL)
        S0 = self.ops.xp.matmul(Q1, K2)
        # S1 shape: (nB, nH, nL, nL)
        S1 = S0.reshape((nB, nH, nL, nL))

        def backprop_attn1(dS1):
            dS0 = dS1.reshape((nB * nH, nL, nL))
            dQ1 = self.ops.xp.matmul(dS0, K2.transpose(0, 2, 1))
            dK2 = self.ops.xp.matmul(Q1.transpose(0, 2, 1), dS0)
            dK1 = (dK2 / self.ops.xp.sqrt(self.nM)).astype("float32")
            dK0 = dK1.reshape((nB, nH, nD, nL)).transpose(0, 1, 3, 2)
            dQ0 = dQ1.reshape((nB, nH, nL, nD))
            return dQ0, dK0

        return S1, backprop_attn1

    # def _mask(self, S0, mask):
    #     S1 = S0.transpose(1, 0, 2, 3)
    #     S2 = S1 - (1 - mask) * (1e9)
    #     S3 = S2.transpose(1, 0, 2, 3)
    #
    #     def backprop_attn2(dS3):
    #         dS2 = dS3.transpose(1, 0, 2, 3)
    #         dS1 = dS2
    #         dS0 = dS1.transpose(1, 0, 2, 3)
    #         return dS0
    #
    #     return S3, backprop_attn2

    def _apply_attn(self, S0, V0):
        ''' Multiplication with values '''
        # TESTED
        # S0: (nB, nH, nL, nL)
        # V0: (nB, nH, nL, nD)
        # S1: (nB*nH, nL, nL)
        # V1: (nB*nH, nL, nD)
        # S2: (nB*nH, nL, nD)
        # S3: (nB, nH, nL, nD)
        nB, nH, nL, nL = S0.shape
        nD = V0.shape[-1]
        V1 = V0.reshape((nB * nH, nL, nD))
        S1 = S0.reshape((nB * nH, nL, nL))
        S2 = self.ops.xp.matmul(S1, V1)
        S3 = S2.reshape((nB, nH, nL, nD))

        def backprop_attn4(dS3):
            dS2 = dS3.reshape((nB * nH, nL, nD))
            # (nB*nH, nL, nD) @ (nB*nH, nL, nD).T --> (nB*nH, nL, nL)
            dS1 = self.ops.xp.matmul(dS2, V1.transpose(0, 2, 1))
            # (nB*nH, nL, nL).T @ (nB*nH, nL, nD) --> (nB*nH, nL, nD)
            dV1 = self.ops.xp.matmul(S1.transpose(0, 2, 1), dS2)
            dS0 = dS1.reshape((nB, nH, nL, nL))
            dV0 = dV1.reshape((nB, nH, nL, nD))
            return dS0, dV0

        return S3, backprop_attn4
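# For reference, the attention pipeline above (scale, mask, softmax, apply to
# values) written as a plain-numpy forward pass. Like the class above, it
# scales by sqrt(nM) rather than the per-head sqrt(nD), and it assumes a 0/1
# mask of shape (nB, nL, nL) broadcast over the heads.
import numpy

def scaled_dot_product_attention(Q, K, V, mask, nM):
    # Q, K, V: (nB, nH, nL, nD); mask: (nB, nL, nL) of 0/1 values.
    S = numpy.matmul(Q, K.transpose(0, 1, 3, 2)) / numpy.sqrt(nM)
    S = S - (1 - mask[:, None, :, :]) * 1e9       # masked positions get -1e9
    S = numpy.exp(S - S.max(axis=-1, keepdims=True))
    S = S / S.sum(axis=-1, keepdims=True)         # softmax over the key axis
    return numpy.matmul(S, V)                     # (nB, nH, nL, nD)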
class SparseAttention(Model):
    '''
    Multi-headed attention computed in steps, factorizing the attention
    matrix into two sparse patterns.
    '''
    def __init__(self, nM=300, nH=6):
        Model.__init__(self)
        self.nH = nH
        self.nM = nM  # model size: the length of the embeddings
        self.nD = nM // nH
        self.get_queries = with_reshape(Affine(nM, nM))
        self.get_keys = with_reshape(Affine(nM, nM))
        self.get_values = with_reshape(Affine(nM, nM))
        self.get_output = with_reshape(Affine(nM, nM))
        self._layers = [self.get_queries, self.get_keys, self.get_values,
                        self.get_output]
        self._softmax = PyTorchWrapper(nn.Softmax(dim=-1))

        # mask conf
        i_grad = [1, 0]
        o_xp = None
        b_map = None
        ret_x = [0]
        conf = [i_grad, o_xp, b_map, ret_x]
        self._mask = PyTorchWrapper(PytorchMaskScores(), conf=conf)

    def begin_update(self, input, drop=0.1):
        if len(input) == 3:
            x0, mask, sentX = input
            sentY = sentX
            y0 = x0
            self_attention = True
        else:
            self_attention = False
            x0, y0, mask, sentX, sentY = input
        # Shapes:
        # x0: nB, nL, nM
        # q0/k0/v0: nB, nL, nM
        # q1/k1/v1: nB, nH, nL, nD
        nB, nL, nD, nH = x0.shape[0], x0.shape[1], self.nD, self.nH
        q0, get_dx0 = self.get_queries.begin_update(x0)
        q1 = q0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        k0, get_dy0_1 = self.get_keys.begin_update(y0)
        k1 = k0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        v0, get_dy0_2 = self.get_values.begin_update(y0)
        v1 = v0.reshape(nB, -1, self.nH, self.nD).transpose(0, 2, 1, 3)
        # Two attention steps: the first attends within strided "floor"
        # blocks; its output feeds the queries of the second, "repetitive"
        # (summary-column) step, matching the gradient chain in finish_update.
        q2, get_dq1_dk1_dv1 = self.attn(
            q1, k1, v1, mask=mask & self.mask_floor(nB, nL),
            sentX=sentX, sentY=sentY, self_attn=self_attention)
        x1, get_dq2_dk1_dv1 = self.attn(
            q2, k1, v1, mask=mask & self.mask_repetitive(nB, nL),
            sentX=sentX, sentY=sentY, self_attn=self_attention)
        x2 = x1.transpose(0, 2, 1, 3).reshape((nB, nL, nH * nD))
        x3, get_dx2 = self.get_output.begin_update(x2)

        def finish_update(dx3, sgd=None):
            dx2 = get_dx2(dx3, sgd=sgd)
            dx1 = dx2.reshape((nB, nL, nH, nD)).transpose(0, 2, 1, 3)
            dq2, dk11, dv11 = get_dq2_dk1_dv1(dx1)
            dq1, dk12, dv12 = get_dq1_dk1_dv1(dq2)
            dk1 = dk11 + dk12
            dv1 = dv11 + dv12
            nM = nH * nD
            dq0 = dq1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dk0 = dk1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dv0 = dv1.transpose(0, 2, 1, 3).reshape((nB, nL, nM))
            dy0 = get_dy0_2(dv0, sgd=sgd) + get_dy0_1(dk0, sgd=sgd)
            dx0 = get_dx0(dq0, sgd=sgd)
            if self_attention:
                return dx0 + dy0
            else:
                return (dx0, dy0)

        return x3, finish_update

    def attn(self, Q, K, V, mask=None, sentX=None, sentY=None, self_attn=True):
        '''
        Compute attention on (query, key, value) triplets. The similarity of
        the (Q, K) pairs is used to compute an attention matrix, which is
        used to rescale V.
        '''
        # TESTED
        S0, bp_scaled_dp = self._scaled_dot_prod(Q, K)
        S1, bp_mask = self._mask.begin_update((S0, mask))
        S2, bp_softmax = self._softmax.begin_update(S1)
        S3, bp_apply_attn = self._apply_attn(S2, V)

        def backprop_attn(dS3):
            ''' Attention: three inputs, one output '''
            dS2, dV = bp_apply_attn(dS3)
            dS1 = bp_softmax(dS2)
            dS0 = bp_mask(dS1)
            dQ, dK = bp_scaled_dp(dS0)
            return dQ, dK, dV

        return S3, backprop_attn

    def _scaled_dot_prod(self, Q0, K0):
        # TESTED
        # Q0: nB, nH, nL, nD
        # K0: nB, nH, nL, nD
        nB, nH, nL, nD = Q0.shape
        # Q1: nB*nH, nL, nD
        Q1 = Q0.reshape((nB * nH, nL, nD))
        # K1: (nB*nH, nD, nL)
        K1 = K0.transpose(0, 1, 3, 2).reshape((nB * nH, nD, nL))
        # K2: (nB*nH, nD, nL)
        K2 = (K1 / self.ops.xp.sqrt(self.nM)).astype("float32")
        # S0: (nB*nH, nL, nL)
        S0 = self.ops.xp.matmul(Q1, K2)
        # S1 shape: (nB, nH, nL, nL)
        S1 = S0.reshape((nB, nH, nL, nL))

        def backprop_attn1(dS1):
            dS0 = dS1.reshape((nB * nH, nL, nL))
            dQ1 = self.ops.xp.matmul(dS0, K2.transpose(0, 2, 1))
            dK2 = self.ops.xp.matmul(Q1.transpose(0, 2, 1), dS0)
            dK1 = (dK2 / self.ops.xp.sqrt(self.nM)).astype("float32")
            dK0 = dK1.reshape((nB, nH, nD, nL)).transpose(0, 1, 3, 2)
            dQ0 = dQ1.reshape((nB, nH, nL, nD))
            return dQ0, dK0

        return S1, backprop_attn1

    def _apply_attn(self, S0, V0):
        ''' Multiplication with values '''
        # TESTED
        # S0: (nB, nH, nL, nL)
        # V0: (nB, nH, nL, nD)
        # S1: (nB*nH, nL, nL)
        # V1: (nB*nH, nL, nD)
        # S2: (nB*nH, nL, nD)
        # S3: (nB, nH, nL, nD)
        nB, nH, nL, nL = S0.shape
        nD = V0.shape[-1]
        V1 = V0.reshape((nB * nH, nL, nD))
        S1 = S0.reshape((nB * nH, nL, nL))
        S2 = self.ops.xp.matmul(S1, V1)
        S3 = S2.reshape((nB, nH, nL, nD))

        def backprop_attn4(dS3):
            dS2 = dS3.reshape((nB * nH, nL, nD))
            # (nB*nH, nL, nD) @ (nB*nH, nL, nD).T --> (nB*nH, nL, nL)
            dS1 = self.ops.xp.matmul(dS2, V1.transpose(0, 2, 1))
            # (nB*nH, nL, nL).T @ (nB*nH, nL, nD) --> (nB*nH, nL, nD)
            dV1 = self.ops.xp.matmul(S1.transpose(0, 2, 1), dS2)
            dS0 = dS1.reshape((nB, nH, nL, nL))
            dV0 = dV1.reshape((nB, nH, nL, nD))
            return dS0, dV0

        return S3, backprop_attn4

    def mask_floor(self, nB, nL):
        ''' Each position attends within its stride block, up to itself. '''
        stride = math.ceil(math.sqrt(nL))
        floor_mask = Model.ops.xp.zeros((nB, nL, nL), dtype=Model.ops.xp.uint8)
        for i in range(nL):
            lower = max(0, i - (i % stride))
            higher = i + 1
            floor_mask[:, i, lower:higher] = 1
        return floor_mask

    def mask_repetitive(self, nB, nL, c=1, mode='left'):
        '''
        Every stride tokens, unmask one column (independent of row).
        c: number of summary columns per stride block;
        mode: which side of the block the summaries sit on.
        '''
        stride = math.ceil(math.sqrt(nL))
        repetitive_mask = Model.ops.xp.zeros((nB, nL, nL),
                                             dtype=Model.ops.xp.uint8)
        for j in range(nL):
            if (j % stride) >= (stride - c):
                if mode == 'left':
                    repetitive_mask[:, j:, j] = 1
        return repetitive_mask
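# A small standalone illustration of the two sparse patterns used above, in
# plain numpy: the "floor" mask lets each position attend within its stride
# block up to and including itself, while the "repetitive" mask lets later
# positions attend to the last column(s) of every earlier block (c=1 here, as
# assumed in the defaults above). Together the two steps factorize full
# attention into cheaper pieces.
import math
import numpy

def show_sparse_masks(nL=9):
    stride = math.ceil(math.sqrt(nL))         # 3 for nL=9
    floor = numpy.zeros((nL, nL), dtype="uint8")
    rep = numpy.zeros((nL, nL), dtype="uint8")
    for i in range(nL):
        floor[i, i - (i % stride):i + 1] = 1  # within-block, causal
    for j in range(nL):
        if (j % stride) >= stride - 1:        # last column of each block
            rep[j:, j] = 1
    print(floor)
    print(rep)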