Example 1
0
def test_differentiable_repeat():
    """Exercise ``DifferentiableMixedRepeat`` twice.

    First with a depth choice derived from a ``ValueChoice`` expression
    (``choice * 2 + 1``) over a stack of ``nn.Linear`` blocks, checking the
    forward shape and the exported architecture sample.  Then with modules
    that return a (tensor, tensor, dict) triple and a softmax with fixed
    weights, checking that the weighted-depth combination is applied to
    every tensor leaf of the nested output structure.
    """
    repeat = DifferentiableMixedRepeat(
        [nn.Linear(8 if idx == 0 else 16, 16) for idx in range(4)],
        ValueChoice([0, 1], label='ccc') * 2 + 1,
        GumbelSoftmax(-1),
        {},
    )
    repeat.resample({})
    assert repeat(torch.randn(2, 8)).size() == torch.Size([2, 16])
    exported = repeat.export({})
    assert 'ccc' in exported
    assert exported['ccc'] in [0, 1]

    class ConstantTriple(nn.Module):
        # Ignores its inputs and emits two constant tensors plus a dict
        # payload, so the repeat's output aggregation over nested
        # structures can be verified.
        def __init__(self, num):
            super().__init__()
            self.num = num

        def forward(self, *args, **kwargs):
            first = torch.full((2, 3), self.num)
            second = torch.full((3, 5), self.num)
            payload = {'a': 7, 'b': [self.num] * 11}
            return first, second, payload

    class FixedSoftmax(nn.Softmax):
        # Deterministic mixing weights make the expected weighted value
        # (2.5) predictable regardless of sampling.
        def forward(self, *args, **kwargs):
            return [0.3, 0.3, 0.4]

    repeat = DifferentiableMixedRepeat(
        [ConstantTriple(depth + 1) for depth in range(4)],
        ValueChoice([1, 2, 4], label='ccc'),
        FixedSoftmax(),
        {},
    )
    repeat.resample({})
    outputs = repeat(None)
    assert len(outputs) == 3
    assert outputs[0].shape == (2, 3)
    assert outputs[0][0][0].item() == 2.5
    assert outputs[2]['a'] == 7
    assert len(outputs[2]['b']) == 11
    assert outputs[2]['b'][-1] == 2.5
Example 2
0
def test_differentiable_repeat():
    """Smoke-test ``DifferentiableMixedRepeat`` with a derived depth choice.

    The depth is an arithmetic expression over a ``ValueChoice``
    (``choice * 2 + 1``); the test checks the forward output shape and
    that exporting yields a valid sample for the underlying label.
    """
    blocks = [nn.Linear(8 if idx == 0 else 16, 16) for idx in range(4)]
    depth_choice = ValueChoice([0, 1], label='ccc') * 2 + 1
    repeat = DifferentiableMixedRepeat(blocks, depth_choice, GumbelSoftmax(-1), {})
    repeat.resample({})
    result = repeat(torch.randn(2, 8))
    assert result.size() == torch.Size([2, 16])
    exported = repeat.export({})
    assert 'ccc' in exported
    assert exported['ccc'] in [0, 1]
Example 3
0
def test_differentiable_layer_input():
    """Check ``DifferentiableMixedLayer`` and ``DifferentiableMixedInput``.

    The mixed layer blends two ``Linear`` candidates and must expose the
    architecture parameter alongside the candidates' weights; the mixed
    input selects 2 out of 5 incoming tensors.
    """
    candidates = [
        ('a', Linear(2, 3, bias=False)),
        ('b', Linear(2, 3, bias=True)),
    ]
    layer = DifferentiableMixedLayer(
        candidates, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert layer(torch.randn(4, 2)).size(-1) == 3
    assert layer.export({})['eee'] in ['a', 'b']
    # Two candidate weight tensors + one bias + arch parameter handling
    # yields exactly three parameters, as in the original assertion.
    assert len(list(layer.parameters())) == 3

    mixed_input = DifferentiableMixedInput(
        5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    incoming = [torch.randn(4, 2) for _ in range(5)]
    assert mixed_input(incoming).size(-1) == 2
    assert len(mixed_input.export({})['ddd']) == 2
Example 4
0
def test_proxyless_layer_input():
    """Check ``ProxylessMixedLayer`` and ``ProxylessMixedInput``.

    Unlike the differentiable variants, the proxyless ones must be
    resampled before forwarding; resample/export are both expected to
    yield a valid choice for their respective labels.
    """
    candidates = [
        ('a', Linear(2, 3, bias=False)),
        ('b', Linear(2, 3, bias=True)),
    ]
    layer = ProxylessMixedLayer(
        candidates, nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert layer.resample({})['eee'] in ['a', 'b']
    assert layer(torch.randn(4, 2)).size(-1) == 3
    assert layer.export({})['eee'] in ['a', 'b']
    assert len(list(layer.parameters())) == 3

    mixed_input = ProxylessMixedInput(
        5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    valid_indices = list(range(5))
    assert mixed_input.resample({})['ddd'] in valid_indices
    incoming = [torch.randn(4, 2) for _ in range(5)]
    assert mixed_input(incoming).size() == torch.Size([4, 2])
    assert mixed_input.export({})['ddd'] in valid_indices