def run(model, mode='ancestral', codes_file=None, audio_file=None,
        prompt_length_in_seconds=None, port=29500, **kwargs):
    """Initialise distributed state and write samples for ``model``.

    ``mode``, ``codes_file``, ``audio_file`` and ``prompt_length_in_seconds``
    are bundled into a dedicated sampling ``Hyperparams`` object; all
    remaining keyword arguments become the general ``Hyperparams``.
    ``port`` is passed to ``setup_dist_from_mpi``. The call to
    ``save_samples`` happens inside ``t.no_grad()``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    # Only the device is needed here; the two rank values are not used.
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams(**kwargs)
    sampling_options = dict(
        mode=mode,
        codes_file=codes_file,
        audio_file=audio_file,
        prompt_length_in_seconds=prompt_length_in_seconds,
    )
    sample_hps = Hyperparams(sampling_options)

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
def run(**kwargs):
    """Sample the top level of the ``1b_lyrics`` model, then upsample it.

    Keyword arguments (each is read by key, so a missing one raises
    ``KeyError``):
        sample_name: stored as ``hps.name``.
        sample_length: requested sample length in seconds.
        artist, genre, lyrics: metadata placed in every sample's labels.

    Raises:
        AssertionError: if the requested length, rounded down to a multiple
            of ``top_prior.raw_to_tokens``, is shorter than
            ``top_prior.n_ctx * top_prior.raw_to_tokens``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    model = "1b_lyrics"
    port = 29500
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested length down to a whole number of top-level tokens.
    sample_length_in_seconds = kwargs["sample_length"]
    raw_to_tokens = top_prior.raw_to_tokens
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // raw_to_tokens) * raw_to_tokens
    assert hps.sample_length >= top_prior.n_ctx * top_prior.raw_to_tokens, \
        'Please choose a larger sample length'

    meta = dict(
        artist=kwargs["artist"],
        genre=kwargs["genre"],
        total_length=hps.sample_length,
        offset=0,
        lyrics=kwargs["lyrics"],
    )
    # One independent dict per sample, so the entries never alias each other.
    metas = [dict(meta) for _ in range(hps.n_samples)]
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98
    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size = 32
    # One entry per level; the last entry configures the top-level prior.
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True, max_batch_size=max_batch_size,
             chunk_size=chunk_size),
    ]

    # Start every level from an empty code sequence and sample level 2 only.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Drop the top prior before the upsamplers are built.
    del top_prior
    empty_cache()
    top_prior = None

    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]
    upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
def run(model, port=29500, **kwargs):
    """Initialise distributed state and save outputs for ``model``.

    ``port`` is passed to ``setup_dist_from_mpi``; every other keyword
    argument is collected into a ``Hyperparams`` object. ``save_outputs``
    is invoked inside ``t.no_grad()``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    # The rank values are returned alongside the device but are not used.
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    hyperparams = Hyperparams(**kwargs)

    with t.no_grad():
        save_outputs(model, device, hyperparams)
def run(model, mode='ancestral', codes_file=None, audio_file=None,
        prompt_length_in_seconds=None, port=29500, **kwargs):
    """Entry point: set up distributed state and save samples for ``model``.

    Example call:
        model=5b, name=sample_5b, levels=3, sample_length_in_seconds=20,
        total_sample_length_in_seconds=180, sr=44100, n_samples=6,
        hop_fraction=[0.5, 0.5, 0.125]

    The four sampling arguments go into their own ``Hyperparams``; any other
    keyword argument goes into the general ``Hyperparams``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    dist_state = setup_dist_from_mpi(port=port)
    rank, local_rank, device = dist_state

    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams({
        'mode': mode,
        'codes_file': codes_file,
        'audio_file': audio_file,
        'prompt_length_in_seconds': prompt_length_in_seconds,
    })

    # Run sampling inside t.no_grad().
    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
