def test_check_model_validity_adder():
    """An Adder whose parents disagree on ``tensor_has_energy`` is rejected.

    The first parent is flagged as energy-free (with both of its axes
    normalized) while the second is flagged as carrying energy. Validating the
    Adder must raise a ``ValueError`` whose message contains
    "should all have energy".
    """
    shape = (2, 2)

    no_energy_parent = wtfc._DynNodeData(index_id="ft", tensor=npr.rand(*shape))
    no_energy_parent.tensor_has_energy = False
    no_energy_parent.norm_axis = (0, 1)

    energy_parent = wtfc._DynNodeData(index_id="ft", tensor=npr.rand(*shape))
    energy_parent.tensor_has_energy = True

    adder = wtf.Adder(index_id="ft", tensor=npr.rand(*shape))
    adder.new_parents(no_energy_parent, energy_parent)

    observer = wtf.PosObserver(index_id="ft", tensor=npr.rand(*shape))
    adder.new_child(observer)

    with pytest.raises(ValueError, match=".*should all have energy.*"):
        adder._check_model_validity()
def test_check_model_validity_dynnodedata():
    """``_DynNodeData._check_model_validity`` rejects badly covered tensors.

    Each scenario builds a fresh parent node and its children, then expects a
    ``ValueError``:

    1. energy-free parent, one coefficient never seen by a child
       -> "seen by at least" error;
    2. energy-free parent normalized along axis 1, children each seeing one
       column of that axis (via slices) -> "incomplete piece" error;
    3. same as (2), but the columns are picked with an integer index and a
       list index -> "incomplete piece" error;
    4. parent carrying energy, children with overlapping slices
       -> "seen by at most" error.
    """
    # 1. Coefficient 1 of the size-2 parent is not covered by any child.
    parent = wtfc._DynNodeData(index_id="d", tensor=npr.rand(2))
    parent.tensor_has_energy = False
    child1 = wtfo._Operator(index_id="d", tensor=npr.rand(1))
    parent.new_child(child1, slice_for_child=slice(0, 1))
    with pytest.raises(
        ValueError,
        match=".*Each coefficient of inner tensor should be seen by at least",
    ):
        parent._check_model_validity()

    # 2. Axis 1 is normalized, but each child only sees one column of it.
    parent = wtfc._DynNodeData(index_id="td", tensor=npr.rand(2, 2))
    parent.tensor_has_energy = False
    parent.norm_axis = (1,)
    child1 = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1))
    parent.new_child(child1, slice_for_child=(..., slice(0, 1)))
    child2 = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1))
    parent.new_child(child2, slice_for_child=(..., slice(1, 2)))
    with pytest.raises(
        ValueError, match=".*its children should not see an incomplete piece of"
    ):
        parent._check_model_validity()

    # 3. Same split of the normalized axis, using integer / list indexing.
    # An integer index drops the "d" axis (child sees shape (2,), index "t");
    # a one-element list keeps it (child sees shape (2, 1), index "td"). The
    # child tensors are sized to match what each slice actually selects.
    parent = wtfc._DynNodeData(index_id="td", tensor=npr.rand(2, 2))
    parent.tensor_has_energy = False
    parent.norm_axis = (1,)
    child1 = wtfo._Operator(index_id="t", tensor=npr.rand(2))
    parent.new_child(child1, slice_for_child=(..., 0))
    child2 = wtfo._Operator(index_id="td", tensor=npr.rand(2, 1))
    parent.new_child(child2, slice_for_child=(..., [1]))
    with pytest.raises(
        ValueError, match=".*its children should not see an incomplete piece of"
    ):
        parent._check_model_validity()

    # 4. Parent carries energy; coefficient 0 is seen by both children.
    parent = wtfc._DynNodeData(index_id="d", tensor=npr.rand(3))
    parent.tensor_has_energy = True
    child1 = wtfo._Operator(index_id="d", tensor=npr.rand(2))
    parent.new_child(child1, slice_for_child=slice(0, 2))
    child2 = wtfo._Operator(index_id="d", tensor=npr.rand(1))
    parent.new_child(child2, slice_for_child=slice(0, 1))
    with pytest.raises(
        ValueError,
        match=".*Each coefficient of inner tensor should be seen by at most",
    ):
        parent._check_model_validity()
