Пример #1
0
def run(**kwargs):
    """Sample the top level with Jukebox ``1b_lyrics``, then upsample it.

    The top-level prior samples level 2 first. It is then released, the two
    upsampler priors are built, and ``upsample`` runs over all levels.

    Keyword Args:
        sample_name: Stored in ``hps.name`` and passed to ``_sample`` /
            ``upsample`` (presumably the output directory; confirm against
            those helpers).
        sample_length: Requested length in seconds. It is rounded down to a
            whole multiple of ``top_prior.raw_to_tokens`` raw samples at
            ``hps.sr`` = 44100.
        artist, genre, lyrics: Conditioning metadata fed to the labellers.

    Returns:
        Whatever ``upsample`` returns (the updated ``zs`` list).

    Raises:
        ValueError: If the rounded sample length is shorter than the top
            prior's context, ``top_prior.n_ctx * top_prior.raw_to_tokens``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    model = "1b_lyrics"
    port = 29500
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested duration down to a whole number of top-level tokens.
    raw_to_tokens = top_prior.raw_to_tokens
    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // raw_to_tokens) * raw_to_tokens
    min_length = top_prior.n_ctx * raw_to_tokens
    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if hps.sample_length < min_length:
        raise ValueError(
            f"Please choose a longer sample_length: got {sample_length_in_seconds}s, "
            f"but the top prior needs at least {min_length / hps.sr:.2f}s "
            f"({min_length} samples)"
        )

    metas = [dict(
        artist=kwargs["artist"],
        genre=kwargs["genre"],
        total_length=hps.sample_length,
        offset=0,
        lyrics=kwargs["lyrics"],
    )] * hps.n_samples

    # Only level 2 (the top level) is sampled first, so only its labels are
    # computed now; the upsampler slots are filled in after sampling.
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    # Per-level sampling settings: [level 0, level 1, top level].
    lower_batch_size = 16
    lower_level_chunk_size = 32
    max_batch_size = 16
    chunk_size = 32
    sampling_temperature = .98
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True,
             max_batch_size=max_batch_size, chunk_size=chunk_size),
    ]

    # Every level starts from an empty (n_samples, 0) code tensor.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Drop the top prior and call empty_cache() (presumably to free GPU memory)
    # before building the upsamplers. A None placeholder keeps its slot in the
    # prior list passed to upsample.
    del top_prior
    empty_cache()
    top_prior = None
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
    return zs
Пример #2
0
# Sample directory for this run, under the Google Drive mount point. Samples
# are saved here, and this script also reads meta.p and
# zs-top-level-final.t from it. Pass a run slug as the first CLI argument to
# select a run; otherwise the default slug below is used.
this_run_slug = sys.argv[1] if len(sys.argv) > 1 else "co_compose_synth2"

# The trailing slash matters: file names are appended directly to hps.name.
hps.name = f'/home/robin/google-drive/samples/{this_run_slug}/'

# NOTE(review): pickle.load can execute arbitrary code. Only point this at
# run directories you trust.
with open(f'{hps.name}meta.p', "rb") as meta_file:
    meta = pickle.load(meta_file)

# Model-dependent settings: 5b_lyrics uses a longer sample length and smaller
# chunk/batch sizes than the other models.
hps.sample_length = 1048576 if model == "5b_lyrics" else 786432
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16
hps.hop_fraction = [.5, .5, .125]
hps.levels = 3

vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=hps.sample_length)),
                   device)

metas_1 = meta[0]
metas_2 = meta[1]

print(metas_1)
print(metas_2)

# Top-level codes, presumably saved by an earlier stage of this run.
zs = t.load(f'{hps.name}zs-top-level-final.t')

# Hard-coded because no top prior is built here to read raw_to_tokens from.
# TODO: confirm it matches the selected model's top prior.
top_prior_raw_to_tokens = 128