# Example #1
def CifarBasicBlockv2(planes,
                      stride=1,
                      option="A",
                      normalization_method=None,
                      use_fixup=False,
                      num_layers=None,
                      w_init=None,
                      actfn=stax.Relu):
    """Build a pre-activation (v2-style) CIFAR residual basic block.

    The residual branch is ``norm -> actfn -> conv3x3(stride) -> norm ->
    actfn -> conv3x3``. Its output is summed with the shortcut branch, and
    no activation is applied after the sum.

    Args:
      planes: Number of output channels of both 3x3 convolutions.
      stride: Spatial stride of the first convolution (and of the option-B
        shortcut convolution).
      option: Shortcut type used when ``stride > 1``: ``"A"`` applies
        ``_shortcut_pad`` through ``LambdaLayer``; ``"B"`` uses a strided
        1x1 convolution. Ignored when ``stride == 1`` (identity shortcut).
      normalization_method: Passed to ``maybe_use_normalization``.
      use_fixup: Must be False; Fixup is not supported by this block.
      num_layers: Unused; accepted so the signature matches
        ``CifarBasicBlock``.
      w_init: Weight initializer for every convolution in the block.
      actfn: Activation layer.

    Returns:
      A stax layer built with ``stax.serial``.

    Raises:
      ValueError: If ``use_fixup`` is True, or if ``stride > 1`` and
        ``option`` is neither ``"A"`` nor ``"B"`` (an identity shortcut
        would not match the strided residual branch).
    """
    # An explicit raise (not ``assert``) so the check survives ``python -O``.
    if use_fixup:
        raise ValueError("CifarBasicBlockv2 does not support use_fixup=True")
    if stride > 1 and option not in ("A", "B"):
        raise ValueError(
            f"Unknown shortcut option {option!r}; expected 'A' or 'B'")

    Main = stax.serial(
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             W_init=w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), padding="SAME", W_init=w_init, bias=False),
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10 ResNet paper uses option A.
            Shortcut = LambdaLayer(_shortcut_pad)
        else:  # option == "B" (validated above)
            Shortcut = Conv(planes, (1, 1),
                            strides=(stride, stride),
                            W_init=w_init,
                            bias=False)
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum)
# Example #2
def CifarBasicBlock(planes,
                    stride=1,
                    option="A",
                    normalization_method=None,
                    use_fixup=False,
                    num_layers=None,
                    w_init=None,
                    actfn=stax.Relu):
    """Build a post-activation CIFAR residual basic block, optionally Fixup.

    The residual branch is two 3x3 convolutions (the first one strided).
    Each convolution is followed by
    ``maybe_use_normalization(normalization_method)``, with ``actfn``
    between them. The branch output is summed with the shortcut, and
    ``actfn`` is applied to the sum.

    When ``use_fixup`` is True:
      - ``FixupBias``/``FixupScale`` layers are inserted around the
        convolutions;
      - the first convolution is initialized with ``fixup_init(num_layers)``;
      - the second convolution is initialized with ``zeros``.

    When ``use_fixup`` is False, those slots are ``Identity`` and both
    convolutions use ``w_init``.

    Args:
      planes: Number of output channels of the block.
      stride: Spatial stride of the first convolution (and of the option-B
        shortcut convolution).
      option: Shortcut type used when ``stride > 1``: ``"A"`` applies
        ``_shortcut_pad`` through ``LambdaLayer``; ``"B"`` uses a strided
        1x1 convolution followed by normalization. Ignored when
        ``stride == 1`` (identity shortcut).
      normalization_method: Passed to ``maybe_use_normalization``.
      use_fixup: Whether to build the Fixup variant of the block.
      num_layers: Passed to ``fixup_init``; only read when ``use_fixup``.
      w_init: Convolution weight initializer for the non-Fixup slots and,
        always, for the option-B shortcut convolution.
      actfn: Activation layer.

    Returns:
      A stax layer built with ``stax.serial``.

    Raises:
      ValueError: If ``stride > 1`` and ``option`` is neither ``"A"`` nor
        ``"B"`` (an identity shortcut would not match the strided residual
        branch).
    """
    if stride > 1 and option not in ("A", "B"):
        raise ValueError(
            f"Unknown shortcut option {option!r}; expected 'A' or 'B'")

    Main = stax.serial(
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3),
             strides=(stride, stride),
             padding="SAME",
             W_init=fixup_init(num_layers) if use_fixup else w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        FixupBias() if use_fixup else Identity,
        actfn,
        FixupBias() if use_fixup else Identity,
        # Fixup zero-initializes the last convolution of the branch.
        Conv(planes, (3, 3),
             padding="SAME",
             bias=False,
             W_init=zeros if use_fixup else w_init),
        maybe_use_normalization(normalization_method),
        FixupScale() if use_fixup else Identity,
        FixupBias() if use_fixup else Identity,
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10 ResNet paper uses option A. Unlike option B, no
            # FixupBias is applied on this shortcut. The one-layer
            # stax.serial wrapper keeps the existing parameter layout.
            Shortcut = stax.serial(LambdaLayer(_shortcut_pad))
        else:  # option == "B" (validated above)
            Shortcut = stax.serial(
                FixupBias() if use_fixup else Identity,
                Conv(planes, (1, 1),
                     strides=(stride, stride),
                     W_init=w_init,
                     bias=False),
                maybe_use_normalization(normalization_method))
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum,
                       actfn)
