Ejemplo n.º 1
0
def test_check_model_validity_adder():
    """An Adder with one energy-less and one energy-carrying parent is invalid.

    The first parent has ``tensor_has_energy = False`` (both axes in
    ``norm_axis``), the second has ``tensor_has_energy = True``. Validating
    the Adder must raise a ``ValueError`` matching "should all have energy".
    """
    shape = (2, 2)

    energyless = wtfc._DynNodeData(index_id="ft", tensor=npr.rand(*shape))
    energyless.tensor_has_energy = False
    energyless.norm_axis = (0, 1)

    energetic = wtfc._DynNodeData(index_id="ft", tensor=npr.rand(*shape))
    energetic.tensor_has_energy = True

    adder = wtf.Adder(index_id="ft", tensor=npr.rand(*shape))
    adder.new_parents(energyless, energetic)
    # Keep the observer bound to a local for the whole test.
    observer = wtf.PosObserver(index_id="ft", tensor=npr.rand(*shape))
    adder.new_child(observer)

    with pytest.raises(ValueError, match=".*should all have energy.*"):
        adder._check_model_validity()
Ejemplo n.º 2
0
def test_check_model_validity_dynnodedata():
    """Exercise ``ValueError`` branches of ``_DynNodeData._check_model_validity``.

    Four independent parent/children layouts are built in turn. Validating
    the parent of each layout is expected to raise a ``ValueError`` whose
    message matches the regex given for that layout.
    """

    def make_parent(index_id, tensor, has_energy, norm_axis=None):
        # ``norm_axis`` is only assigned when given, so layouts that never
        # set it keep whatever the constructor chose.
        node = wtfc._DynNodeData(index_id=index_id, tensor=tensor)
        node.tensor_has_energy = has_energy
        if norm_axis is not None:
            node.norm_axis = norm_axis
        return node

    def assert_invalid(node, pattern):
        with pytest.raises(ValueError, match=pattern):
            node._check_model_validity()

    # Layout 1: energy-less length-2 parent; its only child sees coefficient 0.
    node = make_parent("d", npr.rand(2), has_energy=False)
    op_a = wtfo._Operator(index_id="d", tensor=npr.rand(1))
    node.new_child(op_a, slice_for_child=slice(0, 1))
    assert_invalid(
        node, ".*Each coefficient of inner tensor should be seen by at least"
    )

    # Layout 2: energy-less (2, 2) parent normalized along axis 1; each child
    # receives one column through a slice on the last axis.
    node = make_parent("td", npr.rand(2, 2), has_energy=False, norm_axis=(1,))
    op_a = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1))
    node.new_child(op_a, slice_for_child=(..., slice(0, 1)))
    op_b = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1))
    node.new_child(op_b, slice_for_child=(..., slice(1, 2)))
    assert_invalid(node, ".*its children should not see an incomplete piece of")

    # Layout 3: same parent; columns are selected by an integer index
    # (child indexed "t") and by a one-element list (child indexed "td").
    node = make_parent("td", npr.rand(2, 2), has_energy=False, norm_axis=(1,))
    op_a = wtfo._Operator(index_id="t", tensor=npr.rand(1))
    node.new_child(op_a, slice_for_child=(..., 0))
    op_b = wtfo._Operator(index_id="td", tensor=npr.rand(1))
    node.new_child(op_b, slice_for_child=(..., [1]))
    assert_invalid(node, ".*its children should not see an incomplete piece of")

    # Layout 4: energy-carrying length-3 parent whose two children both
    # cover coefficient 0.
    node = make_parent("d", npr.rand(3), has_energy=True)
    op_a = wtfo._Operator(index_id="d", tensor=npr.rand(2))
    node.new_child(op_a, slice_for_child=slice(0, 2))
    op_b = wtfo._Operator(index_id="d", tensor=npr.rand(1))
    node.new_child(op_b, slice_for_child=slice(0, 1))
    assert_invalid(
        node, ".*Each coefficient of inner tensor should be seen by at most"
    )
