コード例 #1
0
    def test_demultiplexer(self):
        """Checks Demultiplexer port counts and the values it routes.

        Covers ``Demultiplexer(size=4)``, which yields one output port per
        input element, and ``Demultiplexer(size=4, output_ports_sizes=2)``,
        which yields two output ports of two elements each.
        """
        # Test demultiplexer with scalar outputs.
        demux = Demultiplexer(size=4)
        context = demux.CreateDefaultContext()
        self.assertEqual(demux.get_num_input_ports(), 1)
        self.assertEqual(demux.get_num_output_ports(), 4)

        input_vec = np.array([1., 2., 3., 4.])
        context.FixInputPort(0, BasicVector(input_vec))
        output = demux.AllocateOutput()
        demux.CalcOutput(context, output)

        for i in range(4):
            # Compare against a width-1 slice (not a scalar) so that
            # assert_allclose also verifies each port carries one element;
            # it reports the mismatching values on failure.
            np.testing.assert_allclose(
                output.get_vector_data(i).get_value(), input_vec[i:i + 1])

        # Test demultiplexer with vector outputs.
        demux = Demultiplexer(size=4, output_ports_sizes=2)
        context = demux.CreateDefaultContext()
        self.assertEqual(demux.get_num_input_ports(), 1)
        self.assertEqual(demux.get_num_output_ports(), 2)

        context.FixInputPort(0, BasicVector(input_vec))
        output = demux.AllocateOutput()
        demux.CalcOutput(context, output)

        for i in range(2):
            np.testing.assert_allclose(
                output.get_vector_data(i).get_value(),
                input_vec[2 * i:2 * i + 2])
コード例 #2
0
    def test_demultiplexer(self):
        """Checks Demultiplexer port layout and the values it routes.

        Covers three constructions: ``size`` alone (four one-element ports),
        ``size`` with a uniform ``output_ports_size`` (two two-element
        ports), and an explicit per-port ``output_ports_sizes`` array.
        """
        # Test demultiplexer with scalar outputs.
        demux = Demultiplexer(size=4)
        context = demux.CreateDefaultContext()
        self.assertEqual(demux.num_input_ports(), 1)
        self.assertEqual(demux.num_output_ports(), 4)
        numpy_compare.assert_equal(demux.get_output_ports_sizes(),
                                   [1, 1, 1, 1])

        input_vec = np.array([1., 2., 3., 4.])
        demux.get_input_port(0).FixValue(context, input_vec)
        output = demux.AllocateOutput()
        demux.CalcOutput(context, output)

        for i in range(4):
            # Compare against a width-1 slice (not a scalar) so that
            # assert_allclose also verifies each port carries one element;
            # it reports the mismatching values on failure.
            np.testing.assert_allclose(
                output.get_vector_data(i).get_value(), input_vec[i:i + 1])

        # Test demultiplexer with vector outputs.
        demux = Demultiplexer(size=4, output_ports_size=2)
        context = demux.CreateDefaultContext()
        self.assertEqual(demux.num_input_ports(), 1)
        self.assertEqual(demux.num_output_ports(), 2)
        numpy_compare.assert_equal(demux.get_output_ports_sizes(), [2, 2])

        demux.get_input_port(0).FixValue(context, input_vec)
        output = demux.AllocateOutput()
        demux.CalcOutput(context, output)

        for i in range(2):
            np.testing.assert_allclose(
                output.get_vector_data(i).get_value(),
                input_vec[2 * i:2 * i + 2])

        # Test demultiplexer with different output port sizes.
        output_ports_sizes = np.array([1, 2, 1])
        num_output_ports = output_ports_sizes.size
        demux = Demultiplexer(output_ports_sizes=output_ports_sizes)
        context = demux.CreateDefaultContext()
        self.assertEqual(demux.num_input_ports(), 1)
        self.assertEqual(demux.num_output_ports(), num_output_ports)
        numpy_compare.assert_equal(demux.get_output_ports_sizes(),
                                   output_ports_sizes)

        demux.get_input_port(0).FixValue(context, input_vec)
        output = demux.AllocateOutput()
        demux.CalcOutput(context, output)

        # Slice bounds come from the *expected* sizes, not from the allocated
        # outputs, so a mis-sized port cannot make its own check pass;
        # assert_allclose's shape check then catches any width mismatch.
        slice_ends = np.cumsum(output_ports_sizes)
        slice_starts = slice_ends - output_ports_sizes
        for i, (start, end) in enumerate(zip(slice_starts, slice_ends)):
            np.testing.assert_allclose(
                output.get_vector_data(i).get_value(), input_vec[start:end])