Example #1
def test_differentiable_repeat():
    op = DifferentiableMixedRepeat(
        [nn.Linear(8 if i == 0 else 16, 16) for i in range(4)],
        ValueChoice([0, 1], label='ccc') * 2 + 1, GumbelSoftmax(-1), {})
    op.resample({})
    assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
    sample = op.export({})
    assert 'ccc' in sample and sample['ccc'] in [0, 1]

    class TupleModule(nn.Module):
        def __init__(self, num):
            super().__init__()
            self.num = num

        def forward(self, *args, **kwargs):
            return (
                torch.full((2, 3), self.num),
                torch.full((3, 5), self.num),
                {'a': 7, 'b': [self.num] * 11},
            )

    class CustomSoftmax(nn.Softmax):
        def forward(self, *args, **kwargs):
            return [0.3, 0.3, 0.4]

    op = DifferentiableMixedRepeat([TupleModule(i + 1) for i in range(4)],
                                   ValueChoice([1, 2, 4], label='ccc'),
                                   CustomSoftmax(), {})
    op.resample({})
    res = op(None)
    assert len(res) == 3
    assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
    assert res[2]['a'] == 7
    assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5
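The value 2.5 asserted above follows from the fixed softmax weights: the differentiable repeat blends the block outputs at each candidate depth (1, 2 and 4), and the block reached at depth d emits tensors filled with d. A minimal arithmetic sketch of that blend, assuming a plain weighted sum (no NNI imports required):

weights = [0.3, 0.3, 0.4]        # CustomSoftmax output, one weight per depth candidate
depth_outputs = [1.0, 2.0, 4.0]  # TupleModule at depth d fills its tensors with d
blended = sum(w * o for w, o in zip(weights, depth_outputs))
assert abs(blended - 2.5) < 1e-6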
Example #2
def test_mixed_mhattn_batch_first():
    # batch_first is not supported on legacy PyTorch versions;
    # 1.7 is marked because it is the version used on the legacy pipeline

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'),
                                2,
                                kdim=ValueChoice([3, 7], label='kdim'),
                                vdim=ValueChoice([5, 8], label='vdim'),
                                bias=False,
                                add_bias_kv=True,
                                batch_first=True)
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
        torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8)
    )[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
        torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5)
    )[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8),
                                                 torch.randn(1, 7, 7),
                                                 torch.randn(1, 7, 8))
Example #3
def test_valuechoice_utils():
    chosen = {"exp": 3, "add": 1}
    vc0 = ValueChoice([3, 4, 6], label='exp') * 2 + ValueChoice([0, 1],
                                                                label='add')

    assert evaluate_value_choice_with_dict(vc0, chosen) == 7
    vc = vc0 + ValueChoice([3, 4, 6], label='exp')
    assert evaluate_value_choice_with_dict(vc, chosen) == 10

    assert list(dedup_inner_choices([vc0, vc]).keys()) == ['exp', 'add']

    assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
    weights = dict(traverse_all_options(
        vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
    ans = {9: 0.2, 10: 0.3, 12: 0.12, 13: 0.18, 18: 0.08, 19: 0.12}
    assert len(weights) == len(ans)
    for value, weight in ans.items():
        assert abs(weight - weights[value]) < 1e-6

    assert evaluate_constant(
        ValueChoice([3, 4, 6], label='x') -
        ValueChoice([3, 4, 6], label='x')) == 0
    with pytest.raises(ValueError):
        evaluate_constant(ValueChoice([3, 4, 6]) - ValueChoice([3, 4, 6]))

    assert evaluate_constant(
        ValueChoice([3, 4, 6], label='x') * 2 /
        ValueChoice([3, 4, 6], label='x')) == 2
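The distribution in ans above is a product of independent marginals: vc reduces to 3 * exp + add, so each outcome's weight is P(exp) * P(add), e.g. P(9) = P(exp=3) * P(add=0) = 0.5 * 0.4 = 0.2. A standalone sketch of that arithmetic, assuming the two choices are independent (plain Python, no NNI imports):

from itertools import product

exp_w = {3: 0.5, 4: 0.3, 6: 0.2}
add_w = {0: 0.4, 1: 0.6}
expected = {}
for (e, pe), (a, pa) in product(exp_w.items(), add_w.items()):
    value = 3 * e + a  # vc == vc0 + exp == (exp * 2 + add) + exp
    expected[value] = expected.get(value, 0.0) + pe * pa
for value, weight in {9: 0.2, 10: 0.3, 12: 0.12, 13: 0.18, 18: 0.08, 19: 0.12}.items():
    assert abs(expected[value] - weight) < 1e-6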
Example #4
def test_differentiable_valuechoice():
    orig_conv = Conv2d(3,
                       ValueChoice([3, 5, 7], label='456'),
                       kernel_size=ValueChoice([3, 5, 7], label='123'),
                       padding=ValueChoice([3, 5, 7], label='123') // 2)
    conv = MixedConv2d.mutate(
        orig_conv, 'dummy', {},
        {'mixed_op_sampling': MixedOpDifferentiablePolicy})
    # padding tracks kernel_size // 2, so the spatial size is preserved for any sampled kernel
    assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7

    assert set(conv.export({}).keys()) == {'123', '456'}
Example #5
def __init__(self):
    super().__init__()
    ch1 = ValueChoice([16, 32])
    kernel = ValueChoice([3, 5])
    self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
    self.batch_norm = nn.BatchNorm2d(ch1)
    self.conv2 = nn.Conv2d(ch1, 64, 3)
    self.dropout1 = LayerChoice([nn.Dropout(.25), nn.Dropout(.5), nn.Dropout(.75)])
    self.fc = nn.Linear(64, 10)
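The snippet above defines only the layers; its forward is not shown. A plausible sketch under assumptions (the ReLU activations, the pooling and flatten steps, and the torch.nn.functional import are guesses; only the layer names come from the original):

import torch.nn.functional as F

def forward(self, x):
    x = F.relu(self.batch_norm(self.conv1(x)))
    x = F.relu(self.conv2(x))
    x = self.dropout1(x)
    x = F.adaptive_avg_pool2d(x, 1).flatten(1)  # assumed: reduce to 64 features for self.fc
    return self.fc(x)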
Example #6
def test_pathsampling_valuechoice():
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {},
                              {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    conv.resample(memo={'123': 5})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
    conv.resample(memo={'123': 7})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
    assert conv.export({})['123'] in [3, 5, 7]
Example #7
def test_pathsampling_repeat():
    op = PathSamplingRepeat(
        [nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)],
        ValueChoice([1, 2, 3], label='ccc'))
    sample = op.resample({})
    assert sample['ccc'] in [1, 2, 3]
    for i in range(1, 4):
        op.resample({'ccc': i})
        out = op(torch.randn(2, 16))
        assert out.shape[1] == [16, 8, 4][i - 1]

    op = PathSamplingRepeat([nn.Linear(i + 1, i + 2) for i in range(7)],
                            2 * ValueChoice([1, 2, 3], label='ddd') + 1)
    sample = op.resample({})
    assert sample['ddd'] in [1, 2, 3]
    for i in range(1, 4):
        op.resample({'ddd': i})
        out = op(torch.randn(2, 1))
        # depth 2*i+1 chains Linear(1, 2) ... Linear(2*i+1, 2*i+2), mapping dim 1 -> (2*i+1) + 1
        assert out.shape[1] == (2 * i + 1) + 1
Example #8
def test_mixed_batchnorm2d():
    bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))

    assert _mixed_operation_sampling_sanity_check(
        bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
    assert _mixed_operation_sampling_sanity_check(
        bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64

    _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
Example #9
def test_mixed_conv2d():
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'),
                  ValueChoice([2, 4, 8], label='out') * 2, 1)
    assert _mixed_operation_sampling_sanity_check(
        conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8
    _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))

    # stride
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'),
                  ValueChoice([2, 4, 8], label='out'),
                  1,
                  stride=ValueChoice([1, 2], label='stride'))
    assert _mixed_operation_sampling_sanity_check(
        conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
    assert _mixed_operation_sampling_sanity_check(
        conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10

    # groups, dw conv
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'),
                  ValueChoice([3, 6, 9], label='in'),
                  1,
                  groups=ValueChoice([3, 6, 9], label='in'))
    assert _mixed_operation_sampling_sanity_check(
        conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])

    # make sure kernel is sliced correctly
    conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
    conv = MixedConv2d.mutate(conv, 'dummy', {},
                              {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    with torch.no_grad():
        conv.weight.zero_()
        # only center is 1, must pick center to pass this test
        conv.weight[0, 0, 1, 1] = 1
    conv.resample({'k': 1})
    assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9
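The final assert works out as follows: slicing the 3x3 kernel down to k=1 keeps only the center weight (the one set to 1), and a 1x1 convolution with weight 1 over a 3x3 all-ones input yields nine ones. A standalone check of that arithmetic, assuming center-cropped kernel slicing (plain PyTorch, no NNI):

import torch
import torch.nn.functional as F

weight = torch.zeros(1, 1, 3, 3)
weight[0, 0, 1, 1] = 1.0
center = weight[:, :, 1:2, 1:2]  # assumed: center crop for kernel_size 1
out = F.conv2d(torch.ones(1, 1, 3, 3), center)
assert out.sum().item() == 9.0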
Example #10
def test_mixed_linear():
    linear = Linear(ValueChoice([3, 6, 9], label='shared'),
                    ValueChoice([2, 4, 8]))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3},
                                           torch.randn(2, 3))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 9},
                                           torch.randn(2, 9))
    _mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))

    linear = Linear(ValueChoice([3, 6, 9], label='shared'),
                    ValueChoice([2, 4, 8]),
                    bias=False)
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3},
                                           torch.randn(2, 3))

    with pytest.raises(TypeError):
        linear = Linear(ValueChoice([3, 6, 9], label='shared'),
                        ValueChoice([2, 4, 8]),
                        bias=ValueChoice([False, True]))
        _mixed_operation_sampling_sanity_check(linear, {'shared': 3},
                                               torch.randn(2, 3))
Example #11
def test_mixed_mhattn():
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)

    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4}, torch.randn(7, 2, 4), torch.randn(7, 2, 4),
        torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8}, torch.randn(7, 2, 8), torch.randn(7, 2, 8),
        torch.randn(7, 2, 8))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8),
                                                 torch.randn(7, 2, 8),
                                                 torch.randn(7, 2, 8))

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'),
                                ValueChoice([2, 3, 4], label='heads'))
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'heads': 2},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4)
    )[0].size(-1) == 4
    with pytest.raises(AssertionError, match='divisible'):
        assert _mixed_operation_sampling_sanity_check(
            mhattn, {'emb': 4, 'heads': 3},
            torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4)
        )[0].size(-1) == 4

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'),
                                4,
                                kdim=ValueChoice([5, 7], label='kdim'))
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'kdim': 7},
        torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4)
    )[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8, 'kdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8)
    )[0].size(-1) == 8

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'),
                                4,
                                vdim=ValueChoice([5, 8], label='vdim'))
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'vdim': 8},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8)
    )[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8, 'vdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5)
    )[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8),
                                                 torch.randn(5, 3, 8),
                                                 torch.randn(5, 3, 8))
Example #12
def __init__(self, head_count):
    super().__init__()
    embed_dim = ValueChoice(candidates=[32, 64])
    self.linear1 = nn.Linear(128, embed_dim)
    self.mhatt = nn.MultiheadAttention(embed_dim, head_count)
    self.linear2 = nn.Linear(embed_dim, 1)
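As in Example #5, only the constructor is shown. A plausible forward sketch under assumptions (self-attention with the projected input used as query, key and value, and the default sequence-first layout of nn.MultiheadAttention; none of this wiring is in the original):

def forward(self, x):
    # x: (seq_len, batch, 128); project to the sampled embed_dim
    q = self.linear1(x)
    attn_out, _ = self.mhatt(q, q, q)  # assumed self-attention over the projection
    return self.linear2(attn_out)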