Beispiel #1
0
    def forward(self,
                xyz: paddle.Tensor,
                new_xyz: paddle.Tensor,
                features: paddle.Tensor = None):
        """Group all points into a single group.

        Args:
            xyz (Tensor): (B, N, 3) xyz coordinates of the features.
            new_xyz (Tensor): Ignored.
            features (Tensor, optional): (B, C, N) features to group.

        Returns:
            Tensor: Grouped feature. Its shape depends on the inputs:

            - (B, 3 + C, 1, N) when ``features`` is given and
              ``self.use_xyz`` is true;
            - (B, C, 1, N) when ``features`` is given and ``self.use_xyz``
              is false;
            - (B, 3, 1, N) when ``features`` is None.
        """
        # paddle's transpose needs a full permutation:
        # (B, N, 3) -> (B, 3, N), then add the single-group axis -> (B, 3, 1, N).
        grouped_xyz = xyz.transpose((0, 2, 1)).unsqueeze(2)
        if features is None:
            return grouped_xyz

        grouped_features = features.unsqueeze(2)  # (B, C, 1, N)
        if not self.use_xyz:
            return grouped_features
        # (B, 3 + C, 1, N)
        return paddle.concat([grouped_xyz, grouped_features], axis=1)
Beispiel #2
0
    def forward(
        self,
        query: paddle.Tensor,
        key: paddle.Tensor,
        value: paddle.Tensor,
        mask: Optional[paddle.Tensor] = None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Multi-head attention over projected query/key/value.

        Args:
            query: input to ``self.query``; its first axis is the batch axis.
            key: input to ``self.key``.
            value: input to ``self.value``.
            mask: optional mask forwarded as ``attn_mask`` to ``self.attention``.

        Returns:
            Tuple of (``self.dense`` applied to the re-merged heads, the
            attention tensor returned by ``self.attention``).
        """
        bsz = query.shape[0]
        heads = self.num_attention_heads
        head_dim = self.dims_per_head

        def split_heads(projected):
            # reshape to (bsz, -1, heads, head_dim), then swap to
            # (bsz, heads, -1, head_dim)
            return projected.reshape(
                (bsz, -1, heads, head_dim)).transpose((0, 2, 1, 3))

        q = split_heads(self.query(query))
        k = split_heads(self.key(key))
        v = split_heads(self.value(value))

        context, attention = self.attention(q, k, v, attn_mask=mask)

        # undo the head split: (bsz, heads, -1, head_dim) -> (bsz, -1, hidden)
        merged = context.transpose((0, 2, 1, 3)).reshape(
            (bsz, -1, self.hidden_size))
        return self.dense(merged), attention
Beispiel #3
0
 def forward(self, cosine: paddle.Tensor, label):
     """Add an angular margin to the target-class logits and rescale.

     Args:
         cosine (Tensor): cosine similarities whose last axis indexes the
             classes. NOTE(review): presumably already within [-1, 1].
         label (Tensor): integer class ids, one per row of ``cosine``.

     Returns:
         Tensor: ``cos(acos(cosine) + self.m * one_hot(label)) * self.s``.
     """
     # The one-hot mask is added to ``cosine``, so its width must match the
     # class axis of ``cosine``; derive it instead of hard-coding a count.
     num_classes = cosine.shape[-1]
     m_hot = paddle.nn.functional.one_hot(label.astype('int64'),
                                          num_classes=num_classes) * self.m
     # Floating-point error can push values just outside [-1, 1], where
     # acos is NaN; clipping leaves in-range values untouched.
     theta = paddle.clip(cosine, min=-1.0, max=1.0).acos()
     theta = theta + m_hot
     return theta.cos() * self.s
Beispiel #4
0
    def forward(ctx, target: paddle.Tensor,
                source: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]:
        """Find the top-3 nearest neighbors of the target set from the source
        set.

        Args:
            target (Tensor): shape (B, N, 3), points set that needs to
                find the nearest neighbors.
            source (Tensor): shape (B, M, 3), points set that is used
                to find the nearest neighbors of points in target set.

        Returns:
            Tuple[Tensor, Tensor]:

            - shape (B, N, 3) float32: square root of the values the kernel
              writes into ``dist2`` (presumably squared distances, so this is
              the L2 distance to each of the 3 neighbors).
            - shape (B, N, 3) int64: neighbor indices, with
              ``stop_gradient`` set.
        """
        # paddle.Tensor.size is a property (element count), not a method;
        # read the dimensions from ``shape`` instead.
        B, N, _ = target.shape
        m = source.shape[1]
        dist2 = paddle.zeros((B, N, 3), dtype=paddle.float32)
        idx = paddle.zeros((B, N, 3), dtype=paddle.int64)

        # The wrapper presumably fills ``dist2`` and ``idx`` in place.
        interpolate_ops.three_nn_wrapper(B, N, m, target, source, dist2, idx)

        idx.stop_gradient = True

        return paddle.sqrt(dist2), idx
Beispiel #5
0
def make_non_pad_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a boolean mask that is True at non-padded positions.

    Args:
        lengths: tensor of sequence lengths, one per batch item (its first
            axis is the batch axis).

    Returns:
        Bool tensor where entry [b, t] is True iff t < lengths[b]; the time
        axis spans ``0 .. max(lengths) - 1``.
    """
    n_batch = int(lengths.shape[0])
    longest = int(lengths.max())
    positions = paddle.arange(0, longest, dtype=paddle.int64)
    positions = positions.unsqueeze(0).expand([n_batch, longest])
    limits = lengths.unsqueeze(-1)
    # Same as logical_not(positions >= limits) for integer positions.
    return positions < limits
Beispiel #6
0
def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool,
                      out_dtype: paddle.dtype):
    """Undo the temporary batch axis and/or dtype cast applied to ``img``.

    Args:
        img: tensor produced by the wrapped operation.
        need_cast: whether ``img`` must be cast back to ``out_dtype``.
        need_squeeze: whether axis 0 was added earlier and must be removed.
        out_dtype: dtype to cast back to when ``need_cast`` is True.

    Returns:
        Tensor: ``img`` with axis 0 squeezed and/or cast as requested.
    """
    if need_squeeze:
        # Paddle's squeeze takes ``axis`` (``dim`` is the PyTorch spelling).
        img = img.squeeze(axis=0)

    if need_cast:
        if out_dtype in (paddle.uint8, paddle.int8, paddle.int16, paddle.int32,
                         paddle.int64):
            # it is better to round before cast
            img = paddle.round(img)
        # Paddle tensors expose ``astype``; there is no ``as_type`` method.
        img = img.astype(out_dtype)

    return img
