def run(hps="teeny", port=29500, **kwargs):
    """Debug driver: check that MIDI chunks 0..8867 all have shape (95, 128).

    ``kwargs['audio_database']`` is required; it is removed from *kwargs* and
    passed to ``DataProcessor``. The remaining keyword arguments are merged
    into the hyperparameter set named by *hps* via ``setup_hparams``.

    Raises:
        KeyError: if ``audio_database`` is missing from *kwargs*.
        RuntimeError: on the first chunk whose shape differs from (95, 128).
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    # pop so 'audio_database' is not forwarded into the hyperparameters.
    audio_database = kwargs.pop('audio_database')
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    print('device:', device)

    print("hps setup")
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = 0
    hps.argv = " ".join(sys.argv)
    # Single sample / single worker / batch size 1 for the shape check.
    hps.bs_sample, hps.nworkers, hps.bs = 1, 1, 1

    print("setting up database")
    data_processor = DataProcessor(hps, audio_database)

    print("midi chunk call")
    expected_shape = (95, 128)
    # NOTE(review): hard-coded chunk count, presumably the dataset size — confirm.
    n_chunks = 8868
    for idx in range(n_chunks):
        chunk = data_processor.dataset.get_midi_chunk(idx)
        if chunk.shape != expected_shape:
            print(chunk.shape)
            raise RuntimeError(
                f'It failed: MIDI chunk {idx} has shape {tuple(chunk.shape)}, '
                f'expected {expected_shape}')
def run(model, port=29500, **kwargs):
    """Call ``save_outputs`` for *model* with autograd disabled.

    Distributed communication is initialised on *port* first. All remaining
    keyword arguments become the ``Hyperparams`` handed to ``save_outputs``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _, _, device = setup_dist_from_mpi(port=port)
    hparams = Hyperparams(**kwargs)
    with t.no_grad():
        save_outputs(model, device, hparams)
def run(hps="teeny", port=29500, **kwargs):
    """Train and/or evaluate a VQ-VAE, or a prior built on it when ``hps.prior`` is set.

    Epochs run from ``hps.curr_epoch`` up to (excluding) ``hps.epochs``. In each
    epoch the model is trained when ``hps.train`` is set and evaluated when
    ``hps.test`` is set. Rank 0 prints one metrics line per phase, and every
    rank meets at ``dist.barrier()`` after each phase and at the end of the epoch.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    def _report(label, metric_values):
        # "<label> key: value key: value ..." with 4 decimals per value.
        print(label, ' '.join(f'{name}: {value:0.4f}' for name, value in metric_values.items()))

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Data
    data_processor = DataProcessor(hps)

    # Models: the VQ-VAE is always built; a prior on top of it is optional.
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if not hps.prior:
        model = vqvae
    else:
        model = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(model)}")

    # Optimiser/scheduler/scaler, EMA and the distributed wrapper.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    is_root = rank == 0
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)

        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar, ema,
                                  logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if is_root:
                _report('Train', train_metrics)
            dist.barrier()

        if hps.test:
            # ema.swap() is called before and after evaluation, presumably to
            # evaluate with EMA weights and then restore the live ones — confirm.
            if ema:
                ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if is_root:
                # NOTE(review): labelled 'Ema' even when ema is falsy.
                _report('Ema', test_metrics)
            dist.barrier()
            if ema:
                ema.swap()

        dist.barrier()
def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    """Generate samples from *model* via ``save_samples`` with autograd disabled.

    ``**kwargs`` become the general ``Hyperparams``. The sampling-mode
    arguments (mode, codes_file, audio_file, prompt_length_in_seconds) are
    bundled into a second ``Hyperparams``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _, _, device = setup_dist_from_mpi(port=port)
    sampling_options = dict(
        mode=mode,
        codes_file=codes_file,
        audio_file=audio_file,
        prompt_length_in_seconds=prompt_length_in_seconds,
    )
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(sampling_options)
    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
