Exemplo n.º 1
0
def max_pool(value, ksize, strides, pads=(0, 0, 0, 0),
             padding=None, data_format="NCHW", name=None):
    """Perform max pooling on the input.

    Args:
        value: A 4-D `Tensor` (documented as type `tf.float32`).
        ksize: A list of ints of length 4, one entry per input dimension.
            Only the last two entries (H and W in 'NCHW' layout) are used
            as the pooling kernel size.
        strides: A list of ints of length 4. As with `ksize`, only the
            last two entries are used.
        pads: A list of ints or an int, forwarded unchanged to
            `ops.Pool2D` as `pad`. `None` is treated as 0.
        padding: Deprecated and ignored; kept for interface compatibility.
        data_format: A string. Only 'NCHW' is implemented.
        name: Optional name for the operation. Currently ignored.

    Returns:
        The max-pooled output tensor produced by `ops.Pool2D`.

    Raises:
        ValueError: If `strides` or `ksize` does not have length 4.
        NotImplementedError: If `data_format` is not 'NCHW'.
    """
    if len(strides) != 4:
        raise ValueError('strides must be a list of length 4.')

    if len(ksize) != 4:
        # Name the argument that actually failed validation.
        raise ValueError('ksize must be a list of length 4.')

    if data_format != 'NCHW':
        raise NotImplementedError(
            'data_format {!r} is not supported; only "NCHW" is '
            'implemented.'.format(data_format))

    if pads is None:
        pads = 0
    # In 'NCHW' layout, indices 2 and 3 are the spatial (H, W) axes.
    return ops.Pool2D(value,
                      kernel_size=ksize[2:],
                      stride=strides[2:],
                      pad=pads,
                      mode='MAX_POOLING')
Exemplo n.º 2
0
def max_pool(value,
             ksize,
             strides,
             pads=(0, 0, 0, 0),
             padding=None,
             data_format="NCHW",
             name=None):
    """Perform max pooling on the input.

    Args:
        value: The 4-D input tensor.
        ksize: A list of ints of length 4, one entry per input dimension.
            Only the last two entries (H and W in 'NCHW' layout) are used
            as the pooling kernel size.
        strides: A list of ints of length 4. As with `ksize`, only the
            last two entries are used.
        pads: A list of ints or an int, forwarded unchanged to
            `ops.Pool2D` as `pad`. `None` is treated as 0.
        padding: Ignored; kept for interface compatibility.
        data_format: A string. Only 'NCHW' is implemented.
        name: Optional name for the operation. Currently ignored.

    Returns:
        The max-pooled output tensor produced by `ops.Pool2D`.

    Raises:
        ValueError: If `strides` or `ksize` does not have length 4.
        NotImplementedError: If `data_format` is not 'NCHW'.
    """
    if len(strides) != 4:
        raise ValueError('strides must be a list of length 4.')

    if len(ksize) != 4:
        # Name the argument that actually failed validation.
        raise ValueError('ksize must be a list of length 4.')

    if data_format != 'NCHW':
        raise NotImplementedError(
            'data_format {!r} is not supported; only "NCHW" is '
            'implemented.'.format(data_format))

    if pads is None:
        pads = 0
    # In 'NCHW' layout, indices 2 and 3 are the spatial (H, W) axes.
    return ops.Pool2D(value,
                      kernel_size=ksize[2:],
                      stride=strides[2:],
                      pad=pads,
                      mode='MAX_POOLING')
Exemplo n.º 3
0
 def Setup(self, bottom):
     """Run the base-class setup, then build a 2-D pooling op.

     Args:
         bottom: A single input, or a list of inputs of which only the
             first element is pooled.

     Returns:
         The result of `ops.Pool2D` applied to the selected input, with
         keyword arguments taken from `self._param`.
     """
     # Select the input before the base-class Setup runs, as originally.
     if isinstance(bottom, list):
         blob = bottom[0]
     else:
         blob = bottom
     super(PoolingLayer, self).Setup(bottom)
     return ops.Pool2D(blob, **self._param)