Beispiel #1
0
def test_Transformer():
    """Exercise Transformer forward and unreduced loss on MaskedBatch input.

    Builds B random source sequences (lengths in [1, T]) and B random
    target sequences (lengths in [2, T]) over a vocabulary of V ids. It then
    runs ``mb_assert`` twice, comparing the per-example lists
    ``(xs, ys)`` with the batched ``(xb, yb)``: once on the model itself and
    once on ``model.loss(..., reduce=False)``.
    """
    B, T, C, V = 4, 3, 6, 5
    args = argparse.Namespace(
        d_model=C,
        d_hidden=C,
        n_heads=3,
        drop_ratio=0,
        n_layers=2,
        length_ratio=1.5,
    )
    field = MaskedBatchField()
    field.vocab = list(range(V))

    def random_tokens(min_len):
        # A (1, L) LongTensor of ids in [0, V), with L drawn from [min_len, T].
        length = random.randint(min_len, T)
        return Variable(torch.LongTensor(1, length).random_(V))

    src_seqs = [random_tokens(1) for _ in range(B)]
    # Targets have at least 2 tokens. Presumably this keeps the shifted
    # decoder input/output non-empty; confirm against Transformer.loss.
    trg_seqs = [random_tokens(2) for _ in range(B)]
    src_batch = MaskedBatch.fromlist(src_seqs, (True, ))
    trg_batch = MaskedBatch.fromlist(trg_seqs, (True, ))

    model = Transformer(field, field, args)
    mb_assert(model, (src_seqs, trg_seqs), (src_batch, trg_batch), B)

    # model.loss reads a batch object with ``src`` and ``trg`` attributes.
    Batch = namedtuple('_batch', ('src', 'trg'))

    def loss(x, y):
        return model.loss(Batch(x, y), reduce=False)

    mb_assert(loss, (src_seqs, trg_seqs), (src_batch, trg_batch), B)
Beispiel #2
0
def test_embedding():
    """Run F.embedding on a MaskedBatch of index rows through ``mb_assert``.

    Four random (1, L) index rows with L in [1, 3] and ids in [0, 5) are
    embedded with a random 5x2 weight matrix. The per-example list and the
    batched form are both handed to ``mb_assert``.
    """
    n_examples, max_len, vocab, emb_dim = 4, 3, 5, 2
    rows = []
    for _ in range(n_examples):
        length = random.randint(1, max_len)
        rows.append(Variable(torch.LongTensor(1, length).random_(vocab)))
    weight = Variable(torch.rand(vocab, emb_dim))
    batch = MaskedBatch.fromlist(rows, (True, ))
    mb_assert(F.embedding, (rows, weight), (batch, weight), n_examples)
Beispiel #3
0
def mb_rand(*dims):
    """Create random float examples and the MaskedBatch built from them.

    ``dims`` is the batch size followed by one ``(is_dynamic, size)`` pair
    per non-batch dimension. Any empty tuple ``()`` in ``dims`` is dropped
    first. For a dynamic dimension, each example draws its own length
    uniformly from ``[1, size]``. A static dimension is always ``size``.

    Returns ``(xs, xb)``. ``xs`` is the list of per-example ``Variable``
    tensors, each with a leading dimension of 1. ``xb`` is
    ``MaskedBatch.fromlist(xs, flags)``, where ``flags`` holds the
    ``is_dynamic`` values in order.
    """
    spec = [entry for entry in dims if entry != ()]
    batch_size = spec[0]
    per_dim = spec[1:]

    def sample_shape():
        # Dynamic dims get a fresh random length for every example.
        return [random.randint(1, size) if dynamic else size
                for dynamic, size in per_dim]

    examples = [Variable(torch.rand(1, *sample_shape()))
                for _ in range(batch_size)]
    dynamic_flags = tuple(dynamic for dynamic, _ in per_dim)
    return examples, MaskedBatch.fromlist(examples, dynamic_flags)
Beispiel #4
0
def test_Encoder():
    """Run an Encoder on a MaskedBatch of token rows through ``mb_assert``.

    Four random (1, L) token rows with L in [1, 3] and ids in [0, 5) are
    encoded. The field's ``out`` projection is a Linear layer from the
    model width to 5.
    """
    n_examples, max_len, vocab, width = 4, 3, 5, 6
    args = argparse.Namespace(
        d_model=width,
        d_hidden=width,
        n_heads=3,
        drop_ratio=0,
        n_layers=2,
    )
    field = MaskedBatchField()
    # Build the Linear layer before sampling tokens. Parameter init draws
    # from the torch RNG, so this keeps the random stream in its original
    # order.
    field.out = nn.Linear(width, vocab)

    rows = []
    for _ in range(n_examples):
        tokens = torch.LongTensor(1, random.randint(1, max_len))
        rows.append(Variable(tokens.random_(vocab)))
    batch = MaskedBatch.fromlist(rows, (True, ))

    encoder = Encoder(field, args)
    mb_assert(encoder, (rows, ), (batch, ), n_examples)