Ejemplo n.º 3
0
def test_check_model_validity_integrator():
    """An Integrator under a parent normalized only along axis 0 is invalid.

    The energy-less ``"dt"`` parent has ``norm_axis == (0,)``. Validating the
    Integrator child must raise a ``ValueError`` whose message contains
    "or its last axis must be normalized".
    """
    shape = (2, 4)

    source = wtfc._DynNodeData(index_id="dt", tensor=npr.rand(*shape))
    source.tensor_has_energy = False
    source.norm_axis = (0,)

    integrator = wtf.Integrator(index_id="dt", tensor=npr.rand(*shape))
    observer = wtf.PosObserver(index_id="dt", tensor=npr.rand(*shape))
    source.new_child(integrator)
    integrator.new_child(observer)

    with pytest.raises(ValueError, match=".*or its last axis must be normalized.*"):
        integrator._check_model_validity()
Ejemplo n.º 4
0
 def setup(has_energy1, has_energy2, norm_axis1, norm_axis2, concatenate):
     """Build a Multiplexer fed by two length-2 ``"f"`` parents, plus an observer.

     Args:
         has_energy1, has_energy2: ``tensor_has_energy`` of the two parents.
         norm_axis1, norm_axis2: ``norm_axis`` of the two parents.
         concatenate: if truthy, the Multiplexer is indexed ``"f"`` with a
             length-4 tensor and ``multiplexer_idx="f"``. Otherwise it is
             indexed ``"fd"`` with a (2, 2) tensor and ``multiplexer_idx=None``.

     Returns:
         The Multiplexer. A PosObserver built on the same tensor array is
         attached to it as its child.
     """

     def make_parent(has_energy, norm_axis):
         node = wtfc._DynNodeData(index_id="f", tensor=npr.rand(2))
         node.tensor_has_energy = has_energy
         node.norm_axis = norm_axis
         return node

     parent1 = make_parent(has_energy1, norm_axis1)
     parent2 = make_parent(has_energy2, norm_axis2)

     if concatenate:
         index_id_child, multiplexer_idx = "f", "f"
         tensor_child = npr.rand(4)
     else:
         index_id_child, multiplexer_idx = "fd", None
         tensor_child = npr.rand(2, 2)

     multiplexer = wtf.Multiplexer(
         index_id=index_id_child, tensor=tensor_child, multiplexer_idx=multiplexer_idx
     )
     multiplexer.new_parents(parent1, parent2)

     observer = wtf.PosObserver(index_id=index_id_child, tensor=tensor_child)
     observer.new_parent(multiplexer)
     return multiplexer
Ejemplo n.º 5
0
def test_get_norm_for_children():
    """Check ``get_norm_axis_for_children`` on one re-wired parent/child pair.

    A single energy-less parent and a single ``_Operator`` child are reused.
    For each case the child is detached, the parent receives a fresh random
    tensor of the requested shape and normalization axes, and the child is
    re-attached with the case's ``slice_for_child`` / ``shape_for_child``.
    Valid cases compare the returned axes. Invalid cases expect a
    ``ValueError`` whose message matches the given regex.
    """
    parent = wtfc._DynNodeData()
    parent.tensor_has_energy = False
    child = wtfo._Operator()
    sl = slice(None)

    def rewire(parent_shape, parent_norm_axis, slice_for_child, shape_for_child):
        # Drop the link made by the previous case. The first call happens
        # before any link exists.
        parent.remove_child(child)
        parent.tensor = npr.rand(*parent_shape)
        parent.norm_axis = parent_norm_axis
        parent.new_child(
            child, slice_for_child=slice_for_child, shape_for_child=shape_for_child
        )

    seven_dims = (2,) * 7

    # (parent_shape, parent_norm_axis, slice_for_child, shape_for_child,
    #  expected norm_axis_for_child)
    valid_cases = [
        (seven_dims, (1, 3, 4), (0, sl, 1, sl, sl, 1, sl), (2, 2, 2, 2), (0, 1, 2)),
        (
            seven_dims,
            (1, 3, 4),
            (0, sl, 1, sl, sl, 1, slice(1)),
            (2, 2, 2, 1),
            (0, 1, 2),
        ),
        (seven_dims, (1, 3, 4), (0, sl, 1, sl, sl, 1, 0), (2, 2, 2), (0, 1, 2)),
        (seven_dims, (1, 3, 4), (0, sl, 1, sl, sl, 1, sl), (2, 4, 2), (0, 1)),
        (seven_dims, (1, 3, 4), (0, sl, 1, sl, sl, sl, sl), (2, 4, 4), (0, 1)),
        (seven_dims, (4,), (0, sl, 1, sl, sl, 1, sl), (4, 2, 2), (1,)),
        ((2, 3, 4, 2), (1, 2), (0, ...), (2, 6, 2), (0, 1)),
        ((2, 3, 4, 1), (1, 2), (0, ..., 0), (2, 6), (0, 1)),
        ((2,), (), ..., None, ()),
        ((2, 2, 2), (1,), (slice(1),), (1, 2, 2), (1,)),
    ]
    for *wiring, expected_norm_axis in valid_cases:
        rewire(*wiring)
        assert parent.get_norm_axis_for_children(child) == expected_norm_axis

    # (parent_shape, parent_norm_axis, slice_for_child, shape_for_child,
    #  regex the ValueError message must match)
    invalid_cases = [
        ((2, 2, 2), (2,), ..., (2, 4), "Invalid shape_for_child .*"),
        ((2, 2, 2), (1,), ..., (4, 2), "Invalid shape_for_child .*"),
        (
            (2, 2, 2),
            (1,),
            (slice(1),),
            (2, 2),
            ".* 1-size dimension of the sliced tensor.*",
        ),
        (
            (2, 2, 2),
            (2,),
            ...,
            (1, 4, 2),
            ".* 1-size dimensions in `shape_for_child`.*",
        ),
    ]
    for *wiring, error_pattern in invalid_cases:
        rewire(*wiring)
        with pytest.raises(ValueError, match=error_pattern):
            parent.get_norm_axis_for_children(child)
