Example #1
0
def norm_block(is_layer_norm, dim, affine=True):
    """Build a fp32 normalization module over `dim` features.

    Args:
        is_layer_norm: if True, return LayerNorm wrapped in transposes so it
            normalizes the channel dimension of (B, C, T) inputs; otherwise
            return a single-group GroupNorm (equivalent to InstanceNorm over
            all channels jointly).
        dim: number of features/channels to normalize.
        affine: whether the norm has learnable scale/shift parameters.
    """
    if not is_layer_norm:
        return Fp32GroupNorm(1, dim, affine=affine)
    return nn.Sequential(
        TransposeLast(),
        Fp32LayerNorm(dim, elementwise_affine=affine),
        TransposeLast(),
    )
Example #2
0
        def block(
            n_in,
            n_out,
            k,
            stride,
            padding,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Build one Conv2d sub-block: conv -> dropout -> (optional norm) -> GELU.

            Args:
                n_in: number of input channels.
                n_out: number of output channels.
                k: 2-D kernel size; must have length 2 (enforced below).
                stride: convolution stride.
                padding: convolution padding.
                is_layer_norm: use a (transposed) Fp32LayerNorm after dropout.
                    NOTE: this path is currently disabled (see assert below).
                is_group_norm: use Fp32GroupNorm after dropout.
                conv_bias: whether the convolution learns a bias term.

            NOTE(review): `dropout` and `dim` are free variables captured from
            the enclosing scope (not visible in this chunk) — confirm they are
            defined there. The sibling `conv_1d_block` uses `n_out` where this
            function uses `dim`; verify that is intentional.
            """
            def make_conv():
                # Only 2-D kernels are supported here (nn.Conv2d).
                assert len(k) == 2
                conv = nn.Conv2d(n_in,
                                 n_out,
                                 k,
                                 stride=stride,
                                 bias=conv_bias,
                                 padding=padding)
                # Kaiming init suits the GELU/ReLU-family activation that follows.
                nn.init.kaiming_normal_(conv.weight)
                return conv

            # The two normalization flavors are mutually exclusive.
            assert (is_layer_norm and is_group_norm
                    ) == False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                assert False  # NOTE(review): layer-norm path deliberately disabled — output shapes were never verified
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    # Transpose so LayerNorm normalizes the channel dim, then
                    # transpose back to the original layout.
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    # One group per channel (num_groups == num_channels).
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout),
                                     nn.GELU())
Example #3
0
def conv_1d_block(n_in,
                  n_out,
                  k,
                  stride,
                  dropout=0.0,
                  is_layer_norm=False,
                  is_group_norm=False,
                  conv_bias=False):
    """Build a 1-D convolutional block: Conv1d -> Dropout -> (optional norm) -> GELU.

    Args:
        n_in: number of input channels.
        n_out: number of output channels.
        k: kernel size.
        stride: convolution stride.
        dropout: dropout probability applied after the convolution.
        is_layer_norm: insert a (transposed) Fp32LayerNorm over the channels.
        is_group_norm: insert an Fp32GroupNorm with one group per channel.
        conv_bias: whether the convolution learns a bias term.

    The two normalization flags are mutually exclusive.
    """
    # Reject the ambiguous combination up front.
    assert (is_layer_norm and
            is_group_norm) == False, "layer norm and group norm are exclusive"

    def make_conv():
        # Kaiming-initialized 1-D convolution, suited to the GELU that follows.
        layer = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
        nn.init.kaiming_normal_(layer.weight)
        return layer

    if is_layer_norm:
        # Transpose so LayerNorm normalizes the channel dim, then restore layout.
        norm = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(n_out, elementwise_affine=True),
            TransposeLast(),
        )
        return nn.Sequential(make_conv(), nn.Dropout(p=dropout), norm, nn.GELU())

    if is_group_norm:
        # One group per channel (num_groups == num_channels).
        norm = Fp32GroupNorm(n_out, n_out, affine=True)
        return nn.Sequential(make_conv(), nn.Dropout(p=dropout), norm, nn.GELU())

    return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())