예제 #1
0
    def test_max_pool_3d_3D_deprecated_interface(self):
        """Check ``pool_3d`` when called with the ``ds``/``st``/``padding`` keywords.

        A fixed random array of shape (4, 5, 6) is pooled for every
        combination of pool shape, ``ignore_border`` flag and pooling mode.
        The compiled result is compared against the reference produced by
        ``self.numpy_max_pool_nd``.
        """
        rng = np.random.RandomState(utt.fetch_seed())
        maxpoolshps = ((1, 1, 1), (3, 2, 1))
        imval = rng.rand(4, 5, 6)
        images = dtensor3()

        for maxpoolshp, ignore_border, mode in product(
            maxpoolshps,
            [True, False],
            ["max", "sum", "average_inc_pad", "average_exc_pad"],
        ):
            # Reference result computed outside the compiled graph.
            numpy_output_val = self.numpy_max_pool_nd(
                imval, maxpoolshp, ignore_border, mode=mode
            )
            # ``ds`` and ``st`` are both set to the pool shape and the
            # padding is all zeros; these keyword spellings are the point of
            # this test (see the test name).
            output = pool_3d(
                input=images,
                ds=maxpoolshp,
                ignore_border=ignore_border,
                st=maxpoolshp,
                padding=(0, 0, 0),
                mode=mode,
            )
            output_val = function([images], output)(imval)
            # Only the forward values are compared; no gradient check is run.
            utt.assert_allclose(output_val, numpy_output_val)
예제 #2
0
    def test_DownsampleFactorMax(self):
        """Compare pooling outputs against the NumPy reference implementation.

        For every (pool shape, input shape) pair, ``ignore_border`` value and
        pooling mode, a random input is pooled with
        ``self.numpy_max_pool_nd``. That reference is then checked against:

        * the ``pool_2d`` / ``pool_3d`` helper, when the pool shape has 2 or
          3 entries respectively;
        * the shape reported by ``Pool.out_shape``;
        * the output of the ``Pool`` op itself.
        """
        rng = np.random.RandomState(utt.fetch_seed())
        # Each entry is (pool shape, input shape).
        cases = (
            ((2,), (16,)),
            ((2,), (4, 16)),
            ((2,), (4, 2, 16)),
            ((1, 1), (4, 2, 16, 16)),
            ((2, 2), (4, 2, 16, 16)),
            ((3, 3), (4, 2, 16, 16)),
            ((3, 2), (4, 2, 16, 16)),
            ((3, 2, 2), (3, 2, 16, 16, 16)),
            ((2, 2, 3, 2), (3, 2, 6, 6, 6, 5)),
        )
        border_flags = [True, False]
        pool_modes = ["max", "sum", "average_inc_pad", "average_exc_pad"]

        for (maxpoolshp, inputsize), ignore_border, mode in product(
            cases, border_flags, pool_modes
        ):
            imval = rng.rand(*inputsize)
            images = aesara.shared(imval)
            pool_ndim = len(maxpoolshp)

            # Reference result from the pure NumPy implementation.
            expected = self.numpy_max_pool_nd(
                imval, maxpoolshp, ignore_border, mode=mode
            )

            # A convenience helper is exercised only for 2-D and 3-D pool
            # shapes; other ranks go straight to the Pool op check below.
            if pool_ndim == 2:
                helper = pool_2d
            elif pool_ndim == 3:
                helper = pool_3d
            else:
                helper = None

            if helper is not None:
                helper_graph = helper(
                    images, maxpoolshp, ignore_border, mode=mode
                )
                # Compiled with a one-element output list, as a list result.
                helper_result = function([], [helper_graph])()
                utt.assert_allclose(helper_result, expected)

            # The Pool op applied directly.
            pool_graph = Pool(
                ndim=pool_ndim, ignore_border=ignore_border, mode=mode
            )(images, maxpoolshp)

            predicted_shape = Pool.out_shape(
                imval.shape,
                maxpoolshp,
                ndim=pool_ndim,
                ignore_border=ignore_border,
            )
            utt.assert_allclose(np.asarray(predicted_shape), expected.shape)

            pool_result = function([], pool_graph)()
            utt.assert_allclose(pool_result, expected)
예제 #3
0
 def mp(input):
     """Return ``pool_3d`` applied to ``input``.

     ``maxpoolshp``, ``ignore_border`` and ``mode`` are not parameters; they
     are looked up from the surrounding scope when the function is called.
     """
     pooled = pool_3d(input, maxpoolshp, ignore_border, mode=mode)
     return pooled