def test_check_model_validity_integrator():
    """An Integrator rejects a parent normalized only along axis 0.

    The parent (index ``"dt"``, shape (2, 4)) is energy-free with
    ``norm_axis = (0,)``. Validating the Integrator must raise a
    ``ValueError`` whose message contains "or its last axis must be
    normalized".
    """
    shape = (2, 4)

    source = wtfc._DynNodeData(index_id="dt", tensor=npr.rand(*shape))
    source.tensor_has_energy = False
    source.norm_axis = (0,)

    integrator = wtf.Integrator(index_id="dt", tensor=npr.rand(*shape))
    observer = wtf.PosObserver(index_id="dt", tensor=npr.rand(*shape))

    source.new_child(integrator)
    integrator.new_child(observer)

    with pytest.raises(ValueError, match=".*or its last axis must be normalized.*"):
        integrator._check_model_validity()
def setup(has_energy1, has_energy2, norm_axis1, norm_axis2, concatenate):
    """Build a Multiplexer fed by two ``"f"`` parents and observed once.

    Parameters
    ----------
    has_energy1, has_energy2
        ``tensor_has_energy`` flags for the first and second parent.
    norm_axis1, norm_axis2
        ``norm_axis`` values for the first and second parent.
    concatenate
        If truthy, the Multiplexer uses ``multiplexer_idx="f"``, index
        ``"f"`` and a size-4 tensor. Otherwise it uses
        ``multiplexer_idx=None``, index ``"fd"`` and a (2, 2) tensor.

    Returns
    -------
    The Multiplexer node. A PosObserver sharing its tensor is attached as
    its child.
    """
    parents = []
    for has_energy, norm_axis in (
        (has_energy1, norm_axis1),
        (has_energy2, norm_axis2),
    ):
        node = wtfc._DynNodeData(index_id="f", tensor=npr.rand(2))
        node.tensor_has_energy = has_energy
        node.norm_axis = norm_axis
        parents.append(node)

    if concatenate:
        child_index_id, mux_idx, child_tensor = "f", "f", npr.rand(4)
    else:
        child_index_id, mux_idx, child_tensor = "fd", None, npr.rand(2, 2)

    multiplexer = wtf.Multiplexer(
        index_id=child_index_id,
        tensor=child_tensor,
        multiplexer_idx=mux_idx,
    )
    multiplexer.new_parents(*parents)

    # The observer intentionally shares the Multiplexer's tensor object.
    observer = wtf.PosObserver(index_id=child_index_id, tensor=child_tensor)
    observer.new_parent(multiplexer)
    return multiplexer
def test_get_norm_for_children():
    """Exercise ``get_norm_axis_for_children`` over many slice/reshape setups.

    A single parent/child pair is rewired for every case. The parent gets a
    fresh random tensor and ``norm_axis``, and the child is re-attached with
    the given ``slice_for_child`` and ``shape_for_child``. Valid cases compare
    the returned normalized axes. Invalid ones expect a ``ValueError`` whose
    message matches the given pattern.
    """
    parent = wtfc._DynNodeData()
    parent.tensor_has_energy = False
    child = wtfo._Operator()
    full = slice(None)

    def rewire(parent_shape, parent_norm_axis, slice_for_child, shape_for_child):
        # Detach the child first so it can be re-attached with new slicing.
        parent.remove_child(child)
        parent.tensor = npr.rand(*parent_shape)
        parent.norm_axis = parent_norm_axis
        parent.new_child(
            child, slice_for_child=slice_for_child, shape_for_child=shape_for_child
        )

    # (parent_shape, parent_norm_axis, slice_for_child, shape_for_child,
    #  expected norm_axis_for_child)
    valid_cases = [
        ((2,) * 7, (1, 3, 4), (0, full, 1, full, full, 1, full), (2, 2, 2, 2), (0, 1, 2)),
        ((2,) * 7, (1, 3, 4), (0, full, 1, full, full, 1, slice(1)), (2, 2, 2, 1), (0, 1, 2)),
        ((2,) * 7, (1, 3, 4), (0, full, 1, full, full, 1, 0), (2, 2, 2), (0, 1, 2)),
        ((2,) * 7, (1, 3, 4), (0, full, 1, full, full, 1, full), (2, 4, 2), (0, 1)),
        ((2,) * 7, (1, 3, 4), (0, full, 1, full, full, full, full), (2, 4, 4), (0, 1)),
        ((2,) * 7, (4,), (0, full, 1, full, full, 1, full), (4, 2, 2), (1,)),
        ((2, 3, 4, 2), (1, 2), (0, ...), (2, 6, 2), (0, 1)),
        ((2, 3, 4, 1), (1, 2), (0, ..., 0), (2, 6), (0, 1)),
        ((2,), (), ..., None, ()),
        ((2, 2, 2), (1,), (slice(1),), (1, 2, 2), (1,)),
    ]
    for *wiring, expected_norm_axis in valid_cases:
        rewire(*wiring)
        assert parent.get_norm_axis_for_children(child) == expected_norm_axis

    # (parent_shape, parent_norm_axis, slice_for_child, shape_for_child,
    #  expected error-message pattern)
    error_cases = [
        ((2, 2, 2), (2,), ..., (2, 4), "Invalid shape_for_child .*"),
        ((2, 2, 2), (1,), ..., (4, 2), "Invalid shape_for_child .*"),
        (
            (2, 2, 2),
            (1,),
            (slice(1),),
            (2, 2),
            ".* 1-size dimension of the sliced tensor.*",
        ),
        (
            (2, 2, 2),
            (2,),
            ...,
            (1, 4, 2),
            ".* 1-size dimensions in `shape_for_child`.*",
        ),
    ]
    for *wiring, error_pattern in error_cases:
        rewire(*wiring)
        with pytest.raises(ValueError, match=error_pattern):
            parent.get_norm_axis_for_children(child)
