def mp(input, grad):
    """Return MaxPoolGrad applied to ``input``, its pooled output and ``grad``.

    ``maxpoolshp``, ``ignore_border`` and ``stride`` are not defined here;
    they are looked up in the enclosing scope when this function is called.
    """
    # Both ops are configured with exactly the same keyword arguments.
    op_kwargs = dict(ignore_border=ignore_border, st=stride)
    pooled = DownsampleFactorMax(maxpoolshp, **op_kwargs)(input)
    return MaxPoolGrad(maxpoolshp, **op_kwargs)(input, pooled, grad)
def test_infer_shape(self):
    """Run ``_compile_and_check`` on DownsampleFactorMax and MaxPoolGrad.

    Every combination of pool shape, ``ignore_border`` flag and padding is
    visited, except those whose entry in ``out_shapes`` is ``None``. A final
    check uses an input whose last two dimensions are broadcastable.
    """
    image = tensor.dtensor4()
    maxout = tensor.dtensor4()
    gz = tensor.dtensor4()
    rng = numpy.random.RandomState(utt.fetch_seed())

    maxpoolshps = ((1, 1), (2, 2), (3, 3), (2, 3), (3, 2))
    border_modes = [True, False]
    paddings = [(0, 0), (1, 1), (1, 2)]

    image_val = rng.rand(4, 6, 7, 9)

    # Indexed as out_shapes[padding index][pool-shape index][border index].
    # Each entry is the shape used for the random ``maxout`` and ``gz``
    # values fed to MaxPoolGrad; ``None`` marks a combination to skip.
    out_shapes = [
        # padding (0, 0)
        [[[4, 6, 7, 9], [4, 6, 7, 9]],
         [[4, 6, 3, 4], [4, 6, 4, 5]],
         [[4, 6, 2, 3], [4, 6, 3, 3]],
         [[4, 6, 3, 3], [4, 6, 4, 3]],
         [[4, 6, 2, 4], [4, 6, 3, 5]]],
        # padding (1, 1)
        [[None, None],
         [[4, 6, 4, 5], None],
         [[4, 6, 3, 3], None],
         [[4, 6, 4, 3], None],
         [[4, 6, 3, 5], None]],
        # padding (1, 2)
        [[None, None],
         [None, None],
         [[4, 6, 3, 4], None],
         [[4, 6, 4, 4], None],
         [None, None]],
    ]

    # Flatten the three nested loops into one ordered list of cases. The
    # order (pool shape outermost, padding innermost) determines the
    # sequence of random draws below, so it must not change.
    cases = [
        (maxpoolshp, ignore_border, padding, out_shapes[k][i][j])
        for i, maxpoolshp in enumerate(maxpoolshps)
        for j, ignore_border in enumerate(border_modes)
        for k, padding in enumerate(paddings)
        if out_shapes[k][i][j] is not None
    ]

    for pool_shape, border_flag, pad, grad_shape in cases:
        # Shape check for DownsampleFactorMax.
        pool_op = DownsampleFactorMax(pool_shape,
                                      ignore_border=border_flag,
                                      padding=pad)
        self._compile_and_check([image],
                                [pool_op(image)],
                                [image_val],
                                DownsampleFactorMax)

        # Shape check for MaxPoolGrad, fed random values of the listed shape.
        maxout_val = rng.rand(*grad_shape)
        gz_val = rng.rand(*grad_shape)
        grad_op = MaxPoolGrad(pool_shape,
                              ignore_border=border_flag,
                              padding=pad)
        self._compile_and_check([image, maxout, gz],
                                [grad_op(image, maxout, gz)],
                                [image_val, maxout_val, gz_val],
                                MaxPoolGrad,
                                warn=False)

    # Same DownsampleFactorMax check with a broadcastable input.
    bcast_image = tensor.tensor(dtype='float64',
                                broadcastable=(False, False, True, True))
    bcast_val = rng.rand(4, 6, 1, 1)
    bcast_pool = DownsampleFactorMax((2, 2),
                                     ignore_border=True,
                                     padding=(0, 0))
    self._compile_and_check([bcast_image],
                            [bcast_pool(bcast_image)],
                            [bcast_val],
                            DownsampleFactorMax)
def mp(input, grad):
    """Return MaxPoolGrad applied to ``input``, its pooled output and ``grad``.

    Both ops are built with ``ignore_border=True``. ``maxpoolsize``,
    ``stridesize`` and ``paddingsize`` are not defined here; they are looked
    up in the enclosing scope when this function is called.
    """
    pool_op = DownsampleFactorMax(maxpoolsize,
                                  ignore_border=True,
                                  st=stridesize,
                                  padding=paddingsize)
    pooled = pool_op(input)

    # The gradient op gets the same configuration as the pooling op above.
    pool_grad = MaxPoolGrad(maxpoolsize,
                            ignore_border=True,
                            st=stridesize,
                            padding=paddingsize)
    return pool_grad(input, pooled, grad)