예제 #1
0
 def test_conv_unknown_dim_size(self):
     """An unknown (``None``) input dimension gives an unknown (``None``) output shape."""
     shape = conv_output_shape(
         dimension_size=None,
         filter_size=5,
         padding='VALID',
         stride=5,
     )
     # assertIsNone checks identity with None and reports a clearer
     # failure message than assertEqual(shape, None).
     self.assertIsNone(shape)
예제 #2
0
 def test_conv_output_shape_int_padding(self):
     """An integer ``padding`` is accepted and yields the expected output length."""
     expected = 3
     actual = conv_output_shape(
         dimension_size=10,
         filter_size=5,
         stride=5,
         padding=3,
     )
     self.assertEqual(actual, expected)
예제 #3
0
    def test_conv_output_shape_func_exceptions(self):
        """Each invalid argument combination must make conv_output_shape raise ValueError."""
        invalid_calls = [
            # stride is not an integer
            dict(dimension_size=5, filter_size=5, border_mode=5,
                 stride='not int'),
            # filter_size is not an integer
            dict(dimension_size=5, filter_size='not int', border_mode=5,
                 stride=5),
            # border_mode is an unrecognised value
            dict(dimension_size=5, filter_size=5,
                 border_mode='invalid value', stride=5),
        ]
        for call_kwargs in invalid_calls:
            # A fresh context per call: every case must raise on its own.
            with self.assertRaises(ValueError):
                conv_output_shape(**call_kwargs)
예제 #4
0
    def test_conv_output_shape_func_exceptions(self):
        """conv_output_shape rejects bad stride, filter size and border mode with ValueError."""
        def expect_value_error(**kwargs):
            # Each call gets its own assertRaises context.
            with self.assertRaises(ValueError):
                conv_output_shape(**kwargs)

        # Non-integer stride
        expect_value_error(dimension_size=5, filter_size=5,
                           border_mode=5, stride='not int')
        # Non-integer filter size
        expect_value_error(dimension_size=5, filter_size='not int',
                           border_mode=5, stride=5)
        # Unrecognised border mode
        expect_value_error(dimension_size=5, filter_size=5,
                           border_mode='invalid value', stride=5)
예제 #5
0
    def test_conv_output_shape_func_exceptions(self):
        """Invalid stride, filter size or padding arguments must raise ValueError."""
        with self.assertRaises(ValueError):
            # Wrong stride value
            conv_output_shape(
                dimension_size=5, filter_size=5,
                padding='VALID', stride='not int')

        with self.assertRaises(ValueError):
            # Wrong filter size value
            conv_output_shape(
                dimension_size=5, filter_size='not int',
                padding='SAME', stride=5)

        # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
        # (removed in Python 3.12). The raw string keeps ``\S`` a regex
        # character class instead of an invalid string escape sequence.
        with self.assertRaisesRegex(ValueError, r"unknown \S+ padding value"):
            # Wrong padding value (a float)
            conv_output_shape(
                dimension_size=5, filter_size=5,
                padding=1.5, stride=5,
            )
예제 #6
0
 def test_conv_unknown_dim_size(self):
     """An unknown (``None``) input dimension gives an unknown (``None``) output shape."""
     shape = conv_output_shape(dimension_size=None, filter_size=5,
                               padding=5, stride=5)
     # assertIsNone checks identity with None and reports a clearer
     # failure message than assertEqual(shape, None).
     self.assertIsNone(shape)