Example #1
# Assumes PyTorch plus fairseq-style helper modules for the norm wrappers.
import torch.nn as nn

from fairseq.modules import Fp32GroupNorm, Fp32LayerNorm, TransposeLast


def norm_block(is_layer_norm, dim, affine=True):
    """Normalize over channels: LayerNorm (via transpose) or a single-group GroupNorm."""
    if is_layer_norm:
        mod = nn.Sequential(
            TransposeLast(),
            Fp32LayerNorm(dim, elementwise_affine=affine),
            TransposeLast(),
        )
    else:
        mod = Fp32GroupNorm(1, dim, affine=affine)

    return mod
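A minimal usage sketch for norm_block (not from the original source; the tensor shapes below are illustrative): the layer-norm variant transposes to B x T x C so Fp32LayerNorm normalizes over the channel dimension and then transposes back, while the group-norm variant runs a single-group Fp32GroupNorm directly on B x C x T.

import torch

x = torch.randn(4, 512, 100)                    # B x C x T, sizes chosen for illustration

ln = norm_block(is_layer_norm=True, dim=512)    # LayerNorm over channels via transpose
gn = norm_block(is_layer_norm=False, dim=512)   # single-group (whole-channel) GroupNorm

assert ln(x).shape == x.shape
assert gn(x).shape == x.shape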
Example #2
        def block(
            n_in,
            n_out,
            k,
            stride,
            padding,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                assert len(k) == 2
                conv = nn.Conv2d(n_in,
                                 n_out,
                                 k,
                                 stride=stride,
                                 bias=conv_bias,
                                 padding=padding)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (is_layer_norm and is_group_norm), \
                "layer norm and group norm are exclusive"

            if is_layer_norm:
                assert False  # JCh: didn't check the shapes
                return nn.Sequential(
                    make_conv(),
                    # `dropout` and `dim` below are captured from the enclosing scope
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout),
                                     nn.GELU())
Example #3
import torch
import torch.nn as nn

from fairseq.modules import Fp32GroupNorm  # fairseq-style helper, as in Example #1


class KmeansVectorQuantizer(nn.Module):  # class name assumed; the snippet shows only __init__
    def __init__(self,
                 dim,
                 num_vars,
                 groups,
                 combine_groups,
                 vq_dim,
                 time_first,
                 gamma=0.25):
        """Vector quantization using straight pass-through estimator (i.e. kmeans)

        Args:
            dim: input dimension (channels)
            num_vars: number of quantized vectors per group
            groups: number of groups for vector quantization
            combine_groups: whether to use the vectors for all groups
            vq_dim: dimensionality of the resulting quantized vector
            time_first: if true, expect input in BxTxC format, otherwise in BxCxT
            gamma: commitment loss coefficient
        """
        super().__init__()

        self.groups = groups
        self.combine_groups = combine_groups
        self.input_dim = dim
        self.num_vars = num_vars
        self.vq_dim = vq_dim
        self.time_first = time_first

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        self.var_dim = vq_dim // groups
        num_groups = groups if not combine_groups else 1

        self.embedding = nn.Parameter(
            0.01 * torch.randn(num_vars, num_groups, self.var_dim))
        self.projection = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size=1, groups=groups, bias=False),
            Fp32GroupNorm(groups, dim),
        )
        self.gamma = gamma
        self.mse_mean = nn.MSELoss(reduction="mean")
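The snippet above shows only the constructor; the quantization step itself is not included. As a hedged sketch of the technique the docstring names (straight-through k-means, VQ-VAE style), the lookup for a single group of codewords could look like the helper below. The function name straight_through_kmeans and the shapes are assumptions for illustration, not the module's actual forward pass.

import torch

def straight_through_kmeans(z, codebook, gamma=0.25):
    # z: (N, var_dim) projected features, codebook: (num_vars, var_dim) -- assumed shapes
    d = (z.unsqueeze(1) - codebook.unsqueeze(0)).pow(2).sum(-1)  # squared distance to every codeword
    idx = d.argmin(dim=-1)                                       # hard nearest-codeword assignment
    q = codebook[idx]                                            # quantized vectors, (N, var_dim)
    out = z + (q - z).detach()                                   # straight-through: gradient flows to z
    latent_loss = (q - z.detach()).pow(2).mean()                 # pull codewords toward encodings
    commit_loss = (z - q.detach()).pow(2).mean()                 # commitment term, weighted by gamma
    return out, latent_loss + gamma * commit_loss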
Example #4
# Uses the same imports as Example #1 (nn, TransposeLast, Fp32LayerNorm, Fp32GroupNorm).
def conv_1d_block(n_in,
                  n_out,
                  k,
                  stride,
                  dropout=0.0,
                  is_layer_norm=False,
                  is_group_norm=False,
                  conv_bias=False):
    def make_conv():
        conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
        nn.init.kaiming_normal_(conv.weight)
        return conv

    assert not (is_layer_norm and is_group_norm), \
        "layer norm and group norm are exclusive"

    if is_layer_norm:
        return nn.Sequential(
            make_conv(),
            nn.Dropout(p=dropout),
            nn.Sequential(
                TransposeLast(),
                Fp32LayerNorm(n_out, elementwise_affine=True),
                TransposeLast(),
            ),
            nn.GELU(),
        )
    elif is_group_norm:
        return nn.Sequential(
            make_conv(),
            nn.Dropout(p=dropout),
            Fp32GroupNorm(n_out, n_out, affine=True),
            nn.GELU(),
        )
    else:
        return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
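A usage sketch (layer sizes are made up for the example, and the same imports as Example #1 are assumed): stacking a few of these blocks yields a wav2vec-style feature extractor that downsamples a raw waveform from B x 1 x T to B x C x T'.

import torch
import torch.nn as nn

feature_extractor = nn.Sequential(
    conv_1d_block(1, 512, k=10, stride=5, dropout=0.1, is_group_norm=True),
    conv_1d_block(512, 512, k=3, stride=2, dropout=0.1, is_layer_norm=True),
    conv_1d_block(512, 512, k=3, stride=2, dropout=0.1),
)

wav = torch.randn(2, 1, 16000)     # B x 1 x T raw waveform
feats = feature_extractor(wav)     # B x 512 x T' downsampled features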