Example #1
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length//prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
            alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
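
Note: `zs` holds one token sequence per level and is grown top-down, so the caller seeds it with empty token tensors. A minimal sketch of a driver in the style of the library's ancestral sampling (the helper name and the `hps.n_samples` field are assumptions):

import torch as t

def ancestral_sample(labels, sampling_kwargs, priors, hps):
    # Hypothetical driver: seed one empty token sequence per level,
    # then let _sample fill them in from coarsest to finest level.
    sample_levels = list(range(len(priors)))
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda')
          for _ in range(len(priors))]
    return _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)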
Example #2
def finish_reduce(self):
    # Wait for the queued async reduce ops, then log the averaged
    # scalars on rank 0 and clear the queue.
    for tag, layer, val, work in self.works:
        work.wait()
        if self.rank == 0:
            val = val.item() / dist.get_world_size()
            self.lw[layer].add_scalar(tag, val, self.iters)
    self.works = []
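
`finish_reduce` drains a queue that an enqueue step fills elsewhere in the logger. A hypothetical counterpart (the method name and the layout of the `self.works` tuples are inferred from the loop above) would launch the asynchronous reduction and keep the handle:

def start_reduce(self, tag, layer, val):
    # Hypothetical enqueue side: reduce the scalar to rank 0 without
    # blocking, and record the handle for finish_reduce to wait on.
    val = torch.tensor(float(val), device='cuda')
    work = dist.reduce(val, dst=0, async_op=True)  # SUM by default
    self.works.append((tag, layer, val, work))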
Example #3
def run(hps="teeny", port=29500, **kwargs):
    from app.jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar,
                                  ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Train', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in train_metrics.items()
                    ]))
            dist.barrier()

        if hps.test:
            if ema: ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Ema', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in test_metrics.items()
                    ]))
            dist.barrier()
            if ema: ema.swap()
        dist.barrier()
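
The paired `ema.swap()` calls around evaluation exchange the live weights with their exponential moving averages and then restore them, which is why the test metrics are printed under 'Ema'. A minimal sketch of a swap-style EMA helper with that in-place exchange semantics (the actual `get_ema` implementation may differ):

import torch

class SimpleEMA:
    # Minimal swap-style EMA: keep a shadow copy of the parameters
    # and update it after every optimizer step.
    def __init__(self, model, decay=0.999):
        self.model, self.decay = model, decay
        self.shadow = [p.detach().clone() for p in model.parameters()]

    def step(self):
        with torch.no_grad():
            for s, p in zip(self.shadow, self.model.parameters()):
                s.mul_(self.decay).add_(p, alpha=1 - self.decay)

    def swap(self):
        # Exchange live and shadow weights; calling swap() twice restores.
        with torch.no_grad():
            for s, p in zip(self.shadow, self.model.parameters()):
                tmp = p.detach().clone()
                p.copy_(s)
                s.copy_(tmp)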
Example #4
def make_vqvae(hps, device='cuda'):
    from jukebox.vqvae.vqvae import VQVAE
    block_kwargs = dict(
        width=hps.width,
        depth=hps.depth,
        m_conv=hps.m_conv,
        dilation_growth_rate=hps.dilation_growth_rate,
        dilation_cycle=hps.dilation_cycle,
        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        hps.sample_length = (hps.sample_length_in_seconds * hps.sr //
                             top_raw_to_tokens) * top_raw_to_tokens
        print(
            f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}"
        )

    vqvae = VQVAE(input_shape=(hps.sample_length, 1),
                  levels=hps.levels,
                  downs_t=hps.downs_t,
                  strides_t=hps.strides_t,
                  emb_width=hps.emb_width,
                  l_bins=hps.l_bins,
                  mu=hps.l_mu,
                  commit=hps.commit,
                  spectral=hps.spectral,
                  multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers,
                  use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore(hps, vqvae, hps.restore_vqvae)
    if hps.train and not hps.prior:
        print_all("Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Resetting bottleneck emas")
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                num_samples = hps.sample_length
                downsamples = calculate_strides(hps.strides_t, hps.downs_t)
                raw_to_tokens = np.prod(downsamples[:level + 1])
                num_tokens = (num_samples //
                              raw_to_tokens) * dist.get_world_size()
                bottleneck.restore_k(num_tokens=num_tokens,
                                     threshold=hps.revival_threshold)
    else:
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
    return vqvae
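
The sample-length rounding above relies on `calculate_strides`, which collapses the per-level temporal strides into one downsampling factor per level. A sketch of the arithmetic, assuming each level downsamples by stride ** down (consistent with `raw_to_tokens` being a cumulative product over levels):

import numpy as np

def calculate_strides(strides_t, downs_t):
    # One downsampling factor per VQ-VAE level.
    return [stride ** down for stride, down in zip(strides_t, downs_t)]

# Example: strides (2, 2, 2) with downs (3, 2, 2) give factors (8, 4, 4),
# so the top level maps 8 * 4 * 4 = 128 raw samples to one token and
# sample_length is rounded down to a multiple of 128.
top_raw_to_tokens = np.prod(calculate_strides((2, 2, 2), (3, 2, 2)))  # 128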
Example #5
def allgather_lists(xs):
    bs = len(xs)
    total_bs = dist.get_world_size()*len(xs)
    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long, device='cuda')
    lengths = allgather(lengths)
    assert lengths.shape == (total_bs,)
    max_length = torch.max(lengths).item()

    xs = torch.tensor([[*x, *[0]*(max_length - len(x))] for x in xs], device='cuda')
    assert xs.shape == (bs, max_length), f'Expected {(bs, max_length)}, got {xs.shape}'
    xs = allgather(xs)
    assert xs.shape == (total_bs, max_length), f'Expected {(total_bs, max_length)}, got {xs.shape}'

    return [xs[i][:lengths[i]].cpu().numpy().tolist() for i in range(total_bs)]
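
Since `dist.all_gather` requires identically shaped tensors on every rank, the function gathers the lengths first, zero-pads each list to the global maximum, gathers the padded matrix, and trims each row back to its true length. A single-process check of the padding step (the collective itself is omitted):

import torch

xs = [[1, 2], [3]]
max_length = max(len(x) for x in xs)
padded = torch.tensor([[*x, *[0] * (max_length - len(x))] for x in xs])
# tensor([[1, 2],
#         [3, 0]])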
Example #6
def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        stft = librosa.core.stft(np.mean(samples, axis=1),
                                 n_fft=hps.n_fft,
                                 hop_length=hps.hop_length,
                                 win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
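
The returned dictionary holds per-sample statistics: `l2` is the variance E[x**2] - (E[x])**2, `l1` is the mean absolute amplitude, and `spec` is the average spectrogram Frobenius norm, all reduced across ranks so every process computes the same values. A single-process sanity check of the time-domain terms on synthetic data (librosa and the distributed reduction omitted):

import numpy as np

x = np.random.randn(1_000_000)
n = x.size
l2 = np.sum(x ** 2) / n - (np.sum(x) / n) ** 2  # matches the l2 entry
l1 = np.sum(np.abs(x)) / n                      # matches the l1 entry
assert np.isclose(l2, np.var(x))                # l2 is just the variance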
Example #7
def allgather(x):
    # Gather x from every rank and concatenate along the batch dimension.
    xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(xs, x)
    xs = torch.cat(xs, dim=0)
    return xs
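
A minimal runnable sketch of calling `allgather`, using a single-process gloo group so it stays CPU-only (with world size 1 the concatenation is a no-op):

import torch
import torch.distributed as dist

dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)
x = torch.arange(4).reshape(2, 2)
gathered = allgather(x)   # shape (2 * world_size, 2)
print(gathered.shape)     # torch.Size([2, 2])
dist.destroy_process_group()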