def test_rnn():
    """Run ``mb_test`` on ``simple_rnn`` bound to a fresh ``nn.RNNCell(2, 2)``.

    ``mb_test`` receives a two-argument callable ``(x, h0)`` plus one input
    spec for each of those arguments.
    """
    rnn_cell = nn.RNNCell(2, 2)

    def run(x, h0):
        # Close over the cell so mb_test only has to supply x and h0.
        return simple_rnn(x, h0, rnn_cell)

    # NOTE(review): specs look like (batch size, (dynamic?, size), ...) --
    # confirm against mb_test.
    sequence_spec = (4, (True, 3), (False, 2))
    state_spec = (4, (False, 2))
    mb_test(run, sequence_spec, state_spec)
def test_rnn():
    """Run ``mb_test`` on a locally defined, ``@batch``-decorated RNN loop."""
    # NOTE(review): the statements inside the @batch function are kept exactly
    # as written; the decorator may depend on their form -- confirm before
    # restyling.
    @batch
    def simple_rnn(x, h0, cell):
        # Start from the supplied state and fold `cell` over the slices
        # produced by x.unbind(1).
        h = h0
        for xt in x.unbind(1):
            h = cell(xt, h)
        return h

    def SimpleRNN(cell):
        # Bind `cell` so the returned callable takes only (x, h0).
        def inner(x, h0):
            return simple_rnn(x, h0, cell)
        return inner

    # One input spec per argument of `inner`: first for x, second for h0.
    mb_test(SimpleRNN(nn.RNNCell(2, 2)), (4, (True, 3), (False, 2)), (4, (False, 2)))
def test_readme():
    """Run ``mb_test`` on a small RNN module whose ``forward`` is batched."""
    class RNN(nn.Module):
        # Wraps a single nn.RNNCell with equal input and hidden width `size`.
        def __init__(self, size):
            super().__init__()
            self.cell = nn.RNNCell(size, size)

        # NOTE(review): body kept exactly as written; matchbox.batch may
        # depend on the statement forms -- confirm before restyling.
        @matchbox.batch
        def forward(self, x):
            # Zero initial state shaped (x.size(0), x.size(-1)).
            h = x.new_zeros(x.size(0), x.size(-1))
            # Feed each slice from x.unbind(1) through the cell, carrying h.
            for xt in x.unbind(1):
                h = self.cell(xt, h)
            return h

    mb_test(RNN(1), (4, (True, 3), (False, 1)))
def test_bilstm_class():
    """Run ``mb_test`` on a two-direction LSTM module built from LSTMCells."""
    class BiLSTMClass(nn.Module):
        # Holds one LSTMCell per direction: `fcell` (forward pass over the
        # slices) and `rcell` (pass over the slices in reversed order).
        def __init__(self, in_size, out_size):
            super().__init__()
            self.fcell = nn.LSTMCell(in_size, out_size)
            self.rcell = nn.LSTMCell(in_size, out_size)

        # NOTE(review): body kept exactly as written (including the explicit
        # state[0]/state[1] indexing); the @batch decorator may depend on
        # these statement forms -- confirm before restyling.
        @batch
        def forward(self, x, h0=None, c0=None):
            # Default states are zeros of width x.size(-1), i.e. the input
            # feature width. NOTE(review): this only matches the cells'
            # hidden size when in_size == out_size, as in the call below.
            hf = x.batch_zeros(x.size(-1)) if h0 is None else h0
            cf = x.batch_zeros(x.size(-1)) if c0 is None else c0
            # Forward direction: carry (hf, cf) across x.unbind(1).
            for xt in x.unbind(1):
                state = self.fcell(xt, (hf, cf))
                hf = state[0]
                cf = state[1]
            # Reverse direction restarts from the same defaults / h0, c0.
            hr = x.batch_zeros(x.size(-1)) if h0 is None else h0
            cr = x.batch_zeros(x.size(-1)) if c0 is None else c0
            for xt in reversed(x.unbind(1)):
                state = self.rcell(xt, (hr, cr))
                hr = state[0]
                cr = state[1]
            # Only the final hidden states of each direction are returned.
            return hf, hr

    mb_test(BiLSTMClass(2, 2), (4, (True, 3), (False, 2)))
