def _exchange_dims_in_attrs(self):
    """Swap the ORDER entries addressed by ``self._dim0`` and ``self._dim1``.

    Both dims are first mapped through ``transformed_axis`` with
    ``src`` and ``dst`` both ``DataFormat.channel_first`` (presumably an
    identity mapping -- NOTE(review): confirm). Then the ORDER entry stored in
    ``self._attr_value_mem`` is updated in place so that each transformed
    dim holds the other one.
    """
    first, second = (
        transformed_axis(src=DataFormat.channel_first,
                         dst=DataFormat.channel_first,
                         ndim=self._input_ndim,
                         dim=axis)
        for axis in (self._dim0, self._dim1))
    order = self._attr_value_mem[self.AttrName.ORDER]
    order[first] = second
    order[second] = first
def start(self, start):
    """Store the slice begin indices together with their BEGIN_MASK.

    Bit ``i`` of BEGIN_MASK is set when the stored begin index for dimension
    ``i`` equals 0.

    With fewer than 4 input dims, ``start`` is stored unchanged. Otherwise a
    new list of length ``self._input_ndim`` is built:
    - Dims not covered by ``start`` default to 0.
    - Each given position is placed at the ``transformed_axis`` of its index.
    """
    ndim = self._input_ndim
    if ndim < 4:
        begin = start
    else:
        begin = [0] * ndim
        for axis, pos in enumerate(start):
            target = transformed_axis(src=DataFormat.channel_first,
                                      dst=DataFormat.channel_first,
                                      ndim=ndim,
                                      dim=axis)
            begin[target] = pos
    # Axes are distinct, so summing single-bit flags is the same as OR-ing them.
    begin_mask = sum(1 << axis for axis, pos in enumerate(begin) if pos == 0)
    self.set_attr(self.AttrName.BEGIN_MASK, begin_mask)
    self.set_attr(self.AttrName.BEGIN, begin)
def dim(self, value):
    """Replace the AXIS attribute's contents with the transformed ``value``.

    The sequence stored under AXIS in ``self._attr_value_mem`` is updated in
    place via slice assignment, not rebound, and ends up holding exactly
    one element.
    """
    new_axis = transformed_axis(src=DataFormat.channel_first,
                                dst=DataFormat.channel_first,
                                ndim=self._input_ndim,
                                dim=value)
    self._attr_value_mem[self.AttrName.AXIS][:] = [new_axis]
def dims(self, value):
    """Overwrite the ORDER attribute in place with ``value`` mapped through transformed_axis.

    Slice assignment keeps the same sequence object stored under ORDER in
    ``self._attr_value_mem`` and replaces all of its elements.
    """
    self._attr_value_mem[self.AttrName.ORDER][:] = [
        transformed_axis(src=DataFormat.channel_first,
                         dst=DataFormat.channel_first,
                         ndim=self._input_ndim,
                         dim=axis)
        for axis in value
    ]
def dim(self):
    """Return the DIMS attribute as a tuple, each entry mapped through transformed_axis."""
    return tuple(
        transformed_axis(src=DataFormat.channel_first,
                         dst=DataFormat.channel_first,
                         ndim=self._input_ndim,
                         dim=axis)
        for axis in self.get_attr(self.AttrName.DIMS))
def step(self):
    """Return the slice strides.

    With fewer than 4 input dims, the stored STRIDES attribute is returned
    as-is. Otherwise a new list of length ``self._input_ndim`` is returned:
    - Every slot defaults to stride 1.
    - Each stored stride is moved to the ``transformed_axis`` of its index.
    """
    ndim = self._input_ndim
    stored = self.get_attr(self.AttrName.STRIDES)
    if ndim < 4:
        return stored
    strides = [1] * ndim
    for axis, value in enumerate(stored):
        target = transformed_axis(src=DataFormat.channel_first,
                                  dst=DataFormat.channel_first,
                                  ndim=ndim,
                                  dim=axis)
        strides[target] = value
    return strides
def end(self):
    """Return the slice end indices.

    With fewer than 4 input dims, the stored END attribute is returned as-is.
    Otherwise a new list of length ``self._input_ndim`` is returned:
    - Every slot defaults to ``NNDCT_CONSTANT.INT_MAX``.
    - Each stored end index is moved to the ``transformed_axis`` of its index.
    """
    ndim = self._input_ndim
    stored = self.get_attr(self.AttrName.END)
    if ndim < 4:
        return stored
    result = [NNDCT_CONSTANT.INT_MAX] * ndim
    for axis, pos in enumerate(stored):
        target = transformed_axis(src=DataFormat.channel_first,
                                  dst=DataFormat.channel_first,
                                  ndim=ndim,
                                  dim=axis)
        result[target] = pos
    return result
def start(self):
    """Return the slice begin indices.

    With fewer than 4 input dims, the stored BEGIN attribute is returned
    as-is. Otherwise a new list of length ``self._input_ndim`` is returned:
    - Every slot defaults to 0.
    - Each stored begin index is moved to the ``transformed_axis`` of its index.
    """
    ndim = self._input_ndim
    stored = self.get_attr(self.AttrName.BEGIN)
    if ndim < 4:
        return stored
    result = [0] * ndim
    for axis, pos in enumerate(stored):
        target = transformed_axis(src=DataFormat.channel_first,
                                  dst=DataFormat.channel_first,
                                  ndim=ndim,
                                  dim=axis)
        result[target] = pos
    return result
def dim(self, value):
    """Store ``value`` into the DIMS attribute after mapping each axis through transformed_axis.

    ``value`` may be a single axis or a tuple/list of axes. DIMS is always
    set to a new list.
    """
    axes = list(value) if isinstance(value, (tuple, list)) else [value]
    transformed = [
        transformed_axis(src=DataFormat.channel_first,
                         dst=DataFormat.channel_first,
                         ndim=self._input_ndim,
                         dim=axis)
        for axis in axes
    ]
    self.set_attr(self.AttrName.DIMS, transformed)
def end(self, end):
    """Store the slice end indices together with their END_MASK.

    Bit ``i`` of END_MASK is set when the stored end index for dimension
    ``i`` is an ``int`` that is ``>= NNDCT_CONSTANT.INT_MAX``.

    With fewer than 4 input dims, ``end`` is stored unchanged. Otherwise a new
    list of length ``self._input_ndim`` is built:
    - Dims not covered by ``end`` default to ``NNDCT_CONSTANT.INT_MAX``.
    - Each given position is placed at the ``transformed_axis`` of its index.
    """
    ndim = self._input_ndim
    if ndim < 4:
        new_end = end
    else:
        new_end = [NNDCT_CONSTANT.INT_MAX] * ndim
        for axis, pos in enumerate(end):
            target = transformed_axis(src=DataFormat.channel_first,
                                      dst=DataFormat.channel_first,
                                      ndim=ndim,
                                      dim=axis)
            new_end[target] = pos
    # Axes are distinct, so summing single-bit flags is the same as OR-ing them.
    end_mask = sum(
        1 << axis
        for axis, pos in enumerate(new_end)
        if isinstance(pos, int) and pos >= NNDCT_CONSTANT.INT_MAX)
    self.set_attr(self.AttrName.END_MASK, end_mask)
    self.set_attr(self.AttrName.END, new_end)
def dim(self):
    """Return the first entry of the AXIS attribute mapped through transformed_axis."""
    ndim = self._input_ndim
    stored_axis = self._attr_value_mem[self.AttrName.AXIS][0]
    return transformed_axis(src=DataFormat.channel_first,
                            dst=DataFormat.channel_first,
                            ndim=ndim,
                            dim=stored_axis)