def test_double_stash_pop_but_isolated():
    """Two stash/pop pairs reusing the skip name 'foo' verify cleanly when
    each pair is isolated in its own namespace."""

    @skippable(stash=['foo'])
    class Layer1(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer2(nn.Module):
        pass

    @skippable(stash=['foo'])
    class Layer3(nn.Module):
        pass

    @skippable(pop=['foo'])
    class Layer4(nn.Module):
        pass

    ns1 = Namespace()
    ns2 = Namespace()

    # ns1 links layers 1-2, ns2 links layers 3-4; the chains never collide.
    model = nn.Sequential(
        Layer1().isolate(ns1),
        Layer2().isolate(ns1),
        Layer3().isolate(ns2),
        Layer4().isolate(ns2),
    )
    verify_skippables(model)
def test_namespace():
    """The copy plan distinguishes same-named skips isolated in different
    namespaces, and orders copies by source partition index."""
    ns1 = Namespace()
    ns2 = Namespace()

    p1 = nn.Sequential(StashFoo().isolate(ns1))
    p2 = nn.Sequential(StashFoo().isolate(ns2))
    p3 = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))

    layout = inspect_skip_layout([p1, p2, p3])
    policy = [list(layout.copy_policy(i)) for i in range(3)]

    # p3 pops ns2's 'foo' before ns1's 'foo', but the plan is sorted by
    # source partition index: ns1's stash lives in partition 0, ns2's in 1.
    # (The previous comment mentioned a 'bar' skip that this test never uses.)
    assert policy == [[], [], [(0, ns1, 'foo'), (1, ns2, 'foo')]]
def bottleneck(inplanes: int, planes: int, stride: int = 1,
               downsample: Optional[nn.Module] = None,
               inplace: bool = False,
               ) -> nn.Sequential:
    """Creates a bottleneck block in ResNet as a :class:`nn.Sequential`.

    The identity/residual pair shares a private namespace so the skip
    connection stays local to this block.
    """
    ns = Namespace()
    layers: NamedModules = OrderedDict([
        ('identity', Identity().isolate(ns)),  # type: ignore
        ('conv1', conv1x1(inplanes, planes)),
        ('bn1', nn.BatchNorm2d(planes)),
        ('relu1', nn.ReLU(inplace=inplace)),
        ('conv2', conv3x3(planes, planes, stride)),
        ('bn2', nn.BatchNorm2d(planes)),
        ('relu2', nn.ReLU(inplace=inplace)),
        ('conv3', conv1x1(planes, planes * 4)),
        ('bn3', nn.BatchNorm2d(planes * 4)),
        ('residual', Residual(downsample).isolate(ns)),  # type: ignore
        ('relu3', nn.ReLU(inplace=inplace)),
    ])
    return nn.Sequential(layers)
def basicblock(inplanes: int, planes: int, stride: int = 1,
               downsample: Optional[nn.Module] = None,
               inplace: bool = False,
               ) -> nn.Sequential:
    """Creates a basic (two-convolution) residual block as a
    :class:`nn.Sequential`, mirroring :func:`bottleneck`."""
    ns = Namespace()
    # NOTE(review): the final activation is keyed 'relu3' (not 'relu2') to
    # match bottleneck(); kept as-is because module names are state_dict keys.
    layers: NamedModules = OrderedDict([
        ('identity', Identity().isolate(ns)),  # type: ignore
        ('conv1', conv3x3(inplanes, planes, stride)),
        ('bn1', nn.BatchNorm2d(planes)),
        ('relu1', nn.ReLU(inplace=inplace)),
        ('conv2', conv3x3(planes, planes)),
        ('bn2', nn.BatchNorm2d(planes)),
        ('residual', Residual(downsample).isolate(ns)),  # type: ignore
        ('relu3', nn.ReLU(inplace=inplace)),
    ])
    return nn.Sequential(layers)