def run(**kwargs):
    """Sample a song with the ``1b_lyrics`` model: top-level prior, then upsamplers.

    Keyword arguments read: ``sample_name`` (becomes ``hps.name``),
    ``sample_length`` (seconds), ``artist``, ``genre`` and ``lyrics``.

    Returns:
        The codes returned by ``upsample``.

    Raises:
        ValueError: if the requested length, after rounding down to whole
            top-level tokens, is shorter than the top prior's context.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    model = "1b_lyrics"
    port = 29500
    rank, local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested length down to a whole number of top-level tokens.
    sample_length_in_seconds = kwargs["sample_length"]
    tokens_per_raw = top_prior.raw_to_tokens
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // tokens_per_raw) * tokens_per_raw
    min_sample_length = top_prior.n_ctx * tokens_per_raw
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if hps.sample_length < min_sample_length:
        raise ValueError(
            f'Please choose a longer sample_length: {sample_length_in_seconds}s gives '
            f'{hps.sample_length} samples, but the top-level prior needs at least '
            f'{min_sample_length} ({min_sample_length / hps.sr:.2f}s)')

    # The same metadata dict is repeated once per sample.
    metas = [dict(
        artist=kwargs["artist"],
        genre=kwargs["genre"],
        total_length=hps.sample_length,
        offset=0,
        lyrics=kwargs["lyrics"],
    ), ] * hps.n_samples
    # Only the top level (index 2) gets labels now; the upsampler labels are
    # filled in after the upsamplers are built.
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98
    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size = 32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size, chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size, chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True, max_batch_size=max_batch_size, chunk_size=chunk_size),
    ]

    # Start every level with empty code sequences, then sample level 2 only.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Drop the top prior and call empty_cache() before building the upsamplers
    # (presumably to release GPU memory).
    del top_prior
    empty_cache()
    top_prior = None

    # Upsamplers are built on 'cpu'; presumably the sampling helpers move each
    # one to the GPU when it runs — confirm.
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]
    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
    return zs
def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    """Draw samples from *model* through ``save_samples`` with autograd off.

    Example call:
        model=5b, name=sample_5b, levels=3, sample_length_in_seconds=20,
        total_sample_length_in_seconds=180, sr=44100, n_samples=6,
        hop_fraction=[0.5, 0.5, 0.125]

    All extra keyword arguments form the general ``Hyperparams``. The mode,
    codes_file, audio_file and prompt_length_in_seconds arguments form the
    sampling ``Hyperparams``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    rank_info = setup_dist_from_mpi(port=port)
    _rank, _local_rank, device = rank_info

    general = Hyperparams(**kwargs)
    sampling = Hyperparams(dict(
        mode=mode,
        codes_file=codes_file,
        audio_file=audio_file,
        prompt_length_in_seconds=prompt_length_in_seconds,
    ))

    with t.no_grad():
        save_samples(model, device, general, sampling)
def _load_config_file(config_file):
    """Parse the YAML file *config_file* and return its top-level mapping.

    An empty YAML document parses to ``None``; that case is normalised to an
    empty dict so callers can always use ``.get``/``.update`` on the result.
    """
    print("Loading config file %s" % config_file)
    with open(config_file, 'r') as stream:
        return yaml.safe_load(stream) or {}


def _copy_config_to_output_dir(config_file, output_dir):
    """Copy *config_file* into *output_dir* so it is stored with the generated audio."""
    print("Copying config file to output directory")
    output_config = os.path.join(output_dir, os.path.basename(config_file))
    print(" %s -> %s" % (config_file, output_config))
    # exist_ok: every rank runs this, and an isdir()-then-makedirs() check races.
    os.makedirs(output_dir, exist_ok=True)
    try:
        shutil.copyfile(config_file, output_config)
    except shutil.SameFileError:
        # The config already lives in the output directory; nothing to copy.
        pass