# Example #3
def BasicBlock(planes,
               stride=1,
               downsample=None,
               base_width=64,
               norm_layer=Identity,
               actfn=stax.Relu):
    """Build a post-activation ResNet basic block (two 3x3 convolutions).

    Residual branch: ``conv3x3(stride) -> norm_layer -> actfn -> conv3x3 ->
    norm_layer``. It is summed with the shortcut (``downsample`` when given,
    otherwise ``Identity``), and ``actfn`` is applied to the sum.

    Args:
      planes: Output channels of both convolutions.
      stride: Spatial stride of the first convolution.
      downsample: Optional stax layer used as the shortcut branch.
      base_width: Must be 64; accepted for signature parity with
        ``BottleneckBlock``.
      norm_layer: Normalization layer; the default ``Identity`` means none.
      actfn: Activation layer.

    Returns:
      A stax layer built with ``stax.serial``.

    Raises:
      ValueError: If ``base_width`` is not 64.
    """
    if base_width != 64:
        raise ValueError("BasicBlock only supports base_width=64")

    first_conv = Conv(planes, (3, 3),
                      strides=(stride, stride),
                      padding="SAME",
                      bias=False)
    second_conv = Conv(planes, (3, 3), padding="SAME", bias=False)
    residual = stax.serial(first_conv, norm_layer, actfn,
                           second_conv, norm_layer)

    shortcut = Identity if downsample is None else downsample
    branches = stax.parallel(residual, shortcut)
    return stax.serial(FanOut(2), branches, FanInSum, actfn)
# Example #4
def BottleneckBlock(planes,
                    stride=1,
                    downsample=None,
                    base_width=64,
                    norm_layer=Identity,
                    actfn=stax.Relu):
    """Build a post-activation ResNet bottleneck block.

    Residual branch:
      - a 1x1 conv down to ``width`` channels;
      - a strided 3x3 conv at ``width`` channels;
      - a 1x1 conv up to ``planes * 4`` channels.

    Each conv is followed by ``norm_layer``; all but the last are also
    followed by ``actfn``. Here ``width = int(planes * (base_width / 64.))``.
    The branch is summed with the shortcut (``downsample`` when given,
    otherwise ``Identity``), and ``actfn`` is applied to the sum.

    Args:
      planes: Base channel count; the block outputs ``planes * 4`` channels.
      stride: Spatial stride of the 3x3 convolution.
      downsample: Optional stax layer used as the shortcut branch.
      base_width: Scales the inner width relative to 64.
      norm_layer: Normalization layer; the default ``Identity`` means none.
      actfn: Activation layer.

    Returns:
      A stax layer built with ``stax.serial``.
    """
    width_scale = base_width / 64.
    width = int(planes * width_scale)
    # Fixed expansion factor of 4; NOTE(review): callers passing a separate
    # ``expansion`` (e.g. ResNet) should keep it equal to 4 for this block.
    out_channels = planes * 4

    reduce_conv = Conv(width, (1, 1), bias=False)
    spatial_conv = Conv(width, (3, 3),
                        strides=(stride, stride),
                        padding="SAME",
                        bias=False)
    expand_conv = Conv(out_channels, (1, 1), bias=False)
    residual = stax.serial(
        reduce_conv, norm_layer, actfn,
        spatial_conv, norm_layer, actfn,
        expand_conv, norm_layer,
    )

    shortcut = Identity if downsample is None else downsample
    branches = stax.parallel(residual, shortcut)
    return stax.serial(FanOut(2), branches, FanInSum, actfn)