def setup(
    has_energy1,
    has_energy2,
    index_id1,
    index_id2,
    index_id_child,
    norm_axis1,
    norm_axis2,
):
    """Wire two parents into a Multiplier that has a PosObserver child.

    ``tensor1`` and ``tensor2`` are not parameters. They are resolved from an
    enclosing scope, presumably the surrounding test function (not visible
    here). The Multiplier gets a random tensor of size 2 along each of its
    ``index_id_child`` axes. The observer shares that tensor.

    Returns the Multiplier node.
    """
    child_shape = (2,) * len(index_id_child)

    parent1 = wtfc._DynNodeData(tensor=tensor1, index_id=index_id1)
    parent2 = wtfc._DynNodeData(tensor=tensor2, index_id=index_id2)
    multiplier = wtfo.Multiplier(index_id=index_id_child, tensor=npr.rand(*child_shape))
    observer = wtf.PosObserver(tensor=multiplier.tensor)

    for node, has_energy, norm_axis in (
        (parent1, has_energy1, norm_axis1),
        (parent2, has_energy2, norm_axis2),
    ):
        node.tensor_has_energy = has_energy
        node.norm_axis = norm_axis

    multiplier.new_parents(parent1, parent2)
    multiplier.new_child(observer)
    return multiplier
def setup(
    has_energy1,
    has_energy2,
    index_id1,
    index_id2,
    index_id_child,
    norm_axis1,
    norm_axis2,
    conv_idx_ids,
):
    """Wire two parents into a (possibly convolutive) Multiplier and observe it.

    ``tensor1``, ``tensor2`` and ``tensor_child`` are not parameters. They are
    resolved from an enclosing scope, presumably the surrounding test function
    (not visible here). ``conv_idx_ids`` is forwarded unchanged to the
    Multiplier constructor.

    Returns the Multiplier node. A PosObserver sharing its tensor is attached
    as its child.
    """

    def make_parent(tensor, index_id, has_energy, norm_axis):
        # One energy/normalization-configured input node.
        node = wtfc._DynNodeData(tensor=tensor, index_id=index_id)
        node.tensor_has_energy = has_energy
        node.norm_axis = norm_axis
        return node

    first = make_parent(tensor1, index_id1, has_energy1, norm_axis1)
    second = make_parent(tensor2, index_id2, has_energy2, norm_axis2)

    multiplier = wtfo.Multiplier(
        index_id=index_id_child,
        tensor=tensor_child,
        conv_idx_ids=conv_idx_ids,
    )
    multiplier.new_parents(first, second)

    observer = wtf.PosObserver(tensor=multiplier.tensor)
    observer.new_parent(multiplier)
    return multiplier