def run(model='1b_lyrics', config_file=None, mode='ancestral', codes_file=None, audio_file=None,
        prompt_length_in_seconds=None, port=29500, **kwargs):
    """Generate samples, optionally configured from a YAML file.

    Precedence of settings:
      * keys present in *config_file* override the named parameters
        (model, mode, codes_file, audio_file, prompt_length_in_seconds, port);
      * extra keyword arguments override config-file values in the
        ``Hyperparams`` handed to ``save_samples``.

    When *config_file* is given it is copied into the output directory, which
    is the final ``name`` setting (a command-line ``name`` wins over the file's).

    Raises:
        KeyError: if a config file is used but no ``name`` is set anywhere.
        SystemExit: with status -1 if *audio_file* is given but does not exist.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    print('Starting run')

    config_dict = {}
    if config_file:
        config_dict = _load_config_file(config_file)
        # Update the named parameters with config values (if set).
        model = config_dict.get('model', model)
        mode = config_dict.get('mode', mode)
        codes_file = config_dict.get('codes_file', codes_file)
        audio_file = config_dict.get('audio_file', audio_file)
        prompt_length_in_seconds = config_dict.get('prompt_length_in_seconds', prompt_length_in_seconds)
        port = config_dict.get('port', port)

    # Initialise distributed communication only after reading the config, so
    # that a 'port' set in the config file takes effect.
    rank, local_rank, device = setup_dist_from_mpi(port=port)

    if config_file:
        # hps.name below comes from the merged config, where command-line values
        # win, so store the config file in that same directory.
        output_dir = kwargs.get('name', config_dict.get('name'))
        if output_dir is None:
            raise KeyError("name: set 'name' in the config file or on the command line")
        _copy_config_to_output_dir(config_file, output_dir)

    # Command-line values overwrite the ones loaded from the config file.
    config_dict.update(kwargs)
    print('Using these configuration values:')
    pprint.pprint(config_dict, indent=2)

    # If we're using a primer file, check that it exists.
    if audio_file and not os.path.exists(audio_file):
        print("Trying to use audio_file but it doesn't exist!")
        print("Place your priming audio here:")
        print(audio_file)
        sys.exit(-1)

    hps = Hyperparams(**config_dict)
    sample_hps = Hyperparams(
        dict(mode=mode, codes_file=codes_file, audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))
    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    """Sample *model* with a fixed set of eight lyric prompts.

    Every prompt shares artist "Unknown", genre "psychedelic", offset 0 and a
    total length of ``hps.total_sample_length_in_seconds * hps.sr``; only the
    lyrics differ. The list is passed to ``save_samples`` with autograd off.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _, _, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(dict(
        mode=mode,
        codes_file=codes_file,
        audio_file=audio_file,
        prompt_length_in_seconds=prompt_length_in_seconds,
    ))

    song_length = hps.total_sample_length_in_seconds * hps.sr
    # TODO: get metatags from somewhere, fixed or generated.
    lyric_prompts = (
        "nothing",
        "hey come on",
        "is this real",
        "love is an elephant in a suit",
        "you, yes you",
        "why do you",
        "no",
        "i i i i i",
    )
    metatags = [
        dict(
            artist="Unknown",
            genre="psychedelic",
            lyrics=prompt,
            total_length=song_length,
            offset=0,
        )
        for prompt in lyric_prompts
    ]

    with t.no_grad():
        save_samples(model, device, hps, sample_hps, metatags)