def block(in_planes, out_planes, expansion, stride):
    """Builds an inverted-residual-style block: 1x1 expand, 3x3 depthwise
    conv (``groups=planes``), 1x1 project, plus a shortcut sharing the
    block-local namespace with the identity stash."""
    planes = expansion * in_planes
    ns = Namespace()

    modules = OrderedDict()
    modules['identity'] = Identity().isolate(ns)  # type: ignore
    # 1x1 expansion to `planes` channels.
    modules['conv1'] = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1,
                                 padding=0, bias=False)
    modules['bn1'] = nn.BatchNorm2d(planes)
    modules['relu1'] = nn.ReLU(inplace=False)
    # 3x3 depthwise convolution (one group per channel).
    modules['conv2'] = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                 padding=1, groups=planes, bias=False)
    modules['bn2'] = nn.BatchNorm2d(planes)
    modules['relu2'] = nn.ReLU(inplace=False)
    # 1x1 projection down to `out_planes`.
    modules['conv3'] = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1,
                                 padding=0, bias=False)
    modules['bn3'] = nn.BatchNorm2d(out_planes)
    modules['shortcut'] = Shortcut(in_planes, out_planes, stride).isolate(ns)
    return nn.Sequential(modules)
def unet(
    depth: int = 5,
    num_convs: int = 5,
    base_channels: int = 64,
    input_channels: int = 3,
    output_channels: int = 1,
) -> nn.Sequential:
    """Builds a simplified U-Net model.

    Long skip connections tie each encoder stage to its mirror decoder stage::

        [ encoder ]--------------[ decoder ]--[ segment ]
            [ encoder ]--------[ decoder ]
                [ encoder ]--[ decoder ]
                    [ bottleneck ]
    """
    # Channel specifications for every stage of the U.
    enc_specs = []
    for i in range(depth):
        enc_specs.append({
            'in': input_channels if i == 0 else base_channels * (2**(i - 1)),
            'mid': base_channels * (2**i),
            'out': base_channels * (2**i),
        })

    mid_specs = [{
        'in': base_channels * (2**(depth - 1)),
        'mid': base_channels * (2**depth),
        'out': base_channels * (2**(depth - 1)),
    }]

    # Listed shallow-to-deep; the decoder consumes these in reverse order.
    dec_specs = []
    for i in range(depth):
        dec_specs.append({
            'in': base_channels * (2**(i + 1)),
            'mid': int(base_channels * (2**(i - 1))),
            'out': int(base_channels * (2**(i - 1))),
        })

    # Build the convolutional cells for each stage.
    def make_cell(spec: Dict[str, int]) -> nn.Sequential:
        return stacked_convs(spec['in'], spec['mid'], spec['out'], num_convs)

    encoder_cells = [make_cell(s) for s in enc_specs]
    bottleneck_cells = [make_cell(s) for s in mid_specs]
    decoder_cells = [make_cell(s) for s in dec_specs]

    # Link long skip connections: one namespace per encoder/decoder pair.
    namespaces = [Namespace() for _ in range(depth)]

    encoder_stages: List[nn.Module] = []
    for i in range(depth):
        encoder_stages.append(nn.Sequential(OrderedDict([
            ('encode', encoder_cells[i]),
            ('skip', Stash().isolate(namespaces[i])),  # type: ignore
            ('down', nn.MaxPool2d(2, stride=2)),
        ])))

    decoder_stages: List[nn.Module] = []
    for i in reversed(range(depth)):
        decoder_stages.append(nn.Sequential(OrderedDict([
            ('up', nn.Upsample(scale_factor=2)),
            ('skip', PopCat().isolate(namespaces[i])),  # type: ignore
            ('decode', decoder_cells[i]),
        ])))

    segment = nn.Conv2d(dec_specs[0]['out'], output_channels,
                        kernel_size=1, bias=False)

    # Assemble the whole U-Net and flatten it into a single nn.Sequential.
    model = nn.Sequential(OrderedDict([
        ('encoder', nn.Sequential(*encoder_stages)),
        ('bottleneck', nn.Sequential(*bottleneck_cells)),
        ('decoder', nn.Sequential(*decoder_stages)),
        ('segment', segment),
    ]))
    return flatten_sequential(model)
def test_namespace_difference():
    """Two freshly created namespaces never compare equal."""
    first = Namespace()
    second = Namespace()
    assert first != second
