Example 1
    def __init__(self, C, C_out, stride, affine, act='relu'):
        super(ResNetBlockSplit, self).__init__()
        self.act = act
        self.op_1_1 = ConvBNReLU(C,
                                 C_out,
                                 3,
                                 stride,
                                 1,
                                 affine=affine,
                                 relu=False)
        self.op_1_2 = ConvBNReLU(C,
                                 C_out,
                                 3,
                                 stride,
                                 1,
                                 affine=affine,
                                 relu=False)
        self.op_2_1 = ConvBNReLU(C_out,
                                 C_out,
                                 3,
                                 1,
                                 1,
                                 affine=affine,
                                 relu=False)
        self.op_2_2 = ConvBNReLU(C_out,
                                 C_out,
                                 3,
                                 1,
                                 1,
                                 affine=affine,
                                 relu=False)
        self.skip_op = Identity() if stride == 1 else ConvBNReLU(
            C, C_out, 1, stride, 0, affine=affine, relu=False)
Example 2
    def __init__(self, C, C_out, stride, affine, conv_ds=False):
        super(SkipConnectV2, self).__init__()
        self.stride = stride
        self.conv_ds = conv_ds
        self.expansion = C_out // C
        if stride == 2:
            if self.conv_ds:
                # supports arbitrary channel counts
                self.op1 = nn.AvgPool2d(2)
                self.op2 = XNORGroupConv(
                    C,
                    C_out,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    affine=affine,
                    shortcut=False,
                )
            else:
                assert C_out == 2 * C or C_out == C
                self.op1 = nn.AvgPool2d(2)
                if self.expansion == 2:
                    self.op2 = nn.AvgPool2d(2)

        if stride == 1:
            self.op = Identity()
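
A minimal sketch of how a forward pass could combine these branches. This is an assumption, not shown in the source; in particular, the expansion == 2 case doubling channels by concatenating the two pooled branches is only one plausible reading. Assumes `import torch`:

    def forward(self, x):
        # hypothetical forward, for illustration only
        if self.stride == 1:
            return self.op(x)
        if self.conv_ds:
            return self.op2(self.op1(x))
        if self.expansion == 2:
            # one plausible way to double channels: concatenate the pooled branches
            return torch.cat([self.op1(x), self.op2(x)], dim=1)
        return self.op1(x)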
Example 3
    def __init__(self,
                 C,
                 C_out,
                 stride,
                 affine,
                 kernel_size=3,
                 act="relu",
                 downsample="conv"):
        super(ResNetBlock, self).__init__()
        self.stride = stride
        padding = int((kernel_size - 1) / 2)
        self.act = act
        self.activation = F.relu
        if self.act == 'hardtanh':
            self.activation = F.hardtanh
        elif self.act == 'sigmoid':
            self.activation = torch.sigmoid  # F.sigmoid is deprecated; use torch.sigmoid

        self.op_1 = ConvBNReLU(C,
                               C_out,
                               kernel_size,
                               stride,
                               padding,
                               affine=affine,
                               relu=False)
        self.op_2 = ConvBNReLU(C_out,
                               C_out,
                               kernel_size,
                               1,
                               padding,
                               affine=affine,
                               relu=False)
        if downsample == "conv":
            self.skip_op = Identity() if stride == 1 else ConvBNReLU(
                C, C_out, 1, stride, 0, affine=affine, relu=False)
        elif downsample == "avgpool":
            self.skip_op = Identity() if stride == 1 else ResNetDownSample(
                stride)
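
The forward method is not shown; a standard residual forward for this block would look like the following sketch (an assumption, not the source's code):

    def forward(self, x):
        # hypothetical forward: conv stack plus skip branch, activation in between
        out = self.activation(self.op_1(x))
        out = self.op_2(out)
        return self.activation(out + self.skip_op(x))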
Example 4
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, expansion=1, group=1):
        super(XNORDilConv, self).__init__()
        if expansion != 1:
            raise Exception("binary blocks have an inner shortcut, so expansion is not supported")
        if stride == 1:
            self.shortcut = Identity()
        else:
            # note: average pooling on the shortcut may weaken the gradient signal
            self.shortcut = nn.AvgPool2d(kernel_size=2)
        # note: C_out is unused here; with expansion fixed to 1 the conv keeps C_in channels
        self.op = nn.Sequential(
            nn.BatchNorm2d(C_in * expansion, affine=affine),
            XNORConv2d(C_in, C_in * expansion, kernel_size=kernel_size, stride=stride,
                       dilation=dilation, padding=padding, groups=group),
            nn.ReLU(inplace=False),
        )
Example 5
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, expansion=1, group=1, shortcut=True):  # FIXME: defaulting to a shortcut may not be proper for a binary op
        super(XNORGroupConv, self).__init__()
        # only use the shortcut when C_in and C_out are the same
        self.use_shortcut = shortcut
        if self.use_shortcut:
            if expansion != 1:
                raise Exception("expansion is not supported when the shortcut is enabled")
            if C_in != C_out:
                raise Exception("only C_in == C_out is supported when the shortcut is enabled")
            if stride == 1:
                self.shortcut = Identity()
            else:
                self.shortcut = nn.AvgPool2d(kernel_size=2)
        self.op = nn.Sequential(
            nn.BatchNorm2d(C_in * expansion, affine=affine),
            XNORConv2d(C_in, C_out * expansion, kernel_size=kernel_size, stride=stride,
                       padding=padding, groups=group),
            nn.ReLU(inplace=False),
        )
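
A plausible forward for this op (assumed for illustration; not shown in the source) adds the shortcut branch only when it was enabled:

    def forward(self, x):
        # hypothetical forward, for illustration only
        out = self.op(x)
        if self.use_shortcut:
            out = out + self.shortcut(x)
        return out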
Example 6
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(ResSepConv, self).__init__()
        self.conv = SepConv(C_in, C_out, kernel_size, stride, padding)
        self.res = Identity() if stride == 1 else FactorizedReduce(
            C_in, C_out, stride)
Example 7
    def __init__(
        self,
        C,
        C_out,
        stride,
        affine,
        kernel_size=3,
        block="bireal",
        act=BinaryActivation,
        downsample="conv",
        fp32_act=False,
    ):
        super(BinaryResNetBlock, self).__init__()
        self.stride = stride
        padding = int((kernel_size - 1) / 2)
        self.activation = act

        if block == "bireal":
            self.op_1 = BinaryConvBNReLU(
                C,
                C_out,
                kernel_size,
                stride,
                padding,
                affine=affine,
                bias=None,
                activation=self.activation,
            )
            self.op_2 = BinaryConvBNReLU(
                C_out,
                C_out,
                kernel_size,
                1,
                padding,
                affine=affine,
                bias=None,
                activation=self.activation,
            )
        elif block == "xnor":
            self.op_1 = XNORConvBNReLU(
                C,
                C_out,
                kernel_size,
                stride,
                padding,
                affine=affine,
                groups=1,
                dropout_ratio=0,
                fp32_act=fp32_act,
            )
            self.op_2 = XNORConvBNReLU(
                C_out,
                C_out,
                kernel_size,
                1,
                padding,
                affine=affine,
                groups=1,
                dropout_ratio=0,
                fp32_act=fp32_act,
            )
        elif block == "dorefa":
            self.op_1 = DorefaConvBNReLU(C,
                                         C_out,
                                         kernel_size,
                                         stride,
                                         padding,
                                         affine=affine,
                                         groups=1)
            self.op_2 = DorefaConvBNReLU(C_out,
                                         C_out,
                                         kernel_size,
                                         1,
                                         padding,
                                         affine=affine,
                                         groups=1)
        if downsample == "conv":
            self.skip_op = (Identity() if stride == 1 else ConvBNReLU(
                C, C_out, 1, stride, 0, affine=affine))
        elif downsample == "avgpool":
            self.skip_op = Identity() if stride == 1 else ResNetDownSample(
                stride)
Example 8
    def forward_one_step(self, context=None, inputs=None):
        return self.op.forward_one_step(context, inputs)

register_primitive("xnor_conv_3x3",
                   lambda C, C_out, stride, affine: XNORGroupConv(
                       C, C_out, 3, stride, 1, affine=affine, group=1),
)

register_primitive("xnor_conv_3x3_noskip",
                   lambda C, C_out, stride, affine: XNORGroupConv(
                       C, C_out, 3, stride, 1, affine=affine, group=1, shortcut=False),
)
register_primitive("cond_xnor_conv_3x3_noskip",
                   lambda C, C_out, stride, affine: XNORGroupConv(
                       C, C_out, 3, stride, 1, affine=affine, group=1, shortcut=False)\
                   if stride > 1 or C != C_out else Identity()
)

register_primitive("xnor_conv_5x5",
                   lambda C, C_out, stride, affine: XNORGroupConv(
                       C, C_out, 5, stride, 2, affine=affine, group=1),
)

register_primitive("xnor_conv_5x5_noskip",
                   lambda C, C_out, stride, affine: XNORGroupConv(
                       C, C_out, 5, stride, 2, affine=affine, group=1, shortcut=False),
)

register_primitive("xnor_conv_1x1",
                   lambda C, C_out, stride, affine: XNORGroupConv(
                       C, C_out, 1, stride, 0, affine=affine, group=1),
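
For context, each registration pairs a name with a constructor taking (C, C_out, stride, affine). A minimal sketch of such a registry follows; the actual register_primitive implementation is not shown in this listing, and make_op is a hypothetical helper used only for this illustration:

_PRIMITIVES = {}

def register_primitive(name, func):
    # store the constructor under its name
    _PRIMITIVES[name] = func

def make_op(name, C, C_out, stride, affine):
    # hypothetical helper: build an op instance from a registered constructor
    return _PRIMITIVES[name](C, C_out, stride, affine)

# usage: op = make_op("xnor_conv_3x3", C=64, C_out=64, stride=2, affine=True)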
Example 9
    def __init__(
        self,
        C_in,
        C_out,
        stride,
        affine,
        relu=True,
        kernel_size=3,
        downsample="conv",
        # --- the binary cfgs ---
        binary_cfgs=None,  # avoid a mutable default argument
    ):
        super(BinaryResNetBlock, self).__init__()
        binary_cfgs = binary_cfgs if binary_cfgs is not None else {}
        (
            self.C_in,
            self.C_out,
            self.stride,
            self.affine,
            self.relu,
            self.kernel_size,
            self.downsample,
            self.binary_cfgs,
        ) = (
            C_in,
            C_out,
            stride,
            affine,
            relu,
            kernel_size,
            downsample,
            binary_cfgs,
        )
        padding = int((kernel_size - 1) / 2)
        self.op_1 = BinaryConvBNReLU(
            C_in,
            C_out,
            kernel_size,
            stride,
            padding,
            affine,
            relu,
            **binary_cfgs,
        )
        self.op_2 = BinaryConvBNReLU(
            C_out,
            C_out,
            kernel_size,
            1,  # the 2nd block keeps the same dim
            padding,
            affine,
            relu,
            **binary_cfgs,
        )
        # define the skip_op
        assert downsample in ["conv", "avgpool", "binary_conv"]
        if downsample == "conv":
            self.skip_op = (Identity() if stride == 1 else ConvBNReLU(
                C_in, C_out, 1, stride, 0, affine=affine))
        elif downsample == "avgpool":
            self.skip_op = Identity() if stride == 1 else ResNetDownSample(
                stride)
        elif downsample == "binary_conv":
            self.skip_op = (Identity() if stride == 1 else BinaryConvBNReLU(
                C_in, C_out, 3, stride, 1, affine=affine, **binary_cfgs))
Example 10
    "shortcut_op_type": "simple",
    "reduction_op_type": "factorized",
    # "layer_order": "conv_bn_relu",
    "layer_order": "bn_conv_relu",
    "binary_conv_cfgs": {
        "bi_w_scale": 1,
        "bi_act_method": 0,
        "bias": False,
    },
}

# plain shortcut when stride == 1; a stride-2 binary conv with no shortcut otherwise
# (used for the cell-wise shortcut)
register_primitive(
    "xnor_conv_3x3_skip_connect",
    lambda C, C_out, stride, affine: Identity() if stride == 1 else
    BinaryConvBNReLU(C, C_out, 3, stride, 1, affine=affine, **binary_cfgs),
)


# ---- binary NIN ----
class NIN(nn.Module):
    def __init__(self, C, C_out, stride, affine, binary_cfgs=binary_cfgs):
        super(NIN, self).__init__()
        self.conv1 = nn.Conv2d(3, 192, kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(192, eps=1e-4, momentum=0.1, affine=True)

        self.conv2_1 = XNORConvBNReLU(192,
                                      160,
                                      kernel_size=1,
                                      stride=1,