Exemplo n.º 1
0
def WideResnet(block_size,
               k,
               num_classes,
               channels=1024,
               nonlinearity='relu',
               parameterization='standard',
               order=None):
    """Assemble a wide residual network as one ``jax_stax.serial`` stack.

    The stack is, in order: a 3x3 ``MyConv`` stem with ``'SAME'`` padding,
    three ``WideResnetGroup`` stages whose widths are ``int(16 * k)``,
    ``int(32 * k)`` and ``int(64 * k)``, then ``AvgPool((8, 8))``,
    ``Flatten`` and a final ``MyDense`` layer with ``num_classes`` outputs.

    Args:
        block_size: Forwarded unchanged as the first argument of every
            ``WideResnetGroup`` (presumably blocks per group -- confirm
            against ``WideResnetGroup``).
        k: Width multiplier applied to the base widths 16, 32 and 64; each
            product is truncated with ``int``.
        num_classes: First argument of the final ``MyDense`` layer.
        channels: First argument of the stem ``MyConv`` (presumably its
            output channel count -- confirm against ``MyConv``).
        nonlinearity: ``'relu'`` selects ``Relu``, ``'swish'`` selects
            ``Swish``; the chosen layer is passed to every group as
            ``nonlin``.
        parameterization: Forwarded to ``MyConv``, every ``WideResnetGroup``
            and ``MyDense``.
        order: Forwarded to ``MyConv``, every ``WideResnetGroup`` and
            ``MyDense``.

    Returns:
        The result of ``jax_stax.serial`` applied to the layers above.

    Raises:
        ValueError: If ``nonlinearity`` is not ``'relu'`` or ``'swish'``.
    """
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    else:
        # Previously an unknown name left `nonlin` unbound and surfaced as an
        # opaque UnboundLocalError further down; fail early and explicitly.
        raise ValueError(
            f"Unsupported nonlinearity {nonlinearity!r}; "
            "expected 'relu' or 'swish'.")

    # Options shared by the groups (MyConv/MyDense take them minus `nonlin`).
    shared = dict(parameterization=parameterization, order=order)

    return jax_stax.serial(
        MyConv(channels, (3, 3), padding='SAME', **shared),
        WideResnetGroup(block_size, int(16 * k), nonlin=nonlin, **shared),
        # The second and third groups additionally receive (2, 2) as their
        # third positional argument (presumably strides -- confirm).
        WideResnetGroup(block_size, int(32 * k), (2, 2), nonlin=nonlin,
                        **shared),
        WideResnetGroup(block_size, int(64 * k), (2, 2), nonlin=nonlin,
                        **shared),
        AvgPool((8, 8)),
        Flatten,
        MyDense(num_classes, **shared))
Exemplo n.º 2
0
def ResNet50(num_classes):
    """Build a ResNet50 classifier as a single ``stax.serial`` stack.

    Layout: a 7x7/stride-2 ``GeneralConv`` stem with ``BatchNorm``, ``Relu``
    and a 3x3/stride-2 ``MaxPool``; four stages, each one ``ConvBlock``
    followed by 2, 3, 5 and 2 ``IdentityBlock`` layers respectively; then
    ``AvgPool((7, 7))``, ``Flatten``, ``Dense(num_classes)`` and
    ``LogSoftmax``.

    Args:
        num_classes: Size passed to the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` applied to the layers above.
    """
    # (bottleneck width, number of IdentityBlocks after the ConvBlock)
    stage_plan = ((64, 2), (128, 3), (256, 5), (512, 2))

    net = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    for stage_idx, (width, n_identity) in enumerate(stage_plan):
        conv_filters = [width, width, 4 * width]
        if stage_idx == 0:
            # Only the first stage passes explicit strides to ConvBlock.
            net.append(ConvBlock(3, conv_filters, strides=(1, 1)))
        else:
            net.append(ConvBlock(3, conv_filters))
        for _ in range(n_identity):
            net.append(IdentityBlock(3, [width, width]))

    net += [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]
    return stax.serial(*net)
Exemplo n.º 3
0
def LeNet5(num_classes):
    """Build a small two-stage convolutional classifier with ``stax.serial``.

    Two convolution stages (each: convolution, ``BatchNorm``, ``Relu``,
    3x3 ``AvgPool``) are followed by ``Flatten`` and three ``Dense`` layers
    of sizes ``num_classes * 10``, ``num_classes * 5`` and ``num_classes``,
    ending in ``LogSoftmax``.

    Args:
        num_classes: Size of the last ``Dense`` layer; the two preceding
            ``Dense`` layers are 10x and 5x this value.

    Returns:
        The result of ``stax.serial`` applied to the layers above.
    """
    first_stage = (
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),
    )
    second_stage = (
        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),
    )
    # Dense widths shrink 10x -> 5x -> 1x of num_classes.
    dense_head = tuple(Dense(num_classes * scale) for scale in (10, 5))
    dense_head += (Dense(num_classes),)

    return stax.serial(
        *first_stage,
        *second_stage,
        Flatten,
        *dense_head,
        LogSoftmax,
    )
