Ejemplo n.º 1
0
    def __init__(
        self,
        in_channels,
        out_channels,
        conv_args="conv",
        bn_args="bn",
        relu_args="relu",
        # additional arguments for conv
        **kwargs,
    ):
        super().__init__()
        conv_op = build_conv(
            in_channels=in_channels,
            out_channels=out_channels,
            **hp.merge_unify_args(conv_args, kwargs),
        )

        # register in order
        self.empty_input = build_empty_input_op(conv_op)
        self.conv = conv_op

        self.bn = (
            build_bn(num_channels=out_channels, **hp.unify_args(bn_args))
            if bn_args is not None
            else None
        )
        self.relu = (
            build_relu(num_channels=out_channels, **hp.unify_args(relu_args))
            if relu_args is not None
            else None
        )

        self.out_channels = out_channels
Ejemplo n.º 2
0
 def __init__(self, width_ratio=1.0, bn_args="bn", width_divisor=1):
     self.width_ratio = width_ratio
     self.last_depth = -1
     self.width_divisor = width_divisor
     # basic arguments that will be provided to all primitivies, they could be
     #   overrided by primitive parameters
     self.basic_args = {
         "bn_args": hp.unify_args(bn_args),
         "width_divisor": width_divisor,
     }
Ejemplo n.º 3
0
    def __init__(
        self,
        in_channels,
        out_channels,
        conv_args="conv",
        bn_args="bn",
        relu_args="relu",
        upsample_args="default",
        # additional arguments for conv
        **kwargs,
    ):
        super().__init__()

        conv_full_args = hp.merge_unify_args(conv_args, kwargs)
        conv_stride = conv_full_args.pop("stride", 1)
        # build upsample op if stride is negative
        upsample_op, conv_stride = build_upsample_neg_stride(
            stride=conv_stride, **hp.unify_args(upsample_args))
        # build conv
        conv_op = build_conv(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=conv_stride,
            **conv_full_args,
        )

        # register in order
        self.conv = conv_op

        self.bn = (build_bn(num_channels=out_channels,
                            **hp.unify_args(bn_args))
                   if bn_args is not None and len(bn_args) != 0 else None)
        self.relu = (build_relu(num_channels=out_channels,
                                **hp.unify_args(relu_args))
                     if relu_args is not None else None)
        self.upsample = upsample_op

        self.out_channels = out_channels
Ejemplo n.º 4
0
 def __init__(self,
              width_ratio=1.0,
              bn_args=None,
              width_divisor=1,
              basic_args=None):
     self.width_ratio = width_ratio
     self.last_depth = -1
     # basic arguments that will be provided to all primitivies, they could be
     #   overrided by primitive parameters
     self.basic_args = {
         **({
             "bn_args": hp.unify_args(bn_args)
         } if bn_args else {}),
         "width_divisor":
         width_divisor,
     }
     if basic_args is not None:
         assert isinstance(basic_args, dict)
         self.add_basic_args(**basic_args)
Ejemplo n.º 5
0
    def __init__(
        self,
        in_channels,
        out_channels,
        expansion=6,
        kernel_size=-1,
        stride=1,
        bias=False,
        pw_args="conv",
        pwl_args="conv",
        pool_args=None,
        bn_args="bn",
        relu_args="relu",
        se_args=None,  # se after pool
        pw_se_args=None,  # se after pw conv
        res_conn_args="default",
        width_divisor=8,
        always_pw=False,
        less_se_channels=False,
    ):
        super().__init__()

        mid_channels = hp.get_divisible_by(in_channels * expansion, width_divisor)

        self.pw = None
        if in_channels != mid_channels or always_pw:
            self.pw = bb.ConvBNRelu(
                in_channels=in_channels,
                out_channels=mid_channels,
                conv_args={
                    "kernel_size": 1,
                    "stride": 1,
                    "padding": 0,
                    "bias": bias,
                    **hp.unify_args(pw_args),
                },
                bn_args=bn_args,
                relu_args=relu_args,
            )

        pw_se_ratio = 0.25
        if less_se_channels:
            pw_se_ratio /= expansion
        self.pw_se = bb.build_se(
            in_channels=mid_channels,
            mid_channels=(mid_channels * pw_se_ratio),
            width_divisor=width_divisor,
            **hp.merge_unify_args({"relu_args": relu_args}, pw_se_args)
        )

        if kernel_size == -1:
            self.dw = nn.AdaptiveAvgPool2d(1)
        else:
            self.dw = nn.AvgPool2d(
                kernel_size, stride=stride, **hp.unify_args(pool_args)
            )

        se_ratio = 0.25
        if less_se_channels:
            se_ratio /= expansion
        self.se = bb.build_se(
            in_channels=mid_channels,
            mid_channels=(mid_channels * se_ratio),
            width_divisor=width_divisor,
            **hp.merge_unify_args({"relu_args": relu_args}, se_args)
        )
        self.pwl = bb.ConvBNRelu(
            in_channels=mid_channels,
            out_channels=out_channels,
            conv_args={
                "kernel_size": 1,
                "stride": 1,
                "padding": 0,
                "bias": bias,
                **hp.merge_unify_args(pwl_args),
            },
            # no bn
            bn_args=None,
            # has relu
            relu_args=relu_args,
        )
        self.res_conn = bb.build_residual_connect(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            **hp.unify_args(res_conn_args)
        )
        self.out_channels = out_channels
