コード例 #1
0
ファイル: test_rnns.py プロジェクト: xpertasks/matchbox
def test_rnn():
    def SimpleRNN(cell):
        def inner(x, h0):
            return simple_rnn(x, h0, cell)

        return inner

    mb_test(SimpleRNN(nn.RNNCell(2, 2)), (4, (True, 3), (False, 2)),
            (4, (False, 2)))
コード例 #2
0
ファイル: test_rnns.py プロジェクト: stjordanis/matchbox
def test_rnn():
    @batch
    def simple_rnn(x, h0, cell):
        h = h0
        for xt in x.unbind(1):
            h = cell(xt, h)
        return h
    def SimpleRNN(cell):
        def inner(x, h0):
            return simple_rnn(x, h0, cell)
        return inner
    mb_test(SimpleRNN(nn.RNNCell(2, 2)),
            (4, (True, 3), (False, 2)), (4, (False, 2)))
コード例 #3
0
ファイル: test_rnns.py プロジェクト: stjordanis/matchbox
def test_readme():
    class RNN(nn.Module):
        def __init__(self, size):
            super().__init__()
            self.cell = nn.RNNCell(size, size)
        @matchbox.batch
        def forward(self, x):
            h = x.new_zeros(x.size(0), x.size(-1))
            for xt in x.unbind(1):
                h = self.cell(xt, h)
            return h
    mb_test(RNN(1),
            (4, (True, 3), (False, 1)))
コード例 #4
0
ファイル: test_rnns.py プロジェクト: stjordanis/matchbox
def test_bilstm_class():
    class BiLSTMClass(nn.Module):
        def __init__(self, in_size, out_size):
            super().__init__()
            self.fcell = nn.LSTMCell(in_size, out_size)
            self.rcell = nn.LSTMCell(in_size, out_size)
        @batch
        def forward(self, x, h0=None, c0=None):
            hf = x.batch_zeros(x.size(-1)) if h0 is None else h0
            cf = x.batch_zeros(x.size(-1)) if c0 is None else c0
            for xt in x.unbind(1):
                state = self.fcell(xt, (hf, cf))
                hf = state[0]
                cf = state[1]
            hr = x.batch_zeros(x.size(-1)) if h0 is None else h0
            cr = x.batch_zeros(x.size(-1)) if c0 is None else c0
            for xt in reversed(x.unbind(1)):
                state = self.rcell(xt, (hr, cr))
                hr = state[0]
                cr = state[1]
            return hf, hr
    mb_test(BiLSTMClass(2, 2),
            (4, (True, 3), (False, 2)))
コード例 #5
0
ファイル: test_rnns.py プロジェクト: stjordanis/matchbox
def test_accum_birnn_class():
    class AccumBiRNNClass(nn.Module):
        def __init__(self, size):
            super().__init__()
            self.fwd = nn.RNNCell(size, size)
            self.bwd = nn.RNNCell(size, size)
        @batch
        def forward(self, x):
            h0 = x.batch_zeros(x.size(-1))
            h = h0
            fwd = []
            bwd = []
            for xt in x.unbind(1):
                h = self.fwd(xt, h)
                fwd.append(h)
            fwd = F.stack(fwd, 1)
            h = h0
            for xt in reversed(x.unbind(1)):
                h = self.bwd(xt, h)
                bwd.append(h)
            bwd = F.stack(reversed(bwd), 1)
            return F.cat((fwd, bwd), 2)
    mb_test(AccumBiRNNClass(1),
            (4, (True, 3), (False, 1)))
コード例 #6
0
def test_posenc():
    mb_test(lambda x: x + positional_encodings_like(x),
            (4, (True, 3), (False, 6)))
コード例 #7
0
def test_DecoderLayer():
    args = argparse.Namespace()
    args.__dict__.update(d_model=6, d_hidden=6, n_heads=3, drop_ratio=0)
    mb_test(DecoderLayer(args), (4, (True, 3), (False, 6)),
            (4, (True, 3), (False, 6)))
コード例 #8
0
def test_MultiHead():
    mb_test(MultiHead(Attention(6, 0, False), 6, 6, 3),
            (4, (True, 3), (False, 6)), (4, (True, 3), (False, 6)), 1)
    mb_test(MultiHead(Attention(6, 0, True), 6, 6, 3),
            (4, (True, 3), (False, 6)), 0, 0)
コード例 #9
0
def test_std():
    mb_test(lambda x: x.std(2), (4, (True, 3), (False, 2)))
コード例 #10
0
ファイル: test_rnns.py プロジェクト: xpertasks/matchbox
def test_rnn_class():
    mb_test(RNNClass(nn.RNNCell(2, 2)), (4, (True, 3), (False, 2)))
コード例 #11
0
ファイル: test_rnns.py プロジェクト: xpertasks/matchbox
def test_rnn_cell():
    mb_test(nn.RNNCell(2, 2), (4, (False, 2)), (4, (False, 2)))
コード例 #12
0
ファイル: test_rnns.py プロジェクト: xpertasks/matchbox
def test_accum_rnn_class():
    mb_test(AccumRNNClass(nn.RNNCell(2, 2), None), (4, (True, 3), (False, 2)))
    mb_test(AccumRNNClass(nn.RNNCell(2, 2), True), (4, (True, 3), (False, 2)))
コード例 #13
0
def test_causal_mask():
    mb_test(lambda x: x.causal_mask(2, 1).softmax() @ x,
            (4, (False, 3), (False, 3)))
    mb_test(lambda x: (x @ x.transpose(1, 2)).causal_mask(2, 1).softmax() @ x,
            (4, (True, 3), (False, 2)))
コード例 #14
0
def test_transpose():
    mb_test(lambda x: x.transpose(1, 2), (4, (True, 3), (False, 2)))
コード例 #15
0
def test_matmul():
    mb_test(lambda a, b: a @ b, (4, (True, 3), (False, 2)),
            (4, (False, 2), (True, 3)))
コード例 #16
0
def test_while():
    mb_test(while_loop, (4, ()))
コード例 #17
0
def test_if_noelse():
    mb_test(if_noelse, (4, ()))
コード例 #18
0
def test_LayerNorm():
    mb_test(LayerNorm(2), (4, (True, 3), (False, 2)))
コード例 #19
0
ファイル: test_rnns.py プロジェクト: xpertasks/matchbox
def test_accum_birnn_class():
    mb_test(AccumBiRNNClass(1), (4, (True, 3), (False, 1)))
コード例 #20
0
def test_FeedForward():
    mb_test(FeedForward(2, 3, 0), (4, (True, 3), (False, 2)))
コード例 #21
0
def test_ResidualBlock():
    mb_test(ResidualBlock(FeedForward(2, 3, 0), 2, 0),
            (4, (True, 3), (False, 2)))
コード例 #22
0
def test_Attention():
    mb_test(Attention(2, 0, False), (4, (True, 3), (False, 2)), 0, 0)
    mb_test(Attention(2, 0, False), (4, (True, 3), (False, 2)),
            (4, (True, 3), (False, 2)), 1)
    mb_test(Attention(2, 0, True), (4, (True, 3), (False, 2)), 0, 0)
コード例 #23
0
ファイル: test_rnns.py プロジェクト: xpertasks/matchbox
def test_lstm_class():
    mb_test(LSTMClass(2, 2), (4, (True, 3), (False, 2)))
コード例 #24
0
def test_mean():
    mb_test(lambda x: x.mean(2), (4, (True, 3), (False, 2)))