예제 #1
0
 def max_transform(cls, nts):
   """Return a TypeShape like ``nts`` with its UNITS dimensions scaled up.

   BATCH dimensions keep their size; UNITS dimensions become
   ``ceil(size * cls.__max_f)``. Dimension order is preserved.

   Returns None when ``nts.dtype`` is not in ``cls.allowedTypes`` or when
   ``nts`` contains any dimension other than BATCH or UNITS.
   """
   if nts.dtype not in cls.allowedTypes:
     return None
   out_shape = Shape()
   typed_out = TypeShape(nts.dtype, out_shape)
   for dim in nts.shape.dim:
     if dim.name == DimNames.BATCH:
       new_size = dim.size
     elif dim.name == DimNames.UNITS:
       new_size = int(math.ceil(dim.size * cls.__max_f))
     else:
       # Only BATCH and UNITS dimensions are supported by this transform.
       return None
     out_shape.dim.append(Shape.Dim(dim.name, new_size))
   return typed_out
예제 #2
0
 def min_transform(cls, nts: TypeShape):
     """Return a TypeShape with a single UNITS dimension of size 1.

     If ``nts`` has a BATCH dimension, the first one found is carried over
     with the same size. It is placed before UNITS when it sits in the first
     half of the input dimensions. Otherwise it is placed after UNITS. A lone
     BATCH dimension counts as trailing, as does the exact middle of an
     odd-rank shape.

     Returns None when ``nts.dtype`` is not in ``cls.allowedTypes``.
     """
     if nts.dtype not in cls.allowedTypes:
         return None
     s = Shape((DimNames.UNITS, 1))
     result = TypeShape(nts.dtype, s)
     dims = nts.shape.dim
     # Index of the first BATCH dimension in the input, or None if absent.
     idx = next(
         (i for i, d in enumerate(dims) if d.name == DimNames.BATCH), None
     )
     if idx is not None:
         batch = Shape.Dim(DimNames.BATCH, dims[idx].size)
         # Use ``idx < len // 2`` rather than ``idx + 1 < len // 2``. The
         # latter is off by one: it moved a leading batch of a 2-D input such
         # as (BATCH, UNITS) to the end.
         if idx < len(dims) // 2:
             s.dim.insert(0, batch)
         else:
             s.dim.append(batch)
     return result
예제 #3
0
 def min_transform(cls, nts):
     """Return a TypeShape like ``nts`` with CHANNEL/WIDTH/HEIGHT scaled down.

     Sizes are computed per dimension, and dimension order is preserved:

     - BATCH keeps its size.
     - CHANNEL becomes ``floor(size * cls.__min_f_c)``.
     - WIDTH and HEIGHT become ``floor(size * cls.__min_f_hw)``.

     Scaled sizes are clamped to at least 1.

     Returns None when ``nts.dtype`` is not in ``cls.allowedTypes`` or when
     ``nts`` contains any other kind of dimension.
     """
     if nts.dtype not in cls.allowedTypes:
         return None
     s = Shape()
     result = TypeShape(nts.dtype, s)
     for _dim in nts.shape.dim:
         if _dim.name == DimNames.BATCH:
             s.dim.append(Shape.Dim(_dim.name, _dim.size))
             continue
         if _dim.name == DimNames.CHANNEL:
             factor = cls.__min_f_c
         elif _dim.name in (DimNames.WIDTH, DimNames.HEIGHT):
             factor = cls.__min_f_hw
         else:
             return None
         # Rounding down can yield 0 for small inputs; a zero-sized dimension
         # is degenerate, so keep at least 1 (mirrors the UNITS=1 minimum used
         # by the dense min_transform).
         size = max(1, int(math.floor(_dim.size * factor)))
         s.dim.append(Shape.Dim(_dim.name, size))
     return result