示例#1
0
 def _exchange_dims_in_attrs(self):
   """Cross-link ``dim0`` and ``dim1`` inside the stored ORDER attribute.

   Both ``self._dim0`` and ``self._dim1`` are passed through
   ``transformed_axis(src="NCHW", dst="NHWC", ...)``. The ORDER list in
   ``_attr_value_mem`` is then updated in place so that the slot of each
   converted axis holds the other converted axis.
   """
   first = transformed_axis(
       src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=self._dim0)
   second = transformed_axis(
       src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=self._dim1)
   order = self._attr_value_mem[self.AttrName.ORDER]
   order[first] = second
   order[second] = first
示例#2
0
 def dim(self, value):
     """Write ``value`` into the AXIS attribute memory as a one-element list.

     ``value`` is converted with ``transformed_axis(src="NCHW", dst="NHWC",
     ...)``. The list held in ``_attr_value_mem`` for AXIS is then
     overwritten in place, so its identity is preserved.
     """
     converted = transformed_axis(src="NCHW",
                                  dst="NHWC",
                                  ndim=self._input_ndim,
                                  dim=value)
     self._attr_value_mem[self.AttrName.AXIS][:] = [converted]
示例#3
0
    def start(self, start):
        """Set the BEGIN and BEGIN_MASK attributes from ``start``.

        For inputs whose rank is not 4, ``start`` is stored unchanged. For
        rank-4 inputs, the value at each position is moved to the index given
        by ``transformed_axis(src="NCHW", dst="NHWC", ...)``. Indices not
        covered by ``start`` default to 0. In both cases, bit ``i`` of
        BEGIN_MASK is set when the stored begin value at index ``i`` equals 0.
        """
        if self._input_ndim == 4:
            stored = [0] * 4
            for src_axis, pos in enumerate(start):
                dst_axis = transformed_axis(src="NCHW",
                                            dst="NHWC",
                                            ndim=self._input_ndim,
                                            dim=src_axis)
                stored[dst_axis] = pos
        else:
            stored = start

        # Each bit position is distinct, so summing the bits is the same as
        # OR-ing them together.
        mask = sum(1 << axis for axis, pos in enumerate(stored) if pos == 0)

        self.set_attr(self.AttrName.BEGIN_MASK, mask)
        self.set_attr(self.AttrName.BEGIN, stored)
示例#4
0
def squeeze(xgraph: XGraph, node: Node,
            quant_config: NndctQuantInfo) -> None:
    """Emit XIR ops for a squeeze node.

    When the first input tensor has ``ndim == 4`` and the node's DIMS
    attribute holds exactly one entry, two fixed ops are created:

    1. a ``transpose`` named ``<node.name>_i0`` with order ``[0, 3, 1, 2]``
       over the node's input ops (the original intent: resume the input to
       NCHW);
    2. a ``squeeze`` named ``node.name`` on top of that transpose. Its
       ``axis`` is the DIMS entry converted with
       ``transformed_axis("NHWC", "NCHW", ndim=4, ...)``.

    Every other case is delegated to ``to_xir("squeeze")``.

    Args:
        xgraph: graph in which the ops are created and looked up.
        node: the squeeze node being converted.
        quant_config: quantization info forwarded to each created op.
    """
    # Short-circuit kept from the original: DIMS is only read for 4-D inputs.
    if not (node.in_tensors[0].ndim == 4
            and len(node.node_attr(node.op.AttrName.DIMS)) == 1):
        to_xir("squeeze")(xgraph, node, quant_config)
        return

    # Resume the input dimension order to NCHW with an explicit transpose.
    transpose_name = node.name + "_i0"
    transpose_inputs: Dict[str, List[Op]] = {
        "input": [xgraph.get_op_by_name(in_name) for in_name in node.in_nodes]
    }
    xgraph.create_fixed_normal_op(transpose_name,
                                  "transpose",
                                  quant_config,
                                  attrs={"order": [0, 3, 1, 2]},
                                  input_ops=transpose_inputs)

    # The squeeze now operates on the transposed tensor, so its axis is
    # converted from the NHWC index to the NCHW index.
    squeeze_dim = transformed_axis("NHWC", "NCHW", ndim=4,
                                   dim=node.node_attr(node.op.AttrName.DIMS)[0])
    squeeze_inputs: Dict[str, List[Op]] = {
        "input": [xgraph.get_op_by_name(transpose_name)]
    }
    xgraph.create_fixed_normal_op(node.name,
                                  "squeeze",
                                  quant_config,
                                  attrs={"axis": [squeeze_dim]},
                                  input_ops=squeeze_inputs)
示例#5
0
 def dims(self, value):
   """Overwrite the stored ORDER list in place with converted axes.

   Each entry of ``value`` is passed through
   ``transformed_axis(src="NCHW", dst="NHWC", ...)``. The ORDER list held in
   ``_attr_value_mem`` keeps its identity.
   """
   converted = [
       transformed_axis(src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=d)
       for d in value
   ]
   self._attr_value_mem[self.AttrName.ORDER][:] = converted
示例#6
0
 def dim(self):
   """Return the DIMS attribute as a tuple of converted axes.

   Each stored entry is passed through
   ``transformed_axis(src="NHWC", dst="NCHW", ...)``.
   """
   return tuple(
       transformed_axis(src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=d)
       for d in self.get_attr(self.AttrName.DIMS))
示例#7
0
 def dims(self):
   """Return a new list with the stored ORDER entries converted.

   Each entry is passed through ``transformed_axis(src="NHWC", dst="NCHW",
   ...)``. The stored list itself is not modified.
   """
   stored_order = self._attr_value_mem[self.AttrName.ORDER]
   return [
       transformed_axis(src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=d)
       for d in stored_order
   ]
