def _exchange_dims_in_attrs(self):
    """Write the two swapped axes into the stored ORDER attribute.

    ``_dim0`` and ``_dim1`` are each passed through ``transformed_axis``
    (``src="NCHW"``, ``dst="NHWC"``).  The ORDER entry at the converted
    ``_dim0`` position receives the converted ``_dim1`` value, and the entry
    at the converted ``_dim1`` position receives the converted ``_dim0``
    value.
    """

    def to_nhwc(axis):
        # Single place for the layout conversion shared by both axes.
        return transformed_axis(
            src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=axis)

    first = to_nhwc(self._dim0)
    second = to_nhwc(self._dim1)

    order = self._attr_value_mem[self.AttrName.ORDER]
    # Assigned left to right: order[first] = second, then order[second] = first.
    order[first], order[second] = second, first
def dim(self, value):
    """Replace the contents of the AXIS attribute with one converted axis.

    Args:
        value: axis index handed to ``transformed_axis`` with
            ``src="NCHW"`` and ``dst="NHWC"``.
    """
    converted_axis = transformed_axis(
        src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=value)
    # Slice assignment keeps the existing storage object and swaps its content.
    self._attr_value_mem[self.AttrName.AXIS][:] = [converted_axis]
def start(self, start):
    """Set the BEGIN and BEGIN_MASK attributes from ``start``.

    When ``_input_ndim`` is 4, every entry of ``start`` is placed at the index
    returned by ``transformed_axis`` (``src="NCHW"``, ``dst="NHWC"``) inside a
    four-element list pre-filled with 0.  For any other rank ``start`` is
    stored as given.

    Bit ``i`` of BEGIN_MASK is set when entry ``i`` of the stored list
    equals 0.

    Args:
        start: iterable of begin positions, one per axis.
    """
    if self._input_ndim == 4:
        begin = [0] * 4
        for axis, offset in enumerate(start):
            target = transformed_axis(
                src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=axis)
            begin[target] = offset
    else:
        begin = start

    # Every bit is distinct, so summing is the same as OR-ing them together.
    begin_mask = sum(
        1 << bit for bit, offset in enumerate(begin) if offset == 0)

    self.set_attr(self.AttrName.BEGIN_MASK, begin_mask)
    self.set_attr(self.AttrName.BEGIN, begin)
def squeeze(xgraph: XGraph, node: Node, quant_config: NndctQuantInfo) -> NoReturn:
    """Create the ops for a squeeze ``node`` in ``xgraph``.

    If the node's first input tensor has ``ndim == 4`` and its DIMS attribute
    holds exactly one entry, two ops are created:

    1. a ``"transpose"`` op named ``node.name + "_i0"`` with order
       ``[0, 3, 1, 2]`` over the ops looked up for ``node.in_nodes``;
    2. a ``"squeeze"`` op named ``node.name`` that consumes the transpose and
       whose axis is the DIMS entry converted by
       ``transformed_axis("NHWC", "NCHW", ndim=4, ...)``.

    In every other case the work is delegated to ``to_xir("squeeze")``.

    Args:
        xgraph: graph that receives the new ops.
        node: node being converted.
        quant_config: forwarded unchanged to the op-creation calls.
    """
    needs_transpose = (
        node.in_tensors[0].ndim == 4
        and len(node.node_attr(node.op.AttrName.DIMS)) == 1
    )
    if not needs_transpose:
        to_xir("squeeze")(xgraph, node, quant_config)
        return

    # Step 1: restore the NCHW dimension order ahead of the squeeze.
    transpose_attrs: Dict[str, Any] = {"order": [0, 3, 1, 2]}
    transpose_inputs: Dict[str, List[Op]] = {
        "input": [xgraph.get_op_by_name(producer) for producer in node.in_nodes]
    }
    xgraph.create_fixed_normal_op(
        node.name + "_i0",
        "transpose",
        quant_config,
        attrs=transpose_attrs,
        input_ops=transpose_inputs)

    # Step 2: squeeze the transposed tensor along the converted axis.
    stored_dim = node.node_attr(node.op.AttrName.DIMS)[0]
    squeeze_axis = transformed_axis("NHWC", "NCHW", ndim=4, dim=stored_dim)
    squeeze_attrs: Dict[str, Any] = {"axis": [squeeze_axis]}
    squeeze_inputs: Dict[str, List[Op]] = {
        "input": [xgraph.get_op_by_name(node.name + "_i0")]
    }
    xgraph.create_fixed_normal_op(
        node.name,
        "squeeze",
        quant_config,
        attrs=squeeze_attrs,
        input_ops=squeeze_inputs)
def dims(self, value):
    """Overwrite the ORDER attribute with the converted axes from ``value``.

    Args:
        value: iterable of axis indices; each one is passed through
            ``transformed_axis`` with ``src="NCHW"`` and ``dst="NHWC"``.
    """
    converted = [
        transformed_axis(
            src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=axis)
        for axis in value
    ]
    # Slice assignment replaces the content of the existing ORDER storage.
    self._attr_value_mem[self.AttrName.ORDER][:] = converted
