Esempio n. 1
0
 def _exchange_dims_in_attrs(self):
     """Swap the slots of ``self._dim0`` and ``self._dim1`` in the ORDER attr.

     Both dims are first mapped with ``transformed_axis`` (``src`` and ``dst``
     both ``DataFormat.channel_first``, ``ndim=self._input_ndim``). Each mapped
     position in the ORDER list then receives the other mapped index.

     Only those two entries are written. The rest of ORDER is left as-is.
     NOTE(review): this presumably expects ORDER to already hold the desired
     base permutation; confirm with the callers.
     """
     def _map(axis):
         # src == dst layout here; whether that makes the mapping an identity
         # depends on transformed_axis, which is not visible in this file.
         return transformed_axis(src=DataFormat.channel_first,
                                 dst=DataFormat.channel_first,
                                 ndim=self._input_ndim,
                                 dim=axis)

     first = _map(self._dim0)
     second = _map(self._dim1)
     order = self._attr_value_mem[self.AttrName.ORDER]
     # Targets are assigned left to right: order[first] first, then order[second].
     order[first], order[second] = second, first
Esempio n. 2
0
    def start(self, start):
        """Store slice begin positions together with the derived BEGIN_MASK.

        Inputs with fewer than 4 dims store ``start`` exactly as passed in.
        Otherwise, a list of length ``self._input_ndim`` is built with unset
        slots equal to 0. Each ``start[i]`` is placed at
        ``transformed_axis(..., dim=i)``.

        Bit ``i`` of BEGIN_MASK is set whenever the stored begin value at
        position ``i`` equals 0.
        """
        if self._input_ndim < 4:
            stored = start
        else:
            stored = [0] * self._input_ndim
            for axis, value in enumerate(start):
                target = transformed_axis(src=DataFormat.channel_first,
                                          dst=DataFormat.channel_first,
                                          ndim=self._input_ndim,
                                          dim=axis)
                stored[target] = value

        # The mask is derived from what is actually stored, so padded slots
        # (value 0) are flagged too in the >= 4 dim case.
        mask = 0
        for bit, value in enumerate(stored):
            if value == 0:
                mask |= 1 << bit

        self.set_attr(self.AttrName.BEGIN_MASK, mask)
        self.set_attr(self.AttrName.BEGIN, stored)
Esempio n. 3
0
 def dim(self, value):
     """Replace the contents of the AXIS attribute with one mapped axis.

     ``value`` goes through ``transformed_axis`` (``src``/``dst`` both
     ``DataFormat.channel_first``, ``ndim=self._input_ndim``). The stored AXIS
     container is then overwritten through slice assignment rather than
     rebound, so the same container object is kept.
     """
     mapped_axis = transformed_axis(src=DataFormat.channel_first,
                                    dst=DataFormat.channel_first,
                                    ndim=self._input_ndim,
                                    dim=value)
     axis_slot = self._attr_value_mem[self.AttrName.AXIS]
     axis_slot[:] = [mapped_axis]
Esempio n. 4
0
 def dims(self, value):
     """Overwrite the ORDER attribute's contents with the mapped ``value``.

     Each entry of ``value`` is passed through ``transformed_axis``
     (``src``/``dst`` both ``DataFormat.channel_first``,
     ``ndim=self._input_ndim``). The results replace ORDER's elements via
     slice assignment, so the existing container object is kept.
     """
     mapped = [
         transformed_axis(src=DataFormat.channel_first,
                          dst=DataFormat.channel_first,
                          ndim=self._input_ndim,
                          dim=axis)
         for axis in value
     ]
     # Slice assignment copies the elements itself, so no extra copy is needed.
     self._attr_value_mem[self.AttrName.ORDER][:] = mapped
Esempio n. 5
0
 def dim(self):
     """Return the DIMS attribute as a tuple, each entry axis-mapped.

     Every stored dim goes through ``transformed_axis`` with ``src``/``dst``
     both ``DataFormat.channel_first`` and ``ndim=self._input_ndim``.
     """
     stored_dims = self.get_attr(self.AttrName.DIMS)
     mapped = [
         transformed_axis(src=DataFormat.channel_first,
                          dst=DataFormat.channel_first,
                          ndim=self._input_ndim,
                          dim=axis)
         for axis in stored_dims
     ]
     return tuple(mapped)
Esempio n. 6
0
 def step(self):
     """Return the per-axis strides.

     With fewer than 4 input dims, the STRIDES attribute is returned exactly as
     ``get_attr`` yields it. Otherwise, a fresh list of length
     ``self._input_ndim`` is built with default stride 1. Stored stride ``i``
     is placed at ``transformed_axis(..., dim=i)``.
     """
     if self._input_ndim < 4:
         return self.get_attr(self.AttrName.STRIDES)

     strides_out = [1] * self._input_ndim
     for src_axis, stride in enumerate(self.get_attr(self.AttrName.STRIDES)):
         dst_axis = transformed_axis(src=DataFormat.channel_first,
                                     dst=DataFormat.channel_first,
                                     ndim=self._input_ndim,
                                     dim=src_axis)
         strides_out[dst_axis] = stride
     return strides_out