Beispiel #7
0
def normalize(tensor: Tensor,
              mean: List[float],
              std: List[float],
              inplace: bool = False) -> Tensor:
    """Normalize a float tensor image with mean and standard deviation.
    This transform does not support PIL Image.

    .. note::
        This transform acts out of place by default, i.e., it does not mutate the input tensor.

    See :class:`~paddlevision.transforms.Normalize` for more details.

    Args:
        tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
        inplace (bool, optional): If True, write the normalized values back
            into ``tensor`` and return it. Defaults to False.

    Returns:
        Tensor: Normalized Tensor image.

    Raises:
        TypeError: if ``tensor`` is not a paddle Tensor or is not floating point.
        ValueError: if ``tensor`` has fewer than 3 dims or any ``std`` is zero.
    """
    if not isinstance(tensor, paddle.Tensor):
        raise TypeError(
            'Input tensor should be a paddle tensor. Got {}.'.format(
                type(tensor)))

    if tensor.dtype not in (paddle.float16, paddle.float32, paddle.float64):
        raise TypeError(
            'Input tensor should be a float tensor. Got {}.'.format(
                tensor.dtype))

    if tensor.ndim < 3:
        raise ValueError(
            'Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.shape() = '
            '{}.'.format(tensor.shape))

    dtype = tensor.dtype
    mean = paddle.to_tensor(mean, dtype=dtype, place=tensor.place)
    std = paddle.to_tensor(std, dtype=dtype, place=tensor.place)
    if (std == 0).any():
        raise ValueError('std evaluated to zero, leading to division by zero.')
    # Per-channel statistics broadcast over the trailing (H, W) axes.
    if mean.ndim == 1:
        mean = mean.reshape((-1, 1, 1))
    if std.ndim == 1:
        std = std.reshape((-1, 1, 1))

    # subtract/divide already allocate a new tensor, so the out-of-place path
    # needs no defensive clone of the input.
    result = tensor.subtract(mean).divide(std)
    if inplace:
        # Honor ``inplace`` by copying the result into the caller's tensor.
        paddle.assign(result, output=tensor)
        return tensor
    return result
