def test_you_can_create_vertex_array_from_different_types() -> None:
    """Check that three constant vertices convert to a Java vertex array and concatenate.

    The Java-array conversion only has to succeed; its result is not inspected.
    The concatenation (built with 0 as its first argument) must hold the three
    inputs' values back to back.
    """
    first = ConstantDouble(np.array([1., 2.]))
    second = ConstantDouble(np.array([3., 4.]))
    third = ConstantDouble(np.array([5., 6.]))

    # Only exercising the conversion path; the returned array is deliberately unused.
    KeanuContext().to_java_vertex_array([first, second, third])

    joined = Concatenation(0, [first, second, third])
    expected = [1., 2., 3., 4., 5., 6.]
    assert np.allclose(joined.get_value(), expected)
def test_you_can_name_a_sequence() -> None:
    """Check that a named Sequence labels item vertices under the sequence name.

    Builds a two-item sequence in which each item doubles the proxied ``x`` from
    the previous item. It then checks the last item's values and that its ``x``
    output label matches ``<sequence_name>.Sequence_Item_1.<digits>.x``.
    """
    x_label = "x"

    def factory(sequence_item):
        x = sequence_item.add_double_proxy_for(x_label)
        x_out = x * Const(2.0)
        x_out.set_label(x_label)
        sequence_item.add(x_out)

    x_start = ConstantDouble(1.0)
    initial_state: Optional[Dict[str, vertex_constructor_param_types]] = {
        x_label: x_start
    }

    sequence_name = "My_Awesome_Sequence"
    sequence = Sequence(count=2, factories=factory, initial_state=initial_state, name=sequence_name)

    sequence_item_contents = sequence.get_last_item().get_contents()
    x_output = sequence_item_contents.get(x_label)
    x_proxy = sequence_item_contents.get(Sequence.proxy_label_for(x_label))

    assert x_output is not None
    assert x_proxy is not None
    # 1.0 doubled once by item 0 and again by item 1.
    assert x_output.get_value() == 4
    assert x_proxy.get_value() == 2

    x_output_label = x_output.get_label()
    assert x_output_label is not None
    # Raw string: "\d" must reach the regex engine rather than being an
    # (invalid) string escape. The separators are escaped so they must be
    # literal dots instead of "any character".
    expected_label_pattern = re.escape(sequence_name) + r"\.Sequence_Item_1\.\d+\.x"
    assert re.match(expected_label_pattern, x_output_label) is not None
def test_last_item_retrieved_correctly() -> None:
    """Check that get_last_item() returns the final item of a two-item doubling sequence.

    Starting from 1.0, the last item's proxy should hold 2 and its output 4.
    """
    label = "x"

    def double_previous(item):
        previous = item.add_double_proxy_for(label)
        doubled = previous * Const(2.0)
        doubled.set_label(label)
        item.add(doubled)

    start = ConstantDouble(1.0)
    initial_state: Optional[Dict[str, vertex_constructor_param_types]] = {label: start}

    sequence = Sequence(count=2, factories=double_previous, initial_state=initial_state)

    contents = sequence.get_last_item().get_contents()
    output_vertex = contents.get(label)
    proxy_vertex = contents.get(Sequence.proxy_label_for(label))

    assert output_vertex is not None
    assert proxy_vertex is not None
    assert output_vertex.get_value() == 4
    assert proxy_vertex.get_value() == 2
def test_you_can_get_a_bayes_net_from_a_sequence() -> None:
    """Check that every item's ``x`` vertex can be found by label in the BayesNet.

    The net is built from a two-item doubling sequence via ``to_bayes_net()``.
    """
    x_label = "x"

    def factory(sequence_item):
        x = sequence_item.add_double_proxy_for(x_label)
        x_out = x * Const(2.0)
        x_out.set_label(x_label)
        sequence_item.add(x_out)

    x_start = ConstantDouble(1.0)
    initial_state: Optional[Dict[str, vertex_constructor_param_types]] = {
        x_label: x_start
    }

    sequence = Sequence(count=2, factories=factory, initial_state=initial_state)
    net = sequence.to_bayes_net()

    # Without this the loop below could pass vacuously on an empty sequence.
    assert sequence.size() == 2

    for item in sequence:
        vertex = item.get(x_label)
        assert vertex is not None
        full_label = vertex.get_label()
        assert full_label is not None
        assert net.get_vertex_by_label(full_label) is not None
def test_you_can_use_multiple_factories_to_build_sequences() -> None:
    """Check a five-item sequence whose items are built by two cooperating factories.

    Per item:
      - x1 doubles;
      - x4 halves;
      - x3 becomes double the previous x2;
      - x2 becomes half the previous x3.
    All four start at 4, so after five items x1 == 128, x2 == 2, x3 == 8 and
    x4 == 0.125.
    """
    x1_label, x2_label, x3_label, x4_label = "x1", "x2", "x3", "x4"

    two = ConstantDouble(2)
    half = ConstantDouble(0.5)

    def doubling_factory(item):
        prev_x1 = item.add_double_proxy_for(x1_label)
        prev_x2 = item.add_double_proxy_for(x2_label)
        next_x1 = prev_x1 * two
        next_x1.set_label(x1_label)
        next_x3 = prev_x2 * two
        next_x3.set_label(x3_label)
        item.add(next_x1)
        item.add(next_x3)

    def halving_factory(item):
        prev_x3 = item.add_double_proxy_for(x3_label)
        prev_x4 = item.add_double_proxy_for(x4_label)
        next_x2 = prev_x3 * half
        next_x2.set_label(x2_label)
        next_x4 = prev_x4 * half
        next_x4.set_label(x4_label)
        item.add(next_x2)
        item.add(next_x4)

    # Every variable starts at 4; vertices are created in x1..x4 order.
    initial_state: Optional[Dict[str, vertex_constructor_param_types]] = {
        label: ConstantDouble(4) for label in (x1_label, x2_label, x3_label, x4_label)
    }

    sequence = Sequence(count=5, factories=[doubling_factory, halving_factory], initial_state=initial_state)
    assert sequence.size() == 5

    # (label whose proxy is the input, label of the output it feeds)
    proxy_to_output = (
        (x1_label, x1_label),
        (x2_label, x3_label),
        (x3_label, x2_label),
        (x4_label, x4_label),
    )
    for item in sequence:
        for input_label, output_label in proxy_to_output:
            __check_sequence_output_links_to_input(item, Sequence.proxy_label_for(input_label), output_label)

    expected_final = (
        (x1_label, 128),
        (x2_label, 2),
        (x3_label, 8),
        (x4_label, 0.125),
    )
    for label, value in expected_final:
        __check_output_equals(sequence, label, value)