Example 1
    def identity_block(inputs):
        """Bottleneck residual path summed with its input, then ReLU.

        The final 1x1 Conv restores the input's channel count
        (``inputs.shape[3]``) so the skip connection can be added
        elementwise. ``filters1``, ``filters2`` and ``ks`` come from
        the enclosing scope.
        """
        residual_path = Sequential(
            Conv(filters1, (1, 1)), BatchNorm(), relu,
            Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), relu,
            Conv(inputs.shape[3], (1, 1)), BatchNorm())
        # sum over the pair adds the residual branch to the identity branch.
        return relu(sum((residual_path(inputs), inputs)))
Example 2
 def conv_block(inputs):
     """Residual block whose shortcut is a strided 1x1 projection.

     Unlike an identity block, the skip path here projects the input to
     ``filters3`` channels (with matching ``strides``) so both branches
     have the same shape before the elementwise sum. ``filters1``,
     ``filters2``, ``filters3``, ``ks`` and ``strides`` come from the
     enclosing scope.
     """
     residual_path = Sequential(
         Conv(filters1, (1, 1), strides), BatchNorm(), relu,
         Conv(filters2, (ks, ks), padding='SAME'), BatchNorm(), relu,
         Conv(filters3, (1, 1)), BatchNorm())
     projection = Sequential(Conv(filters3, (1, 1), strides), BatchNorm())
     return relu(sum((residual_path(inputs), projection(inputs))))
Example 3
def test_BatchNorm_shape_NCHW(center, scale):
    """BatchNorm over axes (0, 2, 3) preserves shape; params span channels.

    With NCHW-style normalization axes, only axis 1 (size 5) is excluded,
    so beta/gamma — when enabled — are per-channel vectors of length 5.
    """
    shape = (4, 5, 6, 7)
    layer = BatchNorm(axis=(0, 2, 3), center=center, scale=scale)

    x = random_inputs(shape)
    parameters = layer.init_parameters(PRNGKey(0), x)
    y = layer.apply(parameters, x)

    assert y.shape == shape
    if center:
        assert parameters.beta.shape == (5,)
    if scale:
        assert parameters.gamma.shape == (5,)
Example 4
def ResNet50(num_classes):
    """ResNet-50 classifier: conv stem, four residual stages, softmax head.

    Each stage opens with a ConvBlock (projected shortcut) followed by a
    run of IdentityBlocks; the identity blocks reuse the stage's first two
    filter counts. Returns a Sequential model mapping images to
    log-probabilities over ``num_classes`` classes.
    """
    def stage(filters, identity_count, **conv_kwargs):
        # One projected conv block, then `identity_count` identity blocks.
        blocks = [ConvBlock(3, filters, **conv_kwargs)]
        blocks.extend(IdentityBlock(3, filters[:2])
                      for _ in range(identity_count))
        return blocks

    layers = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(), relu, MaxPool((3, 3), strides=(2, 2)),
    ]
    layers += stage([64, 64, 256], 2, strides=(1, 1))
    layers += stage([128, 128, 512], 3)
    layers += stage([256, 256, 1024], 5)
    layers += stage([512, 512, 2048], 2)
    layers += [AvgPool((7, 7)), flatten, Dense(num_classes), logsoftmax]
    return Sequential(*layers)