Пример #1
0
    def __init__(self, block, drop_connect_rate):
        """Wrap an existing block, reusing its settings and layer objects.

        The layer modules of ``block`` are shared by reference, not copied.
        Only the activation is newly created, as a standard ``Swish()``.

        Args:
            block: source block providing ``_bn_mom``, ``_bn_eps``,
                ``id_skip``, ``_block_args`` and the layer attributes.
            drop_connect_rate: stored as ``self.drop_connect_rate``.
        """
        super().__init__()

        # Scalar settings taken directly from the source block.
        self._bn_mom = block._bn_mom
        self._bn_eps = block._bn_eps
        self.id_skip = block.id_skip  # skip connection and drop connect
        self.drop_connect_rate = drop_connect_rate

        # Filter counts and stride come from the block's argument record.
        args = block._block_args
        self.input_filters = args.input_filters
        self.output_filters = args.output_filters
        self.stride = args.stride

        # Reuse the layers in pipeline order: expansion, depthwise conv,
        # squeeze-and-excitation, then output projection. The assignment
        # order matches the original one-by-one assignments, in case the base
        # class records attribute registration order (e.g. torch.nn.Module;
        # TODO confirm base class).
        for attr in (
            "_expand_conv", "_bn0",
            "_depthwise_conv", "_bn1",
            "_se_reduce", "_se_expand",
            "_project_conv", "_bn2",
        ):
            setattr(self, attr, getattr(block, attr))

        self._swish = Swish()
Пример #2
0
 def set_swish(self, memory_efficient=True):
     """Choose the swish activation for this model and all of its blocks.

     Args:
         memory_efficient: when true, use ``MemoryEfficientSwish`` (for
             training); otherwise use the standard ``Swish`` (for export).
     """
     swish_cls = MemoryEfficientSwish if memory_efficient else Swish
     self._swish = swish_cls()
     # Propagate the same choice to every block so the whole network agrees.
     for child in self._blocks:
         child.set_swish(memory_efficient)
Пример #3
0
 def set_swish(self, memory_efficient: bool = True) -> None:
     """Set swish as memory efficient (training) or standard.

     The same choice is propagated to every block in ``self._blocks``.

     Args:
         memory_efficient: if true, use ``MemoryEfficientSwish``; otherwise
             use the standard ``Swish``.
     """
     # Annotated ``-> None`` because the method returns normally. ``NoReturn``
     # would tell type checkers it never returns, so code after a call would
     # be flagged as unreachable.
     self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
     for block in self._blocks:
         block.set_swish(memory_efficient)
Пример #4
0
import torch