Example #1
import torch

# Note: conv3x3, ActivationLayerFactory, NormalizationLayerFactory,
# ActivationLayers, and NormalizationLayers are project helpers that are
# not defined in this snippet.


class BasicBlock(torch.nn.Module):
    """
    A basic ResNet block.
    """
    expansion = 1

    def __init__(self,
                 in_planes,
                 out_planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 activation="ReLU",
                 norm_num_groups=8):
        """
        Basic block initializer.

        Args:
            in_planes (int): The number of input channels.
            out_planes (int): The number of output channels.
            stride (int): The convolution stride. Controls the stride of the cross-correlation.
            downsample (:obj:`torch.nn.Sequential`): A sequential downsampling sub-layer.
            groups (int): The number of groups in the convolution. Controls the connections between inputs and outputs.
            base_width (int): The base number of output channels.
            dilation (int): Controls the spacing between the kernel points.
            activation (str): Desired non-linear activation function.
            norm_num_groups (int): The number of groups for group normalization.
        """
        super().__init__()
        self._activation_layer_factory = ActivationLayerFactory()
        self._normalization_layer_factory = NormalizationLayerFactory()

        if norm_num_groups is not None:
            self._norm1 = self._normalization_layer_factory.create(
                NormalizationLayers.GroupNorm, norm_num_groups, out_planes)
            self._norm2 = self._normalization_layer_factory.create(
                NormalizationLayers.GroupNorm, norm_num_groups, out_planes)
        else:
            self._norm1 = self._normalization_layer_factory.create(
                NormalizationLayers.BatchNorm3d, out_planes)
            self._norm2 = self._normalization_layer_factory.create(
                NormalizationLayers.BatchNorm3d, out_planes)

        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self._conv1 and self._downsample downsample the input when stride != 1
        self._conv1 = conv3x3(in_planes, out_planes, stride)

        if activation == ActivationLayers.PReLU.name:
            self._activation = self._activation_layer_factory.create(
                ActivationLayers.PReLU)
        else:
            self._activation = self._activation_layer_factory.create(
                ActivationLayers.ReLU, inplace=True)

        self._conv2 = conv3x3(out_planes, out_planes)

        self._downsample = downsample
        self._stride = stride

    def forward(self, x):
        identity = x

        out = self._conv1(x)
        out = self._norm1(out)
        out = self._activation(out)

        out = self._conv2(out)
        out = self._norm2(out)

        if self._downsample is not None:
            identity = self._downsample(x)

        out += identity
        out = self._activation(out)

        return out
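
A quick sketch of the helpers and a forward pass, for orientation. The real conv3x3 and conv1x1 live elsewhere in the project; the definitions below are an assumption that mirrors torchvision's 2D helpers but uses Conv3d, consistent with the BatchNorm3d fallback above. The input shape is likewise illustrative.

import torch


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    # Assumed 3D counterpart of torchvision's conv3x3; padding=dilation
    # preserves the spatial size when stride == 1.
    return torch.nn.Conv3d(in_planes, out_planes, kernel_size=3,
                           stride=stride, padding=dilation, groups=groups,
                           dilation=dilation, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    # Assumed 3D counterpart of torchvision's conv1x1.
    return torch.nn.Conv3d(in_planes, out_planes, kernel_size=1,
                           stride=stride, bias=False)


# With stride=1 and no downsample, conv3x3 preserves the input shape, so
# the residual addition in forward() lines up without a projection.
block = BasicBlock(in_planes=32, out_planes=32, norm_num_groups=8)
x = torch.randn(2, 32, 8, 16, 16)  # (batch, channels, depth, height, width)
out = block(x)                     # same shape as x: (2, 32, 8, 16, 16)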
Example #2
import torch

# Same project helpers assumed as in Example #1 (conv3x3, conv1x1, and the
# activation/normalization factories).


class Bottleneck(torch.nn.Module):
    """
    Bottleneck ResNet block.
    """
    expansion = 4

    def __init__(self,
                 in_planes,
                 out_planes,
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 activation="ReLU",
                 norm_num_groups=8):
        """
        Bottleneck Block initializer.

        Args:
            in_planes (int): The number of input channels.
            out_planes (int): The number of output channels.
            stride (int): The convolution stride. Controls the stride of the cross-correlation.
            downsample (:obj:`torch.nn.Sequential`): A sequential downsampling sub-layer.
            groups (int): The number of groups in the convolution. Controls the connections between inputs and outputs.
            base_width (int): The base number of output channels.
            dilation (int): Controls the spacing between the kernel points.
            activation (str): Desired non-linear activation function.
            norm_num_groups (int): The number of groups for group normalization.
        """
        super().__init__()

        self._activation_layer_factory = ActivationLayerFactory()
        self._normalization_layer_factory = NormalizationLayerFactory()

        width = int(out_planes * (base_width / 64.)) * groups
        # Both self._conv2 and self._downsample downsample the input when stride != 1

        if norm_num_groups is not None:
            self._norm1 = self._normalization_layer_factory.create(
                NormalizationLayers.GroupNorm, norm_num_groups, width)
            self._norm2 = self._normalization_layer_factory.create(
                NormalizationLayers.GroupNorm, norm_num_groups, width)
            self._norm3 = self._normalization_layer_factory.create(
                NormalizationLayers.GroupNorm, norm_num_groups,
                out_planes * self.expansion)
        else:
            self._norm1 = self._normalization_layer_factory.create(
                NormalizationLayers.BatchNorm3d, width)
            self._norm2 = self._normalization_layer_factory.create(
                NormalizationLayers.BatchNorm3d, width)
            self._norm3 = self._normalization_layer_factory.create(
                NormalizationLayers.BatchNorm3d, out_planes * self.expansion)

        self._conv1 = conv1x1(in_planes, width)

        self._conv2 = conv3x3(width, width, stride, groups, dilation)

        self._conv3 = conv1x1(width, out_planes * self.expansion)

        if activation == ActivationLayers.PReLU.name:
            self._activation = self._activation_layer_factory.create(
                ActivationLayers.PReLU)
        else:
            self._activation = self._activation_layer_factory.create(
                ActivationLayers.ReLU, inplace=True)

        self._downsample = downsample
        self._stride = stride

    def forward(self, x):
        identity = x

        out = self._conv1(x)
        out = self._norm1(out)
        out = self._activation(out)

        out = self._conv2(out)
        out = self._norm2(out)
        out = self._activation(out)

        out = self._conv3(out)
        out = self._norm3(out)

        if self._downsample is not None:
            identity = self._downsample(x)

        out += identity
        out = self._activation(out)

        return out
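
When stride != 1 or the channel counts differ, the identity branch needs its own projection so the residual addition in forward() stays shape-compatible. A hedged usage sketch, reusing the assumed conv1x1 helper sketched after Example #1 (passing it a stride follows torchvision's convention and is likewise an assumption); all sizes are illustrative.

import torch

in_planes, out_planes, stride = 64, 32, 2

# The main branch ends with out_planes * expansion channels and halved
# spatial dimensions, so the identity branch gets a strided 1x1 projection
# plus a matching norm (GroupNorm here, mirroring norm_num_groups=8).
downsample = torch.nn.Sequential(
    conv1x1(in_planes, out_planes * Bottleneck.expansion, stride),
    torch.nn.GroupNorm(8, out_planes * Bottleneck.expansion),
)

block = Bottleneck(in_planes, out_planes, stride=stride,
                   downsample=downsample, norm_num_groups=8)
x = torch.randn(2, in_planes, 8, 16, 16)
out = block(x)  # (2, 128, 4, 8, 8): channels x4, spatial dims halved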