def __init__(self, in_planes, out_planes, stride=1, downsample=None,
             groups=1, base_width=64, dilation=1, activation="ReLU",
             norm_num_groups=8):
    """
    Basic block initializer.

    Args:
        in_planes (int): The number of input channels.
        out_planes (int): The number of output channels.
        stride (int): The convolution stride. Controls the stride for the cross-correlation.
        downsample (:obj:`torch.nn.Sequential`): A sequential downsampling sub-layer.
        groups (int): The number of groups in the convolution. Controls the connections
            between inputs and outputs. Must be 1 for a basic block.
        base_width (int): The base number of output channels. Must be 64 for a basic block.
        dilation (int): Controls the spacing between the kernel points. Must be 1 for a
            basic block.
        activation (str): Desired non-linear activation function.
        norm_num_groups (int): The number of groups for group normalization. When None,
            3D batch normalization is used instead.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """
    super(BasicBlock, self).__init__()

    # Fail fast: reject unsupported configurations before constructing any layers,
    # instead of building the factories and norm layers first and raising afterwards.
    if groups != 1 or base_width != 64:
        raise ValueError(
            'BasicBlock only supports groups=1 and base_width=64')
    if dilation > 1:
        raise NotImplementedError(
            "Dilation > 1 not supported in BasicBlock")

    self._activation_layer_factory = ActivationLayerFactory()
    self._normalization_layer_factory = NormalizationLayerFactory()

    # Group normalization when a group count is given, otherwise 3D batch norm.
    if norm_num_groups is not None:
        self._norm1 = self._normalization_layer_factory.create(
            NormalizationLayers.GroupNorm, norm_num_groups, out_planes)
        self._norm2 = self._normalization_layer_factory.create(
            NormalizationLayers.GroupNorm, norm_num_groups, out_planes)
    else:
        self._norm1 = self._normalization_layer_factory.create(
            NormalizationLayers.BatchNorm3d, out_planes)
        self._norm2 = self._normalization_layer_factory.create(
            NormalizationLayers.BatchNorm3d, out_planes)

    # Both self._conv1 and the downsample layer downsample the input when stride != 1.
    self._conv1 = conv3x3(in_planes, out_planes, stride)

    if activation == ActivationLayers.PReLU.name:
        self._activation = self._activation_layer_factory.create(
            ActivationLayers.PReLU)
    else:
        self._activation = self._activation_layer_factory.create(
            ActivationLayers.ReLU, inplace=True)

    self._conv2 = conv3x3(out_planes, out_planes)
    self._downsample = downsample
    self._stride = stride
def __init__(self, in_planes, out_planes, stride=1, downsample=None,
             groups=1, base_width=64, dilation=1, activation="ReLU",
             norm_num_groups=8):
    """
    Bottleneck Block initializer.

    Args:
        in_planes (int): The number of input channels.
        out_planes (int): The number of output channels.
        stride (int): The convolution stride. Controls the stride for the cross-correlation.
        downsample (:obj:`torch.nn.Sequential`): A sequential downsampling sub-layer.
        groups (int): The number of groups in the convolution. Controls the connections
            between inputs and outputs.
        base_width (int): The base number of output channels.
        dilation (int): Controls the spacing between the kernel points.
        activation (str): Desired non-linear activation function.
        norm_num_groups (int): The number of groups for group normalization. When None,
            3D batch normalization is used instead.
    """
    super().__init__()

    self._activation_layer_factory = ActivationLayerFactory()
    self._normalization_layer_factory = NormalizationLayerFactory()

    # Inner channel count of the bottleneck, scaled by base width and groups.
    bottleneck_width = int(out_planes * (base_width / 64.)) * groups
    expanded_planes = out_planes * self.expansion

    # 1x1 reduce -> 3x3 (carries stride/groups/dilation) -> 1x1 expand.
    # Both self._conv2 and the downsample layer downsample the input when stride != 1.
    self._conv1 = conv1x1(in_planes, bottleneck_width)
    self._conv2 = conv3x3(bottleneck_width, bottleneck_width, stride, groups, dilation)
    self._conv3 = conv1x1(bottleneck_width, expanded_planes)

    norm_factory = self._normalization_layer_factory
    if norm_num_groups is None:
        # 3D batch normalization after each convolution.
        self._norm1 = norm_factory.create(NormalizationLayers.BatchNorm3d, bottleneck_width)
        self._norm2 = norm_factory.create(NormalizationLayers.BatchNorm3d, bottleneck_width)
        self._norm3 = norm_factory.create(NormalizationLayers.BatchNorm3d, expanded_planes)
    else:
        # Group normalization after each convolution.
        self._norm1 = norm_factory.create(NormalizationLayers.GroupNorm, norm_num_groups, bottleneck_width)
        self._norm2 = norm_factory.create(NormalizationLayers.GroupNorm, norm_num_groups, bottleneck_width)
        self._norm3 = norm_factory.create(NormalizationLayers.GroupNorm, norm_num_groups, expanded_planes)

    activation_factory = self._activation_layer_factory
    if activation == ActivationLayers.PReLU.name:
        self._activation = activation_factory.create(ActivationLayers.PReLU)
    else:
        self._activation = activation_factory.create(ActivationLayers.ReLU, inplace=True)

    self._downsample = downsample
    self._stride = stride