def test_double_stash_pop_but_isolated():
    """Two stash/pop pairs may reuse the skip name 'foo' as long as each
    pair is isolated in its own namespace."""
    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer2(nn.Module):
        pass

    @skippable(stash=['foo'])
    class Layer3(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer4(nn.Module):
        pass

    first_ns, second_ns = Namespace(), Namespace()

    # Pair 1/2 under one namespace and 3/4 under another, then verify.
    model = nn.Sequential(
        Layer1().isolate(first_ns),
        Layer2().isolate(first_ns),
        Layer3().isolate(second_ns),
        Layer4().isolate(second_ns),
    )
    verify_skippables(model)
# Esempio n. 2 (0)
def test_namespace():
    # The same skip name 'foo' is stashed twice, once per namespace, so the
    # two skips are distinct despite sharing a name.
    ns1 = Namespace()
    ns2 = Namespace()

    p1 = nn.Sequential(StashFoo().isolate(ns1))
    p2 = nn.Sequential(StashFoo().isolate(ns2))
    p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))

    layout = inspect_skip_layout([p1, p2, p3])
    policy = [list(layout.copy_policy(i)) for i in range(3)]

    # p3 pops ns2's 'foo' (stashed in partition 1) before ns1's 'foo'
    # (stashed in partition 0), but the plan is sorted by source
    # partition index, so ns1's entry comes first.
    assert policy == [[], [], [(0, ns1, 'foo'), (1, ns2, 'foo')]]
# Esempio n. 3 (0)
def bottleneck(inplanes: int,
               planes: int,
               stride: int = 1,
               downsample: Optional[nn.Module] = None,
               inplace: bool = False,
               ) -> nn.Sequential:
    """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`.

    The identity input is stashed before the convolutions and added back
    by the residual module; both ends share one skip namespace.
    """
    skip_ns = Namespace()

    modules: NamedModules = OrderedDict([
        ('identity', Identity().isolate(skip_ns)),  # type: ignore

        ('conv1', conv1x1(inplanes, planes)),
        ('bn1', nn.BatchNorm2d(planes)),
        ('relu1', nn.ReLU(inplace=inplace)),

        ('conv2', conv3x3(planes, planes, stride)),
        ('bn2', nn.BatchNorm2d(planes)),
        ('relu2', nn.ReLU(inplace=inplace)),

        # The bottleneck expands channels by a factor of 4 at the output.
        ('conv3', conv1x1(planes, planes * 4)),
        ('bn3', nn.BatchNorm2d(planes * 4)),
        ('residual', Residual(downsample).isolate(skip_ns)),  # type: ignore
        ('relu3', nn.ReLU(inplace=inplace)),
    ])
    return nn.Sequential(modules)
# Esempio n. 4 (0)
def basicblock(inplanes: int,
               planes: int,
               stride: int = 1,
               downsample: Optional[nn.Module] = None,
               inplace: bool = False,
               ) -> nn.Sequential:
    """Creates a basic (two-convolution) ResNet block as a
    :class:`nn.Sequential`, with the residual carried over a skip
    namespace from 'identity' to 'residual'."""
    skip_ns = Namespace()

    modules: NamedModules = OrderedDict([
        ('identity', Identity().isolate(skip_ns)),  # type: ignore

        ('conv1', conv3x3(inplanes, planes, stride)),
        ('bn1', nn.BatchNorm2d(planes)),
        ('relu1', nn.ReLU(inplace=inplace)),

        ('conv2', conv3x3(planes, planes)),
        ('bn2', nn.BatchNorm2d(planes)),
        ('residual', Residual(downsample).isolate(skip_ns)),  # type: ignore
        # NOTE(review): key is 'relu3' (not 'relu2') in the original;
        # kept as-is so state_dict keys stay compatible.
        ('relu3', nn.ReLU(inplace=inplace)),
    ])
    return nn.Sequential(modules)
# Esempio n. 5 (0)
def block(in_planes, out_planes, expansion, stride):
    """Builds a MobileNetV2-style inverted-residual block as a
    :class:`nn.Sequential`: 1x1 expand, 3x3 depthwise, 1x1 project, with
    the shortcut carried over a skip namespace."""
    hidden = expansion * in_planes
    skip_ns = Namespace()

    modules = OrderedDict([
        ('identity', Identity().isolate(skip_ns)),  # type: ignore

        # 1x1 pointwise expansion to 'hidden' channels.
        ('conv1', nn.Conv2d(in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)),
        ('bn1', nn.BatchNorm2d(hidden)),
        ('relu1', nn.ReLU(inplace=False)),

        # 3x3 depthwise convolution (groups == channels).
        ('conv2', nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=False)),
        ('bn2', nn.BatchNorm2d(hidden)),
        ('relu2', nn.ReLU(inplace=False)),

        # 1x1 pointwise projection down to the output width (no ReLU).
        ('conv3', nn.Conv2d(hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)),
        ('bn3', nn.BatchNorm2d(out_planes)),

        ('shortcut', Shortcut(in_planes, out_planes, stride).isolate(skip_ns)),
    ])
    return nn.Sequential(modules)
