# Imports required by the methods below.
from copy import deepcopy
from typing import Tuple

import torch.nn as nn


def __init__(
    self,
    conv3d_eq: nn.Conv3d,
    input_THW_tuple: Tuple,
):
    """
    Args:
        conv3d_eq (nn.Conv3d): input nn.Conv3d module to be converted
            into equivalent conv2d.
        input_THW_tuple (tuple): input THW size for conv3d_eq during forward.
    """
    super().__init__()
    # Create the equivalent conv2d module.
    in_channels = conv3d_eq.in_channels
    out_channels = conv3d_eq.out_channels
    bias_flag = conv3d_eq.bias is not None
    self.conv2d_eq = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=(conv3d_eq.kernel_size[1], conv3d_eq.kernel_size[2]),
        stride=(conv3d_eq.stride[1], conv3d_eq.stride[2]),
        groups=conv3d_eq.groups,
        bias=bias_flag,
        padding=(conv3d_eq.padding[1], conv3d_eq.padding[2]),
        dilation=(conv3d_eq.dilation[1], conv3d_eq.dilation[2]),
    )
    state_dict = conv3d_eq.state_dict()
    state_dict["weight"] = state_dict["weight"].squeeze(2)
    self.conv2d_eq.load_state_dict(state_dict)
    self.input_THW_tuple = input_THW_tuple
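
# A hypothetical equivalence check (sketch only, not part of the module;
# `conv3d`, `x`, and the per-frame loop are made-up illustration, assuming
# `import torch`). A conv3d with temporal kernel size 1 applies the same 2d
# kernel to every frame, so squeezing the temporal dimension out of its
# weight gives a conv2d that matches it frame by frame:
#
#   conv3d = nn.Conv3d(4, 8, kernel_size=(1, 3, 3), padding=(0, 1, 1))
#   x = torch.randn(1, 4, 5, 16, 16)  # (N, C, T, H, W)
#   conv2d = nn.Conv2d(4, 8, kernel_size=(3, 3), padding=(1, 1))
#   sd = conv3d.state_dict()
#   sd["weight"] = sd["weight"].squeeze(2)  # (8, 4, 1, 3, 3) -> (8, 4, 3, 3)
#   conv2d.load_state_dict(sd)
#   per_frame = torch.stack([conv2d(x[:, :, t]) for t in range(5)], dim=2)
#   assert torch.allclose(conv3d(x), per_frame, atol=1e-5)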
def __init__(
    self,
    conv3d_in: nn.Conv3d,
    input_THW_tuple: Tuple,
):
    """
    Args:
        conv3d_in (nn.Conv3d): input nn.Conv3d module to be converted
            into equivalent conv2d.
        input_THW_tuple (tuple): input THW size for conv3d_in during forward.
    """
    super().__init__()
    assert conv3d_in.padding[0] == 1, (
        "_Conv3dTemporalKernel3Eq only supports temporal padding of 1, "
        f"but got {conv3d_in.padding[0]}"
    )
    assert conv3d_in.padding_mode == "zeros", (
        "_Conv3dTemporalKernel3Eq only supports zero padding, "
        f"but got {conv3d_in.padding_mode}"
    )
    self._input_THW_tuple = input_THW_tuple
    padding_2d = conv3d_in.padding[1:]
    in_channels = conv3d_in.in_channels
    out_channels = conv3d_in.out_channels
    kernel_size = conv3d_in.kernel_size[1:]
    groups = conv3d_in.groups
    stride_2d = conv3d_in.stride[1:]
    # Create 3 conv2d modules to emulate the conv3d.
    if self._input_THW_tuple[0] > 1:
        # These two conv2d are needed only when the temporal input size > 1.
        self._conv2d_3_3_0 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding_2d,
            stride=stride_2d,
            groups=groups,
            bias=False,
        )
        self._conv2d_3_3_2 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding_2d,
            stride=stride_2d,
            groups=groups,
            bias=False,
        )
    self._conv2d_3_3_1 = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=padding_2d,
        stride=stride_2d,
        groups=groups,
        bias=(conv3d_in.bias is not None),
    )

    state_dict = conv3d_in.state_dict()
    state_dict_1 = deepcopy(state_dict)
    state_dict_1["weight"] = state_dict["weight"][:, :, 1]
    self._conv2d_3_3_1.load_state_dict(state_dict_1)

    if self._input_THW_tuple[0] > 1:
        state_dict_0 = deepcopy(state_dict)
        state_dict_0["weight"] = state_dict["weight"][:, :, 0]
        if conv3d_in.bias is not None:
            # Bias is not needed for the other conv2d instances, to avoid
            # adding it more than once per output frame.
            state_dict_0.pop("bias")
        self._conv2d_3_3_0.load_state_dict(state_dict_0)

        state_dict_2 = deepcopy(state_dict)
        state_dict_2["weight"] = state_dict["weight"][:, :, 2]
        if conv3d_in.bias is not None:
            state_dict_2.pop("bias")
        self._conv2d_3_3_2.load_state_dict(state_dict_2)

    self._add_funcs = nn.ModuleList(
        [
            nn.quantized.FloatFunctional()
            for _ in range(2 * (self._input_THW_tuple[0] - 1))
        ]
    )
    self._cat_func = nn.quantized.FloatFunctional()
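
# Sketch of the computation this decomposition implements (hypothetical
# notation; the actual arithmetic lives in forward() and goes through the
# FloatFunctional modules above). With temporal padding 1, temporal stride 1,
# and T input frames x[0], ..., x[T-1], output frame t of the original conv3d
# equals
#
#   out[t] = _conv2d_3_3_0(x[t-1]) + _conv2d_3_3_1(x[t]) + _conv2d_3_3_2(x[t+1])
#
# where out-of-range frames contribute zero (zero padding) and only the middle
# conv2d carries the bias, so it is added exactly once per output frame. The
# two boundary frames need one elementwise add each and the T - 2 interior
# frames need two, which is why 2 * (T - 1) FloatFunctional adds are allocated.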
def __init__(
    self,
    conv3d_in: nn.Conv3d,
    thw_shape: Tuple[int, int, int],
):
    """
    Args:
        conv3d_in (nn.Conv3d): input nn.Conv3d module to be converted
            into equivalent conv2d.
        thw_shape (tuple): input THW size for conv3d_in during forward.
    """
    super().__init__()
    assert conv3d_in.padding[0] == 2, (
        "_Conv3dTemporalKernel5Eq only supports temporal padding of 2, "
        f"but got {conv3d_in.padding[0]}"
    )
    assert conv3d_in.padding_mode == "zeros", (
        "_Conv3dTemporalKernel5Eq only supports zero padding, "
        f"but got {conv3d_in.padding_mode}"
    )
    self._thw_shape = thw_shape
    padding_2d = conv3d_in.padding[1:]
    in_channels = conv3d_in.in_channels
    out_channels = conv3d_in.out_channels
    kernel_size = conv3d_in.kernel_size[1:]
    groups = conv3d_in.groups
    stride_2d = conv3d_in.stride[1:]
    # Create 5 conv2d modules to emulate the conv3d.
    t, h, w = self._thw_shape
    args_dict = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "kernel_size": kernel_size,
        "padding": padding_2d,
        "stride": stride_2d,
        "groups": groups,
    }

    for iter_idx in range(5):
        if iter_idx != 2:
            if t > 1:
                # These four conv2d are needed only when the temporal input
                # size > 1.
                self.add_module(
                    f"_conv2d_{iter_idx}", nn.Conv2d(**args_dict, bias=False)
                )
        else:  # _conv2d_2 is needed under all circumstances.
            self.add_module(
                f"_conv2d_{iter_idx}",
                nn.Conv2d(**args_dict, bias=(conv3d_in.bias is not None)),
            )

    # State dict for _conv2d_2.
    original_state_dict = conv3d_in.state_dict()
    state_dict_to_load = deepcopy(original_state_dict)
    state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 2]
    self._conv2d_2.load_state_dict(state_dict_to_load)

    if t > 1:
        if conv3d_in.bias is not None:
            # Bias is not needed for the other conv2d instances, to avoid
            # adding it more than once per output frame.
            state_dict_to_load.pop("bias")
        # State dicts for _conv2d_0, _conv2d_1, _conv2d_3, _conv2d_4.
        state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 0]
        self._conv2d_0.load_state_dict(state_dict_to_load)

        state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 1]
        self._conv2d_1.load_state_dict(state_dict_to_load)

        state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 3]
        self._conv2d_3.load_state_dict(state_dict_to_load)

        state_dict_to_load["weight"] = original_state_dict["weight"][:, :, 4]
        self._conv2d_4.load_state_dict(state_dict_to_load)
    # Elementwise adds are needed in the forward function; use
    # nn.quantized.FloatFunctional() for better quantization support.
    # One output frame needs at most 4 elementwise adds; for boundary frames,
    # where zero padding drops some terms, fewer adds are needed.
    # See forward() for more details.
    self._add_funcs = nn.ModuleList(
        [nn.quantized.FloatFunctional() for _ in range(4 * t - 6)]
    )
    self._cat_func = nn.quantized.FloatFunctional()
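
# Sketch of the emulated computation (hypothetical notation; see forward()
# for the actual quantized implementation). With temporal padding 2 and
# temporal stride 1, output frame t sums five per-frame conv2d results:
#
#   out[t] = sum(_conv2d_k(x[t + k - 2]) for k in range(5))
#
# with out-of-range frames dropped (zero padding) and the bias carried only
# by _conv2d_2. For T >= 4, interior frames need 4 elementwise adds each,
# frames 0 and T-1 need 2, and frames 1 and T-2 need 3, for a total of
# 4 * (T - 4) + 10 = 4T - 6 adds, matching the allocation above.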