Example #1
0
 def __init__(self, ksizes, strides, rates, padding="valid"):
     """Create the primitives and layout constants used by this Unfold cell.

     Args:
         ksizes, strides, rates, padding: Forwarded unchanged, in this order,
             to ``inner.ExtractImagePatches``. Nothing is validated in this
             method; any checking is left to that primitive (confirm there).
     """
     super(Unfold, self).__init__()
     patch_extractor = inner.ExtractImagePatches(ksizes, strides, rates, padding)
     self.extract_image_patches = patch_extractor
     self.transpose = P.Transpose()
     # Axis permutations for a 4-D tensor: (0, 2, 3, 1) moves axis 1 to the end
     # (NCHW -> NHWC) and (0, 3, 1, 2) moves the last axis to position 1
     # (NHWC -> NCHW). They are presumably applied via self.transpose in
     # construct, which is not visible here.
     self.format_NHWC, self.format_NCHW = (0, 2, 3, 1), (0, 3, 1, 2)
     # Cached value of the "enable_ge" context flag; how it is consumed is not
     # visible in this block.
     self.is_ge = context.get_context("enable_ge")
Example #2
0
    def __init__(self, ksizes, strides, rates, padding="valid"):
        """Validate the patch arguments and build the patch-extraction primitive.

        ``ksizes``, ``strides`` and ``rates`` must each have the form
        ``[1, row, col, 1]``: a tuple or list of length 4 whose first and last
        entries are 1 and whose middle entries are positive ints. After
        validation each one is reordered to ``(1, 1, row, col)`` and passed,
        together with ``padding``, to ``inner.ExtractImagePatches``.

        Args:
            ksizes: Patch size as ``[1, ksize_row, ksize_col, 1]``.
            strides: Patch stride as ``[1, stride_row, stride_col, 1]``.
            rates: Rate as ``[1, rate_row, rate_col, 1]`` (presumably the
                spacing between sampled patch elements; confirm against the
                primitive's documentation).
            padding: Forwarded unchanged to ``inner.ExtractImagePatches``.
                It is not validated here.

        Raises:
            Whatever ``Validator.check_value_type`` raises (presumably
            TypeError) if one of ``ksizes``/``strides``/``rates`` is not a
            tuple or list.
            ValueError: If one of them does not have length 4, does not have
                1 as its first and last entry, or has a middle entry that is
                not a positive int.
        """
        super(Unfold, self).__init__()

        def _check_tuple_or_list(arg_name, arg_val, prim_name):
            # Type-check the argument actually under validation. Every argument
            # must pass this check before len() and indexing are applied below.
            Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list],
                                       prim_name)
            if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
                raise ValueError(
                    f"For '{prim_name}' the format of {arg_name}s should be [1, {arg_name}_row, "
                    f"{arg_name}_col, 1], but got {arg_val}.")
            # NOTE(review): bool is a subclass of int, so True passes this check
            # as the value 1.
            if not isinstance(arg_val[1], int) or not isinstance(
                    arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1:
                raise ValueError(
                    f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
                    f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col "
                    f"is {arg_val[2]}")

        def _to_primitive_order(arg_val):
            # [1, row, col, 1] -> (1, 1, row, col), i.e. index order 0, 3, 1, 2.
            # Presumably this matches the channel-first layout that
            # ExtractImagePatches expects; confirm against that primitive.
            return arg_val[0], arg_val[3], arg_val[1], arg_val[2]

        _check_tuple_or_list("ksize", ksizes, self.cls_name)
        _check_tuple_or_list("stride", strides, self.cls_name)
        _check_tuple_or_list("rate", rates, self.cls_name)
        ksizes = _to_primitive_order(ksizes)
        strides = _to_primitive_order(strides)
        rates = _to_primitive_order(rates)
        self.extract_image_patches = inner.ExtractImagePatches(
            ksizes, strides, rates, padding)