def test_flatten_n(self):
  """Checks Flatten output shapes on a rank-5 input for various n_axes_to_keep.

  The default Flatten and n_axes_to_keep in 2..4 must produce the listed
  shapes. Values of 5 or more must raise base.LayerError.
  """
  input_shape = (29, 87, 10, 20, 30)

  # Default construction: everything after the first axis is merged.
  self.assertEqual(
      base.check_shape_agreement(core.Flatten(), input_shape),
      (29, 87 * 10 * 20 * 30))

  # Each entry pairs an explicit n_axes_to_keep with the expected shape.
  # Keeping 4 axes of a rank-5 input leaves the shape unchanged.
  shape_cases = (
      (2, (29, 87, 10 * 20 * 30)),
      (3, (29, 87, 10, 20 * 30)),
      (4, (29, 87, 10, 20, 30)),
  )
  for n_axes, want in shape_cases:
    got = base.check_shape_agreement(
        core.Flatten(n_axes_to_keep=n_axes), input_shape)
    self.assertEqual(got, want)

  # Not enough dimensions.
  for n_axes in (5, 6):
    with self.assertRaises(base.LayerError):
      base.check_shape_agreement(
          core.Flatten(n_axes_to_keep=n_axes), input_shape)
def test_flatten_n(self):
  """Verifies Flatten shapes on a rank-5 input, keyed by num_axis_to_keep.

  A case value of None means Flatten is built with no arguments. Requests
  for 5 or 6 kept axes must raise base.LayerError.
  """
  input_shape = (29, 87, 10, 20, 30)

  expectations = [
      (None, (29, 87 * 10 * 20 * 30)),
      (2, (29, 87, 10 * 20 * 30)),
      (3, (29, 87, 10, 20 * 30)),
      (4, (29, 87, 10, 20, 30)),
  ]
  for keep, expected in expectations:
    # An empty kwargs dict reproduces the bare core.Flatten() call.
    flatten_kwargs = {} if keep is None else {'num_axis_to_keep': keep}
    layer = core.Flatten(**flatten_kwargs)
    self.assertEqual(
        base.check_shape_agreement(layer, input_shape), expected)

  # Not enough dimensions.
  for keep in (5, 6):
    with self.assertRaises(base.LayerError):
      base.check_shape_agreement(
          core.Flatten(num_axis_to_keep=keep), input_shape)