def test_name_module_class():
    NamedConv2d = name_module_class(nn.Conv2d, [['N', 'C', 'H', 'W']])
    named_conv_2d = NamedConv2d(3, 5, kernel_size=3, padding=1)
    output_1 = named_conv_2d(x_input_1)
    assert output_1.names == ('N', 'C', 'H', 'W')
    output_2 = named_conv_2d(x_input_2)
    assert output_2.names == ('N', 'C', 'H', 'W')
    with pytest.raises(ValueError):
        assert named_conv_2d(x_input_3)
    with pytest.raises(ValueError):
        assert named_conv_2d(x_input_4)
def test_name_module_class_bilinear():
    x_1 = torch.rand((4, 2, 3), names=('N', 'X', 'C'))
    x_2 = torch.rand((4, 2, 3), names=('N', 'X', 'C'))
    x_3 = torch.rand((4, 2, 3), names=('N', 'C', 'X'))
    NamedBilinear = name_module_class(nn.Bilinear,
                                      [['N', '*', 'C'], ['N', '*', 'C']],
                                      ['N', '*', 'C'])
    named_bilinear = NamedBilinear(3, 3, 5)
    output = named_bilinear(x_1, x_2)
    assert output.names == ('N', 'X', 'C')
    with pytest.raises(ValueError):
        assert named_bilinear(x_1, x_3)
def test_name_module_class_loss_reduce_set_to_none():
    preds_1 = torch.rand((4, 4), names=('N', 'C'))
    preds_1_wrong = torch.rand((4, 4), names=('C', 'N'))
    labels_1 = torch.randint(4, (4, )).refine_names('N')
    preds_2 = torch.rand((4, 3, 5, 6), names=('N', 'C', 'X', 'Y'))
    labels_2 = torch.randint(3, (4, 5, 6)).refine_names('N', 'X', 'Y')
    NamedCELoss = name_module_class(nn.CrossEntropyLoss,
                                    [['N', 'C', '*'], ['N', '*']], ['N', '*'],
                                    reduce_option=True)
    named_celoss = NamedCELoss(reduction='none')
    output_1 = named_celoss(preds_1, labels_1)
    assert output_1.names == ('N', )
    output_2 = named_celoss(preds_2, labels_2)
    assert output_2.names == ('N', 'X', 'Y')
    with pytest.raises(ValueError):
        assert named_celoss(preds_1_wrong, labels_1)
def test_name_module_class_str():
    NamedConv2d = name_module_class(nn.Conv2d, [['N', 'C', 'H', 'W']])
    named_conv_2d = NamedConv2d(3, 5, kernel_size=3, padding=1)
    assert str(
        named_conv_2d
    ) == "NamedConv2d(3, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) NCHW -> NCHW"
def test_asterisk():
    x_input = torch.randint(2, (4, 3)).refine_names('N', 'X')
    NamedEmbedding = name_module_class(nn.Embedding, [['*']], ['*', 'C'])
    named_emb = NamedEmbedding(2, 5)
    output = named_emb(x_input)
    assert output.names == ('N', 'X', 'C')
def test_name_module_class_encapsulation():
    NamedConv2d = name_module_class(nn.Conv2d, [['N', 'C', 'H', 'W']])
    named_conv_2d = NamedConv2d(3, 5, kernel_size=3, padding=1)
    assert not hasattr(named_conv_2d, 'in_names')
def test_name_module_class_name_change():
    NamedConv2d = name_module_class(nn.Conv2d, [['N', 'C', 'H', 'W']],
                                    ['B', 'C', 'X', 'Y'])
    named_conv_2d = NamedConv2d(3, 5, kernel_size=3, padding=1)
    output_2 = named_conv_2d(x_input_2)
    assert output_2.names == ('B', 'C', 'X', 'Y')