Example #1
0
def test_train_eval():
    """Check that ``Cell.set_train`` toggles the ``training`` flag.

    A freshly built Cell must not report training mode. Calling
    ``set_train()`` with no argument must turn it on, and
    ``set_train(False)`` must turn it off again.
    """
    cell = nn.Cell()
    assert not cell.training

    # (positional args passed to set_train, expected truthiness of .training)
    transitions = (
        ((), True),
        ((False,), False),
    )
    for call_args, expected in transitions:
        cell.set_train(*call_args)
        assert bool(cell.training) is expected
Example #2
0
def test_exceptions():
    """ test_exceptions """
    # Each helper Cell below misuses attribute assignment in a different way.
    # The order of statements inside each __init__ is what triggers the error,
    # so it must not be rearranged.
    data = Tensor(np.ones([2, 3]))

    class ModError(nn.Cell):
        """ ModError definition """

        def __init__(self, tensor):
            # Registers a Parameter before the base Cell is initialised; the
            # test expects AttributeError from this assignment.
            self.weight = Parameter(tensor, name="weight")
            super().__init__()

        def construct(self, *args):
            pass

    class ModError1(nn.Cell):
        """ ModError1 definition """

        def __init__(self, tensor):
            super().__init__()
            self.weight = Parameter(tensor, name="weight")
            self.weight = None
            # Rebinding a parameter name to a Cell is expected to be rejected
            # with TypeError.
            self.weight = ModA(tensor)

        def construct(self, *args):
            pass

    class ModError2(nn.Cell):
        """ ModError2 definition """

        def __init__(self, tensor):
            super().__init__()
            self.mod = ModA(tensor)
            self.mod = None
            # Rebinding a child-cell name to a plain tensor is expected to be
            # rejected with TypeError.
            self.mod = tensor

        def construct(self, *args):
            pass

    failing_cells = (
        (ModError, AttributeError),
        (ModError1, TypeError),
        (ModError2, TypeError),
    )
    for cell_cls, expected_exc in failing_cells:
        with pytest.raises(expected_exc):
            cell_cls(data)

    # The base Cell provides no construct() implementation of its own.
    base_cell = nn.Cell()
    with pytest.raises(NotImplementedError):
        base_cell.construct()
Example #3
0
def test_add_attr():
    """ test_add_attr """
    # Exercises argument validation of Cell.insert_param_to_cell and
    # Cell.insert_child_to_cell.
    tensor_a = Tensor(np.ones([2, 3]))
    tensor_b = Tensor(np.ones([1, 4]))
    weight = Parameter(tensor_a, name="weight")
    cell = nn.Cell()
    cell.insert_param_to_cell('weight', weight)

    # A Parameter is not accepted as a child cell.
    with pytest.raises(TypeError):
        cell.insert_child_to_cell("network", weight)

    # Empty and dotted names are rejected when registering a parameter.
    for bad_name in ('', 'a.b'):
        with pytest.raises(KeyError):
            cell.insert_param_to_cell(bad_name, weight)

    # Re-registering the same parameter under its existing name must not raise.
    cell.insert_param_to_cell('weight', weight)

    # Empty and dotted names are also rejected when registering a child cell.
    # The child cells are built inside the raises-block, as in the original.
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('', ModA(tensor_a))
    with pytest.raises(KeyError):
        cell.insert_child_to_cell('a.b', ModB(tensor_b))

    # Wrong value types: tensor as child, tensor as parameter, Parameter as child.
    wrong_type_calls = (
        lambda: cell.insert_child_to_cell('buffer', tensor_b),
        lambda: cell.insert_param_to_cell('w', tensor_a),
        lambda: cell.insert_child_to_cell('m', weight),
    )
    for bad_call in wrong_type_calls:
        with pytest.raises(TypeError):
            bad_call()

    class ModAddCellError(nn.Cell):
        """ ModAddCellError definition """

        def __init__(self, tensor):
            # Assigns a child cell before the base Cell is initialised; the
            # test expects AttributeError from this assignment.
            self.mod = ModA(tensor)
            super().__init__()

        def construct(self, *args):
            pass

    with pytest.raises(AttributeError):
        ModAddCellError(tensor_a)