Beispiel #1
0
    def test_flatten_n(self):
        """Check core.Flatten output shapes for several n_axes_to_keep values.

        The input is a rank-5 shape. Each case pairs Flatten constructor
        kwargs with the shape base.check_shape_agreement is expected to
        report; the empty kwargs dict exercises the default constructor.
        n_axes_to_keep=5 and n_axes_to_keep=6 (>= the input rank) are
        expected to raise base.LayerError.
        """
        shape_in = (29, 87, 10, 20, 30)
        cases = [
            ({}, (29, 87 * 10 * 20 * 30)),
            ({'n_axes_to_keep': 2}, (29, 87, 10 * 20 * 30)),
            ({'n_axes_to_keep': 3}, (29, 87, 10, 20 * 30)),
            # Only the last axis is "flattened", so the shape is unchanged.
            ({'n_axes_to_keep': 4}, (29, 87, 10, 20, 30)),
        ]
        for flatten_kwargs, want in cases:
            got = base.check_shape_agreement(
                core.Flatten(**flatten_kwargs), shape_in)
            self.assertEqual(got, want)

        # Not enough dimensions: the layer is built inside the context so a
        # LayerError raised at construction time is also accepted.
        for too_many in (5, 6):
            with self.assertRaises(base.LayerError):
                base.check_shape_agreement(
                    core.Flatten(n_axes_to_keep=too_many), shape_in)
Beispiel #2
0
    def test_flatten_n(self):
        """Verify core.Flatten output shapes across num_axis_to_keep settings.

        Uses a rank-5 input shape. The default Flatten() and
        num_axis_to_keep in {2, 3, 4} must yield the listed shapes from
        base.check_shape_agreement; num_axis_to_keep of 5 or 6 must raise
        base.LayerError.
        """
        dims = (29, 87, 10, 20, 30)

        self.assertEqual(
            base.check_shape_agreement(core.Flatten(), dims),
            (29, 87 * 10 * 20 * 30))

        # Insertion order (2, 3, 4) matches the original sequence of checks.
        expected_by_keep = {
            2: (29, 87, 10 * 20 * 30),
            3: (29, 87, 10, 20 * 30),
            4: (29, 87, 10, 20, 30),
        }
        for keep, expected in expected_by_keep.items():
            result = base.check_shape_agreement(
                core.Flatten(num_axis_to_keep=keep), dims)
            self.assertEqual(result, expected)

        # Not enough dimensions: keeping 5 or 6 axes of a rank-5 input fails.
        for keep in (5, 6):
            with self.assertRaises(base.LayerError):
                base.check_shape_agreement(
                    core.Flatten(num_axis_to_keep=keep), dims)