def make_vqvae(hps, device='cuda'):
    """Build a VQVAE from hyperparameters, move it to `device`, restore its
    checkpoint, and put it in the right train/eval mode.

    Args:
        hps: hyperparameter object (attribute access, mutable) — `hps.sample_length`
            is filled in here when unset, derived from `sample_length_in_seconds`.
        device: torch device string the model is moved to.

    Returns:
        The constructed (and possibly frozen) VQVAE module.
    """
    from jukebox.vqvae.vqvae import VQVAE

    block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv,
                        dilation_growth_rate=hps.dilation_growth_rate,
                        dilation_cycle=hps.dilation_cycle,
                        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        # Derive sample_length from seconds, rounded down to a multiple of the
        # total downsampling factor so the top level gets whole tokens.
        assert hps.sample_length_in_seconds != 0
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        hps.sample_length = (hps.sample_length_in_seconds * hps.sr // top_raw_to_tokens) * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}")

    vqvae = VQVAE(input_shape=(hps.sample_length, 1), levels=hps.levels,
                  downs_t=hps.downs_t, strides_t=hps.strides_t,
                  emb_width=hps.emb_width, l_bins=hps.l_bins, mu=hps.l_mu,
                  commit=hps.commit, spectral=hps.spectral,
                  multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers,
                  use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore(hps, vqvae, hps.restore_vqvae)
    if hps.train and not hps.prior:
        print_all("Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Resetting bottleneck emas")
            # Hoisted: the stride products don't depend on the loop variable.
            downsamples = calculate_strides(hps.strides_t, hps.downs_t)
            num_samples = hps.sample_length
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                # Tokens produced at this level across all workers.
                raw_to_tokens = np.prod(downsamples[:level + 1])
                num_tokens = (num_samples // raw_to_tokens) * dist.get_world_size()
                bottleneck.restore_k(num_tokens=num_tokens, threshold=hps.revival_threshold)
    else:
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
    return vqvae
def make_prior(hps, vqvae, device='cuda'):
    """Build a SimplePrior over the codes of `vqvae` at level `hps.level`,
    move it to `device`, restore its checkpoint, and set train/eval mode.

    Args:
        hps: hyperparameter object (attribute + `.get` access).
        vqvae: a VQVAE whose `z_shapes`, `encode`, and `decode` the prior uses.
        device: torch device string the model is moved to.

    Returns:
        The constructed (and possibly frozen) SimplePrior module.
    """
    from jukebox.prior.prior import SimplePrior

    prior_kwargs = dict(input_shape=(hps.n_ctx,), bins=hps.l_bins,
                        width=hps.prior_width, depth=hps.prior_depth,
                        heads=hps.heads, attn_order=hps.attn_order,
                        blocks=hps.blocks, spread=hps.spread,
                        attn_dropout=hps.attn_dropout,
                        resid_dropout=hps.resid_dropout,
                        emb_dropout=hps.emb_dropout, zero_out=hps.zero_out,
                        res_scale=hps.res_scale, pos_init=hps.pos_init,
                        init_scale=hps.init_scale,
                        m_attn=hps.m_attn, m_mlp=hps.m_mlp,
                        # Activation checkpointing only matters while training.
                        checkpoint_res=hps.c_res if hps.train else 0,
                        checkpoint_attn=hps.c_attn if hps.train else 0,
                        checkpoint_mlp=hps.c_mlp if hps.train else 0)

    x_cond_kwargs = dict(out_width=hps.prior_width, init_scale=hps.init_scale,
                         width=hps.cond_width, depth=hps.cond_depth,
                         m_conv=hps.cond_m_conv,
                         dilation_growth_rate=hps.cond_dilation_growth_rate,
                         dilation_cycle=hps.cond_dilation_cycle,
                         zero_out=hps.cond_zero_out,
                         res_scale=hps.cond_res_scale,
                         checkpoint_res=hps.cond_c_res)  # have to keep this else names wrong

    y_cond_kwargs = dict(out_width=hps.prior_width, init_scale=hps.init_scale,
                         y_bins=hps.y_bins, t_bins=hps.t_bins,
                         t_ranges=hps.t_ranges,
                         max_bow_genre_size=hps.max_bow_genre_size)

    if hps.use_tokens and not hps.single_enc_dec:
        # Separate prime (lyrics) encoder: full transformer configuration.
        prime_kwargs = dict(use_tokens=hps.use_tokens,
                            prime_loss_fraction=hps.prime_loss_fraction,
                            n_tokens=hps.n_tokens, bins=hps.n_vocab,
                            width=hps.prime_width, depth=hps.prime_depth,
                            heads=hps.prime_heads,
                            attn_order=hps.prime_attn_order,
                            blocks=hps.prime_blocks, spread=hps.prime_spread,
                            attn_dropout=hps.prime_attn_dropout,
                            resid_dropout=hps.prime_resid_dropout,
                            emb_dropout=hps.prime_emb_dropout,
                            zero_out=hps.prime_zero_out,
                            res_scale=hps.prime_res_scale,
                            pos_init=hps.prime_pos_init,
                            init_scale=hps.prime_init_scale,
                            m_attn=hps.prime_m_attn, m_mlp=hps.prime_m_mlp,
                            checkpoint_res=hps.prime_c_res if hps.train else 0,
                            checkpoint_attn=hps.prime_c_attn if hps.train else 0,
                            checkpoint_mlp=hps.prime_c_mlp if hps.train else 0)
    else:
        # Single encoder/decoder (or no tokens): only the minimal prime config.
        prime_kwargs = dict(use_tokens=hps.use_tokens,
                            prime_loss_fraction=hps.prime_loss_fraction,
                            n_tokens=hps.n_tokens, bins=hps.n_vocab)

    # z_shapes for other levels given this level gets n_ctx codes
    def rescale(z_shape):
        return (z_shape[0] * hps.n_ctx // vqvae.z_shapes[hps.level][0],)

    z_shapes = [rescale(z_shape) for z_shape in vqvae.z_shapes]

    prior = SimplePrior(z_shapes=z_shapes, l_bins=hps.l_bins,
                        encoder=vqvae.encode, decoder=vqvae.decode,
                        level=hps.level, downs_t=hps.downs_t,
                        strides_t=hps.strides_t, labels=hps.labels,
                        prior_kwargs=prior_kwargs,
                        x_cond_kwargs=x_cond_kwargs,
                        y_cond_kwargs=y_cond_kwargs,
                        prime_kwargs=prime_kwargs,
                        copy_input=hps.copy_input, labels_v3=hps.labels_v3,
                        merged_decoder=hps.merged_decoder,
                        single_enc_dec=hps.single_enc_dec)
    # Optional alignment hooks; absent from most configs, hence .get().
    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print_all("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)

    prior = prior.to(device)
    restore(hps, prior, hps.restore_prior)
    if hps.train:
        print_all("Loading prior in train mode")
    else:
        print_all("Loading prior in eval mode")
        prior.eval()
        freeze_model(prior)
    return prior