Esempio n. 7
0
 def end(self):
     """Return the per-axis slice end positions.

     With fewer than 4 input dims, the END attribute is returned exactly as
     ``get_attr`` yields it. Otherwise, a fresh list of length
     ``self._input_ndim`` is built, defaulting to ``NNDCT_CONSTANT.INT_MAX``.
     Stored end ``i`` is placed at ``transformed_axis(..., dim=i)``.
     """
     if self._input_ndim < 4:
         return self.get_attr(self.AttrName.END)

     ends_out = [NNDCT_CONSTANT.INT_MAX] * self._input_ndim
     for src_axis, bound in enumerate(self.get_attr(self.AttrName.END)):
         dst_axis = transformed_axis(src=DataFormat.channel_first,
                                     dst=DataFormat.channel_first,
                                     ndim=self._input_ndim,
                                     dim=src_axis)
         ends_out[dst_axis] = bound
     return ends_out
Esempio n. 8
0
 def start(self):
     """Return the per-axis slice begin positions.

     With fewer than 4 input dims, the BEGIN attribute is returned exactly as
     ``get_attr`` yields it. Otherwise, a fresh list of length
     ``self._input_ndim`` is built, defaulting to 0. Stored begin ``i`` is
     placed at ``transformed_axis(..., dim=i)``.
     """
     if self._input_ndim < 4:
         return self.get_attr(self.AttrName.BEGIN)

     begins_out = [0] * self._input_ndim
     for src_axis, offset in enumerate(self.get_attr(self.AttrName.BEGIN)):
         dst_axis = transformed_axis(src=DataFormat.channel_first,
                                     dst=DataFormat.channel_first,
                                     ndim=self._input_ndim,
                                     dim=src_axis)
         begins_out[dst_axis] = offset
     return begins_out
Esempio n. 9
0
    def dim(self, value):
        """Store one axis or a sequence of axes in the DIMS attribute.

        A ``tuple`` or ``list`` is treated as several axes. Any other value is
        wrapped as a single axis. Each axis is mapped with ``transformed_axis``
        (``src``/``dst`` both ``DataFormat.channel_first``,
        ``ndim=self._input_ndim``), and the resulting list is passed to
        ``set_attr``.
        """
        axes = list(value) if isinstance(value, (tuple, list)) else [value]
        mapped = [
            transformed_axis(src=DataFormat.channel_first,
                             dst=DataFormat.channel_first,
                             ndim=self._input_ndim,
                             dim=axis)
            for axis in axes
        ]
        self.set_attr(self.AttrName.DIMS, mapped)
Esempio n. 10
0
    def end(self, end):
        """Store slice end positions together with the derived END_MASK.

        Inputs with fewer than 4 dims store ``end`` exactly as passed in.
        Otherwise, a list of length ``self._input_ndim`` is built with unset
        slots equal to ``NNDCT_CONSTANT.INT_MAX``. Each ``end[i]`` is placed at
        ``transformed_axis(..., dim=i)``.

        Bit ``i`` of END_MASK is set when the stored value at position ``i`` is
        an ``int`` that is ``>= NNDCT_CONSTANT.INT_MAX``. Because the mask is
        computed from the stored list, slots left at the INT_MAX default are
        flagged as well, provided INT_MAX is itself an ``int``.
        """
        if self._input_ndim < 4:
            stored = end
        else:
            stored = [NNDCT_CONSTANT.INT_MAX] * self._input_ndim
            for axis, value in enumerate(end):
                target = transformed_axis(src=DataFormat.channel_first,
                                          dst=DataFormat.channel_first,
                                          ndim=self._input_ndim,
                                          dim=axis)
                stored[target] = value

        mask = 0
        for bit, value in enumerate(stored):
            # Non-int entries are never treated as "open ended".
            if isinstance(value, int) and value >= NNDCT_CONSTANT.INT_MAX:
                mask |= 1 << bit

        self.set_attr(self.AttrName.END_MASK, mask)
        self.set_attr(self.AttrName.END, stored)
Esempio n. 11
0
 def dim(self):
     """Return the first entry of the AXIS attribute after axis mapping.

     The stored axis is passed through ``transformed_axis`` with
     ``src``/``dst`` both ``DataFormat.channel_first`` and
     ``ndim=self._input_ndim``.
     """
     stored_axis = self._attr_value_mem[self.AttrName.AXIS][0]
     return transformed_axis(src=DataFormat.channel_first,
                             dst=DataFormat.channel_first,
                             ndim=self._input_ndim,
                             dim=stored_axis)