def test_demultiplexer(self):
    # Checks that Demultiplexer splits a single 4-element input across its
    # output ports: first as four scalars, then as two 2-element vectors.
    source = np.array([1., 2., 3., 4.])

    # Case 1: scalar outputs -- one output port per input element.
    scalar_demux = Demultiplexer(size=4)
    scalar_context = scalar_demux.CreateDefaultContext()
    self.assertEqual(scalar_demux.get_num_input_ports(), 1)
    self.assertEqual(scalar_demux.get_num_output_ports(), 4)
    scalar_context.FixInputPort(0, BasicVector(source))
    scalar_output = scalar_demux.AllocateOutput()
    scalar_demux.CalcOutput(scalar_context, scalar_output)
    for port, expected in enumerate(source):
        actual = scalar_output.get_vector_data(port).get_value()
        self.assertTrue(np.allclose(actual, expected))

    # Case 2: vector outputs -- the same input split into ports of width 2.
    pair_demux = Demultiplexer(size=4, output_ports_sizes=2)
    pair_context = pair_demux.CreateDefaultContext()
    self.assertEqual(pair_demux.get_num_input_ports(), 1)
    self.assertEqual(pair_demux.get_num_output_ports(), 2)
    pair_context.FixInputPort(0, BasicVector(source))
    pair_output = pair_demux.AllocateOutput()
    pair_demux.CalcOutput(pair_context, pair_output)
    for port in (0, 1):
        # Port k is compared against input elements [2k, 2k + 2).
        begin = 2 * port
        actual = pair_output.get_vector_data(port).get_value()
        self.assertTrue(np.allclose(actual, source[begin:begin + 2]))
def test_demultiplexer(self):
    # Checks that Demultiplexer splits a single 4-element input across its
    # output ports in three configurations: four scalars, two 2-element
    # vectors, and explicitly listed port sizes [1, 2, 1].
    source = np.array([1., 2., 3., 4.])

    # Case 1: scalar outputs -- one output port per input element.
    scalar_demux = Demultiplexer(size=4)
    scalar_context = scalar_demux.CreateDefaultContext()
    self.assertEqual(scalar_demux.num_input_ports(), 1)
    self.assertEqual(scalar_demux.num_output_ports(), 4)
    numpy_compare.assert_equal(
        scalar_demux.get_output_ports_sizes(), [1, 1, 1, 1])
    scalar_demux.get_input_port(0).FixValue(scalar_context, source)
    scalar_output = scalar_demux.AllocateOutput()
    scalar_demux.CalcOutput(scalar_context, scalar_output)
    for port, expected in enumerate(source):
        actual = scalar_output.get_vector_data(port).get_value()
        self.assertTrue(np.allclose(actual, expected))

    # Case 2: vector outputs -- the same input split into ports of width 2.
    pair_demux = Demultiplexer(size=4, output_ports_size=2)
    pair_context = pair_demux.CreateDefaultContext()
    self.assertEqual(pair_demux.num_input_ports(), 1)
    self.assertEqual(pair_demux.num_output_ports(), 2)
    numpy_compare.assert_equal(pair_demux.get_output_ports_sizes(), [2, 2])
    pair_demux.get_input_port(0).FixValue(pair_context, source)
    pair_output = pair_demux.AllocateOutput()
    pair_demux.CalcOutput(pair_context, pair_output)
    for port in (0, 1):
        # Port k is compared against input elements [2k, 2k + 2).
        begin = 2 * port
        actual = pair_output.get_vector_data(port).get_value()
        self.assertTrue(np.allclose(actual, source[begin:begin + 2]))

    # Case 3: different output port sizes, given explicitly.
    requested_sizes = np.array([1, 2, 1])
    port_count = requested_sizes.size
    mixed_source = np.array([1., 2., 3., 4.])
    mixed_demux = Demultiplexer(output_ports_sizes=requested_sizes)
    mixed_context = mixed_demux.CreateDefaultContext()
    self.assertEqual(mixed_demux.num_input_ports(), 1)
    self.assertEqual(mixed_demux.num_output_ports(), port_count)
    numpy_compare.assert_equal(
        mixed_demux.get_output_ports_sizes(), requested_sizes)
    mixed_demux.get_input_port(0).FixValue(mixed_context, mixed_source)
    mixed_output = mixed_demux.AllocateOutput()
    mixed_demux.CalcOutput(mixed_context, mixed_output)

    # Walk the input with a running offset; each port's slice width is taken
    # from the size of the vector actually produced on that port.
    offset = 0
    for port in range(port_count):
        port_vector = mixed_output.get_vector_data(port)
        width = port_vector.size()
        expected = mixed_source[offset:offset + width]
        self.assertTrue(np.allclose(port_vector.get_value(), expected))
        offset += width