def test_namespace_copy():
    """A shallow copy of a namespace compares equal to the original while
    being a distinct object."""
    ns = Namespace()
    duplicate = copy.copy(ns)
    assert duplicate == ns
    assert duplicate is not ns
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, n_feat: int = 32, depth: int = 4):
    """
    A UNet-like architecture for model parallelism.

    Args:
        spatial_dims: number of input spatial dimensions, 2 for (B, in_channels, H, W), 3 for (B, in_channels, H, W, D).
        in_channels: number of input channels.
        out_channels: number of output channels.
        n_feat: number of features in the first convolution.
        depth: number of downsampling stages.
    """
    super(UNetPipe, self).__init__()
    # Encoder filter counts: doubled per stage, capped at 1024 channels.
    n_enc_filter: List[int] = [n_feat]
    for i in range(1, depth + 1):
        n_enc_filter.append(min(n_enc_filter[-1] * 2, 1024))
    # One skip-connection namespace per encoder/decoder pair.
    namespaces = [Namespace() for _ in range(depth)]

    # construct the encoder
    encoder_layers: List[nn.Module] = []
    init_conv = Convolution(
        spatial_dims,
        in_channels,
        n_enc_filter[0],
        strides=2,
        act=Act.LEAKYRELU,
        norm=Norm.BATCH,
        bias=False,
    )
    # The initial conv stashes its output as the shallowest skip (namespace 0).
    encoder_layers.append(
        nn.Sequential(
            OrderedDict([(
                "Conv",
                init_conv,
            ), ("skip", Stash().isolate(namespaces[0]))])))
    for i in range(1, depth + 1):
        down_conv = DoubleConv(spatial_dims, n_enc_filter[i - 1], n_enc_filter[i])
        if i == depth:
            # The deepest stage has no stash: its output flows straight into the decoder.
            layer_dict = OrderedDict([("Down", down_conv)])
        else:
            layer_dict = OrderedDict([("Down", down_conv), ("skip", Stash().isolate(namespaces[i]))])
        encoder_layers.append(nn.Sequential(layer_dict))
    encoder = nn.Sequential(*encoder_layers)

    # construct the decoder
    decoder_layers: List[nn.Module] = []
    for i in reversed(range(1, depth + 1)):
        in_ch, out_ch = n_enc_filter[i], n_enc_filter[i - 1]
        # NOTE(review): Conv1x1x1 takes out_ch * 2 input channels — presumably
        # because PopCat concatenates the stashed skip with the upsampled
        # features along the channel axis; confirm against PopCat's forward.
        layer_dict = OrderedDict([
            ("Up", UpSample(spatial_dims, in_ch, out_ch, 2, True)),
            ("skip", PopCat().isolate(namespaces[i - 1])),
            ("Conv1x1x1", Conv[Conv.CONV, spatial_dims](out_ch * 2, in_ch, kernel_size=1)),
            ("Conv", DoubleConv(spatial_dims, in_ch, out_ch, stride=1, conv_only=True)),
        ])
        decoder_layers.append(nn.Sequential(layer_dict))
    # Final upsampling stage back to full resolution, then 3x3 output head.
    in_ch = min(n_enc_filter[0] // 2, 32)
    layer_dict = OrderedDict([
        ("Up", UpSample(spatial_dims, n_feat, in_ch, 2, True)),
        ("RELU", Act[Act.LEAKYRELU](inplace=False)),
        (
            "out",
            Conv[Conv.CONV, spatial_dims](in_ch, out_channels, kernel_size=3, padding=1),
        ),
    ])
    decoder_layers.append(nn.Sequential(layer_dict))
    decoder = nn.Sequential(*decoder_layers)

    # making a sequential model
    self.add_module("encoder", encoder)
    self.add_module("decoder", decoder)

    # Kaiming init for conv / transposed-conv weights; unit-gain init for batch norm.
    for m in self.modules():
        if isinstance(m, Conv[Conv.CONV, spatial_dims]):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, Norm[Norm.BATCH, spatial_dims]):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, Conv[Conv.CONVTRANS, spatial_dims]):
            nn.init.kaiming_normal_(m.weight)