Exemplo n.º 4
0
def ResNet(num_classes):
    """Build a four-stage residual classifier as one ``stax.serial`` stack.

    After a 7x7/stride-2 ``GeneralConv`` stem with ``BatchNorm``, ``Relu``
    and a 3x3/stride-2 ``MaxPool``, each stage is one ``convBlock`` followed
    by 2, 3, 5 and 2 ``identityBlock`` layers respectively. The head is
    ``AvgPool((7, 7))``, ``Flatten``, ``Dense(num_classes)`` and
    ``LogSoftmax``.

    Args:
        num_classes: Size passed to the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` applied to the layers above.
    """
    sequence = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    # (width, identity-block count) per stage; the convBlock's last filter
    # count is four times the stage width.
    for width, repeats in ((64, 2), (128, 3), (256, 5), (512, 2)):
        sequence.append(convBlock(3, [width, width, width * 4]))
        sequence.extend(
            identityBlock(3, [width, width]) for _ in range(repeats))

    sequence.extend((AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax))
    return stax.serial(*sequence)
Exemplo n.º 5
0
def ResNet(num_classes):
    """Build a minimal residual classifier with ``stax.serial``.

    The stack is a 7x7/stride-2 ``GeneralConv`` stem with ``BatchNorm``,
    ``Relu`` and a 3x3/stride-2 ``MaxPool``, a single ``ConvBlock`` plus a
    single ``IdentityBlock`` (all filter counts are 4), then
    ``AvgPool((3, 3))``, ``Flatten``, ``Dense(num_classes)`` and
    ``LogSoftmax``.

    Args:
        num_classes: Size passed to the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` applied to the layers above.
    """
    stem = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]
    residual_part = [
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
    ]
    head = [AvgPool((3, 3)), Flatten, Dense(num_classes), LogSoftmax]

    return stax.serial(*(stem + residual_part + head))
Exemplo n.º 6
0
def ResNet(block,
           expansion,
           layers,
           normalization_method=None,
           width_per_group=64,
           actfn=stax.Relu):
    """Build the layer list of a four-stage ResNet feature extractor.

    The result is a plain Python list (it is NOT wrapped in ``stax.serial``):
    a 3x3/stride-1 ``Conv`` stem without bias, the normalization layer, the
    activation, four stages of widths 64, 128, 256 and 512 (the last three
    with stride 2), ``AvgPool((4, 4))`` and ``Flatten``.

    Args:
        block: Callable that builds one residual block. The first block of a
            stage is called as ``block(planes, stride, downsample,
            base_width, norm_layer, actfn=actfn)``; the remaining blocks as
            ``block(planes, base_width=..., norm_layer=..., actfn=...)``.
        expansion: Multiplier applied to ``planes`` to get the first argument
            of the 1x1 downsample ``Conv``.
        layers: Indexable with at least four entries; ``layers[i]`` is the
            number of blocks requested for stage ``i``.
        normalization_method: ``"group_norm"`` selects ``GroupNorm(32)``,
            ``"batch_norm"`` selects ``BatchNorm()``; any other value
            (including ``None``) selects ``Identity``.
        width_per_group: Forwarded to every block as ``base_width``.
        actfn: Activation layer placed after the stem and forwarded to every
            block.

    Returns:
        list: The layers described above, in order.
    """
    if normalization_method == "group_norm":
        norm_layer = GroupNorm(32)
    elif normalization_method == "batch_norm":
        norm_layer = BatchNorm()
    else:
        # Unrecognized values deliberately keep the original silent fallback.
        norm_layer = Identity
    base_width = width_per_group

    def _make_layer(block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks of width ``planes`` into one layer."""
        downsample = None
        if stride != 1:
            # A strided stage gets a 1x1 strided conv + norm, handed to the
            # first block as `downsample`.
            downsample = stax.serial(
                Conv(planes * expansion, (1, 1),
                     strides=(stride, stride),
                     bias=False),
                norm_layer,
            )
        # The first block is always created, and is the only one that
        # receives the stride and the downsample path. It now also receives
        # `actfn`: previously it was the only block built without it, so a
        # custom activation was ignored by the first block of every stage.
        stage = [
            block(planes, stride, downsample, base_width, norm_layer,
                  actfn=actfn)
        ]
        stage.extend(
            block(planes,
                  base_width=base_width,
                  norm_layer=norm_layer,
                  actfn=actfn)
            for _ in range(1, blocks))
        return stax.serial(*stage)

    return [
        Conv(64, (3, 3), strides=(1, 1), padding="SAME", bias=False),
        norm_layer,
        actfn,
        # NOTE: the stem max-pool is disabled in this variant
        # (was: MaxPool((3, 3), strides=(2, 2), padding="SAME")).
        _make_layer(block, 64, layers[0]),
        _make_layer(block, 128, layers[1], stride=2),
        _make_layer(block, 256, layers[2], stride=2),
        _make_layer(block, 512, layers[3], stride=2),
        AvgPool((4, 4)),
        Flatten,
    ]