Ejemplo n.º 6
0
    def __init__(
        self,
        in_channels,
        out_channels,
        expansion=6,
        kernel_size=3,
        stride=1,
        bias=False,
        conv_args="conv",
        bn_args="bn",
        relu_args="relu",
        se_args=None,
        round_se_channels=False,
        res_conn_args="default",
        upsample_args="default",
        width_divisor=8,
        pw_args=None,
        pw_bn_args=None,
        dw_args=None,
        dw_bn_args=None,
        pwl_args=None,
        pwl_bn_args=None,
        dw_skip_bnrelu=False,
        skip_pwl_bn=False,
        pw_groups=1,
        dw_group_ratio=1,  # dw_group == mid_channels // dw_group_ratio
        pwl_groups=1,
        always_pw=False,
        less_se_channels=False,
        zero_last_bn_gamma=False,
        drop_connect_rate=None,
    ):
        super().__init__()

        conv_args = hp.unify_args(conv_args)
        bn_args = hp.unify_args(bn_args)
        relu_args = hp.unify_args(relu_args)

        mid_channels = hp.get_divisible_by(in_channels * expansion, width_divisor)

        res_conn = bb.build_residual_connect(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            drop_connect_rate=drop_connect_rate,
            **hp.unify_args(res_conn_args)
        )

        self.pw = None
        self.shuffle = None
        if in_channels != mid_channels or always_pw:
            self.pw = bb.ConvBNRelu(
                in_channels=in_channels,
                out_channels=mid_channels,
                conv_args={
                    "kernel_size": 1,
                    "stride": 1,
                    "padding": 0,
                    "bias": bias,
                    "groups": pw_groups,
                    **hp.merge_unify_args(conv_args, pw_args),
                },
                bn_args=hp.merge_unify_args(bn_args, pw_bn_args),
                relu_args=relu_args,
            )
        if pw_groups > 1:
            self.shuffle = bb.ChannelShuffle(pw_groups)
        # use negative stride for upsampling
        self.upsample, dw_stride = bb.build_upsample_neg_stride(
            stride=stride, **hp.unify_args(upsample_args)
        )

        dw_padding = (
            kernel_size // 2
            if not (dw_args and "padding" in dw_args)
            else dw_args.pop("padding")
        )
        self.dw = bb.ConvBNRelu(
            in_channels=mid_channels,
            out_channels=mid_channels,
            conv_args={
                "kernel_size": kernel_size,
                "stride": dw_stride,
                "padding": dw_padding,
                "groups": mid_channels // dw_group_ratio,
                "bias": bias,
                **hp.merge_unify_args(conv_args, dw_args),
            },
            bn_args=hp.merge_unify_args(bn_args, dw_bn_args)
            if not dw_skip_bnrelu
            else None,
            relu_args=relu_args if not dw_skip_bnrelu else None,
        )
        se_ratio = 0.25
        if less_se_channels:
            se_ratio /= expansion
        self.se = bb.build_se(
            in_channels=mid_channels,
            mid_channels=(
                int(mid_channels * se_ratio)
                if not round_se_channels
                else round(mid_channels * se_ratio)
            ),
            width_divisor=width_divisor,
            **hp.merge(relu_args=relu_args, kwargs=hp.unify_args(se_args))
        )
        self.pwl = bb.ConvBNRelu(
            in_channels=mid_channels,
            out_channels=out_channels,
            conv_args={
                "kernel_size": 1,
                "stride": 1,
                "padding": 0,
                "bias": bias,
                "groups": pwl_groups,
                **hp.merge_unify_args(conv_args, pwl_args),
            },
            bn_args={
                **hp.merge_unify_args(bn_args, pwl_bn_args),
                **{
                    "zero_gamma": (
                        zero_last_bn_gamma if res_conn is not None else False
                    )
                },
            }
            if not skip_pwl_bn
            else None,
            relu_args=None,
        )

        self.res_conn = res_conn
        self.out_channels = out_channels