Ejemplo n.º 1
0
def max_pool(value, ksize, strides, padding, data_format='NHWC', name=None):
    """Perform max pooling on spatial axes.

    Parameters
    ----------
    value : Tensor
        The input tensor.
    ksize : list of int
        The kernel size, one entry per input axis. Must have length >= 4;
        only length 4 (2D pooling) is currently implemented. The batch and
        channel entries must be 1.
    strides : list of int
        The strides, one entry per input axis, with the same length and
        the same batch/channel constraint as ``ksize``.
    padding : str
        The padding algorithm. ``VALID`` or ``SAME``.
    data_format : str
        The data format. ``NHWC`` or ``NCHW``.
    name : None or str
        The optional name of op. Currently unused by this function.

    Returns
    -------
    Tensor
        The output tensor.

    Raises
    ------
    ValueError
        If ``ksize``/``strides`` are shorter than 4 or differ in length,
        if ``data_format`` is not ``NHWC``/``NCHW``, or if the kernel or
        stride is not 1 on the batch or channel axis.
    NotImplementedError
        If ``ksize`` has a length other than 4.

    """
    if len(ksize) < 4:
        raise ValueError('ksize must be a list with length >=4.')
    if len(strides) < 4:
        raise ValueError('strides must be a list with length >=4.')
    if len(ksize) != len(strides):
        raise ValueError('ksize and strides should have the same length.')
    if len(ksize) != 4:
        raise NotImplementedError(
            'Pool{}d has not been implemented yet.'.format(len(ksize) - 2))

    # Position of the channel axis and of the two spatial axes per layout;
    # axis 0 is the batch axis in both layouts.
    if data_format == 'NHWC':
        channel_axis, spatial_axes = 3, (1, 2)
    elif data_format == 'NCHW':
        channel_axis, spatial_axes = 1, (2, 3)
    else:
        # Previously an unknown layout fell through and returned None.
        raise ValueError(
            'Unsupported data format: {}. Expected NHWC or NCHW.'.format(
                data_format))

    # Pooling across the batch or channel axis is not supported.
    for axis in (0, channel_axis):
        if ksize[axis] != 1 or strides[axis] != 1:
            raise ValueError(
                'The pooling can only be performed on spatial axes.')

    return ops.Pool2d(value,
                      [ksize[i] for i in spatial_axes],
                      [strides[i] for i in spatial_axes],
                      padding=padding,
                      data_format=data_format,
                      mode='MAX')
Ejemplo n.º 2
0
 def Setup(self, bottom):
     """Run the parent setup, then build this layer's 2D pooling op.

     If ``bottom`` is a list, only its first element is pooled; otherwise
     ``bottom`` itself is pooled. The parent ``Setup`` always receives the
     original ``bottom``.
     """
     # Renamed from ``input`` so the local no longer shadows the builtin.
     # Selected before the parent Setup call, as in the original ordering.
     pool_input = bottom[0] if isinstance(bottom, list) else bottom
     super(PoolingLayer, self).Setup(bottom)
     # NOTE(review): self._param presumably holds the pooling keyword
     # arguments (kernel size, strides, ...) — confirm where it is built.
     return ops.Pool2d(pool_input, **self._param)
Ejemplo n.º 3
0
 def LayerSetup(self, bottom):
     """Build and return the 2D pooling op applied to ``bottom``.

     Every entry of ``self.arguments`` is forwarded as a keyword argument
     to ``_ops.Pool2d``.
     """
     # NOTE(review): self.arguments is presumably the layer's pooling
     # configuration (kernel size, strides, padding, ...) — confirm.
     pool_kwargs = self.arguments
     return _ops.Pool2d(bottom, **pool_kwargs)