コード例 #1
0
 def __init__(self, inchannel, outchannel, innerchannel, stride=1):
     """Init PruneBasicBlock.

     Builds a PruneBasicConv branch and a ShortCut branch, merges them
     with ``Add`` into ``self.block`` and creates ``self.relu3``.

     :param inchannel: input channel count, forwarded to both branches.
     :param outchannel: output channel count, forwarded to both branches.
     :param innerchannel: forwarded to PruneBasicConv only (presumably the
         pruned inner width -- TODO confirm).
     :param stride: stride forwarded to both branches, default 1.
     """
     super(PruneBasicBlock, self).__init__()
     # Call arguments are evaluated left to right, so the conv branch is
     # still constructed before the shortcut branch.
     self.block = Add(
         PruneBasicConv(inchannel, outchannel, innerchannel, stride),
         ShortCut(inchannel, outchannel, self.expansion, stride),
     )
     self.relu3 = ops.Relu()
コード例 #2
0
ファイル: blocks.py プロジェクト: vineetrao25/vega
    def __init__(self,
                 inchannel,
                 outchannel,
                 groups=1,
                 base_width=64,
                 stride=1,
                 norm_layer=None,
                 Conv2d='Conv2d'):
        """Create BottleneckBlock layers.

        Builds a BottleConv branch and a ShortCut branch, merges them with
        ``Add`` into ``self.block`` and creates ``self.relu``.

        :param inchannel: input channel.
        :type inchannel: int
        :param outchannel: output channel.
        :type outchannel: int
        :param groups: group count forwarded to BottleConv, default 1
        :type groups: int
        :param base_width: base width forwarded to BottleConv, default 64
        :type base_width: int
        :param stride: the number to jump, default 1
        :type stride: int
        :param norm_layer: normalization config forwarded to both BottleConv
            and ShortCut; ``None`` (the default) selects
            ``{"norm_type": 'BN'}``
        :type norm_layer: dict or None
        :param Conv2d: convolution type forwarded to BottleConv, default
            'Conv2d' (presumably a registered op name -- confirm)
        :type Conv2d: str
        """
        super(BottleneckBlock, self).__init__()
        if norm_layer is None:
            # Build a fresh dict per instance: a dict literal used as the
            # default argument is a single object shared by every call, so
            # any mutation by BottleConv/ShortCut would leak into later
            # blocks.
            norm_layer = {"norm_type": 'BN'}
        bottle_conv = BottleConv(inchannel=inchannel,
                                 outchannel=outchannel,
                                 expansion=self.expansion,
                                 stride=stride,
                                 groups=groups,
                                 base_width=base_width,
                                 norm_layer=norm_layer,
                                 Conv2d=Conv2d)
        shortcut = ShortCut(inchannel=inchannel,
                            outchannel=outchannel,
                            expansion=self.expansion,
                            stride=stride,
                            norm_layer=norm_layer)
        self.block = Add(bottle_conv, shortcut)
        self.relu = ops.Relu()
コード例 #3
0
    def __init__(self, in_channel, out_channel, upscale, rgb_mean, blocks,
                 candidates, cib_range, method, code, block_range):
        """Construct the MtMSR class.

        Builds a body from ``blocks`` followed by a conv + PixelShuffle
        tail, combines it with an InterpolateScale branch via ``Add``, and
        creates two MeanShift ops (``sub_mean`` and ``head``).

        :param in_channel: channel count entering the first block.
        :param out_channel: output channel count; the tail conv emits
            ``out_channel * upscale**2`` channels for ``PixelShuffle(upscale)``.
        :param upscale: scale factor used by PixelShuffle and InterpolateScale.
        :param rgb_mean: mean passed to both MeanShift ops.
        :param blocks: block sequence. A list item adds a ChannelIncreaseBlock
            and multiplies the channel count by the list's length; "res2" /
            "res3" add a ResidualBlock with kernel_size 2 / 3; any other item
            is skipped with a warning.
        :param candidates: not used by this constructor.
        :param cib_range: not used by this constructor.
        :param method: not used by this constructor.
        :param code: not used by this constructor.
        :param block_range: not used by this constructor.
            (The five unused parameters are presumably kept for the
            search-space interface -- TODO confirm.)
        """
        super(MtMSR, self).__init__()
        logging.info("start init MTMSR")
        current_channel = in_channel
        layers = []
        for block_name in blocks:
            if isinstance(block_name, list):
                layers.append(ChannelIncreaseBlock(block_name,
                                                   current_channel))
                current_channel *= len(block_name)
            elif block_name == "res2":
                layers.append(
                    ResidualBlock(kernel_size=2,
                                  base_channel=current_channel))
            elif block_name == "res3":
                layers.append(
                    ResidualBlock(kernel_size=3,
                                  base_channel=current_channel))
            else:
                # Previously dropped silently; keep skipping but make it
                # visible so misspelled configs are noticed.
                logging.warning("MtMSR: unknown block name %r, skipped.",
                                block_name)
        # Tail of the body: conv then PixelShuffle (constructed in that
        # order); only the conv gets the 0.1-scaled initialization.
        tail_conv = conv(current_channel, out_channel * upscale**2)
        layers.extend([tail_conv, ops.PixelShuffle(upscale)])
        initialize_weights(tail_conv, 0.1)
        # NOTE(review): the attribute assignment order below is kept as-is;
        # it may determine child registration order in the framework --
        # confirm before reordering.
        self.sub_mean = ops.MeanShift(1.0, rgb_mean)
        body = Sequential(*layers)
        upsample = ops.InterpolateScale(scale_factor=upscale)
        self.add = Add(body, upsample)
        self.head = ops.MeanShift(1.0, rgb_mean, sign=1)
コード例 #4
0
    def __init__(self,
                 inchannel,
                 outchannel,
                 groups=1,
                 base_width=64,
                 stride=1):
        """Create BasicBlock layers.

        Builds a BasicConv branch and a ShortCut branch, merges them with
        ``Add`` into ``self.block`` and creates ``self.relu``.

        :param inchannel: input channel.
        :type inchannel: int
        :param outchannel: output channel.
        :type outchannel: int
        :param groups: group count forwarded to BasicConv, default 1
        :type groups: int
        :param base_width: base width forwarded to BasicConv, default 64
        :type base_width: int
        :param stride: the number to jump, default 1
        :type stride: int
        """
        super(BasicBlock, self).__init__()
        # Keyword arguments common to both branches.
        shared = dict(inchannel=inchannel, outchannel=outchannel, stride=stride)
        # The conv branch is constructed before the shortcut branch.
        main_branch = BasicConv(groups=groups, base_width=base_width, **shared)
        skip_branch = ShortCut(expansion=self.expansion, **shared)
        self.block = Add(main_branch, skip_branch)
        self.relu = ops.Relu()