def run(**kwargs):
    """Generate audio with Jukebox's ``1b_lyrics`` model.

    Samples the top-level prior (level 2), frees it, then upsamples through
    the two lower levels with the upsampler priors.

    Keyword Args:
        sample_name: Assigned to ``hps.name``; presumably the directory the
            Jukebox sampling helpers write results into (confirm in
            ``jukebox.sample``).
        sample_length: Requested duration in seconds. It is rounded down to a
            multiple of ``top_prior.raw_to_tokens`` audio samples.
        artist: Artist conditioning string passed to the labellers.
        genre: Genre conditioning string passed to the labellers.
        lyrics: Lyrics conditioning string passed to the labellers.

    Returns:
        The list ``zs`` returned by ``upsample``. The original returned
        ``None``; callers that ignore the result are unaffected.

    Raises:
        ValueError: if the rounded sample length is shorter than the top
            prior's context (``n_ctx * raw_to_tokens`` samples).
    """
    # `t` was referenced but never imported in this file; import it locally.
    import torch as t
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    model = "1b_lyrics"
    port = 29500
    # NOTE(review): the module-level script also calls setup_dist_from_mpi.
    # If both run in one process this may attempt a second process-group
    # init; confirm before calling run() from that script.
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested duration down to a whole multiple of raw_to_tokens.
    raw_to_tokens = top_prior.raw_to_tokens
    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // raw_to_tokens) * raw_to_tokens

    # An explicit check (not `assert`) so it survives `python -O`.
    min_sample_length = top_prior.n_ctx * raw_to_tokens
    if hps.sample_length < min_sample_length:
        raise ValueError(
            f"Requested sample_length of {sample_length_in_seconds}s is too short: "
            f"the top-level prior needs at least {min_sample_length / hps.sr:.2f}s "
            f"({min_sample_length} samples). Please choose a larger sample length."
        )

    # The same metadata dict is shared by every sample; it is only read.
    metas = [dict(
        artist=kwargs["artist"],
        genre=kwargs["genre"],
        total_length=hps.sample_length,
        offset=0,
        lyrics=kwargs["lyrics"],
    )] * hps.n_samples
    # Only the top level gets labels now; the lower two are filled in later.
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    # Per-level sampling settings: the two lower (upsampler) levels use
    # temp 0.99, the top level uses sampling_temperature.
    sampling_temperature = .98
    lower_batch_size = 16
    lower_level_chunk_size = 32
    max_batch_size = 16
    chunk_size = 32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size, chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size, chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True, max_batch_size=max_batch_size, chunk_size=chunk_size),
    ]

    # One empty (n_samples, 0) token tensor per prior; sample level 2 only.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Release the top prior before loading the upsamplers.
    del top_prior
    empty_cache()
    top_prior = None  # placeholder in the priors list; level 2 is already sampled

    # Upsamplers are built on CPU (presumably moved to GPU one at a time by upsample()).
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]
    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
    return zs
# Module-level setup: initialise distributed state, build the sampling
# hyperparameters, and load the saved metadata for this run.
# `pickle`, `random` and `sys` were used below without being imported.
import pickle
import random
import sys

from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import (
    sample_single_window,
    _sample,
    sample_partial_window,
    upsample,
    load_prompts,
)
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

# Random port for the distributed setup; presumably to avoid clashes between
# concurrent runs on one machine (confirm).
port = random.randint(10000, 20000)
rank, local_rank, device = setup_dist_from_mpi(port=port)

model = "5b_lyrics"  # or "1b_lyrics"
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3 if model == '5b_lyrics' else 16

# Directory to save the samples in, under the Google Drive mount point.
# The run slug comes from the first CLI argument, if given.
if len(sys.argv) > 1:
    this_run_slug = sys.argv[1]
else:
    this_run_slug = "co_compose_synth2"
# Trailing slash matters: the metadata path below is built by plain concatenation.
hps.name = f'/home/robin/google-drive/samples/{this_run_slug}/'

# NOTE(review): pickle.load can execute arbitrary code; only load meta.p
# files you created yourself.
with open(f'{hps.name}meta.p', "rb") as meta_file:
    meta = pickle.load(meta_file)

hps.sample_length = 1048576 if model == "5b_lyrics" else 786432
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16