Example #1
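All three snippets below appear to be __init__ methods of helper modules that convert an nn.Conv3d with temporal kernel size 1, 3, and 5, respectively, into equivalent nn.Conv2d modules. They assume roughly the following imports (a sketch; the original import block isn't shown):

    from copy import deepcopy
    from typing import Tuple

    import torch.nn as nn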
    def __init__(
        self,
        conv3d_eq: nn.Conv3d,
        input_THW_tuple: Tuple,
    ):
        """
        Args:
            conv3d_eq (nn.Conv3d): input nn.Conv3d module to be converted
                into equivalent conv2d.
            input_THW_tuple (tuple): input THW size for conv3d_eq during forward.
        """
        super().__init__()
        # Create the equivalent conv2d module.
        in_channels = conv3d_eq.in_channels
        out_channels = conv3d_eq.out_channels
        bias_flag = conv3d_eq.bias is not None
        self.conv2d_eq = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(conv3d_eq.kernel_size[1], conv3d_eq.kernel_size[2]),
            stride=(conv3d_eq.stride[1], conv3d_eq.stride[2]),
            groups=conv3d_eq.groups,
            bias=bias_flag,
            padding=(conv3d_eq.padding[1], conv3d_eq.padding[2]),
            dilation=(conv3d_eq.dilation[1], conv3d_eq.dilation[2]),
        )
        # The temporal kernel dimension is a singleton, so squeeze it out
        # and load the 3d weights directly into the conv2d.
        state_dict = conv3d_eq.state_dict()
        state_dict["weight"] = state_dict["weight"].squeeze(2)
        self.conv2d_eq.load_state_dict(state_dict)
        self.input_THW_tuple = input_THW_tuple
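A temporal kernel of size 1 makes the conv3d act on each frame independently, so squeezing out that dimension and reusing the weights in a conv2d is exact. A minimal standalone check (the reshaping below is illustrative, not the module's actual forward()):

    import torch
    import torch.nn as nn

    conv3d = nn.Conv3d(4, 8, kernel_size=(1, 3, 3), padding=(0, 1, 1))
    conv2d = nn.Conv2d(4, 8, kernel_size=(3, 3), padding=(1, 1))
    sd = conv3d.state_dict()
    sd["weight"] = sd["weight"].squeeze(2)  # (O, I, 1, kH, kW) -> (O, I, kH, kW)
    conv2d.load_state_dict(sd)

    x = torch.randn(2, 4, 5, 16, 16)  # (N, C, T, H, W)
    n, c, t, h, w = x.shape
    # Fold T into the batch dimension, run the conv2d, then restore T.
    y2d = conv2d(x.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w))
    y2d = y2d.reshape(n, t, 8, h, w).permute(0, 2, 1, 3, 4)
    assert torch.allclose(conv3d(x), y2d, atol=1e-6)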
Example #2
    def __init__(
        self,
        conv3d_in: nn.Conv3d,
        input_THW_tuple: Tuple,
    ):
        """
        Args:
            conv3d_in (nn.Conv3d): input nn.Conv3d module to be converted
                into equivalent conv2d.
            input_THW_tuple (tuple): input THW size for conv3d_in during forward.
        """
        super().__init__()
        assert conv3d_in.padding[0] == 1, (
            "_Conv3dTemporalKernel3Eq only supports temporal padding of 1, "
            f"but got {conv3d_in.padding[0]}")
        assert conv3d_in.padding_mode == "zeros", (
            "_Conv3dTemporalKernel3Eq only supports zero padding, "
            f"but got {conv3d_in.padding_mode}")
        self._input_THW_tuple = input_THW_tuple
        padding_2d = conv3d_in.padding[1:]
        in_channels = conv3d_in.in_channels
        out_channels = conv3d_in.out_channels
        kernel_size = conv3d_in.kernel_size[1:]
        groups = conv3d_in.groups
        stride_2d = conv3d_in.stride[1:]
        # Create 3 conv2d to emulate conv3d.
        if self._input_THW_tuple[0] > 1:
            # These two conv2d are needed only when temporal input > 1.
            self._conv2d_3_3_0 = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding_2d,
                stride=stride_2d,
                groups=groups,
                bias=False,
            )
            self._conv2d_3_3_2 = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=padding_2d,
                stride=stride_2d,
                groups=groups,
                bias=False,
            )
        self._conv2d_3_3_1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding_2d,
            stride=stride_2d,
            groups=groups,
            bias=(conv3d_in.bias is not None),
        )

        state_dict = conv3d_in.state_dict()
        state_dict_1 = deepcopy(state_dict)
        state_dict_1["weight"] = state_dict["weight"][:, :, 1]
        self._conv2d_3_3_1.load_state_dict(state_dict_1)

        if self._input_THW_tuple[0] > 1:
            state_dict_0 = deepcopy(state_dict)
            state_dict_0["weight"] = state_dict["weight"][:, :, 0]
            if conv3d_in.bias is not None:
                """
                Don't need bias for other conv2d instances to avoid duplicated addition of bias.
                """
                state_dict_0.pop("bias")
            self._conv2d_3_3_0.load_state_dict(state_dict_0)

            state_dict_2 = deepcopy(state_dict)
            state_dict_2["weight"] = state_dict["weight"][:, :, 2]
            if conv3d_in.bias is not None:
                state_dict_2.pop("bias")
            self._conv2d_3_3_2.load_state_dict(state_dict_2)

            self._add_funcs = nn.ModuleList([
                nn.quantized.FloatFunctional()
                for _ in range(2 * (self._input_THW_tuple[0] - 1))
            ])
            self._cat_func = nn.quantized.FloatFunctional()
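The forward() is not shown, but the constructor above sets up the decomposition out[t] = W0 * x[t-1] + W1 * x[t] + W2 * x[t+1] with zero padding in time: interior frames need 2 elementwise adds and the two boundary frames need 1 each, hence the 2 * (T - 1) FloatFunctional instances. A sketch of that forward pass for T > 1 (the method body is an assumption; only the attribute names come from the code above):

    def forward(self, x):  # x: (N, C, T, H, W)
        t_dim = x.size(2)
        add_iter = iter(self._add_funcs)
        outs = []
        for t in range(t_dim):
            # The center tap carries the bias.
            out = self._conv2d_3_3_1(x[:, :, t])
            if t > 0:  # left tap; zero padding covers t == 0
                out = next(add_iter).add(out, self._conv2d_3_3_0(x[:, :, t - 1]))
            if t < t_dim - 1:  # right tap; zero padding covers the last frame
                out = next(add_iter).add(out, self._conv2d_3_3_2(x[:, :, t + 1]))
            outs.append(out.unsqueeze(2))
        return self._cat_func.cat(outs, dim=2)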
Example #3
    def __init__(
        self,
        conv3d_in: nn.Conv3d,
        thw_shape: Tuple[int, int, int],
    ):
        """
        Args:
            conv3d_in (nn.Conv3d): input nn.Conv3d module to be converted
                into equivalent conv2d.
            thw_shape (tuple): input THW size for conv3d_in during forward.
        """
        super().__init__()
        assert conv3d_in.padding[0] == 2, (
            "_Conv3dTemporalKernel5Eq only supports temporal padding of 2, "
            f"but got {conv3d_in.padding[0]}")
        assert conv3d_in.padding_mode == "zeros", (
            "_Conv3dTemporalKernel5Eq only supports zero padding, "
            f"but got {conv3d_in.padding_mode}")
        self._thw_shape = thw_shape
        padding_2d = conv3d_in.padding[1:]
        in_channels = conv3d_in.in_channels
        out_channels = conv3d_in.out_channels
        kernel_size = conv3d_in.kernel_size[1:]
        groups = conv3d_in.groups
        stride_2d = conv3d_in.stride[1:]
        # Create 5 conv2d to emulate conv3d.
        t, h, w = self._thw_shape
        args_dict = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "padding": padding_2d,
            "stride": stride_2d,
            "groups": groups,
        }

        for iter_idx in range(5):
            if iter_idx != 2:
                if t > 1:  # Those four conv2d are needed only when temporal input > 1.
                    self.add_module(f"_conv2d_{iter_idx}",
                                    nn.Conv2d(**args_dict, bias=False))
            else:  # _conv2d_2 is needed for all circumstances.
                self.add_module(
                    f"_conv2d_{iter_idx}",
                    nn.Conv2d(**args_dict, bias=(conv3d_in.bias is not None)),
                )

        # State dict for _conv2d_2
        original_state_dict = conv3d_in.state_dict()
        state_dict_to_load = deepcopy(original_state_dict)
        state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 2]
        self._conv2d_2.load_state_dict(state_dict_to_load)

        if t > 1:
            if conv3d_in.bias is not None:
                # Don't need bias for other conv2d instances to avoid duplicated
                # addition of bias.
                state_dict_to_load.pop("bias")
            # State dict for _conv2d_0, _conv2d_1, _conv2d_3, _conv2d_4
            state_dict_to_load["weight"] = original_state_dict["weight"][:, :,
                                                                         0]
            self._conv2d_0.load_state_dict(state_dict_to_load)

            state_dict_to_load["weight"] = original_state_dict["weight"][:, :,
                                                                         1]
            self._conv2d_1.load_state_dict(state_dict_to_load)

            state_dict_to_load["weight"] = original_state_dict["weight"][:, :,
                                                                         3]
            self._conv2d_3.load_state_dict(state_dict_to_load)

            state_dict_to_load["weight"] = original_state_dict["weight"][:, :,
                                                                         4]
            self._conv2d_4.load_state_dict(state_dict_to_load)
            # Elementwise adds are needed in the forward function; use
            # nn.quantized.FloatFunctional() for better quantization support.
            # One output frame needs at most 4 elementwise adds; boundary
            # frames, where zero padding supplies some taps, need fewer.
            # See forward() for more details.
            self._add_funcs = nn.ModuleList(
                [nn.quantized.FloatFunctional() for _ in range(4 * t - 6)])
            self._cat_func = nn.quantized.FloatFunctional()
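The 4 * t - 6 sizing of _add_funcs follows from the same tap counting: output frame i sums the taps x[i - 2 + k] for k = 0..4 that fall inside [0, t), and each frame needs one add fewer than its number of in-range taps. A quick check of that arithmetic (illustrative only):

    def num_adds(t: int) -> int:
        # One add fewer than the number of in-range taps, summed over frames.
        return sum(
            sum(0 <= i - 2 + k < t for k in range(5)) - 1 for i in range(t)
        )

    assert all(num_adds(t) == 4 * t - 6 for t in range(2, 100))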