Beispiel #8
0
def magphase(x: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the magnitude and phase of a complex tensor.

    The complex values are stored with real and imaginary parts in the last
    axis; typically the input is the result of a complex Fourier transform.

    Parameters:
        x(Tensor): The input tensor of shape (..., 2), where ``x[..., 0]`` is
            the real part and ``x[..., 1]`` the imaginary part.
    Returns:
        The tuple of magnitude and phase (phase in radians, within [-pi, pi]).

    Shape:
        x: the shape of x is arbitrary, with the shape of last axis being 2
        outputs: the shapes of magnitude and phase are both input.shape[:-1]

    Raises:
        ParameterError: if the last axis of ``x`` does not have size 2.

    Examples:

        .. code-block:: python

            import paddle
            import paddleaudio.functional as F
            x = paddle.randn((10, 10, 2))
            mag, phase = F.magphase(x)

    """
    if x.shape[-1] != 2:
        raise ParameterError(
            f'complex tensor must be of shape (..., 2), but received {x.shape} instead'
        )
    real = x[..., 0]
    imag = x[..., 1]
    mag = paddle.sqrt(paddle.square(x).sum(axis=-1))
    # paddle.atan2(x, y) computes arctan(x / y) with quadrant handling, so the
    # phase of (real + j * imag) is atan2(imag, real): imaginary part first.
    phase = paddle.atan2(imag, real)

    return mag, phase
Beispiel #9
0
    def forward(self, spectrum: Tensor, signal_length: int) -> Tensor:
        """Reconstruct a time-domain signal from a (real, imag) spectrum.

        Args:
            spectrum (Tensor): 3-D or 4-D tensor. A 3-D input gets a leading
                axis added. After that the four axes are unpacked as
                (bs, freq_dim, frame_num, complex_dim). ``freq_dim`` must be
                ``self.n_fft`` or ``self.n_fft // 2 + 1``, and
                ``complex_dim`` must be 2 (index 0 = real, 1 = imaginary).
            signal_length (int): forwarded to ``F.deframe``.

        Returns:
            Tensor: the output of ``F.deframe`` applied to the synthesized
            frames.

        Raises:
            AssertionError: if the rank, frequency size or last-axis size of
                ``spectrum`` is invalid.
        """

        assert spectrum.ndim == 3 or spectrum.ndim == 4, (
            f'The input spectrum must be a 3-d or 4-d tensor, ' +
            f'but received ndim={spectrum.ndim} instead')

        # Treat a 3-D input as a batch of one.
        if spectrum.ndim == 3:
            spectrum = spectrum.unsqueeze(0)

        bs, freq_dim, frame_num, complex_dim = spectrum.shape

        assert freq_dim == self.n_fft or freq_dim == self.n_fft // 2 + 1, (
            f'The input spectrum should have {self.n_fft} ' +
            f'or {self.n_fft//2+1} frequency ' +
            f'components, but received {freq_dim} instead')
        assert complex_dim == 2, (
            f'The last dimension of input spectrum should be 2 for ' +
            f'storing real and imaginary part of spectrum, ' +
            f'but received {complex_dim} instead')
        real = spectrum[:, :, :, 0]
        imag = spectrum[:, :, :, 1]
        if real.shape[1] == self.n_fft:
            # Already a full (two-sided) spectrum.
            real_full = real
            imag_full = imag
        else:
            # One-sided spectrum: append the conjugate mirror of bins
            # freq_dim-2 .. 1 (real part mirrored, imaginary part negated).
            # NOTE(review): this yields n_fft bins only for even n_fft; for
            # odd n_fft one bin is missing — confirm n_fft is always even.
            real_full = paddle.concat([real, real[:, -2:0:-1]], 1)
            imag_full = paddle.concat([imag, -imag[:, -2:0:-1]], 1)
        # Presumably idft_mat[..., 0] / [..., 1] hold the real / imaginary
        # parts of the inverse DFT matrix, so part1 - part2 is the real part
        # of the inverse transform — TODO confirm against idft_mat's setup.
        part1 = paddle.matmul(self.idft_mat[:, :, :, :, 0], real_full)
        part2 = paddle.matmul(self.idft_mat[:, :, :, :, 1], imag_full)
        # NOTE(review): [0] keeps only the first entry of the leading axis of
        # the matmul result; whether this drops batch items when bs > 1
        # depends on idft_mat's shape — verify.
        frames = part1[0] - part2[0]
        signal = F.deframe(frames, self.n_fft, self.hop_length,
                           self.win_length, signal_length)
        return signal
Beispiel #10
0
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
    """Pad the last two axes of ``img`` by mirroring, including edge pixels.

    Negative padding amounts crop the corresponding side instead of padding.

    Args:
        img: 3-D or 4-D tensor; its last two axes are treated as (H, W).
        padding: ``[left, right, top, bottom]`` amounts.

    Returns:
        Tensor: the padded (and/or cropped) tensor.

    Raises:
        RuntimeError: if ``img`` is not 3-D or 4-D.
    """
    # padding is left, right, top, bottom

    # crop if needed
    if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0:
        crop_left, crop_right, crop_top, crop_bottom = [
            -min(x, 0) for x in padding
        ]
        img = img[..., crop_top:img.shape[-2] - crop_bottom,
                  crop_left:img.shape[-1] - crop_right]
        padding = [max(x, 0) for x in padding]

    ndim = img.ndim
    if ndim not in (3, 4):
        raise RuntimeError(
            "Symmetric padding of N-D tensors are not supported yet")

    def _mirror_indices(size, before, after):
        # Non-negative source indices for a symmetric pad of one axis, e.g.
        # size=4, before=2, after=3 -> [1, 0, 0, 1, 2, 3, 3, 2, 1].
        head = list(range(before - 1, -1, -1))
        body = list(range(size))
        tail = [size - 1 - i for i in range(after)]
        return head + body + tail

    # paddle.Tensor.size is a property and tensors have ``place``, not
    # ``device``; index_select needs non-negative indices.
    height, width = img.shape[-2], img.shape[-1]
    y_indices = paddle.to_tensor(_mirror_indices(height, padding[2],
                                                 padding[3]),
                                 dtype='int64',
                                 place=img.place)
    x_indices = paddle.to_tensor(_mirror_indices(width, padding[0],
                                                 padding[1]),
                                 dtype='int64',
                                 place=img.place)

    img = paddle.index_select(img, y_indices, axis=ndim - 2)
    return paddle.index_select(img, x_indices, axis=ndim - 1)
Beispiel #11
0
def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
    """Get gram matrix.

    Args:
        data: 4-D tensor unpacked as (b, ch, h, w).

    Returns:
        Tensor of shape (b, ch, ch): ``F @ F^T / (ch * h * w)``, where ``F``
        is ``data`` flattened to (b, ch, h * w).
    """
    batch, channels, height, width = data.shape
    flat = data.reshape((batch, channels, width * height))
    scale = channels * height * width
    return paddle.bmm(flat, flat.transpose((0, 2, 1))) / scale
Beispiel #12
0
def _cast_squeeze_in(
        img: Tensor, req_dtypes: List[paddle.dtype]
) -> Tuple[Tensor, bool, bool, paddle.dtype]:
    """Prepare ``img`` for an operation that needs a batch axis and given dtypes.

    Args:
        img: input tensor; a leading axis is added when it has fewer than
            4 dims.
        req_dtypes: acceptable dtypes; if ``img.dtype`` is not among them,
            ``img`` is cast to ``req_dtypes[0]``.

    Returns:
        Tuple of (prepared tensor, whether a cast happened, whether an axis
        was added, the original dtype).
    """
    need_squeeze = False
    # make image NCHW (adds a leading batch axis to e.g. a CHW image)
    if img.ndim < 4:
        # Paddle's unsqueeze takes ``axis`` (``dim`` is the PyTorch spelling).
        img = img.unsqueeze(axis=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = out_dtype not in req_dtypes
    if need_cast:
        # Paddle tensors expose ``astype``; there is no ``as_type`` method.
        img = img.astype(req_dtypes[0])
    return img, need_cast, need_squeeze, out_dtype
Beispiel #13
0
    def forward(self, X: paddle.Tensor):
        """Apply the transposed matrix ``P`` to the flattened feature map.

        Side effect: stores ``P = bmm(weight expanded like G, G)`` on
        ``self.P``.

        Args:
            X: feature map (the original note calls it 3D); it is flattened
                to (X.shape[0], X.shape[1], -1) for the batched matmul.

        Returns:
            Tensor reshaped back to ``X.shape``.
        """
        # input X is a 3D feature map
        self.P = paddle.bmm(self.weight.expand_as(self.G), self.G)

        n = X.shape[0]
        transform = self.P.transpose((0, 2, 1)).expand((n, self.C, self.C))
        flat = X.reshape((n, X.shape[1], -1))
        mixed = paddle.bmm(transform, flat)
        return mixed.reshape(X.shape)
Beispiel #14
0
def normalize_kernel2d(input: paddle.Tensor) -> paddle.Tensor:
    r"""Normalizes both derivative and smoothing kernel.

    Divides ``input`` by the sum of absolute values over its last two axes.

    Raises:
        TypeError: if ``input`` has fewer than 2 dimensions.
    """
    if len(input.shape) < 2:
        raise TypeError("input should be at least 2D tensor. Got {}".format(
            input.shape))
    # Reduce the last axis, then the one before it, keeping both as size 1 so
    # the result broadcasts against ``input``.
    l1_norm = input.abs().sum(axis=-1, keepdim=True).sum(axis=-2,
                                                         keepdim=True)
    return input / l1_norm
Beispiel #15
0
    def get_extended_attention_mask(
        self,
        attention_mask: paddle.Tensor,
        input_ids: paddle.Tensor,
    ) -> paddle.Tensor:
        """Turn a 0/1 attention mask into a broadcastable additive mask.

        Args:
            attention_mask: 2-D mask (axes 1 and 2 are inserted) or 3-D mask
                (axis 1 is inserted).
            input_ids: only used in the error message.

        Returns:
            ``(1.0 - mask) * -10000.0``: 0 where the mask is 1 and -10000
            where it is 0.

        Raises:
            ValueError: if ``attention_mask`` is neither 2-D nor 3-D.
        """
        rank = attention_mask.dim()
        if rank == 2:
            expanded = attention_mask.unsqueeze((1, 2))
        elif rank == 3:
            expanded = attention_mask.unsqueeze(1)
        else:
            raise ValueError("Wrong shape for input_ids (shape {}) "
                             "or attention_mask (shape {})".format(
                                 input_ids.shape, attention_mask.shape))

        return (1.0 - expanded) * -10000.0
Beispiel #16
0
    def forward(self, x: Tensor) -> Tensor:
        """Filter every row of ``x`` with the next weight from ``self.rir_source``.

        Args:
            x: 2-D tensor, one signal per row.

        Returns:
            Channel 0 of the conv1d output. The total padding of
            ``kernel_len - 1`` keeps its length equal to the input length.
            NOTE(review): conv1d is cross-correlation (the kernel is not
            flipped); confirm rir_source accounts for that.
        """
        assert x.ndim == 2, (f'the input tensor must be 2d tensor, ' +
                             f'but received x.ndim={x.ndim}')

        weight = self.rir_source()  # get next weight
        kernel_len = weight.shape[-1]
        left = kernel_len // 2 - 1
        right = kernel_len - kernel_len // 2
        filtered = paddle.nn.functional.conv1d(x.unsqueeze(1),
                                               weight,
                                               padding=[left, right])
        return filtered[:, 0, :]
Beispiel #17
0
def noise_estimation_loss(model,
                          x0: paddle.Tensor,
                          t: paddle.Tensor,
                          e: paddle.Tensor,
                          b: paddle.Tensor,
                          keepdim=False):
    """Squared error between the injected noise ``e`` and the model's output.

    ``a`` is the cumulative product of ``1 - b`` gathered at the indices ``t``
    and reshaped to (-1, 1, 1, 1). The model is called on
    ``x0 * sqrt(a) + e * sqrt(1 - a)`` and ``t`` cast to float32.

    Args:
        model: callable ``model(x, t_float32)``.
        x0: input samples; the loss sums over axes 1-3, so presumably 4-D.
        t: integer indices into ``b``, presumably one per sample.
        e: noise, presumably the same shape as ``x0``.
        b: schedule tensor (presumably 1-D betas).
        keepdim: if True, return per-sample sums; otherwise their mean.

    Returns:
        Per-sample sums of squared error, or their mean over axis 0.
    """
    alpha_bar = (1 - b).cumprod(0).index_select(t, 0).reshape((-1, 1, 1, 1))
    noisy = x0 * alpha_bar.sqrt() + e * (1.0 - alpha_bar).sqrt()
    predicted = model(noisy, t.astype('float32'))
    per_sample = (e - predicted).square().sum((1, 2, 3))
    return per_sample if keepdim else per_sample.mean(0)
Beispiel #18
0
    def forward(ctx, features: paddle.Tensor, indices: paddle.Tensor,
                weight: paddle.Tensor) -> paddle.Tensor:
        """Performs weighted linear interpolation on 3 features.

        Args:
            features (Tensor): (B, C, M) Features descriptors to be
                interpolated from
            indices (Tensor): (B, n, 3) index three nearest neighbors
                of the target features in features
            weight (Tensor): (B, n, 3) weights of interpolation

        Returns:
            Tensor: (B, C, n) float32 tensor of the interpolated features
        """
        # paddle.Tensor.size is a property (element count), not a method;
        # read the dimensions from ``shape`` instead.
        B, c, m = features.shape
        n = indices.shape[1]
        # Stashed for backward: neighbor indices, weights and the source size M.
        ctx.three_interpolate_for_backward = (indices, weight, m)
        output = paddle.zeros((B, c, n), dtype=paddle.float32)

        # The wrapper presumably writes the result into ``output`` in place.
        interpolate_ops.three_interpolate_wrapper(B, c, m, n, features,
                                                  indices, weight, output)
        return output
Beispiel #19
0
    def forward(self, x: Tensor):
        """Frame ``x`` with ``self.conv`` and split the channels into pairs.

        Args:
            x (Tensor): 1-D (non-batched) or 2-D (batched) signal.

        Returns:
            Tensor: the conv output of shape (B, C, L) rearranged to
            (B, C // 2, L, 2). The last axis pairs consecutive conv output
            channels (presumably real/imaginary parts — confirm against
            ``self.conv``'s weights).

        Raises:
            AssertionError: if ``x`` is not 1-D or 2-D.
        """
        # Report ``x.ndim`` (not the builtin ``input``), so a failing check
        # raises this AssertionError rather than an AttributeError.
        assert x.ndim in [
            1, 2
        ], (f'The input signal x must be a 1-d tensor for ' +
            'non-batched signal or 2-d tensor for batched signal, ' +
            f'but received ndim={x.ndim} instead')
        if x.ndim == 1:
            x = x.unsqueeze((0, 1))
        elif x.ndim == 2:
            x = x.unsqueeze(1)

        if self.center:
            x = paddle.nn.functional.pad(
                x,
                pad=[self.n_fft // 2, self.n_fft // 2],
                mode=self.pad_mode,
                data_format="NCL")
        signal = self.conv(x)  # (B, C, L)
        signal = signal.transpose([0, 2, 1])  # (B, L, C)
        signal = signal.reshape(
            [signal.shape[0], signal.shape[1], signal.shape[2] // 2, 2])
        signal = signal.transpose((0, 2, 1, 3))  # (B, C // 2, L, 2)
        return signal
Beispiel #20
0
    def _get_feat_extract_output_lengths(self, input_lengths: paddle.Tensor):
        """
        Computes the output length of the convolutional layers

        Applies ``(length - kernel_size) // stride + 1`` once for each
        (kernel_size, stride) pair zipped from ``self.config.conv_kernel`` and
        ``self.config.conv_stride``, then casts the result to int64.
        """
        # Output length of an unpadded, undilated 1D convolution, see
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        lengths = input_lengths
        for kernel, step in zip(self.config.conv_kernel,
                                self.config.conv_stride):
            lengths = (lengths - kernel) // step + 1
        return lengths.astype('int64')
Beispiel #21
0
    def forward(self,
                query: paddle.Tensor,
                keys: paddle.Tensor,
                attn_mask=None):
        """Score every key against the query with ``self.compute_scores``.

        Args:
            query: batch x query_size.
            keys: batch x num keys x key_size.
            attn_mask: passed to ``maybe_mask`` together with the logits.

        Returns:
            Attention logits of shape batch x num keys.
        """
        batch, num_keys = query.shape[0], keys.shape[1]
        # Repeat the query once per key: batch x num keys x query_size.
        tiled = query.unsqueeze(1).expand([batch, num_keys, query.shape[-1]])
        # Score each (query, key) pair: batch x num keys x 1 -> batch x num keys.
        pairs = paddle.concat((tiled, keys), axis=2)
        logits = self.compute_scores(pairs).squeeze(2)
        # maybe_mask's return value is discarded, so it presumably modifies
        # ``logits`` in place — NOTE(review): confirm.
        maybe_mask(logits, attn_mask)
        return logits
Beispiel #22
0
    def backward(ctx, grad_out: paddle.Tensor):
        """Backward of three interpolate.

        Args:
            grad_out (Tensor): (B, C, N) tensor with gradients of outputs

        Returns:
            Tuple: a (B, C, M) float32 tensor with gradients of features,
            then ``None`` twice (no gradient for the other two inputs).
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        # paddle.Tensor.size is a property (element count), not a method;
        # read the dimensions from ``shape`` instead.
        B, c, n = grad_out.shape

        grad_features = paddle.zeros((B, c, m), dtype=paddle.float32)
        grad_out_data = grad_out.data

        # The wrapper presumably writes the gradient into ``grad_features``.
        interpolate_ops.three_interpolate_grad_wrapper(B, c, n, m,
                                                       grad_out_data, idx,
                                                       weight,
                                                       grad_features.data)
        return grad_features, None, None
Beispiel #23
0
 def forward(
     self,
     query: paddle.Tensor,
     key: paddle.Tensor,
     value: paddle.Tensor,
     attn_mask: Optional[paddle.Tensor] = None,
 ) -> Tuple[paddle.Tensor, paddle.Tensor]:
     r"""Scaled dot-product attention.

     Args:
         query: [batch, num_attention_heads, len_query, dim_query]
         key: [batch, num_attention_heads, len_key, dim_key]
         value: [batch, num_attention_heads, len_value, dim_value]
         attn_mask: [batch, num_attention_heads, len_query, len_key], added to
             the scaled scores before the softmax.

     Returns:
         (context, attention): ``attention @ value`` and the softmax weights
         after ``self.dropout``.
     """
     key_t = key.transpose((0, 1, 3, 2))
     scale = math.sqrt(query.shape[-1])
     scores = paddle.matmul(query, key_t) / scale
     if attn_mask is not None:
         scores = scores + attn_mask
     # Functional form of the nn.Softmax(axis=-1) layer; same computation
     # without constructing a layer on every call.
     weights = self.dropout(nn.functional.softmax(scores, axis=-1))
     return paddle.matmul(weights, value), weights
    def forward(ctx,
                k: int,
                xyz: paddle.Tensor,
                center_xyz: paddle.Tensor = None,
                transposed: bool = False) -> paddle.Tensor:
        """Forward.

        Args:
            k (int): number of nearest neighbors.
            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
                xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) if transposed == False,
                else (B, 3, npoint). centers of the knn query. Defaults to
                ``xyz`` when None.
            transposed (bool): whether the input tensors are transposed.
                defaults to False. Should not explicitly use this keyword
                when calling knn (=KNN.apply), just add the fourth param.

        Returns:
            Tensor: (B, k, npoint) int32 tensor with the indices of
                the features that form k-nearest neighbours.

        Raises:
            AssertionError: if ``k <= 0`` or the inputs are on different
                devices.
        """
        assert k > 0

        if center_xyz is None:
            center_xyz = xyz

        if transposed:
            # (B, 3, N) -> (B, N, 3); paddle's transpose needs a full
            # permutation.
            xyz = xyz.transpose((0, 2, 1))
            center_xyz = center_xyz.transpose((0, 2, 1))

        # Paddle tensors expose ``place`` (there is no ``get_device``).
        assert str(center_xyz.place) == str(xyz.place), \
            'center_xyz and xyz should be put on the same device'

        B, npoint, _ = center_xyz.shape
        N = xyz.shape[1]

        # Output buffers, presumably filled in place by the kernel wrapper.
        idx = paddle.zeros((B, npoint, k), dtype=paddle.int32)
        dist2 = paddle.zeros((B, npoint, k), dtype=paddle.float32)

        knn_ops.knn_wrapper(B, N, npoint, k, xyz, center_xyz, idx, dist2)
        # idx shape to [B, k, npoint]
        idx = idx.transpose((0, 2, 1))
        ctx.mark_non_differentiable(idx)
        return idx
Beispiel #25
0
    def forward(ctx, features: paddle.Tensor,
                indices: paddle.Tensor) -> paddle.Tensor:
        """Gather ``features`` at ``indices`` along the point axis.

        Args:
            features (Tensor): (B, C, N) features to gather.
            indices (Tensor): (B, M) where M is the number of points.

        Returns:
            Tensor: (B, C, M) float32 gathered features.
        """
        batch, num_out = indices.shape
        _, channels, num_in = features.shape
        gathered = paddle.zeros([batch, channels, num_out],
                                dtype=paddle.float32)

        # The wrapper presumably writes its result into ``gathered``.
        gather_points_ops.gather_points_wrapper(batch, channels, num_in,
                                                num_out, features, indices,
                                                gathered)

        # Keep the indices plus the C and N sizes for the backward pass.
        ctx.save_for_backward(indices, channels, num_in)
        indices.stop_gradient = True

        return gathered
Beispiel #26
0
 def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int):
     """Split ``tensor`` into attention heads.

     Reshapes to (bsz, seq_len, num_heads, head_dim) and swaps the middle
     axes, giving (bsz, num_heads, seq_len, head_dim).
     """
     per_head = tensor.reshape((bsz, seq_len, self.num_heads, self.head_dim))
     return per_head.transpose((0, 2, 1, 3))
Beispiel #27
0
def zero_(tensor: Tensor):
    """Overwrite every element of ``tensor`` with zero.

    Returns:
        Whatever ``tensor.set_value`` returns.
    """
    zeros = paddle.zeros_like(tensor)
    return tensor.set_value(zeros)
Beispiel #28
0
def fill_(tensor: Tensor, value):
    """Overwrite every element of ``tensor`` with ``value``.

    Returns:
        Whatever ``tensor.set_value`` returns.
    """
    replacement = paddle.full_like(tensor, value)
    return tensor.set_value(replacement)
Beispiel #29
0
def pad(img: Tensor,
        padding: List[int],
        fill: int = 0,
        padding_mode: str = "constant") -> Tensor:
    """Pad the last two (H, W) axes of an image tensor.

    Args:
        img: image tensor (validated by ``_assert_image_tensor``).
        padding: an int (same on all sides) or a 1-, 2- or 4-element sequence.
            Two elements mean (left/right, top/bottom); four mean
            (left, top, right, bottom).
        fill: constant value used when ``padding_mode == "constant"``.
        padding_mode: one of "constant", "edge", "reflect", "symmetric".

    Returns:
        Tensor: the padded image.

    Raises:
        TypeError: if an argument has an unsupported type.
        ValueError: if ``padding`` has an unsupported length or
            ``padding_mode`` is unknown.
    """
    _assert_image_tensor(img)

    if not isinstance(padding, (int, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (int, float)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, tuple):
        padding = list(padding)

    if isinstance(padding, list) and len(padding) not in [1, 2, 4]:
        raise ValueError(
            "Padding must be an int or a 1, 2, or 4 element tuple, not a " +
            "{} element tuple".format(len(padding)))

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError(
            "Padding mode should be either constant, edge, reflect or symmetric"
        )

    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:
        pad_left = pad_right = pad_top = pad_bottom = padding[0]
    elif len(padding) == 2:
        pad_left = pad_right = padding[0]
        pad_top = pad_bottom = padding[1]
    else:
        pad_left = padding[0]
        pad_top = padding[1]
        pad_right = padding[2]
        pad_bottom = padding[3]

    p = [pad_left, pad_right, pad_top, pad_bottom]

    if padding_mode == "edge":
        # remap padding_mode str
        padding_mode = "replicate"
    elif padding_mode == "symmetric":
        # route to another implementation
        return _pad_symmetric(img, p)

    need_squeeze = False
    if img.ndim < 4:
        # Paddle's unsqueeze takes ``axis`` (``dim`` is the PyTorch spelling).
        img = img.unsqueeze(axis=0)
        need_squeeze = True

    out_dtype = img.dtype
    need_cast = False
    if (padding_mode != "constant") and img.dtype not in (paddle.float32,
                                                          paddle.float64):
        # Here we temporary cast input tensor to float
        need_cast = True
        # Paddle tensors expose ``astype``; there is no ``as_type`` method.
        img = img.astype(paddle.float32)

    img = paddle_pad(img, p, mode=padding_mode, value=float(fill))

    if need_squeeze:
        img = img.squeeze(axis=0)

    if need_cast:
        img = img.astype(out_dtype)

    return img
Beispiel #30
0
def resize(img: Tensor,
           size: List[int],
           interpolation: str = "bilinear",
           max_size: Optional[int] = None,
           antialias: Optional[bool] = None) -> Tensor:
    """Resize an image tensor.

    Args:
        img: image tensor (validated by ``_assert_image_tensor``).
        size: int or 1-element sequence -> target length of the shorter edge
            (aspect ratio kept); 2-element sequence -> (height, width).
        interpolation: "nearest", "bilinear" or "bicubic".
        max_size: optional cap on the longer edge. It is only allowed when
            ``size`` specifies the shorter edge, and must be strictly larger
            than that edge.
        antialias: validated (only bilinear/bicubic allow True) but not
            otherwise used by this implementation.

    Returns:
        Tensor: the resized image, cast/squeezed back by
        ``_cast_squeeze_out``. ``img`` itself is returned unchanged when the
        shorter edge already has the requested length.

    Raises:
        TypeError: if ``size`` or ``interpolation`` has an unsupported type.
        ValueError: for unsupported modes, bad ``size`` length, or invalid
            ``max_size`` / ``antialias`` combinations.
    """
    _assert_image_tensor(img)

    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")
    if not isinstance(interpolation, str):
        raise TypeError("Got inappropriate interpolation arg")

    if interpolation not in ["nearest", "bilinear", "bicubic"]:
        raise ValueError(
            "This interpolation mode is unsupported with Tensor input")

    if isinstance(size, tuple):
        size = list(size)

    if isinstance(size, list):
        if len(size) not in [1, 2]:
            raise ValueError(
                "Size must be an int or a 1 or 2 element tuple/list, not a "
                "{} element tuple/list".format(len(size)))
        if max_size is not None and len(size) != 1:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge."
            )

    if antialias is None:
        antialias = False

    if antialias and interpolation not in ["bilinear", "bicubic"]:
        raise ValueError(
            "Antialias option is supported for bilinear and bicubic interpolation modes only"
        )

    w, h = _get_image_size(img)

    if isinstance(size, int) or len(
            size) == 1:  # specified size only for the smallest edge
        short, long = (w, h) if w <= h else (h, w)
        requested_new_short = size if isinstance(size, int) else size[0]

        if short == requested_new_short:
            return img

        new_short, new_long = requested_new_short, int(requested_new_short *
                                                       long / short)

        if max_size is not None:
            if max_size <= requested_new_short:
                raise ValueError(
                    f"max_size = {max_size} must be strictly greater than the requested "
                    f"size for the smaller edge size = {size}")
            if new_long > max_size:
                new_short, new_long = int(max_size * new_short /
                                          new_long), max_size

        new_w, new_h = (new_short, new_long) if w <= h else (new_long,
                                                             new_short)

    else:  # specified both h and w
        new_w, new_h = size[1], size[0]

    img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
        img, [paddle.float32, paddle.float64])

    # paddle.nn.functional.interpolate requires a bool ``align_corners`` and
    # rejects True for "nearest"; False is valid for every supported mode.
    img = interpolate(img,
                      size=[new_h, new_w],
                      mode=interpolation,
                      align_corners=False)

    if interpolation == "bicubic" and out_dtype == paddle.uint8:
        # Bicubic can overshoot the input range; keep values castable to
        # uint8. Paddle spells this ``clip`` (there is no ``clamp``).
        img = img.clip(min=0, max=255)

    img = _cast_squeeze_out(img,
                            need_cast=need_cast,
                            need_squeeze=need_squeeze,
                            out_dtype=out_dtype)

    return img