self.check_cache(bs, n, False) y_chunk = self.forward(x_chunk, encoder_kv=encoder_kv, sample=True) y_chunks.append(y_chunk) n += x_chunk.shape[1] self.check_cache(bs, n, False) y_forw_in_chunks = t.cat(y_chunks, dim=1) max_err = t.max(t.abs(y_forw - y_forw_in_chunks)) assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}" if __name__ == '__main__': from jukebox.utils.dist_utils import setup_dist_from_mpi setup_dist_from_mpi(port=29600) n_in = 16 n_ctx = 192 n_head = 4 n_depth = 12 blocks = 16 for attn_order in [0, 2, 6]: encoder_dims = {0: 0, 2: 0, 6: 64}[attn_order] prior = Transformer(n_in, n_ctx, n_head, n_depth, mask=True, attn_order=attn_order, encoder_dims=encoder_dims, blocks=blocks).cuda()
def __init__(self, levels): super().__init__() self.level_blocks = nn.ModuleList() self.levels = levels for level in range(levels): self.level_blocks.append(NoBottleneckBlock()) def encode(self, xs): return xs def decode(self, zs, start_level=0, end_level=None): if end_level is None: end_level = self.levels return zs def forward(self, xs): zero = t.zeros(()).cuda() commit_losses = [zero for _ in range(self.levels)] metrics = [ dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels) ] return xs, xs, commit_losses, metrics if __name__ == '__main__': from jukebox.utils.dist_utils import setup_dist_from_mpi rank, local_rank, device = setup_dist_from_mpi(port=29600) bottleneck = Bottleneck(256, 64, 0.99, 2).to(device) bottleneck.check()
def run(mode='ancestral', audio_file=None, prompt_length_in_seconds=12.0, port=29500):
    """Worker loop: poll the job queue forever and generate samples for each job.

    For each job this:
      1. builds hyperparameters and metadata from ``job['params']``;
      2. locks the job and sets its status to "top_started";
      3. logs a results URL built from ``./get_ip.sh``;
      4. runs ``save_samples`` with autograd disabled;
      5. sets the status to "upsampling_done".

    When no job is queued, it closes the DB connection, sleeps 60 seconds and
    polls again. Never returns.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    from jukebox.utils import queue

    # Set up distributed communications.
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    sr = 44100

    while True:
        db, cur = queue.connectdb()
        job = queue.get_next_job(cur)

        if not job:
            # Close the connection before idling; otherwise every empty poll
            # leaks one connection.
            queue.closedb(db)
            print('Zzz...')
            time.sleep(60)
            continue

        print(job)
        job_id = job['job_id']
        params = job['params']
        job_model = params['model']
        # 'length' is converted with int(): it may not arrive as an int, and the
        # raw value must not be multiplied (a str * int repeats the string).
        length_in_seconds = int(params['length'])

        kw = dict()
        kw['sr'] = sr
        kw['n_samples'] = 15 if job_model == '5b_lyrics' else 16
        kw['hop_fraction'] = (0.5, 0.5, 0.25)
        # Record the model that is actually sampled below.
        kw['model'] = job_model
        kw['levels'] = 3
        kw['sample_length_in_seconds'] = length_in_seconds
        kw['total_sample_length_in_seconds'] = length_in_seconds
        kw['job_id'] = job_id
        kw['name'] = params['name']
        hps = Hyperparams(kw)

        # Artist, lyrics, genre.
        metas = Hyperparams(
            dict(
                artist=params['artist'],
                genre=params['genre'],
                lyrics=params['lyrics'],
                total_length=length_in_seconds * sr,
                offset=0))
        print(hps)

        sample_hps = Hyperparams(
            dict(mode=mode, audio_file=audio_file,
                 prompt_length_in_seconds=prompt_length_in_seconds))

        # Lock the job and mark it started.
        # NOTE(review): get_next_job + lock is not atomic; two workers could
        # pick the same job — confirm whether queue.lock guards against this.
        queue.lock(cur, job_id)
        queue.update_status(cur, job_id, "top_started")

        # Log the URL where results are served.
        ip = subprocess.run(['./get_ip.sh'], stdout=subprocess.PIPE).stdout.decode().strip()
        url = "http://{}/jukebox/{}_{}/".format(ip, job_id, params['name'])
        queue.log(cur, job_id, "URL: " + url)

        # Close the DB connection to avoid a timeout error during long sampling.
        queue.closedb(db)

        # Run the full generating script.
        with t.no_grad():
            save_samples(job_model, device, hps, sample_hps, [metas])

        # Finish: open a fresh connection to record completion.
        db, cur = queue.connectdb()
        queue.update_status(cur, job_id, "upsampling_done")
        queue.closedb(db)