Ejemplo n.º 1
0
 def test_pooling_output_shape_exception(self):
     """Reject an unsupported ``padding`` value with ``ValueError``.

     Passes an integer padding (``1``) to ``pooling_output_shape`` and
     expects a ``ValueError`` whose message matches
     ``unknown <word> padding value``.
     """
     expected_msg = r"unknown \S+ padding value"

     # ``assertRaisesRegexp`` is a deprecated alias that was removed in
     # Python 3.12. Prefer ``assertRaisesRegex`` and fall back to the old
     # spelling only on interpreters that lack it (e.g. Python 2.7).
     assert_raises_regex = getattr(self, "assertRaisesRegex", None)
     if assert_raises_regex is None:
         assert_raises_regex = self.assertRaisesRegexp

     with assert_raises_regex(ValueError, expected_msg):
         pooling_output_shape(dimension_size=5,
                              pool_size=2,
                              padding=1,
                              stride=2)
Ejemplo n.º 2
0
    def test_pooling_output_shape(self):
        """Check ``pooling_output_shape`` with ``'VALID'`` padding.

        Calling it with all-``None`` arguments must return ``None``.
        For a dimension of size 5 and a pool window of 2, strides of
        2, 1 and 4 must produce output sizes of 2, 4 and 1 respectively.
        """
        # NOTE(review): another method named ``test_pooling_output_shape``
        # (numeric padding + ``ignore_border``) follows this one in the
        # snippet. If both live in the same class, the later definition
        # replaces this one and this test never runs. Consider renaming
        # one of them.
        output_shape = pooling_output_shape(None, None, None, None)
        self.assertIsNone(output_shape)

        output_shape = pooling_output_shape(
            dimension_size=5, pool_size=2,
            padding='VALID', stride=2)
        self.assertEqual(output_shape, 2)

        output_shape = pooling_output_shape(
            dimension_size=5, pool_size=2,
            padding='VALID', stride=1)
        self.assertEqual(output_shape, 4)

        output_shape = pooling_output_shape(
            dimension_size=5, pool_size=2,
            padding='VALID', stride=4)
        self.assertEqual(output_shape, 1)

    def test_pooling_output_shape(self):
        """Check ``pooling_output_shape`` with numeric padding.

        Calling it with all-``None`` arguments must return ``None``.
        For a dimension of size 5, a pool window of 2 and zero padding:

        * stride 2 with ``ignore_border=True`` gives 2;
        * stride 2 with ``ignore_border=False`` gives 3 (the trailing
          partial window is counted);
        * stride 1 with ``ignore_border=False`` gives 4;
        * stride 4 with ``ignore_border=False`` gives 2.
        """
        # NOTE(review): this method shares its name with the
        # ``test_pooling_output_shape`` defined just before it in the
        # snippet. If both are in the same class, this definition shadows
        # the earlier one, so only this test runs. Consider renaming one
        # of them.
        output_shape = pooling_output_shape(None, None, None, None)
        self.assertIsNone(output_shape)

        output_shape = pooling_output_shape(dimension_size=5, pool_size=2,
                                            padding=0, stride=2,
                                            ignore_border=True)
        self.assertEqual(output_shape, 2)

        output_shape = pooling_output_shape(dimension_size=5, pool_size=2,
                                            padding=0, stride=2,
                                            ignore_border=False)
        self.assertEqual(output_shape, 3)

        output_shape = pooling_output_shape(dimension_size=5, pool_size=2,
                                            padding=0, stride=1,
                                            ignore_border=False)
        self.assertEqual(output_shape, 4)

        output_shape = pooling_output_shape(dimension_size=5, pool_size=2,
                                            padding=0, stride=4,
                                            ignore_border=False)
        self.assertEqual(output_shape, 2)