Example #1
import copy

import torch

# Assumed import path: TiledLinear from DeepSpeed's ZeRO utilities.
from deepspeed.zero import TiledLinear


# The split/feature arguments are presumably supplied by pytest parametrization.
def test_tiled_backward(in_splits, out_splits, bias, in_f, out_f):
    base = torch.nn.Linear(in_f, out_f, bias=bias)
    test = TiledLinear(in_f,
                       out_f,
                       bias=bias,
                       init_linear=copy.deepcopy(base),
                       out_splits=out_splits,
                       in_splits=in_splits)

    inp = torch.rand(in_f)

    # Forward results must agree before gradients are compared.
    base_out = base(copy.deepcopy(inp))
    test_out = test(copy.deepcopy(inp))
    assert torch.allclose(base_out, test_out, rtol=1e-4)

    base_out.sum().backward()
    test_out.sum().backward()

    # Each tile's grad must equal the matching slice of the dense layer's grad.
    for row in range(out_splits):
        rstart = test.out_parts[row]
        rstop = test.out_parts[row + 1]

        for col in range(in_splits):
            cstart = test.in_parts[col]
            cstop = test.in_parts[col + 1]

            local = test.linears[row][col]
            base_grad = base.weight.grad[rstart:rstop, cstart:cstop]
            assert torch.allclose(base_grad, local.weight.grad, rtol=1e-4)

            if local.bias is not None:
                base_grad = base.bias.grad[rstart:rstop]
                assert torch.allclose(base_grad, local.bias.grad, rtol=1e-4)
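
The slice-wise comparison above relies on a simple identity: each tile multiplies only its input slice and feeds only its output slice, so its weight gradient equals the matching block of the dense layer's weight gradient. A minimal standalone check of that identity in plain PyTorch (independent of TiledLinear; the tile boundaries 0:3 and 0:2 are arbitrary):

import torch

torch.manual_seed(0)
base = torch.nn.Linear(4, 6)
# A 3x2 tile of the dense weight, cloned as its own leaf parameter.
tile = torch.nn.Parameter(base.weight[0:3, 0:2].detach().clone())

x = torch.rand(4)
base(x).sum().backward()              # dense backward
(x[0:2] @ tile.t()).sum().backward()  # the tile sees only its input slice

# The tile's grad is the matching block of the dense grad.
assert torch.allclose(base.weight.grad[0:3, 0:2], tile.grad)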
Example #2
import copy

import torch

# Assumed import path: TiledLinear from DeepSpeed's ZeRO utilities.
from deepspeed.zero import TiledLinear


def test_tiled_forward(in_splits, out_splits, bias, in_f, out_f):
    base = torch.nn.Linear(in_f, out_f, bias=bias)
    test = TiledLinear(in_f,
                       out_f,
                       bias=bias,
                       init_linear=copy.deepcopy(base),
                       out_splits=out_splits,
                       in_splits=in_splits)

    inp = torch.rand(in_f)

    base_out = base(copy.deepcopy(inp))
    test_out = test(copy.deepcopy(inp))

    assert torch.allclose(base_out, test_out, rtol=1e-4)
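
The equivalence tested here follows from splitting the weight into a grid of tiles: output tile r is the sum, over the column tiles of row r, of each tile applied to its input slice. A minimal sketch of such a forward pass (an illustration of the technique, not DeepSpeed's actual implementation):

import torch

def tiled_forward(linears, in_parts, out_parts, x):
    # linears[row][col] maps x[in_parts[col]:in_parts[col + 1]] to the
    # output slice out_parts[row]:out_parts[row + 1]. Summing a row of
    # column tiles reproduces the dense matmul; only the last column
    # tile holds the bias, so it is added exactly once.
    out_tiles = []
    for row in range(len(out_parts) - 1):
        acc = None
        for col in range(len(in_parts) - 1):
            y = linears[row][col](x[..., in_parts[col]:in_parts[col + 1]])
            acc = y if acc is None else acc + y
        out_tiles.append(acc)
    return torch.cat(out_tiles, dim=-1)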
Example #3
import copy

import torch

# Assumed import path: TiledLinear from DeepSpeed's ZeRO utilities.
from deepspeed.zero import TiledLinear


def test_tiled_init(in_splits, out_splits):
    in_f = 32
    out_f = 40
    base = torch.nn.Linear(in_f, out_f, bias=True)
    tiled = TiledLinear(in_f,
                        out_f,
                        bias=True,
                        init_linear=copy.deepcopy(base),
                        out_splits=out_splits,
                        in_splits=in_splits)

    for out_id in range(out_splits):
        for in_id in range(in_splits):
            local_l = tiled.linears[out_id][in_id]
            assert isinstance(local_l, torch.nn.Linear)

            rstart = tiled.out_parts[out_id]
            rstop = tiled.out_parts[out_id + 1]
            cstart = tiled.in_parts[in_id]
            cstop = tiled.in_parts[in_id + 1]

            local_out = rstop - rstart
            local_in = cstop - cstart
            assert local_l.weight.size(1) == local_in, \
                f'local[{out_id}][{in_id}].size {local_l.weight.size()}'
            assert local_l.weight.size(0) == local_out

            test = base.weight[rstart:rstop, cstart:cstop]

            assert local_l.weight.size() == test.size()
            assert torch.equal(local_l.weight.data, test.data)

            # Only the last column tile carries the bias, so it is added once.
            if in_id == in_splits - 1:
                assert local_l.bias is not None
                assert local_l.bias.size()[0] == local_out
            else:
                assert local_l.bias is None
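
The in_parts/out_parts arrays consulted above are tile boundary indices. One way to compute such boundaries is a uniform partition that spreads any remainder over the leading tiles; the helper below is hypothetical, shown only to make the boundary arithmetic concrete:

def partition_uniform(num_items, num_parts):
    # Boundary indices splitting num_items into num_parts contiguous
    # chunks whose sizes differ by at most one.
    parts = [0] * (num_parts + 1)
    chunk, remainder = divmod(num_items, num_parts)
    for p in range(num_parts):
        parts[p + 1] = parts[p] + chunk + (1 if p < remainder else 0)
    return parts

# e.g. partition_uniform(40, 3) -> [0, 14, 27, 40]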
Example #4
import pytest

# Assumed import path: TiledLinear from DeepSpeed's ZeRO utilities.
from deepspeed.zero import TiledLinear


def test_tiled_baddim(in_splits, out_splits):
    dim = 32
    # Invalid split counts (parametrized; e.g. negative, zero, or larger
    # than dim) must be rejected at construction time.
    with pytest.raises(RuntimeError):
        TiledLinear(dim, dim, out_splits=out_splits, in_splits=in_splits)
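
For reference, the constructor validation this test exercises might look like the sketch below; it assumes out-of-range split counts are the failure mode, and the library's actual checks may differ:

def _validate_splits(in_features, out_features, in_splits, out_splits):
    # Every tile must contain at least one row and one column.
    if in_splits < 1 or out_splits < 1:
        raise RuntimeError('split counts must be >= 1')
    if in_splits > in_features or out_splits > out_features:
        raise RuntimeError('more tiles requested than features available')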