def make_model(model, device, hps, levels=None):
    """Build the VQ-VAE and the selected priors for a named model.

    ``MODELS[model]`` is unpacked as one VQ-VAE entry followed by the prior
    entries; each entry is handed to ``setup_hparams``.

    Args:
        model: Key into ``MODELS``.
        device: Device passed to ``make_vqvae``.
        hps: Hyperparameter object. ``sample_length`` and
            ``sample_length_in_seconds`` are read through ``hps.get`` (both
            default to 0). It is mutated: ``hps.sample_length`` is overwritten
            with the ``sample_length`` of the constructed VQ-VAE.
        levels: Iterable of indices into the prior entries. ``None`` selects
            every prior, in order.

    Returns:
        A ``(vqvae, priors)`` tuple, where ``priors`` is a list ordered like
        ``levels``.
    """
    vqvae_name, *prior_names = MODELS[model]

    length_overrides = dict(
        sample_length=hps.get('sample_length', 0),
        sample_length_in_seconds=hps.get('sample_length_in_seconds', 0),
    )
    vqvae = make_vqvae(setup_hparams(vqvae_name, length_overrides), device)
    hps.sample_length = vqvae.sample_length

    selected_levels = range(len(prior_names)) if levels is None else levels

    # Priors are always created on 'cpu', independently of `device`.
    priors = []
    for level in selected_levels:
        prior_hps = setup_hparams(prior_names[level], dict())
        priors.append(make_prior(prior_hps, vqvae, 'cpu'))

    return vqvae, priors
def run(**kwargs):
    """Sample with the "1b_lyrics" model: top level first, then upsample.

    Keyword Args:
        sample_name: Stored as ``hps.name``.
        sample_length: Requested length in seconds; multiplied by ``hps.sr``
            (44100) and floored to a multiple of ``top_prior.raw_to_tokens``.
        artist, genre, lyrics: Copied into the metadata dict from which the
            labels are built.

    Raises:
        KeyError: If one of the keyword arguments above is missing.
        AssertionError: If the resulting ``hps.sample_length`` is smaller than
            ``top_prior.n_ctx * top_prior.raw_to_tokens``.

    Returns:
        None. The final ``zs`` returned by ``upsample`` is not returned.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    model = "1b_lyrics"
    rank, local_rank, device = setup_dist_from_mpi(port=29500)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae_name, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae_name, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Floor the requested length (in raw samples) to a multiple of
    # top_prior.raw_to_tokens.
    sample_length_in_seconds = kwargs["sample_length"]
    raw_per_token = top_prior.raw_to_tokens
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // raw_per_token) * raw_per_token
    # NOTE(review): the message mentions "sampling rate", but the quantity
    # being checked is the sample length - confirm the intended wording.
    assert hps.sample_length >= top_prior.n_ctx * top_prior.raw_to_tokens, \
        'Please choose a larger sampling rate'

    # One metadata dict, repeated (same object) once per sample.
    meta = dict(
        artist=kwargs["artist"],
        genre=kwargs["genre"],
        total_length=hps.sample_length,
        offset=0,
        lyrics=kwargs["lyrics"],
    )
    metas = [meta] * hps.n_samples
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    # Per-level sampling settings; the last entry is used for the top level.
    batch_limit = 16
    chunk = 32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=batch_limit, chunk_size=chunk),
        dict(temp=0.99, fp16=True, max_batch_size=batch_limit, chunk_size=chunk),
        dict(temp=.98, fp16=True, max_batch_size=batch_limit, chunk_size=chunk),
    ]

    # Start every level from an empty (n_samples, 0) token tensor.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in priors]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Drop the top prior and clear the cache before building the upsamplers.
    del top_prior
    empty_cache()
    top_prior = None

    upsamplers = []
    for prior_name in priors[:-1]:
        upsamplers.append(make_prior(setup_hparams(prior_name, dict()), vqvae, 'cpu'))
    labels[:2] = [up.labeller.get_batch_labels(metas, 'cuda') for up in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
def run(hps="teeny", port=29500, **kwargs):
    """Check that MIDI chunks 0..8867 of the dataset all have shape (95, 128).

    Args:
        hps: Hyperparameter set name passed to ``setup_hparams``.
        port: Port passed to ``setup_dist_from_mpi``.
        **kwargs: Must contain ``audio_database``, which is removed and given
            to ``DataProcessor``; the remaining items are passed to
            ``setup_hparams`` as overrides.

    Raises:
        KeyError: If ``audio_database`` is not supplied.
        RuntimeError: On the first chunk whose shape is not ``(95, 128)``
            (the offending shape is printed first).
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    # Remove the database argument so it is not treated as an hparam override.
    audio_database = kwargs.pop('audio_database')

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    print('device:', device)

    print("hps setup")
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = 0
    hps.nworkers = 0
    hps.argv = " ".join(sys.argv)
    # Force batch size, sample batch size and worker count to 1 ...
    hps.bs_sample = hps.nworkers = hps.bs = 1
    # ... then mirror hps.bs into the other two, as the original did.
    hps.bs_sample = hps.nworkers = hps.bs

    print("setting up database")
    # Setup dataset
    data_processor = DataProcessor(hps, audio_database)

    print("midi chunk call")
    expected_shape = (95, 128)
    for idx in range(8868):
        chunk = data_processor.dataset.get_midi_chunk(idx)
        if chunk.shape == expected_shape:
            continue
        print(chunk.shape)
        raise RuntimeError('It failed')