# Example #5
def ResNet(block,
           expansion,
           layers,
           normalization_method=None,
           width_per_group=64,
           actfn=stax.Relu):
    """Return the layer list for a four-stage ResNet.

    The stem is a stride-1 3x3 convolution with 64 channels, followed by
    ``norm_layer`` and ``actfn``, with no max-pool. Four stages of ``block``
    follow, with 64, 128, 256 and 512 planes and first-block strides
    1, 2, 2, 2. The list ends with ``AvgPool((4, 4))`` and ``Flatten``.

    Args:
      block: Block constructor such as ``BasicBlock`` or ``BottleneckBlock``.
        The first block of a stage is called as
        ``block(planes, stride, downsample, base_width, norm_layer,
        actfn=...)``; the rest as
        ``block(planes, base_width=..., norm_layer=..., actfn=...)``.
      expansion: Ratio of a block's output channels to ``planes`` (1 for
        ``BasicBlock``, 4 for ``BottleneckBlock``).
      layers: Block counts of the four stages (``layers[0]``..``layers[3]``).
        A count of 0 still builds one block.
      normalization_method: ``"group_norm"`` selects ``GroupNorm(32)``,
        ``"batch_norm"`` selects ``BatchNorm()``; any other value means no
        normalization (``Identity``).
      width_per_group: Passed to every block as ``base_width``.
      actfn: Activation layer used in the stem and in every block.

    Returns:
      A list of stax layers. NOTE(review): presumably combined by the caller
      with ``stax.serial(*...)``.
    """
    norm_layer = Identity
    if normalization_method == "group_norm":
        norm_layer = GroupNorm(32)
    elif normalization_method == "batch_norm":
        norm_layer = BatchNorm()
    base_width = width_per_group

    def _make_layer(block, inplanes, planes, blocks, stride=1):
        """Build one stage; return ``(stage_layer, output_channels)``."""
        outplanes = planes * expansion
        downsample = None
        # Project the shortcut whenever the first block changes the spatial
        # size OR the channel count. Otherwise FanInSum would add tensors of
        # different shapes (e.g. a bottleneck stage 1 maps 64 -> 256
        # channels at stride 1).
        if stride != 1 or inplanes != outplanes:
            downsample = stax.serial(
                Conv(outplanes, (1, 1),
                     strides=(stride, stride),
                     bias=False),
                norm_layer,
            )
        # Every block, including the first, receives the configured actfn.
        stage = [
            block(planes, stride, downsample, base_width, norm_layer,
                  actfn=actfn)
        ]
        for _ in range(1, blocks):
            stage.append(
                block(planes,
                      base_width=base_width,
                      norm_layer=norm_layer,
                      actfn=actfn))
        return stax.serial(*stage), outplanes

    stem_planes = 64
    net = [
        Conv(stem_planes, (3, 3), strides=(1, 1), padding="SAME", bias=False),
        norm_layer,
        actfn,
        # A MaxPool((3, 3), strides=(2, 2), padding="SAME") after the stem
        # is omitted here.
    ]
    inplanes = stem_planes
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for stage_idx, (planes, stride) in enumerate(stage_specs):
        stage, inplanes = _make_layer(block, inplanes, planes,
                                      layers[stage_idx], stride=stride)
        net.append(stage)
    # NOTE(review): the fixed 4x4 window presumably matches the final
    # feature map for 32x32 inputs (three stride-2 stages); confirm.
    net += [AvgPool((4, 4)), Flatten]
    return net
