Example #1
import torch as t
from app.jukebox.hparams import Hyperparams

# save_alignment is defined alongside this function in the alignment module.
def run(model, port=29500, **kwargs):
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)

    # Alignment is inference-only, so disable gradient tracking.
    with t.no_grad():
        save_alignment(model, device, hps)
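A minimal invocation sketch. It assumes, as in the sampling entry point of Example #3, that model is a registry name string consumed downstream; the name and extra keyword argument below are illustrative, not part of the original module.

# Hedged sketch; the model name and extra kwargs are illustrative.
run('1b_lyrics', port=29500, name='alignment_run')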
Example #2
import sys

import torch.distributed as dist  # upstream Jukebox routes this through its dist_adapter wrapper
from app.jukebox.hparams import setup_hparams
from app.jukebox.make_models import make_vqvae, make_prior
from app.jukebox.data.data_processor import DataProcessor
from app.jukebox.utils.logger import init_logging
from app.jukebox.utils.dist_utils import print_once
from app.jukebox.utils.torch_utils import count_parameters

# get_optimizer, get_ema, get_ddp, train and evaluate are helpers defined
# alongside this function in the training module.
def run(hps="teeny", port=29500, **kwargs):
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    # Reuse the training batch size for the sampling batch size and the
    # number of data-loader workers.
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    # Resume the logger's iteration counter from the model's restored step.
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar,
                                  ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Train', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in train_metrics.items()
                    ]))
            dist.barrier()

        if hps.test:
            # Evaluate with EMA weights: swap them in, evaluate, then swap
            # back below.
            if ema: ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Ema', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in test_metrics.items()
                    ]))
            dist.barrier()
            if ema: ema.swap()
        dist.barrier()
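A sketch of how this trainer is typically launched. Upstream Jukebox drives the equivalent function with fire under mpiexec; the hyperparameter set name, path, and flag values below are illustrative.

# Hedged launch sketch; hps name, path and batch size are illustrative.
# Under MPI each rank calls run() and setup_dist_from_mpi forms the group,
# mirroring the upstream command line:
#   mpiexec -n 8 python train.py --hps=small_vqvae --name=small_vqvae \
#       --sample_length=262144 --bs=4 --audio_files_dir=/path/to/audio \
#       --labels=False --train --aug_shift --aug_blend
run(hps='small_vqvae',
    name='small_vqvae',
    sample_length=262144,
    bs=4,
    audio_files_dir='/path/to/audio',
    labels=False,
    train=True,
    aug_shift=True,
    aug_blend=True)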
Example #3
import torch as t
from app.jukebox.hparams import Hyperparams

# save_samples is defined alongside this function in the sampling module;
# mode selects the sampling strategy (e.g. 'ancestral' to sample from
# scratch, or priming from audio_file).
def run(model,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
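A sketch of a typical call, mirroring the upstream Jukebox sampling command; the model name and hyperparameter values are illustrative here.

# Hedged sketch; the model name and values mirror upstream Jukebox docs
# and are illustrative for this codebase.
run('5b_lyrics',
    mode='ancestral',
    name='sample_5b',
    levels=3,
    sample_length_in_seconds=20,
    total_sample_length_in_seconds=180,
    sr=44100,
    n_samples=6,
    hop_fraction=(0.5, 0.5, 0.125))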
Example #4
import torch as t
import torch.nn as nn

class NoBottleneckBlock(nn.Module):
    def restore_k(self):
        # Nothing to restore: there is no codebook without quantization.
        pass

class NoBottleneck(nn.Module):
    def __init__(self, levels):
        super().__init__()
        self.level_blocks = nn.ModuleList()
        self.levels = levels
        for level in range(levels):
            self.level_blocks.append(NoBottleneckBlock())

    def encode(self, xs):
        return xs

    def decode(self, zs, start_level=0, end_level=None):
        # start_level/end_level are accepted for API compatibility with
        # the quantized bottleneck; the codes pass through unchanged.
        if end_level is None:
            end_level = self.levels
        return zs

    def forward(self, xs):
        # No quantization, so nothing to commit to: return the inputs as
        # both outputs and report zero losses/metrics for every level.
        zero = t.zeros(()).cuda()
        commit_losses = [zero for _ in range(self.levels)]
        metrics = [dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero) for _ in range(self.levels)]
        return xs, xs, commit_losses, metrics

if __name__ == '__main__':
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=29600)
    # Bottleneck is the quantized counterpart defined earlier in this
    # module; check() runs its consistency self-test.
    bottleneck = Bottleneck(256, 64, 0.99, 2).to(device)
    bottleneck.check()
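A small usage sketch for NoBottleneck; tensor shapes are illustrative. It acts as an identity stand-in for the quantized bottleneck: encode and decode return their inputs, and forward reports zero commit losses (it allocates those zeros on CUDA, so a GPU is required).

# Hedged usage sketch; shapes are illustrative and a GPU is required
# because forward() puts its zero losses on CUDA.
import torch as t

bottleneck = NoBottleneck(levels=3).cuda()
xs = [t.randn(2, 64, 1024).cuda() for _ in range(3)]  # one latent per level
zs = bottleneck.encode(xs)                            # identity: zs is xs
xs_out, xs_q, commit_losses, metrics = bottleneck(xs)
assert zs is xs and xs_out is xs and xs_q is xs
assert all(loss.item() == 0.0 for loss in commit_losses)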
Example #5
import app.jukebox
import torch as t
import librosa
import os
from IPython.display import Audio
from app.jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from app.jukebox.hparams import Hyperparams, setup_hparams
from app.jukebox.sample import (sample_single_window, _sample,
                                sample_partial_window, upsample)
from app.jukebox.utils.dist_utils import setup_dist_from_mpi
from app.jukebox.utils.torch_utils import empty_cache
rank, local_rank, device = setup_dist_from_mpi()


def generate_audio(request):
    # Placeholder handler: echoes the request back unchanged; the actual
    # generation logic is not part of this snippet.
    response = request
    return response

# The indented fragment below is an excerpt from a different file: the
# chunked-sampling consistency check in FactoredAttention. x_chunks,
# encoder_kv, bs, total_len, y_chunks, y_forw and l are set up earlier in
# that method; each chunk is decoded against the cached keys/values and
# the concatenated result must match the single full forward pass.
            for x_chunk in x_chunks:
                y_chunk = self.forward(x_chunk.contiguous(),
                                       encoder_kv=encoder_kv,
                                       sample=True)
                total_len += x_chunk.shape[1]
                self.check_cache(bs, total_len, False)
                y_chunks.append(y_chunk)
            y_forw_in_chunks = t.cat(y_chunks, dim=1)

            max_err = t.max(t.abs(y_forw - y_forw_in_chunks))
            assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}"
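The same chunked-versus-full consistency idea in a self-contained form: a sketch using a toy causal operation (a running cumulative sum with carried state) in place of cached attention. All names here are illustrative.

# Hedged self-contained sketch of the chunked-vs-full check, with a toy
# stateful causal op standing in for cached attention.
import torch as t

def full_forward(x):
    return t.cumsum(x, dim=1)

def chunked_forward(x, n_chunks=4):
    carry = t.zeros(x.shape[0], 1, x.shape[2], dtype=x.dtype)
    ys = []
    for chunk in t.chunk(x, n_chunks, dim=1):
        y = t.cumsum(chunk, dim=1) + carry  # reuse state from earlier chunks
        carry = y[:, -1:, :]
        ys.append(y)
    return t.cat(ys, dim=1)

x = t.randn(4, 64, 16, dtype=t.float64)  # float64 keeps the tolerance tight
max_err = t.max(t.abs(full_forward(x) - chunked_forward(x)))
assert max_err <= 1e-6, f"Max err is {max_err}"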


if __name__ == '__main__':
    # Smoke-test FactoredAttention (defined earlier in this module)
    # across several attention patterns.
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    setup_dist_from_mpi(port=29600)
    n_in = 16
    n_state = n_in * 2
    n_ctx = 6144
    n_head = 4
    n_depth = 12
    blocks = 64
    chunk_size = 8
    for attn_func in [0, 1, 2, 3, 6, 7]:
        # attn_func selects the attention pattern; pattern 6 cross-attends
        # to an encoder (hence encoder_dims=64) and pattern 7 attends over
        # a primed prefix (hence prime_len=384).
        encoder_dims = {0: 0, 1: 0, 2: 0, 3: 0, 6: 64, 7: 0}[attn_func]
        prime_len = {0: 0, 1: 0, 2: 0, 3: 0, 6: 0, 7: 384}[attn_func]
        attn = FactoredAttention(n_in,
                                 n_ctx + prime_len,
                                 n_state,
                                 n_head,
                                 mask=True,