Ejemplo n.º 6
0
 def setup(
     has_energy1,
     has_energy2,
     index_id1,
     index_id2,
     index_id_child,
     norm_axis1,
     norm_axis2,
 ):
     """Build a Multiplier with two parents and a PosObserver child.

     The parent tensors come from ``tensor1`` / ``tensor2`` in the enclosing
     scope. The Multiplier gets a random tensor with one size-2 axis per
     character of ``index_id_child``, and the observer is built on that same
     tensor.

     Returns:
         The Multiplier node.
     """
     parent1 = wtfc._DynNodeData(tensor=tensor1, index_id=index_id1)
     parent2 = wtfc._DynNodeData(tensor=tensor2, index_id=index_id2)

     child_shape = (2,) * len(index_id_child)
     multiplier = wtfo.Multiplier(index_id=index_id_child, tensor=npr.rand(*child_shape))
     observer = wtf.PosObserver(tensor=multiplier.tensor)

     for node, has_energy, norm_axis in (
         (parent1, has_energy1, norm_axis1),
         (parent2, has_energy2, norm_axis2),
     ):
         node.tensor_has_energy = has_energy
         node.norm_axis = norm_axis

     multiplier.new_parents(parent1, parent2)
     multiplier.new_child(observer)
     return multiplier
Ejemplo n.º 7
0
 def setup(
     has_energy1,
     has_energy2,
     index_id1,
     index_id2,
     index_id_child,
     norm_axis1,
     norm_axis2,
     conv_idx_ids,
 ):
     """Build a Multiplier (with ``conv_idx_ids``), two parents and an observer.

     ``tensor1``, ``tensor2`` and ``tensor_child`` are read from the enclosing
     scope. The PosObserver is built on the Multiplier's tensor and attached
     as its child.

     Returns:
         The Multiplier node.
     """

     def make_parent(tensor, index_id, has_energy, norm_axis):
         node = wtfc._DynNodeData(tensor=tensor, index_id=index_id)
         node.tensor_has_energy = has_energy
         node.norm_axis = norm_axis
         return node

     parent1 = make_parent(tensor1, index_id1, has_energy1, norm_axis1)
     parent2 = make_parent(tensor2, index_id2, has_energy2, norm_axis2)

     multiplier = wtfo.Multiplier(
         index_id=index_id_child, tensor=tensor_child, conv_idx_ids=conv_idx_ids
     )
     multiplier.new_parents(parent1, parent2)

     observer = wtf.PosObserver(tensor=multiplier.tensor)
     observer.new_parent(multiplier)
     return multiplier