def _get_output_shape(self): """Get output shape""" shape = self.inputs[0].shape multiples = self.attrs["multiples"] shape = list(shape) multiples = list(multiples) diff_len = len(multiples) - len(shape) if diff_len < 0: raise GKException( "Dimensions of multiples{} < dimensions of input{} in Tile". format(multiples, shape)) if diff_len > 0: for _ in range(diff_len): shape.insert(0, 1) output_shape = [] for sh, mul in list(zip(shape, multiples)): if sh != 1 and mul != 1: raise GKException( "Tile op in expander only Support Automatic Broadcast!") dim = sh * mul output_shape.append(dim) return output_shape
def _check_format(obj):
    """Accept ``obj`` only if its input formats match one registered format list.

    ``format_list_name`` is a free variable, presumably the attribute name
    supplied to the enclosing decorator (TODO confirm). ``getattr(obj,
    format_list_name)`` yields the candidate format lists, which are tried in
    order.

    Raises:
        GKException: if a candidate's length differs from the number of
            inputs (checked before that candidate is compared), or if no
            candidate matches.
    """
    actual = [item['format'] for item in obj.inputs]
    for candidate in getattr(obj, format_list_name):
        if len(candidate) != len(actual):
            raise GKException("length of registered format doesn't match the input of {}".format(obj.name))
        # Element-wise equality against the real input formats.
        if all(want == got for want, got in zip(candidate, actual)):
            return
    raise GKException("Unregistered format ({}) for op {}".format(','.join(actual), obj.name))
def _check_format(obj): inp_formats = [inp['format'] for inp in obj.inputs] if all([fmt == inp_formats[0] for fmt in inp_formats[1:]]): return raise GKException( "[check_all_formats_same] unmatched formats ({}) for op {}". format(','.join(inp_formats), obj.name))
def _check(self):
    """Validate Conv2D inputs and attributes, then precompute padded shapes.

    Requirements enforced (in this order):
      * both inputs are float16;
      * ``groups`` and ``group`` are both 1;
      * ``dilation`` has 4 entries, all equal to 1;
      * ``pad_list``, both input shapes and ``stride`` pass ``check_nd(..., 4)``;
      * the N dims of the two inputs are at least ``N0_CHANNEL_ALIGN`` /
        ``N1_CHANNEL_ALIGN``, and both C dims are equal and at least
        ``C_CHANNEL_ALIGN``;
      * unless the conv can be turned into a matmul, N * out_H * out_W of the
        output is a multiple of ``OUT_NHW_ALIGN``.

    Side effects: sets ``self.has_pad``, ``self.m``, ``self.n``, ``self.k``,
    ``self.can_optimize_to_matmul``, ``self.shape_0_pad`` and
    ``self.shape_1_pad``.

    Raises:
        GKException: when any requirement above is violated.
    """
    type_0 = self.inputs[0]['data_type']
    type_1 = self.inputs[1]['data_type']
    if type_0 != "float16" or type_1 != "float16":
        raise GKException("inputs type should be float16, but got {} and {}".format(type_0, type_1))
    groups = self.attrs['groups']
    group = self.attrs['group']
    if groups != 1 or group != 1:
        raise GKException("groups and group should be both 1, but got {} and {}.".format(groups, group))
    dilation = self.attrs['dilation']
    check_nd(dilation, 4)
    # Compare as a list: a tuple such as (1, 1, 1, 1) never equals the list
    # [1, 1, 1, 1] in Python, so the raw comparison would wrongly reject it.
    if list(dilation) != [1, 1, 1, 1]:
        raise GKException("dilation should be all 1, but got {}".format(dilation))
    pad_list = self.attrs['pad_list']
    pad_mode = self.attrs['pad_mode']
    check_nd(pad_list, 4)
    # NOTE(review): conv_had_pad presumably reports whether any padding applies; confirm.
    self.has_pad = conv_had_pad(pad_list, pad_mode)
    shape_0 = self.inputs[0]['shape']
    shape_1 = self.inputs[1]['shape']
    stride = self.attrs['stride']
    check_nd(shape_0, 4)
    check_nd(shape_1, 4)
    check_nd(stride, 4)
    # Both shapes are unpacked as (N, H, W, C).
    n0, h0, w0, c0 = shape_0
    n1, h1, w1, c1 = shape_1
    if n0 < N0_CHANNEL_ALIGN:
        raise GKException("N({}) channel of first input should >= {}".format(n0, N0_CHANNEL_ALIGN))
    if n1 < N1_CHANNEL_ALIGN:
        raise GKException("N({}) channel of second input should >= {}".format(n1, N1_CHANNEL_ALIGN))
    if c0 != c1 or c0 < C_CHANNEL_ALIGN:
        raise GKException("C channel of inputs({}, {}) should be same and >= {}".format(c0, c1, C_CHANNEL_ALIGN))
    # n0 pad: round up to the next multiple of N0_CHANNEL_ALIGN.
    n0 = ((n0 + N0_CHANNEL_ALIGN - 1) // N0_CHANNEL_ALIGN) * N0_CHANNEL_ALIGN
    # h0, w0 pad: add top/bottom and left/right padding from pad_list.
    if self.has_pad:
        h0 = h0 + pad_list[0] + pad_list[1]
        w0 = w0 + pad_list[2] + pad_list[3]
    # c0, c1 pad: round up to C_CHANNEL_ALIGN; c1 == c0 was checked above.
    c0 = ((c0 + C_CHANNEL_ALIGN - 1) // C_CHANNEL_ALIGN) * C_CHANNEL_ALIGN
    c1 = c0
    # n1 pad: round up to the next multiple of N1_CHANNEL_ALIGN.
    n1 = ((n1 + N1_CHANNEL_ALIGN - 1) // N1_CHANNEL_ALIGN) * N1_CHANNEL_ALIGN
    # check if can optimize to matmul; m/n/k must be set first, since
    # _optimize_to_matmul presumably reads them (TODO confirm).
    self.m, self.n, self.k = n0 * h0 * w0, n1, c1
    self.can_optimize_to_matmul = self._optimize_to_matmul()
    # Output spatial size of a valid (already padded) convolution.
    out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
    if not self.can_optimize_to_matmul and n0 * out_h * out_w % OUT_NHW_ALIGN != 0:
        raise GKException("N({}) * H({}) * W({}) of Conv2d output should be multiples of {}"
                          .format(n0, out_h, out_w, OUT_NHW_ALIGN))
    self.shape_0_pad = [n0, h0, w0, c0]
    self.shape_1_pad = [n1, h1, w1, c1]
def _check_attr(obj):
    """Raise for the first name in ``args`` that is not a key of ``obj.attrs``.

    ``args`` is a free variable, presumably the attribute names given to the
    enclosing decorator (TODO confirm against the enclosing scope).

    Raises:
        GKException: naming the first missing attribute.
    """
    for attr_name in args:
        if attr_name in obj.attrs:
            continue
        raise GKException("attr '{}' does not exist.".format(attr_name))
def _check(self):
    """Reject the op when both ``grad_x`` and ``grad_y`` are explicitly falsy.

    A missing flag counts as True. When at least one gradient is requested,
    validation continues with the parent class's ``_check`` and its result is
    returned.

    Raises:
        GKException: if both grad_x and grad_y are False.
    """
    # De Morgan form of "not grad_x and not grad_y"; grad_y is only read
    # when grad_x is falsy, exactly as in the short-circuit original.
    wants_grad = self.attrs.get('grad_x', True) or self.attrs.get('grad_y', True)
    if not wants_grad:
        raise GKException("both grad_x and grad_y are False.")
    return super()._check()