def test_accum_birnn_class():
    """Run ``mb_test`` on a bidirectional RNN that accumulates every step."""
    class AccumBiRNNClass(nn.Module):
        # One RNNCell per direction, each with input and hidden width `size`.
        def __init__(self, size):
            super().__init__()
            self.fwd = nn.RNNCell(size, size)
            self.bwd = nn.RNNCell(size, size)

        # NOTE(review): body kept exactly as written; the @batch decorator
        # may depend on these statement forms -- confirm before restyling.
        @batch
        def forward(self, x):
            # Both directions start from the same zero state.
            h0 = x.batch_zeros(x.size(-1))
            h = h0
            fwd = []
            bwd = []
            # Forward direction: record the state after every slice.
            for xt in x.unbind(1):
                h = self.fwd(xt, h)
                fwd.append(h)
            fwd = F.stack(fwd, 1)
            # Reverse direction: reset the state, walk the slices backwards.
            h = h0
            for xt in reversed(x.unbind(1)):
                h = self.bwd(xt, h)
                bwd.append(h)
            # Reverse the collected states again before stacking, so they
            # line up with the order of `fwd`.
            bwd = F.stack(reversed(bwd), 1)
            # Join the two directions along dim 2.
            return F.cat((fwd, bwd), 2)

    mb_test(AccumBiRNNClass(1), (4, (True, 3), (False, 1)))
def test_posenc():
    """Run ``mb_test`` on ``x + positional_encodings_like(x)``."""
    # NOTE(review): spec looks like (batch size, (dynamic?, size), ...) --
    # confirm against mb_test.
    input_spec = (4, (True, 3), (False, 6))
    mb_test(lambda x: x + positional_encodings_like(x), input_spec)
def test_DecoderLayer():
    """Run ``mb_test`` on a ``DecoderLayer`` built from a small Namespace.

    The layer is called with two inputs that share the same spec.
    """
    # Keyword construction sets the same four attributes on the Namespace.
    hparams = argparse.Namespace(d_model=6, d_hidden=6, n_heads=3, drop_ratio=0)
    input_spec = (4, (True, 3), (False, 6))
    mb_test(DecoderLayer(hparams), input_spec, input_spec)
def test_MultiHead():
    """Run ``mb_test`` on ``MultiHead`` for both values of Attention's flag.

    Each case builds its own module immediately before it is tested.
    """
    input_spec = (4, (True, 3), (False, 6))
    head_args = (6, 6, 3)
    # (third Attention argument, arguments forwarded to mb_test after the
    # module). NOTE(review): the flag presumably toggles causal masking, and
    # the bare ints are passed through as-is -- confirm against
    # Attention / mb_test.
    cases = (
        (False, (input_spec, input_spec, 1)),
        (True, (input_spec, 0, 0)),
    )
    for flag, call_args in cases:
        mb_test(MultiHead(Attention(6, 0, flag), *head_args), *call_args)
def test_std():
    """Run ``mb_test`` on ``x.std(2)``."""
    input_spec = (4, (True, 3), (False, 2))
    mb_test(lambda x: x.std(2), input_spec)
def test_rnn_class():
    """Run ``mb_test`` on ``RNNClass`` wrapping an ``nn.RNNCell(2, 2)``."""
    module = RNNClass(nn.RNNCell(2, 2))
    input_spec = (4, (True, 3), (False, 2))
    mb_test(module, input_spec)
def test_rnn_cell():
    """Run ``mb_test`` directly on an ``nn.RNNCell(2, 2)``.

    The cell receives two inputs that share the same spec.
    """
    cell = nn.RNNCell(2, 2)
    step_spec = (4, (False, 2))
    mb_test(cell, step_spec, step_spec)
def test_accum_rnn_class():
    """Run ``mb_test`` on ``AccumRNNClass`` with second argument None, then True.

    A fresh ``nn.RNNCell(2, 2)`` and module are built for each case, right
    before that case is tested.
    """
    input_spec = (4, (True, 3), (False, 2))
    for second_arg in (None, True):
        mb_test(AccumRNNClass(nn.RNNCell(2, 2), second_arg), input_spec)
