class ResAttnBlock(nn.Module):
    def __init__(self, n_in, n_ctx, n_head,
                 attn_dropout=0.0, resid_dropout=0.0,
                 afn='quick_gelu', scale=True, mask=False,
                 zero_out=False, init_scale=1.0, res_scale=1.0,
                 m_attn=0.25, m_mlp=1.,
                 checkpoint_attn=0, checkpoint_mlp=0,
                 attn_func=0, blocks=None, spread=None,
                 encoder_dims=None, prime_len=None):
        super().__init__()
        self.attn = FactoredAttention(n_in=n_in, n_ctx=n_ctx, n_state=int(m_attn * n_in), n_head=n_head,
                                      attn_dropout=attn_dropout, resid_dropout=resid_dropout,
                                      scale=scale, mask=mask,
                                      zero_out=zero_out, init_scale=init_scale,
                                      checkpoint_attn=checkpoint_attn,
                                      attn_func=attn_func, blocks=blocks, spread=spread,
                                      encoder_dims=encoder_dims, prime_len=prime_len)
        self.ln_0 = LayerNorm(n_in)
        self.mlp = MLP(n_in=n_in, n_state=int(m_mlp * n_in),
                       resid_dropout=resid_dropout,
                       afn=afn,
                       zero_out=zero_out, init_scale=init_scale)
        self.ln_1 = LayerNorm(n_in)
        self.res_scale = res_scale

        self.checkpoint_attn = checkpoint_attn
        self.checkpoint_mlp = checkpoint_mlp
        self.n_in = n_in
        self.attn_func = attn_func

    def forward(self, x, encoder_kv, sample=False):
        if sample:
            a = self.attn(self.ln_0(x), encoder_kv, sample)
            m = self.mlp(self.ln_1(x + a))
        else:
            if self.attn_func == 6:
                assert encoder_kv is not None
                a = checkpoint(lambda _x, _enc_kv, _s=sample: self.attn(self.ln_0(_x), _enc_kv, _s),
                               (x, encoder_kv),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            else:
                assert encoder_kv is None
                a = checkpoint(lambda _x, _enc_kv=None, _s=sample: self.attn(self.ln_0(_x), _enc_kv, _s),
                               (x,),
                               (*self.attn.parameters(), *self.ln_0.parameters()),
                               self.checkpoint_attn == 3)  # 2 recomputes after the projections, and 1 recomputes after head splitting.
            m = checkpoint(lambda _x: self.mlp(self.ln_1(_x)), (x + a,),
                           (*self.mlp.parameters(), *self.ln_1.parameters()),
                           self.checkpoint_mlp == 1)
        if self.res_scale == 1.0:
            h = x + a + m
        else:
            h = x + self.res_scale * (a + m)
        return h
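
def _example_res_attn_block():
    """Illustrative usage sketch; the hyperparameters here are assumptions, not values from
    this codebase. attn_func=0 selects dense self-attention, so encoder_kv must be None, and
    with the default checkpoint flags (0) the non-sampling path runs attention and MLP
    without recomputation (assuming the checkpoint helper calls the wrapped function
    directly when its flag is False)."""
    import torch
    block = ResAttnBlock(n_in=512, n_ctx=1024, n_head=8, mask=True, attn_func=0)
    x = torch.randn(2, 1024, 512)   # (batch, n_ctx, n_in)
    h = block(x, encoder_kv=None)   # pre-norm residual: x + attn(ln_0(x)) + mlp(ln_1(x + attn))
    assert h.shape == x.shape       # the block preserves the input shape
    return h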
class Conditioner(nn.Module):
    """Embeds the codes of the level above and upsamples them in time to condition this level."""
    def __init__(self, input_shape, bins, down_t, stride_t, out_width, init_scale, zero_out, res_scale, **block_kwargs):
        super().__init__()
        self.x_shape = input_shape

        # Embedding
        self.width = out_width
        self.x_emb = nn.Embedding(bins, out_width)
        nn.init.normal_(self.x_emb.weight, std=0.02 * init_scale)

        # Conditioner
        self.cond = DecoderConvBock(self.width, self.width, down_t, stride_t, **block_kwargs, zero_out=zero_out, res_scale=res_scale)
        self.ln = LayerNorm(self.width)
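
def _example_conditioner():
    """Illustrative instantiation sketch; all hyperparameters are assumptions. The trailing
    keyword arguments (width, depth, m_conv) are not consumed by Conditioner itself: they
    land in **block_kwargs and are forwarded to DecoderConvBock, whose signature is assumed
    to accept them and to upsample the embedded tokens by stride_t ** down_t in time."""
    cond = Conditioner(input_shape=(2048,), bins=2048, down_t=1, stride_t=2,
                       out_width=512, init_scale=1.0, zero_out=False, res_scale=1.0,
                       width=512, depth=4, m_conv=1.0)  # forwarded to DecoderConvBock
    return cond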
def __init__(self, z_shapes, l_bins, encoder, decoder, level,
             downs_t, strides_t, labels, prior_kwargs, x_cond_kwargs, y_cond_kwargs,
             prime_kwargs, copy_input, labels_v3=False,
             merged_decoder=False, single_enc_dec=False):
    super().__init__()

    self.use_tokens = prime_kwargs.pop('use_tokens')
    self.n_tokens = prime_kwargs.pop('n_tokens')
    self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')

    self.copy_input = copy_input
    if self.copy_input:
        prime_kwargs['bins'] = l_bins

    self.z_shapes = z_shapes
    self.levels = len(self.z_shapes)

    self.z_shape = self.z_shapes[level]

    self.level = level
    assert level < self.levels, f"Total levels {self.levels}, got level {level}"

    self.l_bins = l_bins

    # Passing functions instead of the vqvae module to avoid getting params
    self.encoder = encoder
    self.decoder = decoder

    # X conditioning
    self.x_cond = (level != (self.levels - 1))
    self.cond_level = level + 1

    # Y conditioning
    self.y_cond = labels

    self.single_enc_dec = single_enc_dec

    # X conditioning
    if self.x_cond:
        self.conditioner_blocks = nn.ModuleList()
        conditioner_block = lambda _level: Conditioner(input_shape=z_shapes[_level],
                                                       bins=l_bins,
                                                       down_t=downs_t[_level],
                                                       stride_t=strides_t[_level],
                                                       **x_cond_kwargs)
        if dist.get_rank() == 0:
            print(f"Conditioning on 1 above level(s)")
        self.conditioner_blocks.append(conditioner_block(self.cond_level))

    # Y conditioning
    if self.y_cond:
        self.n_time = self.z_shape[0]  # Assuming STFT=TF order and raw=T1 order, so T is first dim
        self.y_emb = LabelConditioner(n_time=self.n_time,
                                      include_time_signal=not self.x_cond,
                                      **y_cond_kwargs)

    # Lyric conditioning
    if single_enc_dec:
        # Single encoder-decoder transformer
        self.prior_shapes = [(self.n_tokens,), prior_kwargs.pop('input_shape')]
        self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
        self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
        self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
        self.prior_width = prior_kwargs['width']
        print_once(f'Creating cond. autoregress with prior bins {self.prior_bins}, ')
        print_once(f'dims {self.prior_dims}, ')
        print_once(f'shift {self.prior_bins_shift}')
        print_once(f'input shape {sum(self.prior_dims)}')
        print_once(f'input bins {sum(self.prior_bins)}')
        print_once(f'Self copy is {self.copy_input}')

        self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[0], self.prior_dims[1]
        self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
        self.prior = ConditionalAutoregressive2D(input_shape=(sum(self.prior_dims),),
                                                 bins=sum(self.prior_bins),
                                                 x_cond=(self.x_cond or self.y_cond), y_cond=True,
                                                 prime_len=self.prime_loss_dims,
                                                 **prior_kwargs)
    else:
        # Separate encoder-decoder transformer
        if self.n_tokens != 0 and self.use_tokens:
            from jukebox.transformer.ops import Conv1D
            prime_input_shape = (self.n_tokens,)
            self.prime_loss_dims = np.prod(prime_input_shape)
            self.prime_acts_width, self.prime_state_width = prime_kwargs['width'], prior_kwargs['width']
            self.prime_prior = ConditionalAutoregressive2D(input_shape=prime_input_shape,
                                                           x_cond=False, y_cond=False,
                                                           only_encode=True,
                                                           **prime_kwargs)
            self.prime_state_proj = Conv1D(self.prime_acts_width, self.prime_state_width,
                                           init_scale=prime_kwargs['init_scale'])
            self.prime_state_ln = LayerNorm(self.prime_state_width)
            self.prime_bins = prime_kwargs['bins']
            self.prime_x_out = nn.Linear(self.prime_state_width, self.prime_bins, bias=False)
            nn.init.normal_(self.prime_x_out.weight, std=0.02 * prior_kwargs['init_scale'])
        else:
            self.prime_loss_dims = 0
        self.gen_loss_dims = np.prod(self.z_shape)
        self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
        self.prior = ConditionalAutoregressive2D(x_cond=(self.x_cond or self.y_cond), y_cond=self.y_cond,
                                                 encoder_dims=self.prime_loss_dims,
                                                 merged_decoder=merged_decoder,
                                                 **prior_kwargs)

    self.n_ctx = self.gen_loss_dims
    self.downsamples = calculate_strides(strides_t, downs_t)
    self.cond_downsample = self.downsamples[level + 1] if level != self.levels - 1 else None
    self.raw_to_tokens = np.prod(self.downsamples[:level + 1])
    self.sample_length = self.n_ctx * self.raw_to_tokens
    if labels:
        self.labels_v3 = labels_v3
        self.labeller = Labeller(self.y_emb.max_bow_genre_size, self.n_tokens, self.sample_length, v3=self.labels_v3)
    print(f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}")
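
def _example_context_arithmetic():
    """Illustrative arithmetic for the context bookkeeping above, with made-up settings and
    assuming calculate_strides returns stride ** down per level. raw_to_tokens is the hop
    size of one token at this level in raw audio samples; sample_length is the raw audio
    span covered by one n_ctx-token window."""
    downs_t, strides_t = (3, 2, 2), (2, 2, 2)
    downsamples = calculate_strides(strides_t, downs_t)   # [2**3, 2**2, 2**2] = [8, 4, 4]
    level, n_ctx = 2, 8192                                # top level of a 3-level hierarchy
    raw_to_tokens = np.prod(downsamples[:level + 1])      # 8 * 4 * 4 = 128 samples per token
    sample_length = n_ctx * raw_to_tokens                 # 8192 * 128 = 1,048,576 samples
    return raw_to_tokens, sample_length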