def dim(self):
    """Return the DIMS attribute as a tuple of converted axes.

    Each entry obtained from ``get_attr(AttrName.DIMS)`` is passed through
    ``transformed_axis`` with ``src="NHWC"`` and ``dst="NCHW"``.

    Returns:
        tuple with one converted axis per stored entry, in stored order.
    """
    return tuple(
        transformed_axis(
            src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
        for stored_axis in self.get_attr(self.AttrName.DIMS)
    )
def dims(self):
    """Return the ORDER attribute as a list of converted axes.

    Each stored ORDER entry is passed through ``transformed_axis`` with
    ``src="NHWC"`` and ``dst="NCHW"``.

    Returns:
        list with one converted axis per stored entry, in stored order.
    """
    stored_order = self._attr_value_mem[self.AttrName.ORDER]
    return [
        transformed_axis(
            src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
        for stored_axis in stored_order
    ]
def step(self):
    """Return the STRIDES attribute, re-indexed when the input rank is 4.

    For ``_input_ndim == 4`` the result is a four-element list pre-filled
    with 1; the stored stride at index ``i`` is written to the index returned
    by ``transformed_axis`` (``src="NHWC"``, ``dst="NCHW"``) for ``i``.  For
    any other rank the stored attribute is returned unchanged.
    """
    if self._input_ndim == 4:
        reordered = [1] * 4
        for stored_axis, stride in enumerate(
                self.get_attr(self.AttrName.STRIDES)):
            target = transformed_axis(
                src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
            reordered[target] = stride
        return reordered

    return self.get_attr(self.AttrName.STRIDES)
def end(self):
    """Return the END attribute, re-indexed when the input rank is 4.

    For ``_input_ndim == 4`` the result is a four-element list pre-filled
    with ``NNDCT_CONSTANT.INT_MAX``; the stored end position at index ``i`` is
    written to the index returned by ``transformed_axis`` (``src="NHWC"``,
    ``dst="NCHW"``) for ``i``.  For any other rank the stored attribute is
    returned unchanged.
    """
    if self._input_ndim == 4:
        reordered = [NNDCT_CONSTANT.INT_MAX] * 4
        for stored_axis, stop in enumerate(self.get_attr(self.AttrName.END)):
            target = transformed_axis(
                src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
            reordered[target] = stop
        return reordered

    return self.get_attr(self.AttrName.END)
def start(self):
    """Return the BEGIN attribute, re-indexed when the input rank is 4.

    For ``_input_ndim == 4`` the result is a four-element list pre-filled
    with 0; the stored begin position at index ``i`` is written to the index
    returned by ``transformed_axis`` (``src="NHWC"``, ``dst="NCHW"``) for
    ``i``.  For any other rank the stored attribute is returned unchanged.
    """
    if self._input_ndim == 4:
        reordered = [0] * 4
        for stored_axis, offset in enumerate(
                self.get_attr(self.AttrName.BEGIN)):
            target = transformed_axis(
                src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
            reordered[target] = offset
        return reordered

    return self.get_attr(self.AttrName.BEGIN)
def dim(self, value):
    """Store one or more axes in the DIMS attribute after conversion.

    Args:
        value: a single axis, or a tuple/list of axes.  Each axis is passed
            through ``transformed_axis`` with ``src="NCHW"`` and
            ``dst="NHWC"``.
    """
    axes = list(value) if isinstance(value, (tuple, list)) else [value]
    converted = [
        transformed_axis(
            src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=axis)
        for axis in axes
    ]
    # The converted list is handed to set_attr wrapped in an outer list.
    self.set_attr(self.AttrName.DIMS, [converted])
def end(self, end):
    """Set the END and END_MASK attributes from ``end``.

    When ``_input_ndim`` is 4, every entry of ``end`` is placed at the index
    returned by ``transformed_axis`` (``src="NCHW"``, ``dst="NHWC"``) inside a
    four-element list pre-filled with ``NNDCT_CONSTANT.INT_MAX``.  For any
    other rank ``end`` is stored as given.

    Bit ``i`` of END_MASK is set when entry ``i`` of the stored list is an
    ``int`` that is greater than or equal to ``NNDCT_CONSTANT.INT_MAX``.

    Args:
        end: iterable of end positions, one per axis.
    """
    if self._input_ndim == 4:
        stored_end = [NNDCT_CONSTANT.INT_MAX] * 4
        for axis, stop in enumerate(end):
            target = transformed_axis(
                src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=axis)
            stored_end[target] = stop
    else:
        stored_end = end

    # Every bit is distinct, so summing is the same as OR-ing them together.
    end_mask = sum(
        1 << bit
        for bit, stop in enumerate(stored_end)
        if isinstance(stop, int) and stop >= NNDCT_CONSTANT.INT_MAX
    )

    self.set_attr(self.AttrName.END_MASK, end_mask)
    self.set_attr(self.AttrName.END, stored_end)
def dim(self):
    """Return the first AXIS entry after layout conversion.

    The stored value is passed through ``transformed_axis`` with
    ``src="NHWC"`` and ``dst="NCHW"``.
    """
    stored_axis = self._attr_value_mem[self.AttrName.AXIS][0]
    return transformed_axis(
        src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
def dim1(self):
    """Return the ORDER entry found at the slot derived from ``_dim0``.

    The slot index is ``transformed_axis`` applied to ``_dim0`` with
    ``src="NHWC"`` and ``dst="NCHW"``.
    """
    order = self._attr_value_mem[self.AttrName.ORDER]
    # NOTE(review): the lookup is keyed on _dim0, not _dim1. That matches
    # _exchange_dims_in_attrs, which stores the _dim1 axis in the _dim0 slot,
    # so it is presumably intentional -- confirm before changing.
    slot = transformed_axis(
        src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=self._dim0)
    return order[slot]