def test_causal_mask():
    """Run ``mb_test`` on two expressions that use ``causal_mask(2, 1)``.

    The first masks the input itself; the second masks ``x @ x.transpose(1, 2)``.
    In both, the softmax of the masked value is multiplied by ``x``.
    """
    square_spec = (4, (False, 3), (False, 3))
    mb_test(lambda x: x.causal_mask(2, 1).softmax() @ x, square_spec)

    sequence_spec = (4, (True, 3), (False, 2))
    mb_test(
        lambda x: (x @ x.transpose(1, 2)).causal_mask(2, 1).softmax() @ x,
        sequence_spec,
    )
def test_transpose():
    """Run ``mb_test`` on ``x.transpose(1, 2)``."""
    input_spec = (4, (True, 3), (False, 2))
    mb_test(lambda x: x.transpose(1, 2), input_spec)
def test_matmul():
    """Run ``mb_test`` on ``a @ b`` with two differently laid-out operands."""
    # The two specs carry the same entries in swapped order.
    left_spec = (4, (True, 3), (False, 2))
    right_spec = (4, (False, 2), (True, 3))
    mb_test(lambda a, b: a @ b, left_spec, right_spec)
def test_while():
    """Run ``mb_test`` on the ``while_loop`` function."""
    # NOTE(review): the empty inner tuple presumably means no dims beyond
    # the leading 4 -- confirm against mb_test.
    input_spec = (4, ())
    mb_test(while_loop, input_spec)
def test_if_noelse():
    """Run ``mb_test`` on the ``if_noelse`` function."""
    # NOTE(review): the empty inner tuple presumably means no dims beyond
    # the leading 4 -- confirm against mb_test.
    input_spec = (4, ())
    mb_test(if_noelse, input_spec)
def test_LayerNorm():
    """Run ``mb_test`` on ``LayerNorm(2)``."""
    layer = LayerNorm(2)
    input_spec = (4, (True, 3), (False, 2))
    mb_test(layer, input_spec)
def test_accum_birnn_class():
    """Run ``mb_test`` on ``AccumBiRNNClass(1)``."""
    # NOTE(review): a function with this same name appears earlier in this
    # source; if both live in one module, this definition replaces it.
    module = AccumBiRNNClass(1)
    input_spec = (4, (True, 3), (False, 1))
    mb_test(module, input_spec)
def test_FeedForward():
    """Run ``mb_test`` on ``FeedForward(2, 3, 0)``."""
    layer = FeedForward(2, 3, 0)
    input_spec = (4, (True, 3), (False, 2))
    mb_test(layer, input_spec)
def test_ResidualBlock():
    """Run ``mb_test`` on a ``ResidualBlock`` wrapping ``FeedForward(2, 3, 0)``."""
    inner_layer = FeedForward(2, 3, 0)
    block = ResidualBlock(inner_layer, 2, 0)
    input_spec = (4, (True, 3), (False, 2))
    mb_test(block, input_spec)
def test_Attention():
    """Run ``mb_test`` on ``Attention(2, 0, flag)`` for three argument sets.

    Each case builds its own ``Attention`` immediately before it is tested.
    """
    input_spec = (4, (True, 3), (False, 2))
    # (third Attention argument, arguments forwarded to mb_test after the
    # module). NOTE(review): the flag presumably toggles causal masking, and
    # the bare ints are passed through as-is -- confirm against
    # Attention / mb_test.
    cases = (
        (False, (input_spec, 0, 0)),
        (False, (input_spec, input_spec, 1)),
        (True, (input_spec, 0, 0)),
    )
    for flag, call_args in cases:
        mb_test(Attention(2, 0, flag), *call_args)
def test_lstm_class():
    """Run ``mb_test`` on ``LSTMClass(2, 2)``."""
    module = LSTMClass(2, 2)
    input_spec = (4, (True, 3), (False, 2))
    mb_test(module, input_spec)
def test_mean():
    """Run ``mb_test`` on ``x.mean(2)``."""
    input_spec = (4, (True, 3), (False, 2))
    mb_test(lambda x: x.mean(2), input_spec)