def check_sample():
    """Run the ancestral and primed sampling checks against dummy priors.

    Builds one ``DummyPrior`` per level, a label tensor ``y`` repeated for
    every sample, and a ``Hyperparams`` object, then calls
    ``test_ancestral_sample`` and ``test_primed_sample`` with them.
    """
    n_ctx = 8192
    n_samples = 4
    levels = 3
    priors = [DummyPrior(n_ctx, level, levels) for level in range(levels)]

    max_total_length, offset, sample_length = 4134368, 0, n_ctx * 8 * 4 * 4
    # A single 9-element label row, repeated once per sample.
    label_row = [max_total_length, offset, sample_length, 10, 1, -1, -1, -1, -1]
    y = t.tensor(label_row, dtype=t.long, device='cuda').view(1, 9).repeat(n_samples, 1)

    # One empty info list per sample. The previous `[[]*n_samples]` multiplied
    # the inner empty list and therefore always produced `[[]]`.
    labels = [
        dict(y=y, info=[[] for _ in range(n_samples)])
        for _ in range(levels)
    ]

    hps = Hyperparams({
        'levels': 3,
        'sample_length': sample_length,
        'n_segment': 2,
        'n_ctx': n_ctx,
        'n_tokens': 0,
        'hop_lengths': [n_ctx // 2, n_ctx // 2, n_ctx // 8],
        'n_samples': n_samples,
        'use_tokens': False,
    })

    test_ancestral_sample(labels, priors, hps)
    test_primed_sample(labels, priors, hps)
def run(model='1b_lyrics', config_file=None, mode='ancestral', codes_file=None,
        audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    """Sample from ``model``, optionally driven by a YAML config file.

    When ``config_file`` is given it is loaded with ``yaml.safe_load``; keys
    named like this function's parameters (``model``, ``mode``,
    ``codes_file``, ``audio_file``, ``prompt_length_in_seconds``, ``port``)
    override the corresponding arguments, and the file is copied into the
    directory named by the config's ``name`` key. Extra keyword arguments
    then override config values before everything is turned into
    ``Hyperparams``.

    Raises:
        KeyError: if a config file is given but has no ``name`` key.
        SystemExit: with code -1 if ``audio_file`` is set but does not exist.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    print('Starting run')

    # Start with an empty configuration
    config_dict = {}

    # Load configuration from yaml file, if specified
    if config_file:
        print("Loading config file %s" % config_file)
        with open(config_file, 'r') as stream:
            # safe_load returns None for an empty document; treat it as {}.
            config_dict = yaml.safe_load(stream) or {}

        # Update the named parameters with config values (if set)
        model = config_dict.get('model', model)
        mode = config_dict.get('mode', mode)
        codes_file = config_dict.get('codes_file', codes_file)
        audio_file = config_dict.get('audio_file', audio_file)
        prompt_length_in_seconds = config_dict.get('prompt_length_in_seconds',
                                                   prompt_length_in_seconds)
        port = config_dict.get('port', port)

        # Copy the config file to the output directory so
        # it's stored with the generated audio
        print("Copying config file to output directory")
        output_dir = config_dict['name']
        output_config = os.path.join(output_dir, os.path.basename(config_file))
        print(" %s -> %s" % (config_file, output_config))
        # exist_ok avoids a failure when several processes create the
        # directory at the same time.
        os.makedirs(output_dir, exist_ok=True)
        shutil.copyfile(config_file, output_config)

    # Update the config from any values specified on the command line.
    # The command line values will overwrite the ones loaded from the config file.
    config_dict.update(kwargs)

    print('Using these configuration values:')
    pprint.pprint(config_dict, indent=2)

    # If we're using a primer file, check that it exists
    if audio_file:
        if not os.path.exists(audio_file):
            print("Trying to use audio_file but it doesn't exist!")
            print("Place your priming audio here:")
            print(audio_file)
            sys.exit(-1)

    # Distributed setup happens only after the config has been read, so a
    # `port` value coming from the config file is actually honoured.
    rank, local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams(**config_dict)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
def run(model, mode='ancestral', codes_file=None, audio_file=None,
        prompt_length_in_seconds=None, port=29500, **kwargs):
    """Save samples for ``model`` using a fixed set of eight metatags.

    Every metatag shares the same artist, genre, total length and offset and
    differs only in its lyrics. The total length is
    ``hps.total_sample_length_in_seconds * hps.sr``, so both keys must be
    supplied through ``kwargs``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    rank, local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(dict(
        mode=mode,
        codes_file=codes_file,
        audio_file=audio_file,
        prompt_length_in_seconds=prompt_length_in_seconds,
    ))

    offset = 0
    total_l = hps.total_sample_length_in_seconds * hps.sr

    # get metatags from somewhere, fixed or generated
    lyric_lines = (
        "nothing",
        "hey come on",
        "is this real",
        "love is an elephant in a suit",
        "you, yes you",
        "why do you",
        "no",
        "i i i i i",
    )
    metatags = [
        dict(
            artist="Unknown",
            genre="psychedelic",
            lyrics=line,
            total_length=total_l,
            offset=offset,
        )
        for line in lyric_lines
    ]

    with t.no_grad():
        save_samples(model, device, hps, sample_hps, metatags)
def run(mode='ancestral', audio_file=None, prompt_length_in_seconds=12.0, port=29500):
    """Poll the job queue forever and run sampling for each job found.

    Each iteration opens a database connection and asks the queue for the
    next job. If there is one, the job is locked, its status is set to
    ``"top_started"``, a URL is logged, ``save_samples`` is run inside
    ``t.no_grad()``, and the status is finally set to ``"upsampling_done"``.
    If there is no job the connection is closed and the loop sleeps for 60
    seconds. This function never returns.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    from jukebox.utils import queue

    # setup distributed communications
    rank, local_rank, device = setup_dist_from_mpi(port=port)

    while True:
        # connect to db
        db, cur = queue.connectdb()
        offset = 0

        # get the next job
        job = queue.get_next_job(cur)
        if not job:
            # Close the connection opened for this poll; otherwise a new
            # connection would be leaked on every idle iteration.
            queue.closedb(db)
            # pause the program for a minute and check back for new jobs
            print('Zzz...')
            time.sleep(60)
            continue

        print(job)
        job_id = job['job_id']
        params = job['params']
        length_in_seconds = int(params['length'])

        kw = dict()
        kw['sr'] = 44100
        kw['n_samples'] = 15 if '5b_lyrics' == params['model'] else 16
        kw['hop_fraction'] = (0.5, 0.5, 0.25)
        kw['model'] = '5b_lyrics'
        kw['levels'] = 3
        kw['sample_length_in_seconds'] = length_in_seconds
        kw['total_sample_length_in_seconds'] = length_in_seconds
        kw['job_id'] = job_id
        kw['name'] = params['name']
        hps = Hyperparams(kw)

        # artist, lyrics, genre
        metas = Hyperparams(
            dict(
                artist=params['artist'],
                genre=params['genre'],
                lyrics=params['lyrics'],
                # Use the int-converted length, consistent with the other
                # length fields; multiplying a raw string value by the
                # sample rate would repeat the string instead.
                total_length=length_in_seconds * kw['sr'],  # remove hardcoded sr
                offset=offset))
        print(hps)

        sample_hps = Hyperparams(
            dict(mode=mode,
                 audio_file=audio_file,
                 prompt_length_in_seconds=prompt_length_in_seconds))

        # Lock the job
        queue.lock(cur, job_id)
        # Start the job
        queue.update_status(cur, job_id, "top_started")

        # Log the URL
        curl = subprocess.Popen(os.path.expanduser('./get_ip.sh'),
                                stdout=subprocess.PIPE)
        ip, _ = curl.communicate()  # (ip, error)
        url = "http://{}/jukebox/{}_{}/".format(ip.decode().strip(), job_id,
                                                params['name'])
        queue.log(cur, job_id, "URL: {}".format(url))

        # close db connection to avoid timeout error after sampling
        queue.closedb(db)

        # Run the full generating script here
        with t.no_grad():
            save_samples(params['model'], device, hps, sample_hps, [metas])

        # FINISH
        # open fresh db connection
        db, cur = queue.connectdb()
        # update status
        queue.update_status(cur, job_id, "upsampling_done")
        queue.closedb(db)
import pickle
import random
import sys

from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import (
    sample_single_window,
    _sample,
    sample_partial_window,
    upsample,
    load_prompts,
)
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

# Pick a random port for the distributed setup.
port = random.randint(10000, 20000)
rank, local_rank, device = setup_dist_from_mpi(port=port)

model = "5b_lyrics"  # or "1b_lyrics"
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3 if model == '5b_lyrics' else 16

# Specifies the directory to save the sample in.
# We set this to the Google Drive mount point.
# The run slug comes from the first command-line argument when present.
if len(sys.argv) > 1:
    this_run_slug = sys.argv[1]
else:
    this_run_slug = "co_compose_synth2"
hps.name = '/home/robin/google-drive/samples/' + this_run_slug + '/'

# NOTE(review): pickle.load executes arbitrary code from the file; only use
# this with meta.p files from a trusted source.
with open(f'{hps.name}meta.p', "rb") as meta_file:
    meta = pickle.load(meta_file)

hps.sample_length = 1048576 if model == "5b_lyrics" else 786432
sample_options = { "name": "sample_5b", "levels": 3, "sample_length_in_seconds": 20, "total_sample_length_in_seconds": 180, "sr": 44100, "n_samples": 6, "hop_fraction": [0.5, 0.5, 0.125] } train_options = {"bs": 1, "labels": False} rank, local_rank, device = setup_dist_from_mpi(port=29500) print("Device: {}".format(device)) hps = Hyperparams(**sample_options) hps = setup_hparams( "vqvae", dict(sample_length=hps.get('sample_length', 0), sample_length_in_seconds=hps.get('sample_length_in_seconds', 0), labels=False, bs=1)) # print(hps) vqvae = make_vqvae(hps, 'cuda:0') def compute_metrics(vqvae, hps, output_folder): json_path = '/home/kevin/feedforward/mp3s_for_jukebox_test.json' mp3_dict = load_json(json_path) csv = {