def block(
    n_in,
    n_out,
    k,
    stride,
    is_layer_norm=False,
    is_group_norm=False,
    conv_bias=False,
):
    """Build one conv -> dropout -> [optional norm] -> GELU stage.

    Args:
        n_in: number of input channels given to ``nn.Conv1d``.
        n_out: number of output channels given to ``nn.Conv1d``.
        k: convolution kernel size.
        stride: convolution stride.
        is_layer_norm: if truthy, an ``Fp32LayerNorm`` wrapped between two
            ``TransposeLast`` modules is inserted after the dropout.
        is_group_norm: if truthy (and ``is_layer_norm`` is not), an
            ``Fp32GroupNorm`` is inserted after the dropout.
        conv_bias: whether the convolution has a bias term.

    Returns:
        An ``nn.Sequential`` holding the convolution, an ``nn.Dropout``,
        the selected normalization (if any) and an ``nn.GELU``, in that
        order.

    Raises:
        AssertionError: if both ``is_layer_norm`` and ``is_group_norm``
            are truthy.

    Note:
        ``dropout`` and ``dim`` are not parameters; they are read from the
        enclosing scope at call time.
    """
    # Truthiness test rather than ``== False``: a falsy non-bool flag such
    # as ``None`` must not be reported as a conflict.
    assert not (
        is_layer_norm and is_group_norm
    ), "layer norm and group norm are exclusive"

    conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
    # Overwrite the default Conv1d weight initialization.
    nn.init.kaiming_normal_(conv.weight)

    layers = [conv, nn.Dropout(p=dropout)]
    if is_layer_norm:
        # The norm is sandwiched between two TransposeLast modules.
        # NOTE(review): presumably because the norm acts on the last axis
        # while channels sit on another one - confirm against TransposeLast.
        layers.append(
            nn.Sequential(
                TransposeLast(),
                Fp32LayerNorm(dim, elementwise_affine=True),
                TransposeLast(),
            )
        )
    elif is_group_norm:
        # NOTE(review): ``dim`` is presumably equal to ``n_out`` when this
        # is called - confirm against the caller.
        layers.append(Fp32GroupNorm(dim, dim, affine=True))
    layers.append(nn.GELU())

    return nn.Sequential(*layers)
def norm_block(is_layer_norm, dim, affine=True):
    """Return a normalization module over ``dim`` features.

    Args:
        is_layer_norm: selects the layer-norm variant when truthy,
            otherwise the group-norm variant.
        dim: size handed to the underlying normalization module.
        affine: forwarded as ``elementwise_affine`` (layer norm) or
            ``affine`` (group norm).

    Returns:
        When ``is_layer_norm`` is truthy, an ``nn.Sequential`` of
        ``TransposeLast``, ``Fp32LayerNorm`` and ``TransposeLast``;
        otherwise an ``Fp32GroupNorm(1, dim, affine=affine)``.
    """
    if not is_layer_norm:
        # First argument is 1 - presumably a single group spanning all
        # channels; confirm against Fp32GroupNorm's signature.
        return Fp32GroupNorm(1, dim, affine=affine)

    return nn.Sequential(
        TransposeLast(),
        Fp32LayerNorm(dim, elementwise_affine=affine),
        TransposeLast(),
    )
def __init__(
    self, dim, num_vars, groups, combine_groups, vq_dim, time_first, gamma=0.25
):
    """Vector quantization using straight pass-through estimator (i.e. kmeans)

    Args:
        dim: input dimension (channels)
        num_vars: number of quantized vectors per group
        groups: number of groups for vector quantization
        combine_groups: whether to use the vectors for all groups
        vq_dim: dimensionality of the resulting quantized vector
        time_first: if true, expect input in BxTxC format, otherwise in BxCxT
        gamma: commitment loss coefficient

    Raises:
        AssertionError: if ``vq_dim`` is not a multiple of ``groups``.
    """
    super().__init__()

    # Plain configuration values, stored as given.
    self.input_dim = dim
    self.num_vars = num_vars
    self.groups = groups
    self.combine_groups = combine_groups
    self.vq_dim = vq_dim
    self.time_first = time_first

    assert (
        vq_dim % groups == 0
    ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"
    # Each group contributes an equal slice of the quantized vector.
    self.var_dim = vq_dim // groups

    # With combine_groups set, a single set of vectors is allocated
    # instead of one set per group.
    codebook_groups = 1 if combine_groups else groups
    codebook_init = 0.01 * torch.randn(num_vars, codebook_groups, self.var_dim)
    self.embedding = nn.Parameter(codebook_init)

    # Grouped 1x1 convolution (dim -> dim, no bias) followed by a group norm.
    grouped_conv = nn.Conv1d(dim, dim, kernel_size=1, groups=groups, bias=False)
    grouped_norm = Fp32GroupNorm(groups, dim)
    self.projection = nn.Sequential(grouped_conv, grouped_norm)

    self.gamma = gamma
    self.mse_mean = nn.MSELoss(reduction="mean")