def test_differentiable_repeat():
    op = DifferentiableMixedRepeat(
        [nn.Linear(8 if i == 0 else 16, 16) for i in range(4)],
        ValueChoice([0, 1], label='ccc') * 2 + 1,
        GumbelSoftmax(-1),
        {})
    op.resample({})
    assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
    sample = op.export({})
    assert 'ccc' in sample and sample['ccc'] in [0, 1]

    class TupleModule(nn.Module):
        def __init__(self, num):
            super().__init__()
            self.num = num

        def forward(self, *args, **kwargs):
            return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11}

    class CustomSoftmax(nn.Softmax):
        def forward(self, *args, **kwargs):
            return [0.3, 0.3, 0.4]

    op = DifferentiableMixedRepeat(
        [TupleModule(i + 1) for i in range(4)],
        ValueChoice([1, 2, 4], label='ccc'),
        CustomSoftmax(),
        {})
    op.resample({})
    res = op(None)
    assert len(res) == 3
    assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
    assert res[2]['a'] == 7
    assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5

def test_mixed_mhattn_batch_first():
    # batch_first is not supported for legacy pytorch versions
    # mark 1.7 because 1.7 is used on legacy pipeline
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 2,
                                kdim=ValueChoice([3, 7], label='kdim'),
                                vdim=ValueChoice([5, 8], label='vdim'),
                                bias=False, add_bias_kv=True, batch_first=True)

    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
        torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
        torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(
        mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8))

def test_valuechoice_utils():
    chosen = {'exp': 3, 'add': 1}
    vc0 = ValueChoice([3, 4, 6], label='exp') * 2 + ValueChoice([0, 1], label='add')
    assert evaluate_value_choice_with_dict(vc0, chosen) == 7

    vc = vc0 + ValueChoice([3, 4, 6], label='exp')
    assert evaluate_value_choice_with_dict(vc, chosen) == 10

    assert list(dedup_inner_choices([vc0, vc]).keys()) == ['exp', 'add']

    assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
    weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
    ans = {9: 0.2, 10: 0.3, 12: 0.12, 13: 0.18, 18: 0.08, 19: 0.12}
    assert len(weights) == len(ans)
    for value, weight in ans.items():
        assert abs(weight - weights[value]) < 1e-6

    assert evaluate_constant(ValueChoice([3, 4, 6], label='x') - ValueChoice([3, 4, 6], label='x')) == 0
    with pytest.raises(ValueError):
        evaluate_constant(ValueChoice([3, 4, 6]) - ValueChoice([3, 4, 6]))

    assert evaluate_constant(ValueChoice([3, 4, 6], label='x') * 2 / ValueChoice([3, 4, 6], label='x')) == 2

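# Illustrative sketch (added for clarity, not part of the NNI API): where the expected
# options and weights in test_valuechoice_utils come from. Since
# vc == exp * 2 + add + exp == 3 * exp + add and the two choices are independent,
# each option's probability is simply P(exp) * P(add).
def _sketch_weighted_option_enumeration():
    import itertools
    expected = {}
    for (exp, p_exp), (add, p_add) in itertools.product(
            zip([3, 4, 6], [0.5, 0.3, 0.2]), zip([0, 1], [0.4, 0.6])):
        expected[3 * exp + add] = p_exp * p_add
    assert sorted(expected) == [9, 10, 12, 13, 18, 19]
    for value, weight in {9: 0.2, 10: 0.3, 12: 0.12, 13: 0.18, 18: 0.08, 19: 0.12}.items():
        assert abs(expected[value] - weight) < 1e-6
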
def test_differentiable_valuechoice():
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='456'),
                       kernel_size=ValueChoice([3, 5, 7], label='123'),
                       padding=ValueChoice([3, 5, 7], label='123') // 2)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
    assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7
    assert set(conv.export({}).keys()) == {'123', '456'}

# Constructor of a small search-space module (the enclosing class definition is
# elided in this excerpt): shared ValueChoice objects tie the channel count and
# kernel size across layers, and a LayerChoice selects among dropout rates.
def __init__(self):
    super().__init__()
    ch1 = ValueChoice([16, 32])
    kernel = ValueChoice([3, 5])
    self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
    self.batch_norm = nn.BatchNorm2d(ch1)
    self.conv2 = nn.Conv2d(ch1, 64, 3)
    self.dropout1 = LayerChoice([nn.Dropout(.25), nn.Dropout(.5), nn.Dropout(.75)])
    self.fc = nn.Linear(64, 10)

def test_pathsampling_valuechoice():
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    conv.resample(memo={'123': 5})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
    conv.resample(memo={'123': 7})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
    assert conv.export({})['123'] in [3, 5, 7]

def test_pathsampling_repeat():
    op = PathSamplingRepeat([nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)],
                            ValueChoice([1, 2, 3], label='ccc'))
    sample = op.resample({})
    assert sample['ccc'] in [1, 2, 3]
    for i in range(1, 4):
        op.resample({'ccc': i})
        out = op(torch.randn(2, 16))
        assert out.shape[1] == [16, 8, 4][i - 1]

    op = PathSamplingRepeat([nn.Linear(i + 1, i + 2) for i in range(7)],
                            2 * ValueChoice([1, 2, 3], label='ddd') + 1)
    sample = op.resample({})
    assert sample['ddd'] in [1, 2, 3]
    for i in range(1, 4):
        op.resample({'ddd': i})
        out = op(torch.randn(2, 1))
        assert out.shape[1] == (2 * i + 1) + 1

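# Illustrative sketch (added for clarity): why the second loop above expects
# out.shape[1] == (2 * i + 1) + 1. The blocks are nn.Linear(d, d + 1) for d = 1..7,
# so stacking the first k of them maps feature dimension 1 -> k + 1; with the depth
# expression 2 * ddd + 1, the output width is 2 * ddd + 2.
def _sketch_repeat_output_width(depth):
    width = 1  # input feature dimension fed to the first block
    for d in range(1, depth + 1):
        width = d + 1  # block d is nn.Linear(d, d + 1), so it emits d + 1 features
    return width

assert [_sketch_repeat_output_width(2 * ddd + 1) for ddd in (1, 2, 3)] == [4, 6, 8]
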
def test_mixed_batchnorm2d():
    bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))

    assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
    assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64

    _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))

def test_mixed_conv2d():
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out') * 2, 1)
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8
    _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))

    # stride
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out'), 1,
                  stride=ValueChoice([1, 2], label='stride'))
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10

    # groups, dw conv
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1,
                  groups=ValueChoice([3, 6, 9], label='in'))
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])

    # make sure kernel is sliced correctly
    conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
    conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    with torch.no_grad():
        conv.weight.zero_()
        # only center is 1, must pick center to pass this test
        conv.weight[0, 0, 1, 1] = 1
    conv.resample({'k': 1})
    assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9

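# Illustrative sketch (added for clarity): why the kernel-slicing check above works.
# Slicing the 3x3 kernel down to 1x1 is expected to keep the center element; since
# only the center weight is 1, the sampled 1x1 convolution acts as the identity, and
# an all-ones 3x3 input therefore sums to 9. Any other 1x1 slice would be all zeros.
def _sketch_kernel_center_slice():
    import torch.nn.functional as F
    weight = torch.zeros(1, 1, 3, 3)
    weight[0, 0, 1, 1] = 1
    center = weight[:, :, 1:2, 1:2]  # the 1x1 slice a correct implementation keeps
    assert F.conv2d(torch.ones(1, 1, 3, 3), center).sum().item() == 9
    assert F.conv2d(torch.ones(1, 1, 3, 3), weight[:, :, :1, :1]).sum().item() == 0
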
def test_mixed_linear():
    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9))
    _mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))

    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=False)
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))

    with pytest.raises(TypeError):
        linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=ValueChoice([False, True]))
        _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))

def test_mixed_mhattn():
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)

    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(
        mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), ValueChoice([2, 3, 4], label='heads'))
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'heads': 2},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
    with pytest.raises(AssertionError, match='divisible'):
        assert _mixed_operation_sampling_sanity_check(
            mhattn, {'emb': 4, 'heads': 3},
            torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, kdim=ValueChoice([5, 7], label='kdim'))
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'kdim': 7},
        torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8, 'kdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, vdim=ValueChoice([5, 8], label='vdim'))
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 4, 'vdim': 8},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(
        mhattn, {'emb': 8, 'vdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(
        mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))

# Constructor of another search-space module (the enclosing class definition is
# elided in this excerpt): a shared embedding-dimension ValueChoice feeds both the
# linear projections and the multi-head attention layer.
def __init__(self, head_count):
    super().__init__()
    embed_dim = ValueChoice(candidates=[32, 64])
    self.linear1 = nn.Linear(128, embed_dim)
    self.mhatt = nn.MultiheadAttention(embed_dim, head_count)
    self.linear2 = nn.Linear(embed_dim, 1)