# Esempio n. 6 (0)
def unet(
    depth: int = 5,
    num_convs: int = 5,
    base_channels: int = 64,
    input_channels: int = 3,
    output_channels: int = 1,
) -> nn.Sequential:
    """Builds a simplified U-Net model.

    Args:
        depth: number of encoder (and matching decoder) stages.
        num_convs: number of convolutions stacked in each cell.
        base_channels: channel width of the first encoder stage; each
            deeper stage doubles it.
        output_channels: channels of the final segmentation map.

    Returns:
        The whole model flattened into one :class:`nn.Sequential`, so it
        can be split into pipeline partitions.
    """
    # The U-Net structure: channel widths double at each encoder stage.
    encoder_channels = [{
        'in':
        input_channels if i == 0 else base_channels * (2**(i - 1)),
        'mid':
        base_channels * (2**i),
        'out':
        base_channels * (2**i),
    } for i in range(depth)]

    # The bottleneck expands once more and contracts back to the widest
    # encoder width.
    bottleneck_channels = [{
        'in': base_channels * (2**(depth - 1)),
        'mid': base_channels * (2**depth),
        'out': base_channels * (2**(depth - 1)),
    }]

    # Listed from deepest (i = depth-1) to shallowest (i = 0); int() only
    # matters at i == 0, where 2**(-1) would otherwise be fractional.
    inverted_decoder_channels = [{
        'in': base_channels * (2**(i + 1)),
        'mid': int(base_channels * (2**(i - 1))),
        'out': int(base_channels * (2**(i - 1))),
    } for i in range(depth)]

    # Build cells.
    def cell(ch: Dict[str, int]) -> nn.Sequential:
        return stacked_convs(ch['in'], ch['mid'], ch['out'], num_convs)

    encoder_cells = [cell(c) for c in encoder_channels]
    bottleneck_cells = [cell(c) for c in bottleneck_channels]
    decoder_cells = [cell(c) for c in inverted_decoder_channels]

    # Link long skip connections.
    #
    # [ encoder ]--------------[ decoder ]--[ segment ]
    #    [ encoder ]--------[ decoder ]
    #       [ encoder ]--[ decoder ]
    #            [ bottleneck ]
    #
    # One namespace per skip: encoder stage i stashes into namespaces[i]
    # and the matching decoder stage pops from the same namespace.
    namespaces = [Namespace() for _ in range(depth)]

    encoder_layers: List[nn.Module] = []
    for i in range(depth):
        ns = namespaces[i]
        encoder_layers.append(
            nn.Sequential(
                OrderedDict([
                    ('encode', encoder_cells[i]),
                    ('skip', Stash().isolate(ns)),  # type: ignore
                    ('down', nn.MaxPool2d(2, stride=2))
                ])))
    encoder = nn.Sequential(*encoder_layers)

    bottleneck = nn.Sequential(*bottleneck_cells)

    # Decoder stages run in reverse so the deepest skip is consumed first.
    decoder_layers: List[nn.Module] = []
    for i in reversed(range(depth)):
        ns = namespaces[i]
        decoder_layers.append(
            nn.Sequential(
                OrderedDict([
                    ('up', nn.Upsample(scale_factor=2)),
                    ('skip', PopCat().isolate(ns)),  # type: ignore
                    ('decode', decoder_cells[i])
                ])))
    decoder = nn.Sequential(*decoder_layers)

    # 1x1 convolution maps the shallowest decoder width to the output.
    final_channels = inverted_decoder_channels[0]['out']
    segment = nn.Conv2d(final_channels,
                        output_channels,
                        kernel_size=1,
                        bias=False)

    # Construct a U-Net model as nn.Sequential, then flatten the nesting
    # so every layer sits at one level for partitioning.
    model = nn.Sequential(
        OrderedDict([('encoder', encoder), ('bottleneck', bottleneck),
                     ('decoder', decoder), ('segment', segment)]))
    model = flatten_sequential(model)
    return model