def run(hps="teeny", port=29500, **kwargs):
    """Set up data, model and optimizer, then run the train/eval epoch loop.

    Args:
        hps: Hyperparameter set name passed to ``setup_hparams``.
        port: Port passed to ``setup_dist_from_mpi``.
        **kwargs: Hyperparameter overrides passed to ``setup_hparams``.

    The model being optimized is the prior when ``hps.prior`` is truthy and
    the VQ-VAE otherwise. For each epoch in ``range(hps.curr_epoch,
    hps.epochs)``, training runs when ``hps.train`` is set and evaluation runs
    when ``hps.test`` is set; rank 0 prints the metrics of each phase.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    def format_metrics(metrics_dict):
        # "key: value" pairs with four decimals, separated by single spaces.
        return ' '.join([f'{key}: {val:0.4f}' for key, val in metrics_dict.items()])

    rank, local_rank, device = setup_dist_from_mpi(port=port)

    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    is_main = rank == 0

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)

        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar, ema,
                                  logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if is_main:
                print('Train', format_metrics(train_metrics))
            dist.barrier()

        if hps.test:
            # ema.swap() is called before and after evaluation; presumably the
            # first call swaps the EMA weights in and the second restores the
            # training weights - confirm against the EMA implementation.
            if ema:
                ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if is_main:
                # The evaluation metrics are printed under the 'Ema' label.
                print('Ema', format_metrics(test_metrics))
            dist.barrier()
            if ema:
                ema.swap()

        dist.barrier()
def test_dataset_loader():
    """Load one batch from ``FilesAudioDataset`` and log it to TensorBoard.

    Builds "teeny" hyperparameters with a few overrides, wraps the dataset in
    a ``DistributedSampler``/``DataLoader`` pair, and for the first batch only
    writes each clip before (``in_<n>``) and after (``out_<n>``) an
    ``audio_preprocess`` -> ``audio_postprocess`` round trip.

    The summary writer is always closed on exit, including when an error is
    raised, so that logged events are flushed and the file handle is released.

    Assumes the default process group is already initialized (this function
    calls ``dist.barrier()`` but never initializes it) and that a CUDA device
    is available, since batches are moved to ``'cuda'``.

    Raises:
        KeyError: If ``hps.sr`` is not one of 22050, 44100 or 48000.
    """
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset
    from tensorboardX import SummaryWriter

    def collate_fn(batch):
        # Stack the per-item numpy arrays into one tensor along a new dim 0.
        return t.stack([t.from_numpy(b) for b in batch], 0)

    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)

    dataset_name = hps.dataset
    root = hps.root
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]

    writer = SummaryWriter(f'{root}/{dataset_name}/logs/{sr}/logs')
    try:
        dataset = FilesAudioDataset(hps)
        print("Length of dataset", len(dataset))

        # Torch Loader
        sampler = DistributedSampler(dataset)
        train_loader = DataLoader(dataset, batch_size=hps.bs, num_workers=hps.nworkers,
                                  pin_memory=False, sampler=sampler, drop_last=True,
                                  collate_fn=collate_fn)

        dist.barrier()
        sampler.set_epoch(0)
        for i, x in enumerate(tqdm(train_loader)):
            x = x.to('cuda', non_blocking=True)
            for j, aud in enumerate(x):
                writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
            print("Wrote in")

            x = audio_preprocess(x, hps)
            x = audio_postprocess(x, hps)
            for j, aud in enumerate(x):
                writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
            print("Wrote out")

            dist.barrier()
            # Only the first batch is needed for this check.
            break
    finally:
        # Flush pending events and release the event file.
        writer.close()
this_run_slug = sys.argv[1] else: this_run_slug = "co_compose_synth2" hps.name = '/home/robin/google-drive/samples/' + this_run_slug + '/' meta = pickle.load(open(f'{hps.name}meta.p', "rb")) hps.sample_length = 1048576 if model == "5b_lyrics" else 786432 chunk_size = 16 if model == "5b_lyrics" else 32 max_batch_size = 3 if model == "5b_lyrics" else 16 hps.hop_fraction = [.5, .5, .125] hps.levels = 3 vqvae, *priors = MODELS[model] vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=hps.sample_length)), device) metas_1 = meta[0] metas_2 = meta[1] print(metas_1) print(metas_2) zs = t.load(f'{hps.name}zs-top-level-final.t') top_prior_raw_to_tokens = 128 hps.sample_length = zs[2].shape[1] * top_prior_raw_to_tokens upsamplers = [
"sample_length_in_seconds": 20, "total_sample_length_in_seconds": 180, "sr": 44100, "n_samples": 6, "hop_fraction": [0.5, 0.5, 0.125] } train_options = {"bs": 1, "labels": False} rank, local_rank, device = setup_dist_from_mpi(port=29500) print("Device: {}".format(device)) hps = Hyperparams(**sample_options) hps = setup_hparams( "vqvae", dict(sample_length=hps.get('sample_length', 0), sample_length_in_seconds=hps.get('sample_length_in_seconds', 0), labels=False, bs=1)) # print(hps) vqvae = make_vqvae(hps, 'cuda:0') def compute_metrics(vqvae, hps, output_folder): json_path = '/home/kevin/feedforward/mp3s_for_jukebox_test.json' mp3_dict = load_json(json_path) csv = { "client": [], "media_id": [], "external_id": [], "s3_key": [],