Example #1
    def version_11(cls, node, **kwargs):
        x = kwargs["tensor_dict"][node.inputs[0]]
        x_rank = len(x.get_shape())
        storage_format, compute_format = get_data_format(x_rank)
        attrs = copy.deepcopy(node.attrs)
        attrs["data_format"] = storage_format
        mode = attrs.get("mode", "DCR")

        if mode == "CRD":
            # CRD mode has no dedicated TF op, so compute it natively via reshape/transpose
            bsize = attrs.get("blocksize")
            batch, channel, height, width = x.shape
            csize = channel // (bsize**2)

            reshape_node = tf.reshape(
                x, [batch, csize, bsize, bsize, height, width])
            transpose_node = tf.transpose(reshape_node,
                                          perm=[0, 1, 4, 2, 5, 3])
            return [
                tf.reshape(transpose_node,
                           [batch, csize, height * bsize, width * bsize])
            ]

        else:
            return [
                cls.make_tensor_from_onnx_node(node,
                                               attrs=attrs,
                                               c_first_cuda_only=True,
                                               **kwargs)
            ]
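A minimal sketch of the CRD-mode rearrangement above on a concrete channels-first tensor (the blocksize and shapes are made up for illustration):

import tensorflow as tf

# Toy NCHW input: blocksize 2 turns C=8 into C=2 while upscaling H and W by 2.
x = tf.reshape(tf.range(1 * 8 * 2 * 2, dtype=tf.float32), [1, 8, 2, 2])
bsize = 2
batch, channel, height, width = x.shape
csize = channel // (bsize ** 2)

y = tf.reshape(x, [batch, csize, bsize, bsize, height, width])
y = tf.transpose(y, perm=[0, 1, 4, 2, 5, 3])
y = tf.reshape(y, [batch, csize, height * bsize, width * bsize])
print(y.shape)  # (1, 2, 4, 4)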
Example #2
    def _common(cls, node, **kwargs):
        x = kwargs["tensor_dict"][node.inputs[0]]
        x_rank = len(x.get_shape())
        storage_format, compute_format = get_data_format(x_rank)
        attrs = copy.deepcopy(node.attrs)
        attrs["data_format"] = storage_format

        if sys_config.device == 'CUDA' and x.dtype not in {
                tf.uint8, tf.float16, tf.float32
        }:
            # The TensorFlow GPU kernels don't support these dtypes, but the
            # CPU kernels do, so run the op on the CPU.
            with tf.device("/cpu:0"):
                compute_format = compute_format.replace("C", "") + "C"
                pre_perm = get_perm_from_formats(storage_format,
                                                 compute_format)
                post_perm = get_perm_from_formats(compute_format,
                                                  storage_format)
                x_t = tf.transpose(x, perm=pre_perm)
                y = tf.nn.space_to_depth(x_t, attrs["blocksize"],
                                         compute_format)
                y = tf.transpose(y, perm=post_perm)
        else:
            y = cls.make_tensor_from_onnx_node(node,
                                               attrs=attrs,
                                               c_first_cuda_only=True,
                                               **kwargs)
        return [y]
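The pre/post permutations come from get_perm_from_formats; an illustrative stand-in (an assumption about its behavior, not the library implementation) is:

# Hypothetical helper: the permutation that rearranges axes laid out as
# from_format into to_format.
def perm_from_formats(from_format, to_format):
    return [from_format.find(d) for d in to_format]

print(perm_from_formats("NCHW", "NHWC"))  # [0, 2, 3, 1]
print(perm_from_formats("NHWC", "NCHW"))  # [0, 3, 1, 2]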
Example #3
    def pool_v11(cls, node, input_dict, pooling_type, strict=True):
        x = input_dict[node.inputs[0]]

        kernel_shape = node.attrs["kernel_shape"]

        spatial_size = len(kernel_shape)
        x_rank = spatial_size + 2

        kernel_shape = node.attrs["kernel_shape"]
        strides = node.attrs.get("strides", [1] * spatial_size)
        dilations = node.attrs.get("dilations", [1] * spatial_size)
        ceil_mode = bool(node.attrs.get("ceil_mode", 0))
        pads = node.attrs.get("auto_pad", "NOTSET")
        if pads == "NOTSET":
            pads = node.attrs.get("pads", [0] * spatial_size * 2)

        if spatial_size > 3:
            exception.OP_UNSUPPORTED_EXCEPT(
                "MaxPool with {}D input".format(x_rank), "Tensorflow")
        if pooling_type == "MAX_WITH_ARGMAX" and x_rank != 4:
            exception.OP_UNSUPPORTED_EXCEPT(
                "MaxPool with {}D input".format(x_rank), "Tensorflow")
        if node.attrs.get("storage_order", 0) != 0:
            exception.OP_UNSUPPORTED_EXCEPT("MaxPool with column major",
                                            "Tensorflow")

        storage_format, _ = get_data_format(x_rank)

        need_trans = storage_format.startswith("NC")
        if need_trans:
            compute_format = "N" + storage_format[2:] + "C"
            x = tf.transpose(x,
                             perm=get_perm_from_formats(
                                 storage_format, compute_format))

        dp = DilatedPooling(input=x,
                            kernel_shape=kernel_shape,
                            strides=strides,
                            dilations=dilations,
                            padding=pads,
                            ceil_mode=ceil_mode)

        # select the correct op depending on the pooling type
        def pooling_op():
            if pooling_type == "MAX":
                return dp.dilated_maxpool(), None
            return dp.dilated_maxpool_with_argmax()

        # select the correct transpose ops depending on the input storage format
        perm = get_perm_from_formats(compute_format, storage_format)

        def postprocess_op(pooled, argmax):
            return (tf.transpose(pooled, perm=perm) if need_trans else pooled,
                    tf.transpose(argmax, perm=perm)
                    if need_trans and argmax is not None else argmax)

        pooled, argmax = pooling_op()
        pooled, argmax = postprocess_op(pooled, argmax)

        result = [pooled] if argmax is None else [pooled, argmax]

        return result
Example #4
 def version_1(cls, node, **kwargs):
     x = kwargs["tensor_dict"][node.inputs[0]]
     x_rank = len(x.get_shape())
     storage_format, compute_format = get_data_format(x_rank)
     attrs = copy.deepcopy(node.attrs)
     attrs["data_format"] = storage_format
     return [
         cls.make_tensor_from_onnx_node(node,
                                        attrs=attrs,
                                        c_first_cuda_only=True,
                                        **kwargs)
     ]
Example #5
  def c_last_only(cls, tf_func, inputs, attrs):
    """ Handle operators for which only channel-last layout is supported.
    Two transposes are always added.

    :param tf_func: Callable Tf function.
    :param inputs: Inputs tensor.
    :param attrs: Attributes.
    :return: Tensor.
    """
    storage_format, compute_format = get_data_format(len(inputs[0].get_shape()))
    compute_format = compute_format.replace("C", "") + "C"
    return cls._tuck_transpose(tf_func, inputs, attrs,
                               (storage_format, compute_format))
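The transpose-wrap pattern that c_last_only (and _tuck_transpose in Example #7) relies on looks roughly like this for an NHWC-only op such as tf.nn.space_to_depth; shapes are illustrative:

import tensorflow as tf

x_nchw = tf.random.normal([1, 3, 8, 8])            # channels-first input
x_nhwc = tf.transpose(x_nchw, perm=[0, 2, 3, 1])   # NCHW -> NHWC
y_nhwc = tf.nn.space_to_depth(x_nhwc, 2, data_format="NHWC")
y_nchw = tf.transpose(y_nhwc, perm=[0, 3, 1, 2])   # NHWC -> NCHW
print(y_nchw.shape)  # (1, 12, 4, 4)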
Example #6
    def max_unpool(cls, node, input_dict):
        """
            MaxUnpooling operation
        """
        x = input_dict[node.inputs[0]]
        ind = input_dict[node.inputs[1]]
        if len(node.inputs) > 2:
            output_shape = input_dict.get(node.inputs[2], None)
        else:
            output_shape = None

        kernel_shape = node.attrs["kernel_shape"]

        spatial_size = len(kernel_shape)
        x_rank = spatial_size + 2
        storage_format, _ = get_data_format(x_rank)

        # if strides are not provided default is 1 along each spatial axis
        strides = node.attrs.get("strides", [1] * spatial_size)
        pads = node.attrs.get("pads", None)

        input_shape = tf_shape(x)
        default_shape = cls._get_default_shape(input_shape, kernel_shape,
                                               strides)

        need_trans = storage_format != "NHWC"
        if need_trans:
            x = tf.transpose(x,
                             perm=get_perm_from_formats(
                                 storage_format, "NHWC"))
            ind = tf.transpose(ind,
                               perm=get_perm_from_formats(
                                   storage_format, "NHWC"))

        # arrange default_shape in NHWC order: [batch] + spatial dims + [channels]
        default_shape = [input_shape[0]] + default_shape + \
                        [input_shape[1]]

        unpooled = cls._unpool(x, ind, default_shape)

        if need_trans:
            unpooled = tf.transpose(unpooled,
                                    perm=get_perm_from_formats(
                                        "NHWC", storage_format))

        if output_shape is not None:
            pads = cls._get_pads_from_output_shape(unpooled, output_shape)
        if pads is not None:
            unpooled = cls._pad_output(unpooled, pads, 0)

        return [unpooled]
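The project-specific cls._unpool is not shown here; conceptually, max unpooling scatters the pooled values back to their argmax positions. A toy sketch of that idea (not the library code):

import tensorflow as tf

pooled = tf.constant([9.0, 7.0])                    # pooled values
argmax = tf.constant([[3], [12]], dtype=tf.int64)   # flat output positions
unpooled_flat = tf.scatter_nd(argmax, pooled, [16])
print(unpooled_flat.numpy())  # zeros except 9.0 at index 3 and 7.0 at index 12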
Example #7
 def _tuck_transpose(cls, tf_func, inputs, attrs, data_format=None):
     x = inputs[0]
     x_rank = len(x.get_shape())
     if not data_format:
         data_format = get_data_format(x_rank)
     pre_perm = get_perm_from_formats(data_format[0], data_format[1])
     post_perm = get_perm_from_formats(data_format[1], data_format[0])
     attrs["data_format"] = data_format[1]
     if pre_perm != list(range(x_rank)):
         x_t = tf.transpose(x, perm=pre_perm)
         y = cls._run_tf_func(tf_func, [x_t] + inputs[1:], attrs)
         y_t = tf.transpose(y, perm=post_perm)
         return y_t
     return cls._run_tf_func(tf_func, inputs, attrs)
Example #8
    def __init__(self,
                 input,
                 kernel_shape,
                 strides,
                 dilations,
                 padding="VALID",
                 ceil_mode=False,
                 count_include_pad=False,
                 pooling_type="MAX",
                 p=2):
        self.input = tf.convert_to_tensor(input)

        self.kernel_shape = kernel_shape
        self.strides = strides
        self.dilations = dilations
        self.padding = padding
        self.is_explicit_padding = type(padding) is list
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad
        self.pooling_type = pooling_type.upper()
        self.p = p

        self.is_known_shape = self.input.shape.is_fully_defined()
        self.spatial_size = len(kernel_shape)
        self.input_rank = self.spatial_size + 2

        # if the rank is not defined, set it to the calculated input_rank
        # rank should be known for ops like tf.gather_nd
        if not input.shape.rank:
            input.set_shape([None] * self.input_rank)
        self.orig_input_shape = tf_shape(input)
        self.input_shape = self.orig_input_shape

        if pooling_type.startswith("MAX"):
            self.padding_constant = input.dtype.min
        else:
            self.padding_constant = 0

        self.storage_format, self.compute_format = get_data_format(
            self.spatial_size + 2)
        self.need_trans = self.storage_format != self.compute_format
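Why padding_constant is the dtype minimum for MAX pooling: explicitly padded cells must never win the max. A small standalone sketch (illustration only, not the DilatedPooling internals):

import tensorflow as tf

x = tf.constant([[-5.0, -6.0], [-7.0, -8.0]])     # all-negative input
padded = tf.pad(x, [[1, 1], [1, 1]], constant_values=x.dtype.min)
pooled = tf.nn.max_pool2d(padded[tf.newaxis, :, :, tf.newaxis],
                          ksize=2, strides=2, padding="VALID")
print(tf.squeeze(pooled))  # [[-5. -6.] [-7. -8.]] -- padding never wins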
Example #9
    def _common(cls, node, **kwargs):
        tensor_dict = kwargs['tensor_dict']
        feat = tensor_dict[node.inputs[0]]
        boxes = tensor_dict[node.inputs[1]]
        indx = tensor_dict[node.inputs[2]]
        output_height = node.attrs['output_height']
        output_width = node.attrs['output_width']
        sampling_ratio = node.attrs['sampling_ratio']
        spatial_scale = node.attrs['spatial_scale']
        adaptive_ratio = False
        if sampling_ratio <= 0:
            sampling_ratio = int((output_height + output_width) / 2)
            adaptive_ratio = True
            logger.warning("Do not fully support sampling_ratio <= 0.")

        boxes = boxes * spatial_scale

        feat_rank = len(feat.shape)
        storage_format, _ = get_data_format(feat_rank)
        need_trans = storage_format.startswith("NC")
        if need_trans:
            compute_format = "N" + storage_format[2:] + "C"
            feat = tf.transpose(feat,
                                perm=get_perm_from_formats(
                                    storage_format, compute_format))

        ret = crop_and_resize(feat,
                              boxes,
                              tf.cast(indx, tf.int32),
                              (output_height, output_width),
                              sampling_ratio,
                              adaptive_ratio=adaptive_ratio)
        ret = tf.nn.avg_pool(ret, [1, sampling_ratio, sampling_ratio, 1],
                             [1, sampling_ratio, sampling_ratio, 1],
                             padding='SAME',
                             data_format='NHWC')
        ret = tf.transpose(ret, perm=(0, 3, 1, 2))
        return [ret]
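The handler above calls a project-specific crop_and_resize helper; the built-in tf.image.crop_and_resize gives a rough feel for the sampling step (assumed shapes, illustration only):

import tensorflow as tf

feat = tf.random.normal([1, 16, 16, 3])           # NHWC feature map
boxes = tf.constant([[0.0, 0.0, 0.5, 0.5]])       # normalized y1, x1, y2, x2
box_indices = tf.constant([0], dtype=tf.int32)
crops = tf.image.crop_and_resize(feat, boxes, box_indices, crop_size=[7, 7])
print(crops.shape)  # (1, 7, 7, 3)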
Example #10
  def conv(cls, node, input_dict, transpose=False):
    """ Convolution method for both conv and transposed conv
    For transposed conv,
      Attr pads is not used for input, but declares how much output is padded.
      Here, output means output from transposed conv which already pad output_padding if set.
      So the pseudo explanation for output should be:
        output = conv_transpose_output + output_padding - pads
      And conv_transpose_output shape should be:
        conv_transpose_output_shape[i] = strides[i] * (input_shape[i] - 1) + kernel_shape[i]
    """
    x = input_dict[node.inputs[0]]
    x_rank = len(x.get_shape())
    x_shape = tf_shape(x, tf.int32)
    spatial_size = x_rank - 2

    storage_format, compute_format = get_data_format(x_rank)
    compute_c_idx = compute_format.find("C")
    spatial_format = "".join([d for d in compute_format if d not in ["N", "C"]])

    in_weights = input_dict[node.inputs[1]]
    weights_rank = len(in_weights.get_shape())
    if transpose:
      # Translate weights from (C x M x KH x KW) to (KH x KW X M X C)
      perm = list(range(2, weights_rank)) + [1, 0]
    else:
      # Translate weights from (M x C x KH x KW) to (KH x KW X C X M)
      perm = list(range(2, weights_rank)) + [1, 0]

    if "kernel_shape" in node.attrs.keys():
      kernel_shape = node.attrs["kernel_shape"]
      if in_weights.get_shape().is_fully_defined():
        assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
            "kernel_shape "
            "attr of convolution does not match the actual weight "
            "passed to this operation, attr {}, actual {}").format(
                kernel_shape,
                in_weights.get_shape().as_list())
    else:
      kernel_shape = tf_shape(in_weights, tf.int32)[2:]

    weights = tf.transpose(in_weights, perm)
    dilations = node.attrs.get("dilations", [1] * spatial_size)
    strides = node.attrs.get("strides", [1] * spatial_size)

    pads = node.attrs.get("pads", [0, 0] * spatial_size)

    # Check auto_pad nonexistent or NOTSET first
    if "auto_pad" not in node.attrs or node.attrs["auto_pad"] == "NOTSET":
      if not transpose:
        if pads != [0, 0] * spatial_size:
          x = PadMixin.get_padding_as_op(x, pads)
        pad_mode = "VALID"
      else:
        pad_mode = "NOTSET"
    # Then we use auto_pad to setup pad_mode
    elif node.attrs["auto_pad"] == "SAME_UPPER":
      pad_mode = "SAME"
    elif node.attrs["auto_pad"] == "VALID":
      pad_mode = "VALID"
    elif node.attrs["auto_pad"] == "SAME_LOWER":
      pad_mode = PAD_TF_INCOMPATIBLE
    else:
      raise ValueError("Invalid auto_pad attribute: {}".format(
          node.attrs["auto_pad"]))

    # Currently auto_pad = SAME_LOWER is not supported
    if pad_mode is PAD_TF_INCOMPATIBLE:
      if transpose:
        exception.OP_UNSUPPORTED_EXCEPT(
            "ConvTranspose with auto_pad `SAME_LOWER`", "Tensorflow")
      else:
        exception.OP_UNSUPPORTED_EXCEPT("Conv with auto_pad `SAME_LOWER`",
                                        "Tensorflow")

    group = node.attrs.get("group", 1)
    weight_shape = weights.get_shape().as_list()
    # Can this convolution be handled as a depthwise convolution we support?
    depthwise = (x_rank == 4 and len(weight_shape) == 4 and group != 1 and
                 not transpose and not (None in weight_shape))
    if depthwise and isinstance(x_shape, np.ndarray):
      depthwise = bool(group == x_shape[1])

    if depthwise is True:
      # Depthwise convolution.
      # The convolution kernel layout in tf.depthwise_conv is:
      # [filter_height, filter_width, in_channels, channel_multiplier]
      # Weight is now (KH x KW X C/g X M), or more precisely, (KH x KW X C/g X (g * M/g)),
      # we reshape it to (KH x KW x C x M/g)
      # NOTE: Assuming weight has fixed shape.

      depthwise_filter_shape = weight_shape[0:2] + [
          -1, weight_shape[3] // group
      ]
      weights = tf.reshape(weights, depthwise_filter_shape)

      if not sys_config.device == 'CUDA':
        # transpose input to NHWC layout
        x = tf.transpose(x,
                         perm=get_perm_from_formats(storage_format,
                                                    compute_format))
      weight_groups = [weights]
      xs = [x]
    else:
      weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)
      if sys_config.device == 'CUDA':
        if group == 1:
          xs = [x]
        else:
          xs = tf.split(x, num_or_size_splits=group, axis=1)
      else:
        x = tf.transpose(x,
                         perm=get_perm_from_formats(storage_format,
                                                    compute_format))
        if group == 1:
          xs = [x]
        else:
          xs = tf.split(x, num_or_size_splits=group, axis=-1)

    if transpose:
      if dilations != [1] * spatial_size:
        raise RuntimeError("Cannot set non-1 dilation for conv transpose.")
      convolved = []
      # this is a workaround for TensorFlow AutoGraph not correctly
      # detecting x; this is fixed in tf >= 2.2.0
      x = None
      for (x, weight) in zip(xs, weight_groups):
        x_spatial_shape = [
            x_shape[storage_format.find(d)] for d in spatial_format
        ]
        weights_shape = tf_shape(weights, tf.int32)
        output_shape = node.attrs.get("output_shape", None)
        conv_output_shape = [x_shape[storage_format.find("N")]]

        # calculate output shape
        if pad_mode == "NOTSET":
          if output_shape is None:
            conv_output_shape += [
                strides[i] * x_spatial_shape[i] - strides[i] +
                (kernel_shape[i] - 1) * dilations[i] + 1
                for i in list(range(spatial_size))
            ]
          else:
            conv_output_shape += [
                s + pads[i] + pads[spatial_size + i]
                for i, s in enumerate(output_shape[-2:])
            ]
          conv_output_shape.insert(compute_c_idx, weights_shape[-2])

          # make strides to match input rank
          strides_full = [1] + strides
          strides_full.insert(compute_c_idx, 1)

          # get corresponding function in tf
          if spatial_size == 1:
            conv_func = tf.nn.conv1d_transpose
            strides_full = strides[0]
          elif spatial_size == 2:
            conv_func = tf.nn.conv2d_transpose
          elif spatial_size == 3:
            conv_func = tf.nn.conv3d_transpose
          else:
            raise NotImplementedError(
                "Transposed convolution for {}d is not implemented in Tensorflow"
                .format(spatial_size))

          # use raw input x to do transposed conv
          conv_rs = conv_func(x,
                              weight,
                              conv_output_shape,
                              strides_full,
                              padding="VALID",
                              data_format=compute_format)

          # pad output first by output_padding attr
          if "output_padding" in node.attrs and output_shape is None:
            output_padding = [[0, 0]
                             ] + [[0, p] for p in node.attrs["output_padding"]]
            output_padding.insert(compute_c_idx, [0, 0])
            conv_rs = tf.pad(conv_rs, output_padding)

          # remove pads set in pads attr
          conv_rs_shape = tf_shape(conv_rs, tf.int32)
          conv_rs_shape_list = [
              conv_rs_shape[i] for i in range(conv_rs.shape.rank)
          ]
          begin = [0] + pads[:spatial_size]
          begin.insert(compute_c_idx, 0)
          size = [
              s if d in ["N", "C"] else s - pads[spatial_format.find(d)] -
              pads[spatial_format.find(d) + spatial_size]
              for d, s in zip(compute_format, conv_rs_shape_list)
          ]

          conv_rs = tf.slice(conv_rs, begin=begin, size=size)

          convolved.append(conv_rs)
        else:
          # No need to check pads if auto_pad is specifically provided.
          # The assumption is that once auto_pad is provided as either VALID
          # or SAME_UPPER (SAME_LOWER is currently not supported in TF) the
          # output_shape will always be inferred. That is, the output_shape
          # and output_padding will not be used in this case.
          if pad_mode == "VALID":
            conv_output_shape += [
                strides[i] * (x_spatial_shape[i] - 1) + weights_shape[i]
                for i in list(range(spatial_size))
            ]
          else:
            conv_output_shape += [
                strides[i] * x_spatial_shape[i]
                for i in list(range(spatial_size))
            ]
          conv_output_shape.insert(compute_c_idx, weights_shape[-2])

          # make strides to match input rank
          strides_full = [1] + strides
          strides_full.insert(compute_c_idx, 1)

          # get corresponding function in tf
          if spatial_size == 1:
            conv_func = tf.nn.conv1d_transpose
            strides_full = strides[0]
          elif spatial_size == 2:
            conv_func = tf.nn.conv2d_transpose
          elif spatial_size == 3:
            conv_func = tf.nn.conv3d_transpose
          else:
            raise NotImplementedError(
                "Transposed convolution for {}d is not implemented in Tensorflow"
                .format(spatial_size))

          # use raw input x to do transposed conv
          conv_rs = conv_func(x,
                              weight,
                              conv_output_shape,
                              strides_full,
                              padding=pad_mode,
                              data_format=compute_format)
          convolved.append(conv_rs)

    else:  # not transpose:
      if depthwise is True:
        if compute_format == "NHWC":
          strides = [1] + strides + [1]
        elif compute_format == 'NCHW':
          strides = [1, 1] + strides
        else:
          raise ValueError("Invalid compute_format: {}".format(compute_format))

        convolved = [
            tf.nn.depthwise_conv2d(x,
                                   weight,
                                   padding=pad_mode,
                                   strides=strides,
                                   dilations=dilations,
                                   data_format=compute_format)
            for (x, weight) in zip(xs, weight_groups)
        ]

      else:
        convolved = [
            tf.nn.convolution(x,
                              weight,
                              padding=pad_mode,
                              strides=strides,
                              dilations=dilations,
                              data_format=compute_format)
            for (x, weight) in zip(xs, weight_groups)
        ]

    if len(node.inputs) == 2:
      if sys_config.device == 'CUDA':
        output = tf.concat(convolved, axis=1)
      else:
        output = tf.concat(convolved, axis=-1)
        output = tf.transpose(output,
                              perm=get_perm_from_formats(
                                  compute_format, storage_format))
    else:
      bias = input_dict[node.inputs[2]]
      bias = cls.explicit_broadcast([x, bias], compute_c_idx)

      if sys_config.device == 'CUDA':
        output = tf.concat(convolved, axis=1)
        output = tf.add(output, bias)
      else:
        output = tf.concat(convolved, axis=-1)
        output = tf.add(output, bias)
        output = tf.transpose(output,
                              perm=get_perm_from_formats(
                                  compute_format, storage_format))

    return [output]
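A worked check of the output-shape formula from the docstring, with made-up strides, input and kernel sizes:

# conv_transpose_output_shape[i] = strides[i] * (input_shape[i] - 1) + kernel_shape[i]
strides, input_spatial, kernel = [2, 2], [5, 5], [3, 3]
out = [s * (i - 1) + k for s, i, k in zip(strides, input_spatial, kernel)]
print(out)  # [11, 11]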
Example #11
  def conv(cls, node, input_dict, transpose=False):
    """ Convolution method for both conv and transposed conv
    For transposed conv,
      Attr pads is not used for input, but declares how much output is padded.
      Here, output means output from transposed conv which already pad output_padding if set.
      So the pseudo explanation for output should be:
        output = conv_transpose_output + output_padding - pads
      And conv_transpose_output shape should be:
        conv_transpose_output_shape[i] = strides[i] * (input_shape[i] - 1) + kernel_shape[i]
    """
    x = input_dict[node.inputs[0]]
    x_rank = len(x.get_shape())
    x_shape = x.get_shape().as_list()
    spatial_size = x_rank - 2

    support_cuda = supports_device("CUDA")
    storage_format, compute_format = get_data_format(x_rank)
    compute_c_idx = compute_format.find("C")
    spatial_format = "".join([d for d in compute_format if d not in ["N", "C"]])

    in_weights = input_dict[node.inputs[1]]
    weights_rank = len(in_weights.get_shape())
    if transpose:
      # Translate weights from (C x M x KH x KW) to (KH x KW X M X C)
      perm = list(range(2, weights_rank)) + [1, 0]
    else:
      # Translate weights from (M x C x KH x KW) to (KH x KW X C X M)
      perm = list(range(2, weights_rank)) + [1, 0]

    if "kernel_shape" in node.attrs.keys():
      kernel_shape = node.attrs["kernel_shape"]
      assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
          "kernel_shape "
          "attr of convolution does not match the actual weight "
          "passed to this operation, attr {}, actual {}").format(
              kernel_shape,
              in_weights.get_shape().as_list())

    weights = tf.transpose(in_weights, perm)
    dilations = node.attrs.get("dilations", [1] * spatial_size)
    strides = node.attrs.get("strides", [1] * spatial_size)

    pads = node.attrs.get("pads", [0, 0] * spatial_size)

    if not transpose:
      x = PadMixin.get_padding_as_op(x, pads)

    group = node.attrs.get("group", 1)

    weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)

    if support_cuda:
      xs = tf.split(x, num_or_size_splits=group, axis=1)
    else:
      x = tf.transpose(
          x, perm=get_perm_from_formats(storage_format, compute_format))
      xs = tf.split(x, num_or_size_splits=group, axis=-1)

    if transpose:
      if dilations != [1] * spatial_size:
        raise RuntimeError("Cannot set non-1 dilation for conv transpose.")
      convolved = []
      for (x, weight) in zip(xs, weight_groups):
        x_spatial_shape = [
            x_shape[storage_format.find(d)] for d in spatial_format
        ]
        weights_shape = weights.get_shape().as_list()

        # calculate output shape
        output_shape = node.attrs.get("output_shape", None)
        conv_output_shape = [x_shape[storage_format.find("N")]]
        if output_shape is None:
          conv_output_shape += [
              strides[i] * (x_spatial_shape[i] - 1) + weights_shape[i]
              for i in list(range(spatial_size))
          ]
        else:
          conv_output_shape += [
              s + pads[i] + pads[spatial_size + i]
              for i, s in enumerate(output_shape[-2:])
          ]
        conv_output_shape.insert(compute_c_idx, weights_shape[-2])

        # make strides to match input rank
        strides_full = [1] + strides
        strides_full.insert(compute_c_idx, 1)

        # get corresponding function in tf
        if spatial_size == 1:
          conv_func = tf.contrib.nn.conv1d_transpose
          strides_full = strides[0]
        elif spatial_size == 2:
          conv_func = tf.nn.conv2d_transpose
        elif spatial_size == 3:
          conv_func = tf.nn.conv3d_transpose
        else:
          raise NotImplementedError(
              "Transposed convolution for {}d is not implemented in Tensorflow".
              format(spatial_size))

        # use raw input x to do transposed conv
        conv_rs = conv_func(
            x,
            weight,
            conv_output_shape,
            strides_full,
            padding="VALID",
            data_format=compute_format)

        # pad output first by output_padding attr
        if "output_padding" in node.attrs and output_shape is None:
          output_padding = [[0, 0]
                           ] + [[0, p] for p in node.attrs["output_padding"]]
          output_padding.insert(compute_c_idx, [0, 0])
          conv_rs = tf.pad(conv_rs, output_padding)

        # remove pads set in pads attr
        conv_rs_shape = conv_rs.get_shape().as_list()
        begin = [0] + pads[:spatial_size]
        begin.insert(compute_c_idx, 0)
        size = [
            s if d in ["N", "C"] else s - pads[spatial_format.find(d)] -
            pads[spatial_format.find(d) + spatial_size]
            for d, s in zip(compute_format, conv_rs_shape)
        ]
        conv_rs = tf.slice(conv_rs, begin=begin, size=size)

        convolved.append(conv_rs)
    else:
      if group != weights.shape[-1]:
        convolved = [
            tf.nn.convolution(
                x,
                weight,
                "VALID",
                strides=strides,
                dilation_rate=dilations,
                data_format=compute_format)
            for (x, weight) in zip(xs, weight_groups)
        ]
      else:
        # convert to a depthwise convolution when group == number of channels
        convolved = [
          tf.nn.depthwise_conv2d(
               x,
               tf.transpose(weights, [0, 1, 3, 2]),  # [filter_height, filter_width, in_channels, multiplier (=1)]
               strides=_get_sequence(strides, 2, channel_index=3, name="strides"),  # requires a 4-d list
               padding="VALID",
               rate=dilations,  # NOTE: unsure whether this is correct; newer TensorFlow versions use the `dilations` parameter.
               data_format=compute_format,
           )
        ]

    if len(node.inputs) == 2:
      if support_cuda:
        output = tf.concat(convolved, axis=1)
      else:
        output = tf.concat(convolved, axis=-1)
        output = tf.transpose(
            output, perm=get_perm_from_formats(compute_format, storage_format))
    else:
      bias = input_dict[node.inputs[2]]
      bias = cls.explicit_broadcast([x, bias], compute_c_idx)

      if support_cuda:
        output = tf.concat(convolved, axis=1)
        output = tf.add(output, bias)
      else:
        output = tf.concat(convolved, axis=-1)
        output = tf.add(output, bias)
        output = tf.transpose(
            output, perm=get_perm_from_formats(compute_format, storage_format))

    return [output]
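The depthwise branch above expects the kernel laid out as [filter_height, filter_width, in_channels, channel_multiplier]; a minimal standalone sketch with illustrative shapes:

import tensorflow as tf

x = tf.random.normal([1, 8, 8, 4])                # NHWC, 4 channels
w = tf.random.normal([3, 3, 4, 1])                # one filter per channel
y = tf.nn.depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
print(y.shape)  # (1, 8, 8, 4)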
Example #12
    def conv(cls, node, input_dict, transpose=False):
        """ Convolution method for both conv and transposed conv
    For transposed conv,
      Attr pads is not used for input, but declares how much output is padded.
      Here, output means output from transposed conv which already pad output_padding if set.
      So the pseudo explanation for output should be:
        output = conv_transpose_output + output_padding - pads
      And conv_transpose_output shape should be:
        conv_transpose_output_shape[i] = strides[i] * (input_shape[i] - 1) + kernel_shape[i]
    """
        x = input_dict[node.inputs[0]]
        x_rank = len(x.get_shape())
        x_shape = x.get_shape().as_list()
        spatial_size = x_rank - 2

        support_cuda = supports_device("CUDA")
        storage_format, compute_format = get_data_format(x_rank)
        compute_c_idx = compute_format.find("C")
        spatial_format = "".join(
            [d for d in compute_format if d not in ["N", "C"]])

        in_weights = input_dict[node.inputs[1]]
        weights_rank = len(in_weights.get_shape())
        if transpose:
            # Translate weights from (C x M x KH x KW) to (KH x KW X M X C)
            perm = list(range(2, weights_rank)) + [1, 0]
        else:
            # Translate weights from (M x C x KH x KW) to (KH x KW X C X M)
            perm = list(range(2, weights_rank)) + [1, 0]

        if "kernel_shape" in node.attrs.keys():
            kernel_shape = node.attrs["kernel_shape"]
            assert in_weights.get_shape().as_list()[2:] == kernel_shape, (
                "kernel_shape "
                "attr of convolution does not match the actual weight "
                "passed to this operation, attr {}, actual {}").format(
                    kernel_shape,
                    in_weights.get_shape().as_list())

        weights = tf.transpose(in_weights, perm)
        dilations = node.attrs.get("dilations", [1] * spatial_size)
        strides = node.attrs.get("strides", [1] * spatial_size)

        pads = node.attrs.get("pads", [0, 0] * spatial_size)

        # Check auto_pad nonexistent or NOTSET first
        if "auto_pad" not in node.attrs or node.attrs["auto_pad"] == "NOTSET":
            if not transpose:
                if pads != [0, 0] * spatial_size:
                    x = PadMixin.get_padding_as_op(x, pads)
                pad_mode = "VALID"
            else:
                pad_mode = "NOTSET"
        # Then we use auto_pad to setup pad_mode
        elif node.attrs["auto_pad"] == "SAME_UPPER":
            pad_mode = "SAME"
        elif node.attrs["auto_pad"] == "VALID":
            pad_mode = "VALID"
        elif node.attrs["auto_pad"] == "SAME_LOWER":
            pad_mode = PAD_TF_INCOMPATIBLE
        else:
            raise ValueError("Invalid auto_pad attribute: {}".format(
                node.attrs["auto_pad"]))

        # Currently auto_pad = SAME_LOWER is not supported
        if pad_mode is PAD_TF_INCOMPATIBLE:
            if transpose:
                exception.OP_UNSUPPORTED_EXCEPT(
                    "ConvTranspose with auto_pad `SAME_LOWER`", "Tensorflow")
            else:
                exception.OP_UNSUPPORTED_EXCEPT(
                    "Conv with auto_pad `SAME_LOWER`", "Tensorflow")

        group = node.attrs.get("group", 1)

        weight_groups = tf.split(weights, num_or_size_splits=group, axis=-1)

        if support_cuda:
            xs = tf.split(x, num_or_size_splits=group, axis=1)
        else:
            x = tf.transpose(x,
                             perm=get_perm_from_formats(
                                 storage_format, compute_format))
            xs = tf.split(x, num_or_size_splits=group, axis=-1)

        if transpose:
            if dilations != [1] * spatial_size:
                raise RuntimeError(
                    "Cannot set non-1 dilation for conv transpose.")
            convolved = []
            for (x, weight) in zip(xs, weight_groups):
                x_spatial_shape = [
                    x_shape[storage_format.find(d)] for d in spatial_format
                ]
                weights_shape = weights.get_shape().as_list()
                output_shape = node.attrs.get("output_shape", None)
                conv_output_shape = [x_shape[storage_format.find("N")]]

                # calculate output shape
                if pad_mode == "NOTSET":
                    if output_shape is None:
                        conv_output_shape += [
                            strides[i] * x_spatial_shape[i] +
                            max(weights_shape[i] - strides[i], 0)
                            for i in list(range(spatial_size))
                        ]
                    else:
                        conv_output_shape += [
                            s + pads[i] + pads[spatial_size + i]
                            for i, s in enumerate(output_shape[-2:])
                        ]
                    conv_output_shape.insert(compute_c_idx, weights_shape[-2])

                    # make strides to match input rank
                    strides_full = [1] + strides
                    strides_full.insert(compute_c_idx, 1)

                    # get corresponding function in tf
                    if spatial_size == 1:
                        conv_func = tf.nn.conv1d_transpose
                        strides_full = strides[0]
                    elif spatial_size == 2:
                        conv_func = tf.nn.conv2d_transpose
                    elif spatial_size == 3:
                        conv_func = tf.nn.conv3d_transpose
                    else:
                        raise NotImplementedError(
                            "Transposed convolution for {}d is not implemented in Tensorflow"
                            .format(spatial_size))

                    # use raw input x to do transposed conv
                    conv_rs = conv_func(x,
                                        weight,
                                        conv_output_shape,
                                        strides_full,
                                        padding="VALID",
                                        data_format=compute_format)

                    # pad output first by output_padding attr
                    if "output_padding" in node.attrs and output_shape is None:
                        output_padding = [[
                            0, 0
                        ]] + [[0, p] for p in node.attrs["output_padding"]]
                        output_padding.insert(compute_c_idx, [0, 0])
                        conv_rs = tf.pad(conv_rs, output_padding)

                    # remove pads set in pads attr
                    conv_rs_shape = conv_rs.get_shape().as_list()
                    begin = [0] + pads[:spatial_size]
                    begin.insert(compute_c_idx, 0)
                    size = [
                        s if d in ["N", "C"] else s -
                        pads[spatial_format.find(d)] -
                        pads[spatial_format.find(d) + spatial_size]
                        for d, s in zip(compute_format, conv_rs_shape)
                    ]
                    conv_rs = tf.slice(conv_rs, begin=begin, size=size)

                    convolved.append(conv_rs)
                else:
                    # No need to check pads if auto_pad is specifically provided.
                    # The assumption is that once auto_pad is provided as either VALID
                    # or SAME_UPPER (SAME_LOWER is currently not supported in TF) the
                    # output_shape will always be inferred. That is, the output_shape
                    # and output_padding will not be used in this case.
                    if pad_mode == "VALID":
                        conv_output_shape += [
                            strides[i] * (x_spatial_shape[i] - 1) +
                            weights_shape[i] for i in list(range(spatial_size))
                        ]
                    else:
                        conv_output_shape += [
                            strides[i] * x_spatial_shape[i]
                            for i in list(range(spatial_size))
                        ]
                    conv_output_shape.insert(compute_c_idx, weights_shape[-2])

                    # make strides to match input rank
                    strides_full = [1] + strides
                    strides_full.insert(compute_c_idx, 1)

                    # get corresponding function in tf
                    if spatial_size == 1:
                        conv_func = tf.contrib.nn.conv1d_transpose
                        strides_full = strides[0]
                    elif spatial_size == 2:
                        conv_func = tf.nn.conv2d_transpose
                    elif spatial_size == 3:
                        conv_func = tf.nn.conv3d_transpose
                    else:
                        raise NotImplementedError(
                            "Transposed convolution for {}d is not implemented in Tensorflow"
                            .format(spatial_size))

                    # use raw input x to do transposed conv
                    conv_rs = conv_func(x,
                                        weight,
                                        conv_output_shape,
                                        strides_full,
                                        padding=pad_mode,
                                        data_format=compute_format)
                    convolved.append(conv_rs)

        else:
            convolved = [
                tf.nn.convolution(x,
                                  weight,
                                  padding=pad_mode,
                                  strides=strides,
                                  dilations=dilations,
                                  data_format=compute_format)
                for (x, weight) in zip(xs, weight_groups)
            ]

        if len(node.inputs) == 2:
            if support_cuda:
                output = tf.concat(convolved, axis=1)
            else:
                output = tf.concat(convolved, axis=-1)
                output = tf.transpose(output,
                                      perm=get_perm_from_formats(
                                          compute_format, storage_format))
        else:
            bias = input_dict[node.inputs[2]]
            bias = cls.explicit_broadcast([x, bias], compute_c_idx)

            if support_cuda:
                output = tf.concat(convolved, axis=1)
                output = tf.add(output, bias)
            else:
                output = tf.concat(convolved, axis=-1)
                output = tf.add(output, bias)
                output = tf.transpose(output,
                                      perm=get_perm_from_formats(
                                          compute_format, storage_format))

        return [output]
Example #13
    def pool(cls, node, input_dict, pool_func, pooling_type, strict=True):
        x = input_dict[node.inputs[0]]
        x_rank = len(x.get_shape())
        x_shape = x.get_shape().as_list()
        spatial_size = x_rank - 2

        support_cuda = supports_device("CUDA")
        storage_format, compute_format = get_data_format(x_rank)

        kernel_shape = node.attrs["kernel_shape"]
        strides = node.attrs.get("strides", [1] * spatial_size)
        pads = node.attrs.get("pads", None)
        pad = PAD_TF_INCOMPATIBLE
        # from version 7
        count_include_pad = node.attrs.get("count_include_pad", 0)

        # If padding is specified, try to recover it from explicit padding
        # specification to tensorflow padding mode:
        if pads is not None:
            pad = cls._get_tf_pad(x_shape[2:], kernel_shape, strides, pads)
        else:
            # Neither pad nor auto_pad is specified, assume no padding.
            if "auto_pad" not in node.attrs:
                pad = "VALID"
            # We consult auto_pad if pad is not specified and auto_pad
            # is available.
            else:
                if node.attrs["auto_pad"] == "SAME_UPPER":
                    pad = "SAME"
                elif node.attrs["auto_pad"] == "VALID":
                    pad = "VALID"
                elif node.attrs["auto_pad"] == "SAME_LOWER":
                    pad = PAD_TF_INCOMPATIBLE
                if count_include_pad == 1:
                    _, pads = cls._pool_get_shapes(node.attrs["auto_pad"],
                                                   x_shape[2:], kernel_shape,
                                                   strides,
                                                   [0] * spatial_size * 2)

        if strict and count_include_pad == 0:
            if pad is PAD_TF_INCOMPATIBLE:
                return cls._compatibility_pool(node, input_dict, pooling_type)
        else:
            if pads != [0] * spatial_size * 2:
                x = PadMixin.get_padding_as_op(x, pads)
            pad = "VALID"

        if support_cuda:
            pooled = pool_func(x,
                               kernel_shape,
                               padding=pad,
                               strides=strides,
                               data_format=compute_format)
        else:
            x = tf.transpose(x,
                             perm=get_perm_from_formats(
                                 storage_format, compute_format))
            pooled = pool_func(x,
                               kernel_shape,
                               padding=pad,
                               strides=strides,
                               data_format=compute_format)
            pooled = tf.transpose(pooled,
                                  perm=get_perm_from_formats(
                                      compute_format, storage_format))

        return [pooled]
Example #14
    def pool(cls, node, input_dict, pooling_type, strict=True):
        x = input_dict[node.inputs[0]]
        orig_x = x

        kernel_shape = node.attrs["kernel_shape"]

        spatial_size = len(kernel_shape)
        x_rank = spatial_size + 2

        kernel_shape = node.attrs["kernel_shape"]
        strides = node.attrs.get("strides", [1] * spatial_size)
        dilations = node.attrs.get("dilations", [1] * spatial_size)
        ceil_mode = bool(node.attrs.get("ceil_mode", 0))
        pads = node.attrs.get("auto_pad", "NOTSET")
        p = node.attrs.get("p", 2)

        if pads == "NOTSET":
            pads = node.attrs.get("pads", [0] * spatial_size * 2)
            # In case shape is fully defined, check if pads match
            # SAME padding in Tensorflow
            if x.shape.is_fully_defined() and pads != [0] * spatial_size * 2:
                in_shape = x.get_shape().as_list()
                same_paddings = calc_pads_same(in_shape[1:x_rank - 1],
                                               kernel_shape, strides,
                                               dilations, "SAME_UPPER")
                if pads == same_paddings:
                    pads = "SAME_UPPER"

        count_include_pad = bool(node.attrs.get("count_include_pad", 0))
        if pooling_type == "AVG":
            pooling_name = "AveragePool"
        elif pooling_type == "MAX":
            pooling_name = "MaxPool"
        elif pooling_type == "MAX_WITH_ARGMAX":
            pooling_name = "MaxPoolWithArgmax"
        elif pooling_type == "LP":
            pooling_name = "LpPool"

        if spatial_size > 3:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with {}D input".format(x_rank), "Tensorflow")
        if pooling_type == "MAX_WITH_ARGMAX" and x_rank != 4:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with {}D input".format(x_rank), "Tensorflow")
        if node.attrs.get("storage_order", 0) != 0:
            exception.OP_UNSUPPORTED_EXCEPT(
                pooling_name + " with column major", "Tensorflow")

        storage_format, _ = get_data_format(x_rank)

        need_trans = storage_format.startswith("NC")
        if need_trans:
            compute_format = "N" + storage_format[2:] + "C"
            x = tf.transpose(x,
                             perm=get_perm_from_formats(
                                 storage_format, compute_format))

        dp = DilatedPooling(input=x,
                            kernel_shape=kernel_shape,
                            strides=strides,
                            dilations=dilations,
                            padding=pads,
                            ceil_mode=ceil_mode,
                            pooling_type=pooling_type,
                            count_include_pad=count_include_pad,
                            p=p)
        if not dp.is_supported():
            if strict:
                logger.warning(
                    "Using the pooling op in compatibility mode. "
                    "This means your graph cannot be serialized.")

                result = tf.py_func(py_pool, [
                    orig_x, kernel_shape, strides, dilations, pads, ceil_mode,
                    pooling_type, False
                ], orig_x.dtype)

                if orig_x.shape.is_fully_defined():
                    shape = orig_x.get_shape().as_list()
                    output_shape = shape[0:2] + calc_output_shape(
                        shape[2:x_rank], kernel_shape, strides, dilations,
                        pads, ceil_mode)
                else:
                    output_shape = [None] * x_rank
                result.set_shape(output_shape)
                return [result]
            else:
                exception.OP_UNSUPPORTED_EXCEPT(
                    "strict == 0 and " + pooling_name +
                    " arguments not compatible", "Tensorflow")

        def dilated_pool():
            return (dp.dilated_pool(), None)

        # select correct op depending on the pooling type
        pooling_op = dilated_pool if pooling_type in ["MAX", "AVG", "LP"] else \
            dp.dilated_maxpool_with_argmax

        # select the correct transpose ops depending on the input storage format
        perm = get_perm_from_formats(compute_format, storage_format)

        def postprocess(pooled, argmax):
            return (tf.transpose(pooled, perm=perm) if need_trans else pooled,
                    tf.transpose(argmax, perm=perm)
                    if need_trans and argmax is not None else argmax)

        pooled, argmax = pooling_op()
        pooled, argmax = postprocess(pooled, argmax)

        result = [pooled] if argmax is None else [pooled, argmax]

        return result
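The compatibility fallback wraps a Python implementation with tf.py_func (the TF1-style API; tf.py_function is the closer TF2 equivalent). A minimal sketch of the mechanism, not of py_pool itself:

import tensorflow as tf

def double(a):
    # plain Python code executed eagerly from inside the graph
    return a * 2

y = tf.py_function(func=double, inp=[tf.constant([1.0, 2.0])], Tout=tf.float32)
print(y.numpy())  # [2. 4.]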
Example #15
    def pool(cls, node, input_dict, pool_func, pooling_type, strict=True):
        x = input_dict[node.inputs[0]]
        x_rank = len(x.get_shape())
        x_shape = x.get_shape().as_list()
        spatial_size = x_rank - 2

        if spatial_size > 3:
            exception.OP_UNSUPPORTED_EXCEPT(
                "MaxPool with {}D input".format(x_rank), "Tensorflow")

        support_cuda = supports_device("CUDA")
        storage_format, compute_format = get_data_format(x_rank)

        kernel_shape = node.attrs["kernel_shape"]
        strides = node.attrs.get("strides", [1] * spatial_size)
        pads = node.attrs.get("pads", None)
        pad = PAD_TF_INCOMPATIBLE
        # from version 7
        count_include_pad = node.attrs.get("count_include_pad", 0)

        auto_pad = node.attrs.get("auto_pad", "NOTSET")
        # if auto_pad is NOTSET, we check pads
        if auto_pad == "NOTSET":
            # If padding is specified, try to recover it from explicit padding
            # specification to tensorflow padding mode:
            if pads is not None:
                pad = cls._get_tf_pad(x_shape[2:], kernel_shape, strides, pads)
            else:
                pad = "VALID"
        else:
            if auto_pad == "SAME_UPPER":
                pad = "SAME"
            elif auto_pad == "VALID":
                pad = "VALID"
            elif auto_pad == "SAME_LOWER":
                pad = PAD_TF_INCOMPATIBLE
            if count_include_pad == 1:
                _, pads = cls._pool_get_shapes(auto_pad, x_shape[2:],
                                               kernel_shape, strides,
                                               [0] * spatial_size * 2)

        if pooling_type in ("AVG", "MAX"):
            if strict and count_include_pad == 0:
                if pad is PAD_TF_INCOMPATIBLE:
                    return cls._compatibility_pool(node, input_dict,
                                                   pooling_type)
            else:
                if pads != [0] * spatial_size * 2:
                    x = PadMixin.get_padding_as_op(x, pads)
                pad = "VALID"
        elif pooling_type == "MAX_WITH_ARGMAX":
            if pad is PAD_TF_INCOMPATIBLE:
                exception.OP_UNSUPPORTED_EXCEPT(
                    "MaxPoolWithArgmax with pad is None or incompatible mode",
                    "Tensorflow")
            if x_rank != 4:
                exception.OP_UNSUPPORTED_EXCEPT(
                    "MaxPoolWithArgmax with {}D input".format(x_rank),
                    "Tensorflow")
            if node.attrs.get("storage_order", 0) != 0:
                exception.OP_UNSUPPORTED_EXCEPT(
                    "MaxPoolWithArgmax with column major", "Tensorflow")

            need_trans = storage_format != "NHWC"
            if need_trans:
                x = tf.transpose(x,
                                 perm=get_perm_from_formats(
                                     storage_format, "NHWC"))
            pooled, argmax = pool_func(x, [1] + kernel_shape + [1],
                                       padding=pad,
                                       strides=[1] + strides + [1])
            if need_trans:
                pooled = tf.transpose(pooled,
                                      perm=get_perm_from_formats(
                                          "NHWC", storage_format))
                argmax = tf.transpose(argmax,
                                      perm=get_perm_from_formats(
                                          "NHWC", storage_format))

            return [pooled, argmax]

        if support_cuda:
            pooled = pool_func(x,
                               kernel_shape,
                               padding=pad,
                               strides=strides,
                               data_format=compute_format)
        else:
            x = tf.transpose(x,
                             perm=get_perm_from_formats(
                                 storage_format, compute_format))
            pooled = pool_func(x,
                               kernel_shape,
                               padding=pad,
                               strides=strides,
                               data_format=compute_format)
            pooled = tf.transpose(pooled,
                                  perm=get_perm_from_formats(
                                      compute_format, storage_format))

        return [pooled]
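The MAX_WITH_ARGMAX branch relies on TensorFlow's NHWC-only argmax pooling; a standalone sketch of that op with illustrative shapes:

import tensorflow as tf

x = tf.random.normal([1, 4, 4, 1])                # NHWC input
pooled, argmax = tf.nn.max_pool_with_argmax(x,
                                            ksize=[1, 2, 2, 1],
                                            strides=[1, 2, 2, 1],
                                            padding="VALID")
print(pooled.shape, argmax.shape)  # (1, 2, 2, 1) (1, 2, 2, 1)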
Example #16
  def pool_v11(cls, node, input_dict, pooling_type, strict=True):
    x = input_dict[node.inputs[0]]
    orig_x = x

    kernel_shape = node.attrs["kernel_shape"]

    spatial_size = len(kernel_shape)
    x_rank = spatial_size + 2

    kernel_shape = node.attrs["kernel_shape"]
    strides = node.attrs.get("strides", [1] * spatial_size)
    dilations = node.attrs.get("dilations", [1] * spatial_size)
    ceil_mode = bool(node.attrs.get("ceil_mode", 0))
    pads = node.attrs.get("auto_pad", "NOTSET")
    if pads == "NOTSET":
      pads = node.attrs.get("pads", [0] * spatial_size * 2)

    count_include_pad = bool(node.attrs.get("count_include_pad", 0))
    if pooling_type == "AVG":
        pooling_name = "AveragePool"
    elif pooling_type == "MAX":
        pooling_name = "MaxPool"
    elif pooling_type == "MAX_WITH_ARGMAX":
        pooling_name = "MaxPoolWithArgmax"

    if spatial_size > 3:
      exception.OP_UNSUPPORTED_EXCEPT(
          pooling_name + " with {}D input".format(x_rank), "Tensorflow")
    if pooling_type == "MAX_WITH_ARGMAX" and x_rank != 4:
      exception.OP_UNSUPPORTED_EXCEPT(
          pooling_name + " with {}D input".format(x_rank), "Tensorflow")
    if node.attrs.get("storage_order", 0) != 0:
      exception.OP_UNSUPPORTED_EXCEPT(pooling_name + " with column major",
                                      "Tensorflow")

    storage_format, _ = get_data_format(x_rank)

    need_trans = storage_format.startswith("NC")
    if need_trans:
      compute_format = "N" + storage_format[2:] + "C"
      x = tf.transpose(x, perm=get_perm_from_formats(storage_format,
                                                     compute_format))

    dp = DilatedPooling(input=x, kernel_shape=kernel_shape, strides=strides,
                        dilations=dilations, padding=pads, ceil_mode=ceil_mode,
                        pooling_type=pooling_type,
                        count_include_pad=count_include_pad)
    if not dp.is_supported():
      if strict:
        warnings.warn(
            "Using the pooling op in compatibility mode. "
            "This means your graph cannot be serialized.", UserWarning)

        return [tf.py_func(py_pool, [orig_x, kernel_shape, strides,
                                     dilations, pads, ceil_mode, "AVG",
                                     False], orig_x.dtype)]
      else:
        exception.OP_UNSUPPORTED_EXCEPT("strict == 0 and average pool"
                                        " arguments not compatible",
                                        "Tensorflow")

    def dilated_pool():
      return (dp.dilated_pool(), None)

    # select correct op depending on the pooling type
    pooling_op = dilated_pool if pooling_type in ["MAX", "AVG"] else \
        dp.dilated_maxpool_with_argmax

    # select the correct transpose ops depending on the input storage format
    perm = get_perm_from_formats(compute_format, storage_format)

    def postprocess(pooled, argmax):
      return (tf.transpose(pooled, perm=perm) if need_trans else pooled,
              tf.transpose(argmax, perm=perm) if need_trans and argmax
              is not None else argmax)

    pooled, argmax = pooling_op()
    pooled, argmax = postprocess(pooled, argmax)

    result = [pooled] if argmax is None else [pooled, argmax]

    return result