# Esempio n. 7 (0)
def test_namespace_difference():
    """Every Namespace instance is distinct from every other."""
    first, second = Namespace(), Namespace()
    assert first != second
# Esempio n. 8 (0)
def test_namespace_copy():
    """A shallow copy of a Namespace compares equal but is a new object."""
    original = Namespace()
    clone = copy.copy(original)
    assert clone == original
    assert clone is not original
# Esempio n. 9 (0)
    def __init__(self,
                 spatial_dims: int,
                 in_channels: int,
                 out_channels: int,
                 n_feat: int = 32,
                 depth: int = 4):
        """
        A UNet-like architecture for model parallelism.

        Args:
            spatial_dims: number of input spatial dimensions,
                2 for (B, in_channels, H, W), 3 for (B, in_channels, H, W, D).
            in_channels: number of input channels.
            out_channels: number of output channels.
            n_feat: number of features in the first convolution.
            depth: number of downsampling stages.
        """
        super(UNetPipe, self).__init__()
        # Encoder channel widths: start at n_feat and double per stage,
        # capped at 1024 to bound memory at the deepest levels.
        n_enc_filter: List[int] = [n_feat]
        for i in range(1, depth + 1):
            n_enc_filter.append(min(n_enc_filter[-1] * 2, 1024))
        # One namespace per long skip connection: encoder stage i stashes
        # into namespaces[i]; the matching decoder stage pops it back.
        namespaces = [Namespace() for _ in range(depth)]

        # construct the encoder
        encoder_layers: List[nn.Module] = []
        # First stage: strided convolution downsamples the raw input.
        init_conv = Convolution(
            spatial_dims,
            in_channels,
            n_enc_filter[0],
            strides=2,
            act=Act.LEAKYRELU,
            norm=Norm.BATCH,
            bias=False,
        )
        encoder_layers.append(
            nn.Sequential(
                OrderedDict([(
                    "Conv",
                    init_conv,
                ), ("skip", Stash().isolate(namespaces[0]))])))
        for i in range(1, depth + 1):
            down_conv = DoubleConv(spatial_dims, n_enc_filter[i - 1],
                                   n_enc_filter[i])
            if i == depth:
                # Deepest stage feeds the decoder directly, so it does
                # not stash a skip connection.
                layer_dict = OrderedDict([("Down", down_conv)])
            else:
                layer_dict = OrderedDict([("Down", down_conv),
                                          ("skip",
                                           Stash().isolate(namespaces[i]))])
            encoder_layers.append(nn.Sequential(layer_dict))
        encoder = nn.Sequential(*encoder_layers)

        # construct the decoder, from the deepest stage up to the first
        decoder_layers: List[nn.Module] = []
        for i in reversed(range(1, depth + 1)):
            in_ch, out_ch = n_enc_filter[i], n_enc_filter[i - 1]
            layer_dict = OrderedDict([
                ("Up", UpSample(spatial_dims, in_ch, out_ch, 2, True)),
                # PopCat concatenates the stashed encoder feature, which
                # doubles the channel count to out_ch * 2.
                ("skip", PopCat().isolate(namespaces[i - 1])),
                ("Conv1x1x1", Conv[Conv.CONV, spatial_dims](out_ch * 2,
                                                            in_ch,
                                                            kernel_size=1)),
                ("Conv",
                 DoubleConv(spatial_dims,
                            in_ch,
                            out_ch,
                            stride=1,
                            conv_only=True)),
            ])
            decoder_layers.append(nn.Sequential(layer_dict))
        # Final stage: upsample to full resolution and emit the output
        # map through a 3x3 convolution.
        in_ch = min(n_enc_filter[0] // 2, 32)
        layer_dict = OrderedDict([
            ("Up", UpSample(spatial_dims, n_feat, in_ch, 2, True)),
            ("RELU", Act[Act.LEAKYRELU](inplace=False)),
            (
                "out",
                Conv[Conv.CONV, spatial_dims](in_ch,
                                              out_channels,
                                              kernel_size=3,
                                              padding=1),
            ),
        ])
        decoder_layers.append(nn.Sequential(layer_dict))
        decoder = nn.Sequential(*decoder_layers)

        # making a sequential model
        self.add_module("encoder", encoder)
        self.add_module("decoder", decoder)

        # Kaiming init for convolutions; unit weight / zero bias for
        # batch norms.
        for m in self.modules():
            if isinstance(m, Conv[Conv.CONV, spatial_dims]):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, Norm[Norm.BATCH, spatial_dims]):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, Conv[Conv.CONVTRANS, spatial_dims]):
                nn.init.kaiming_normal_(m.weight)