示例#1
0
def make_vqvae(hps, device='cuda'):
    """Build a VQVAE from hyperparameters, restore it and set its mode.

    When ``hps.sample_length`` is falsy it is derived from
    ``hps.sample_length_in_seconds * hps.sr``, rounded down to a multiple of
    the product of all per-level strides, and written back onto ``hps``.

    After construction the model is moved to ``device`` and ``restore`` is
    called with ``hps.restore_vqvae``. If ``hps.train`` is truthy and
    ``hps.prior`` is falsy, no eval/freeze is applied and, when
    ``hps.restore_vqvae`` is not the empty string, ``restore_k`` is called on
    every bottleneck level block. Otherwise the model is put in eval mode and
    passed to ``freeze_model``.

    Args:
        hps: Hyperparameter object read via attribute access. Its
            ``sample_length`` attribute may be overwritten (see above).
        device: Target handed to ``vqvae.to``. Defaults to ``'cuda'``.

    Returns:
        The constructed VQVAE instance.

    Raises:
        AssertionError: If ``hps.sample_length`` is falsy and
            ``hps.sample_length_in_seconds`` equals 0.
    """
    # Imported lazily, as in the original, so the module import stays local
    # to this factory.
    from jukebox.vqvae.vqvae import VQVAE

    block_kwargs = dict(
        width=hps.width,
        depth=hps.depth,
        m_conv=hps.m_conv,
        dilation_growth_rate=hps.dilation_growth_rate,
        dilation_cycle=hps.dilation_cycle,
        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0, (
            "hps.sample_length_in_seconds must be non-zero when "
            "hps.sample_length is not set")
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        # Floor-divide, then multiply back: rounds the raw length down to a
        # whole multiple of top_raw_to_tokens.
        hps.sample_length = (hps.sample_length_in_seconds * hps.sr //
                             top_raw_to_tokens) * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} "
              f"(i.e. {hps.sample_length/hps.sr} seconds) "
              f"to be multiple of {top_raw_to_tokens}")

    vqvae = VQVAE(input_shape=(hps.sample_length, 1),
                  levels=hps.levels,
                  downs_t=hps.downs_t,
                  strides_t=hps.strides_t,
                  emb_width=hps.emb_width,
                  l_bins=hps.l_bins,
                  mu=hps.l_mu,
                  commit=hps.commit,
                  spectral=hps.spectral,
                  multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers,
                  use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore(hps, vqvae, hps.restore_vqvae)

    if hps.train and not hps.prior:
        print_all("Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Reseting bottleneck emas")
            # These values do not depend on the level, so compute them once
            # instead of on every loop iteration.
            num_samples = hps.sample_length
            downsamples = calculate_strides(hps.strides_t, hps.downs_t)
            world_size = dist.get_world_size()
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                # Cumulative stride product up to and including this level.
                raw_to_tokens = np.prod(downsamples[:level + 1])
                num_tokens = (num_samples // raw_to_tokens) * world_size
                bottleneck.restore_k(num_tokens=num_tokens,
                                     threshold=hps.revival_threshold)
    else:
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
    return vqvae
示例#2
0
def make_prior(hps, vqvae, device='cuda'):
    """Build a SimplePrior over ``vqvae``, restore it and set its mode.

    Four keyword bundles (prior, x-conditioner, y-conditioner, prime) are
    assembled from ``hps`` and handed to ``SimplePrior`` together with the
    VQVAE's ``encode``/``decode`` callables and a rescaled list of z shapes.
    The checkpoint_* entries of the prior and prime bundles are taken from
    ``hps`` only when ``hps.train`` is truthy and are 0 otherwise.

    After construction, ``alignment_head``/``alignment_layer`` are attached
    from ``hps.get``, conv weights are optionally converted to fp16 (when
    ``hps.fp16_params`` is truthy), the model is moved to ``device`` and
    ``restore`` is called with ``hps.restore_prior``. When ``hps.train`` is
    falsy the prior is put in eval mode and passed to ``freeze_model``.

    Args:
        hps: Hyperparameter object read via attribute access and ``.get``.
        vqvae: Object providing ``z_shapes``, ``encode`` and ``decode``.
        device: Target handed to ``prior.to``. Defaults to ``'cuda'``.

    Returns:
        The constructed SimplePrior instance.
    """
    from jukebox.prior.prior import SimplePrior

    prior_kwargs = {
        'input_shape': (hps.n_ctx, ),
        'bins': hps.l_bins,
        'width': hps.prior_width,
        'depth': hps.prior_depth,
        'heads': hps.heads,
        'attn_order': hps.attn_order,
        'blocks': hps.blocks,
        'spread': hps.spread,
        'attn_dropout': hps.attn_dropout,
        'resid_dropout': hps.resid_dropout,
        'emb_dropout': hps.emb_dropout,
        'zero_out': hps.zero_out,
        'res_scale': hps.res_scale,
        'pos_init': hps.pos_init,
        'init_scale': hps.init_scale,
        'm_attn': hps.m_attn,
        'm_mlp': hps.m_mlp,
        'checkpoint_res': hps.c_res if hps.train else 0,
        'checkpoint_attn': hps.c_attn if hps.train else 0,
        'checkpoint_mlp': hps.c_mlp if hps.train else 0,
    }

    x_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'width': hps.cond_width,
        'depth': hps.cond_depth,
        'm_conv': hps.cond_m_conv,
        'dilation_growth_rate': hps.cond_dilation_growth_rate,
        'dilation_cycle': hps.cond_dilation_cycle,
        'zero_out': hps.cond_zero_out,
        'res_scale': hps.cond_res_scale,
        # Original author's note: have to keep this else names wrong.
        'checkpoint_res': hps.cond_c_res,
    }

    y_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'y_bins': hps.y_bins,
        't_bins': hps.t_bins,
        't_ranges': hps.t_ranges,
        'max_bow_genre_size': hps.max_bow_genre_size,
    }

    # These four entries are always present; the remaining prime_* settings
    # are only added when tokens are used without a single encoder-decoder.
    prime_kwargs = {
        'use_tokens': hps.use_tokens,
        'prime_loss_fraction': hps.prime_loss_fraction,
        'n_tokens': hps.n_tokens,
        'bins': hps.n_vocab,
    }
    if hps.use_tokens and not hps.single_enc_dec:
        prime_kwargs.update(
            width=hps.prime_width,
            depth=hps.prime_depth,
            heads=hps.prime_heads,
            attn_order=hps.prime_attn_order,
            blocks=hps.prime_blocks,
            spread=hps.prime_spread,
            attn_dropout=hps.prime_attn_dropout,
            resid_dropout=hps.prime_resid_dropout,
            emb_dropout=hps.prime_emb_dropout,
            zero_out=hps.prime_zero_out,
            res_scale=hps.prime_res_scale,
            pos_init=hps.prime_pos_init,
            init_scale=hps.prime_init_scale,
            m_attn=hps.prime_m_attn,
            m_mlp=hps.prime_m_mlp,
            checkpoint_res=hps.prime_c_res if hps.train else 0,
            checkpoint_attn=hps.prime_c_attn if hps.train else 0,
            checkpoint_mlp=hps.prime_c_mlp if hps.train else 0,
        )

    # z_shapes for other levels given this level gets n_ctx codes
    z_shapes = [
        (shape[0] * hps.n_ctx // vqvae.z_shapes[hps.level][0], )
        for shape in vqvae.z_shapes
    ]

    prior = SimplePrior(
        z_shapes=z_shapes,
        l_bins=hps.l_bins,
        encoder=vqvae.encode,
        decoder=vqvae.decode,
        level=hps.level,
        downs_t=hps.downs_t,
        strides_t=hps.strides_t,
        labels=hps.labels,
        prior_kwargs=prior_kwargs,
        x_cond_kwargs=x_cond_kwargs,
        y_cond_kwargs=y_cond_kwargs,
        prime_kwargs=prime_kwargs,
        copy_input=hps.copy_input,
        labels_v3=hps.labels_v3,
        merged_decoder=hps.merged_decoder,
        single_enc_dec=hps.single_enc_dec,
    )

    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print_all("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)

    prior = prior.to(device)
    restore(hps, prior, hps.restore_prior)

    if hps.train:
        print_all("Loading prior in train mode")
        return prior

    print_all("Loading prior in eval mode")
    prior.eval()
    freeze_model(prior)
    return prior