示例#8
0
 def step(self):
   """Return the STRIDES attribute reordered for the caller.

   For inputs whose rank is not 4, the attribute is returned exactly as
   ``get_attr`` yields it. For rank-4 inputs, a new 4-element list is built
   with the stride at each stored index moved to
   ``transformed_axis(src="NHWC", dst="NCHW", ...)``. Unfilled slots stay 1.
   """
   if self._input_ndim == 4:
     reordered = [1] * 4
     for stored_axis, stride in enumerate(self.get_attr(self.AttrName.STRIDES)):
       target = transformed_axis(
           src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
       reordered[target] = stride
     return reordered
   return self.get_attr(self.AttrName.STRIDES)
示例#9
0
 def end(self):
   """Return the END attribute reordered for the caller.

   For inputs whose rank is not 4, the attribute is returned exactly as
   ``get_attr`` yields it. For rank-4 inputs, a new 4-element list is built
   with the value at each stored index moved to
   ``transformed_axis(src="NHWC", dst="NCHW", ...)``. Unfilled slots hold
   ``NNDCT_CONSTANT.INT_MAX``.
   """
   if self._input_ndim == 4:
     reordered = [NNDCT_CONSTANT.INT_MAX] * 4
     for stored_axis, pos in enumerate(self.get_attr(self.AttrName.END)):
       target = transformed_axis(
           src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
       reordered[target] = pos
     return reordered
   return self.get_attr(self.AttrName.END)
示例#10
0
 def start(self):
   """Return the BEGIN attribute reordered for the caller.

   For inputs whose rank is not 4, the attribute is returned exactly as
   ``get_attr`` yields it. For rank-4 inputs, a new 4-element list is built
   with the value at each stored index moved to
   ``transformed_axis(src="NHWC", dst="NCHW", ...)``. Unfilled slots stay 0.
   """
   if self._input_ndim == 4:
     reordered = [0] * 4
     for stored_axis, pos in enumerate(self.get_attr(self.AttrName.BEGIN)):
       target = transformed_axis(
           src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=stored_axis)
       reordered[target] = pos
     return reordered
   return self.get_attr(self.AttrName.BEGIN)
示例#11
0
 def dim(self, value):
   """Store one or more axes in the DIMS attribute.

   ``value`` may be a single axis or a tuple/list of axes. Each axis is
   converted with ``transformed_axis(src="NCHW", dst="NHWC", ...)``. The
   resulting list is handed to ``set_attr`` wrapped in an outer list.
   """
   axes = list(value) if isinstance(value, (tuple, list)) else [value]
   converted = [
       transformed_axis(src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=a)
       for a in axes
   ]
   # NOTE(review): the extra outer list is kept as-is; presumably set_attr /
   # the DIMS attr unwraps one level of nesting -- confirm.
   self.set_attr(self.AttrName.DIMS, [converted])
示例#12
0
  def end(self, end):
    """Set the END and END_MASK attributes from ``end``.

    For inputs whose rank is not 4, ``end`` is stored unchanged. For rank-4
    inputs, the value at each position is moved to the index given by
    ``transformed_axis(src="NCHW", dst="NHWC", ...)``. Indices not covered
    by ``end`` default to ``NNDCT_CONSTANT.INT_MAX``. In both cases, bit
    ``i`` of END_MASK is set when the stored value at index ``i`` is an
    ``int`` that is at least ``NNDCT_CONSTANT.INT_MAX``.
    """
    if self._input_ndim == 4:
      stored = [NNDCT_CONSTANT.INT_MAX] * 4
      for src_axis, pos in enumerate(end):
        dst_axis = transformed_axis(
            src="NCHW", dst="NHWC", ndim=self._input_ndim, dim=src_axis)
        stored[dst_axis] = pos
    else:
      stored = end

    # Each bit position is distinct, so summing the bits equals OR-ing them.
    mask = sum(1 << axis
               for axis, pos in enumerate(stored)
               if isinstance(pos, int) and pos >= NNDCT_CONSTANT.INT_MAX)

    self.set_attr(self.AttrName.END_MASK, mask)
    self.set_attr(self.AttrName.END, stored)
示例#13
0
 def dim(self):
     """Return the first AXIS entry converted with ``transformed_axis``.

     The conversion uses ``src="NHWC", dst="NCHW"`` and the input rank in
     ``self._input_ndim``.
     """
     rank = self._input_ndim
     stored_axis = self._attr_value_mem[self.AttrName.AXIS][0]
     return transformed_axis(src="NHWC", dst="NCHW", ndim=rank, dim=stored_axis)
示例#14
0
 def dim1(self):
     """Return the ORDER entry at the converted position of ``self._dim0``.

     The index is ``transformed_axis(src="NHWC", dst="NCHW", ...,
     dim=self._dim0)``. The looked-up value is returned as stored.

     NOTE(review): if ORDER is written as
     ``ORDER[transformed_axis(NCHW->NHWC, dim0)] = transformed_axis(NCHW->NHWC,
     dim1)``, recovering ``dim1`` would need the NCHW->NHWC index conversion
     and a NHWC->NCHW conversion of the result. Confirm which convention this
     class uses before relying on this getter for ``dim0 == 0``.
     """
     order = self._attr_value_mem[self.AttrName.ORDER]
     position = transformed_axis(
         src="NHWC", dst="NCHW", ndim=self._input_ndim, dim=self._dim0)
     return order[position]