def run(model, port=29500, **kwargs):
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    with t.no_grad():
        save_alignment(model, device, hps)
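# Hedged usage sketch (not part of the original file): upstream Jukebox drives
# this entry point with a pretrained model key; '5b_lyrics' and the 'name'
# kwarg below are illustrative assumptions, not values this module defines.
def _demo_run_alignment():
    run(model='5b_lyrics', port=29500, name='alignments')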
def run(hps="teeny", port=29500, **kwargs): from app.jukebox.utils.dist_utils import setup_dist_from_mpi rank, local_rank, device = setup_dist_from_mpi(port=port) hps = setup_hparams(hps, kwargs) hps.ngpus = dist.get_world_size() hps.argv = " ".join(sys.argv) hps.bs_sample = hps.nworkers = hps.bs # Setup dataset data_processor = DataProcessor(hps) # Setup models vqvae = make_vqvae(hps, device) print_once(f"Parameters VQVAE:{count_parameters(vqvae)}") if hps.prior: prior = make_prior(hps, vqvae, device) print_once(f"Parameters Prior:{count_parameters(prior)}") model = prior else: model = vqvae # Setup opt, ema and distributed_model. opt, shd, scalar = get_optimizer(model, hps) ema = get_ema(model, hps) distributed_model = get_ddp(model, hps) logger, metrics = init_logging(hps, local_rank, rank) logger.iters = model.step # Run training, eval, sample for epoch in range(hps.curr_epoch, hps.epochs): metrics.reset() data_processor.set_epoch(epoch) if hps.train: train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger, metrics, data_processor, hps) train_metrics['epoch'] = epoch if rank == 0: print( 'Train', ' '.join([ f'{key}: {val:0.4f}' for key, val in train_metrics.items() ])) dist.barrier() if hps.test: if ema: ema.swap() test_metrics = evaluate(distributed_model, model, logger, metrics, data_processor, hps) test_metrics['epoch'] = epoch if rank == 0: print( 'Ema', ' '.join([ f'{key}: {val:0.4f}' for key, val in test_metrics.items() ])) dist.barrier() if ema: ema.swap() dist.barrier()
def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file,
                                  prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
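# Hedged sampling sketch: mode='ancestral' generates from scratch, while
# 'primed' would continue the prompt in audio_file for
# prompt_length_in_seconds. The model key and hps values mirror upstream
# Jukebox examples and are assumptions here, not values this file defines.
def _demo_run_sampling():
    run(model='1b_lyrics', mode='ancestral', port=29500, name='sample_1b',
        levels=3, sample_length_in_seconds=20,
        total_sample_length_in_seconds=180, sr=44100, n_samples=4)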
    def restore_k(self):
        pass


class NoBottleneck(nn.Module):
    def __init__(self, levels):
        super().__init__()
        self.level_blocks = nn.ModuleList()
        self.levels = levels
        for level in range(levels):
            self.level_blocks.append(NoBottleneckBlock())

    def encode(self, xs):
        return xs

    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        return zs

    def forward(self, xs):
        zero = t.zeros(()).cuda()
        commit_losses = [zero for _ in range(self.levels)]
        metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
        return xs, xs, commit_losses, metrics


if __name__ == '__main__':
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=29600)
    bottleneck = Bottleneck(256, 64, 0.99, 2).to(device)
    bottleneck.check()
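# Hedged check (not in the original file): NoBottleneck is a pure pass-through,
# so encode/decode return their inputs and every commitment loss is zero. The
# shapes are illustrative; a CUDA device is assumed because forward() builds
# its zero losses with .cuda().
def _demo_no_bottleneck():
    nb = NoBottleneck(levels=2)
    xs = [t.randn(1, 64, 256).cuda() for _ in range(2)]
    assert nb.encode(xs) is xs and nb.decode(xs) is xs
    _, _, commit_losses, _ = nb(xs)
    assert all(loss.item() == 0.0 for loss in commit_losses)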
import app.jukebox
import torch as t
import librosa
import os
from IPython.display import Audio
from app.jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from app.jukebox.hparams import Hyperparams, setup_hparams
from app.jukebox.sample import sample_single_window, _sample, \
    sample_partial_window, upsample
from app.jukebox.utils.dist_utils import setup_dist_from_mpi
from app.jukebox.utils.torch_utils import empty_cache

rank, local_rank, device = setup_dist_from_mpi()


def generate_audio(request):
    # Placeholder: echoes the request back unchanged; no sampling happens here yet.
    response = request
    return response
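# Hedged model-loading sketch (mirrors the upstream Jukebox Colab pattern; the
# '5b_lyrics' key and sample_length are assumptions): builds the pretrained
# VQ-VAE and the top-level prior that a real generate_audio() would sample from.
def _load_model(model='5b_lyrics', sample_length=1048576):
    vqvae_name, *prior_names = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae_name, dict(sample_length=sample_length)), device)
    top_prior = make_prior(setup_hparams(prior_names[-1], dict()), vqvae, device)
    return vqvae, top_prior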
    for x_chunk in x_chunks:
        y_chunk = self.forward(x_chunk.contiguous(), encoder_kv=encoder_kv, sample=True)
        total_len += x_chunk.shape[1]
        self.check_cache(bs, total_len, False)
        y_chunks.append(y_chunk)
    y_forw_in_chunks = t.cat(y_chunks, dim=1)

    max_err = t.max(t.abs(y_forw - y_forw_in_chunks))
    assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}"


if __name__ == '__main__':
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    setup_dist_from_mpi(port=29600)
    n_in = 16
    n_state = n_in * 2
    n_ctx = 6144
    n_head = 4
    n_depth = 12
    blocks = 64
    chunk_size = 8
    for attn_func in [0, 1, 2, 3, 6, 7]:
        encoder_dims = {0: 0, 1: 0, 2: 0, 3: 0, 6: 64, 7: 0}[attn_func]
        prime_len = {0: 0, 1: 0, 2: 0, 3: 0, 6: 0, 7: 384}[attn_func]
        attn = FactoredAttention(n_in, n_ctx + prime_len, n_state, n_head, mask=True,