def __init__(self, ksizes, strides, rates, padding="valid"):
    """Set up the operators used by this Unfold cell.

    Args:
        ksizes: forwarded unchanged to ``inner.ExtractImagePatches``.
        strides: forwarded unchanged to ``inner.ExtractImagePatches``.
        rates: forwarded unchanged to ``inner.ExtractImagePatches``.
        padding: forwarded unchanged to ``inner.ExtractImagePatches``
            (defaults to "valid").

    This variant performs no validation of its own; any checking of the
    arguments is left to ``ExtractImagePatches``.
    """
    super(Unfold, self).__init__()
    patch_extractor = inner.ExtractImagePatches(ksizes, strides, rates, padding)
    self.extract_image_patches = patch_extractor
    self.transpose = P.Transpose()
    # Axis permutations named after the layout they produce, presumably
    # consumed by self.transpose: (0, 2, 3, 1) turns an NCHW tensor into
    # NHWC, and (0, 3, 1, 2) turns an NHWC tensor back into NCHW.
    to_nhwc, to_nchw = (0, 2, 3, 1), (0, 3, 1, 2)
    self.format_NHWC = to_nhwc
    self.format_NCHW = to_nchw
    # Snapshot of the `enable_ge` context flag, read once at construction time.
    self.is_ge = context.get_context("enable_ge")
def __init__(self, ksizes, strides, rates, padding="valid"):
    """Validate the patch parameters and build the ExtractImagePatches primitive.

    Args:
        ksizes: tuple or list of the form [1, ksize_row, ksize_col, 1].
        strides: tuple or list of the form [1, stride_row, stride_col, 1].
        rates: tuple or list of the form [1, rate_row, rate_col, 1].
        padding: padding mode, forwarded unchanged to
            ``inner.ExtractImagePatches`` (defaults to "valid").

    Raises:
        Whatever ``Validator.check_value_type`` raises (presumably TypeError)
            if any of ksizes/strides/rates is not a tuple or list.
        ValueError: if any of ksizes/strides/rates does not have length 4 with
            1 in the first and last positions, or if its row/col entries are
            not positive ints.
    """
    super(Unfold, self).__init__()

    def _check_tuple_or_list(arg_name, arg_val, prim_name):
        # Type-check the argument actually being validated. The previous code
        # passed `ksizes` here for every call, so a bad `strides` or `rates`
        # type was never reported by this check.
        Validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], prim_name)
        if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1:
            raise ValueError(
                f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
                f"{arg_name}_col, 1], but got {arg_val}.")
        row, col = arg_val[1], arg_val[2]
        if not isinstance(row, int) or not isinstance(col, int) or row < 1 or col < 1:
            raise ValueError(
                f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an "
                f"positive integer number, but got {arg_name}_row is {row}, {arg_name}_col "
                f"is {col}")

    _check_tuple_or_list("ksize", ksizes, self.cls_name)
    _check_tuple_or_list("stride", strides, self.cls_name)
    _check_tuple_or_list("rate", rates, self.cls_name)
    # Reorder [1, row, col, 1] into (v[0], v[3], v[1], v[2]) == (1, 1, row, col).
    # NOTE(review): presumably the channel-first order ExtractImagePatches
    # expects; confirm against that primitive's signature.
    ksizes = ksizes[0], ksizes[3], ksizes[1], ksizes[2]
    strides = strides[0], strides[3], strides[1], strides[2]
    rates = rates[0], rates[3], rates[1], rates[2]
    self.extract_image_patches = inner.ExtractImagePatches(
        ksizes, strides, rates, padding)