# Example #6
 def _make_layer(block, planes, blocks, stride=1, inplanes=None):
     """Build one ResNet stage of ``blocks`` residual blocks.

     This is a closure: ``expansion``, ``norm_layer``, ``base_width`` and
     ``actfn`` come from an enclosing scope that is not shown here.

     Args:
       block: Block constructor (e.g. ``BasicBlock``/``BottleneckBlock``).
       planes: Base channel count; the shortcut projection outputs
         ``planes * expansion`` channels.
       blocks: Number of blocks; at least one block is always built.
       stride: Stride of the first block and of its projection shortcut.
       inplanes: Optional channel count of the stage input. When given and
         different from ``planes * expansion``, the shortcut is projected
         even at stride 1 so the residual sum has matching shapes. When None,
         only ``stride != 1`` triggers a projection.

     Returns:
       A stax layer built with ``stax.serial``.
     """
     outplanes = planes * expansion
     downsample = None
     channels_change = inplanes is not None and inplanes != outplanes
     if stride != 1 or channels_change:
         downsample = stax.serial(
             Conv(outplanes, (1, 1),
                  strides=(stride, stride),
                  bias=False),
             norm_layer,
         )
     # Every block, including the first, receives the configured actfn.
     stage = [
         block(planes, stride, downsample, base_width, norm_layer,
               actfn=actfn)
     ]
     for _ in range(1, blocks):
         stage.append(
             block(planes,
                   base_width=base_width,
                   norm_layer=norm_layer,
                   actfn=actfn))
     return stax.serial(*stage)
# Example #7
def CifarResNet(block,
                num_blocks,
                expansion=1,
                option="A",
                normalization_method=None,
                use_fixup=False,
                init=None,
                actfn=stax.Relu):
    """Return the layer list for a three-stage CIFAR ResNet.

    Layout:
      - Stem: a 3x3 convolution with ``16 * expansion`` channels,
        normalization, an optional ``FixupBias`` and ``actfn``.
      - Stages: three stages of ``block`` with 16, 32 and 64 (times
        ``expansion``) planes; the first blocks have strides 1, 2 and 2.
      - Head: ``AvgPoolAll()``, ``Flatten`` and an optional trailing
        ``FixupBias``.

    Args:
      block: Block constructor such as ``CifarBasicBlock``, called as
        ``block(planes, stride, option=..., normalization_method=...,
        use_fixup=..., num_layers=..., w_init=..., actfn=...)``.
      num_blocks: Per-stage block counts. The first three entries size the
        stages (a count of 0 still builds one block); the sum of all entries
        is passed to every block as ``num_layers``.
      expansion: Channel multiplier applied to 16/32/64.
      option: Shortcut option passed to every block.
      normalization_method: Passed to ``maybe_use_normalization`` and to
        every block.
      use_fixup: Whether to add the stem/head ``FixupBias`` layers; also
        passed to every block.
      init: ``"he"`` selects
        ``variance_scaling(2.0, "fan_out", "truncated_normal")`` as the
        convolution weight initializer; any other value leaves it None.
      actfn: Activation layer.

    Returns:
      A list of stax layers. NOTE(review): presumably combined by the caller
      with ``stax.serial(*...)``.
    """
    w_init = (variance_scaling(2.0, "fan_out", "truncated_normal")
              if init == "he" else None)
    num_layers = sum(num_blocks)

    def _build_stage(stage_block, planes, count, first_stride):
        # Only the first block strides. ``[1] * (count - 1)`` is empty for
        # count <= 1, so a stage always has at least one block.
        block_strides = [first_stride] + [1] * (count - 1)
        return stax.serial(*[
            stage_block(planes,
                        block_stride,
                        option=option,
                        normalization_method=normalization_method,
                        use_fixup=use_fixup,
                        num_layers=num_layers,
                        w_init=w_init,
                        actfn=actfn)
            for block_stride in block_strides
        ])

    stem = [
        Conv(16 * expansion, (3, 3), padding="SAME", W_init=w_init,
             bias=False),
        maybe_use_normalization(normalization_method),
        FixupBias() if use_fixup else Identity,
        actfn,
    ]
    stages = [
        _build_stage(block, 16 * expansion, num_blocks[0], 1),
        _build_stage(block, 32 * expansion, num_blocks[1], 2),
        _build_stage(block, 64 * expansion, num_blocks[2], 2),
    ]
    head = [
        AvgPoolAll(),
        Flatten,
        FixupBias() if use_fixup else Identity,
    ]
    return stem + stages + head