Example #1
def finish_reduce(self):
    # Wait on the queued async all-reduce handles, then log the
    # world-averaged scalar from rank 0 only.
    for tag, layer, val, work in self.works:
        work.wait()
        if self.rank == 0:
            val = val.item() / dist.get_world_size()
            self.lw[layer].add_scalar(tag, val, self.iters)
    self.works = []
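For context, a minimal sketch of the producer side that finish_reduce assumes: scalars are summed with non-blocking all-reduces and the returned handles are queued on self.works. The method name start_reduce is an assumption, not from the source.

import torch
import torch.distributed as dist

def start_reduce(self, tag, layer, val):
    # Hypothetical counterpart to finish_reduce(): launch a non-blocking
    # all-reduce and queue the handle so finish_reduce() can wait on it.
    val = torch.tensor(float(val), device='cuda')
    work = dist.all_reduce(val, async_op=True)
    self.works.append((tag, layer, val, work))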
def make_vqvae(hps, device='cuda'):
    from app.jukebox.vqvae.vqvae import VQVAE
    block_kwargs = dict(
        width=hps.width,
        depth=hps.depth,
        m_conv=hps.m_conv,
        dilation_growth_rate=hps.dilation_growth_rate,
        dilation_cycle=hps.dilation_cycle,
        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        hps.sample_length = (hps.sample_length_in_seconds * hps.sr //
                             top_raw_to_tokens) * top_raw_to_tokens
        print(
            f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}"
        )
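        # Worked example of the rounding above (numbers are illustrative):
        # with sr=44100, sample_length_in_seconds=24, top_raw_to_tokens=128,
        # (24 * 44100 // 128) * 128 = 8268 * 128 = 1058304 samples (~23.998s),
        # the largest multiple of top_raw_to_tokens within the requested time.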

    vqvae = VQVAE(input_shape=(hps.sample_length, 1),
                  levels=hps.levels,
                  downs_t=hps.downs_t,
                  strides_t=hps.strides_t,
                  emb_width=hps.emb_width,
                  l_bins=hps.l_bins,
                  mu=hps.l_mu,
                  commit=hps.commit,
                  spectral=hps.spectral,
                  multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers,
                  use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore_model(hps, vqvae, hps.restore_vqvae)
    if hps.train and not hps.prior:
        print_all(f"Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Reseting bottleneck emas")
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                num_samples = hps.sample_length
                downsamples = calculate_strides(hps.strides_t, hps.downs_t)
                raw_to_tokens = np.prod(downsamples[:level + 1])
                num_tokens = (num_samples //
                              raw_to_tokens) * dist.get_world_size()
                bottleneck.restore_k(num_tokens=num_tokens,
                                     threshold=hps.revival_threshold)
    else:
        print_all(f"Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
    return vqvae
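A hedged call-site sketch; setup_hparams and the 'vqvae' preset name are assumptions about the surrounding repo, and the override values are placeholders:

# Hypothetical call site (names and values are assumptions, not repo defaults).
from jukebox.hparams import setup_hparams

hps = setup_hparams('vqvae', dict(sample_length=0, sample_length_in_seconds=24,
                                  sr=44100, restore_vqvae='',
                                  train=False, prior=False))
vqvae = make_vqvae(hps, device='cuda')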
def allgather_lists(xs):
    # Gather variable-length lists across ranks: all-gather the lengths,
    # zero-pad every list to the global max, all-gather the padded rows,
    # then trim each row back to its true length.
    bs = len(xs)
    total_bs = dist.get_world_size() * len(xs)
    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long, device='cuda')
    lengths = allgather(lengths)
    assert lengths.shape == (total_bs,)
    max_length = torch.max(lengths).item()

    xs = torch.tensor([[*x, *[0]*(max_length - len(x))] for x in xs], device='cuda')
    assert xs.shape == (bs, max_length), f'Expected {(bs, max_length)}, got {xs.shape}'
    xs = allgather(xs)
    assert xs.shape == (total_bs, max_length), f'Expected {(total_bs, max_length)}, got {xs.shape}'

    return [xs[i][:lengths[i]].cpu().numpy().tolist() for i in range(total_bs)]
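A toy trace of the ragged gather with a world size of 2 (values are illustrative):

# rank 0: xs = [[1, 2, 3], [4]]   ->  lengths gather to [3, 1, 1, 2], max_length = 3
# rank 1: xs = [[5], [6, 7]]      ->  padded rows [[1,2,3],[4,0,0],[5,0,0],[6,7,0]]
# Every rank then returns [[1, 2, 3], [4], [5], [6, 7]].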
Example #4
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        # Sample the coarsest (top) level first, then upsample downward;
        # only the active prior lives on the GPU at any time.
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level,
                          prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:],
                         start_level=level,
                         bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        os.makedirs(logdir, exist_ok=True)
        t.save(
            dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x),
            f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if (alignments is None and priors[-1] is not None
                and priors[-1].n_tokens > 0
                and not isinstance(priors[-1].labeller, EmptyLabeller)):
            alignments = get_alignment(x, zs, labels[-1], priors[-1],
                                       sampling_kwargs[-1]['fp16'], hps)
        save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
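A hedged driver sketch for sampling from scratch: starting every level from an empty token tensor is consistent with how zs[level] is indexed above, though the repo's actual entry point may differ.

# Hypothetical driver (n_samples and the empty-z initialization are assumptions).
n_samples = 4
zs = [t.zeros(n_samples, 0, dtype=t.long, device='cuda')
      for _ in range(len(priors))]
zs = _sample(zs, labels, sampling_kwargs, priors,
             sample_levels=list(range(len(priors))), hps=hps)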
Example #5
def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, _ = x  # drop the label if the dataset yields (audio, label) pairs
        samples = x.astype(np.float64)
        # n_fft is keyword-only in recent librosa releases.
        stft = librosa.core.stft(np.mean(samples, axis=1),
                                 n_fft=hps.n_fft,
                                 hop_length=hps.hop_length,
                                 win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += max(16, dist.get_world_size())  # starting from idx=rank, this stride keeps the ranks' clips disjoint

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
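# Note: bandwidth['l2'] is the signal variance, since
# total_sq / n_seen - mean**2 == E[x^2] - (E[x])^2; 'l1' is the mean
# absolute amplitude, and 'spec' the mean spectrogram Frobenius norm
# over the clips visited.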
def allgather(x):
    # Gather x from every rank and concatenate the copies along dim 0.
    xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(xs, x)
    xs = torch.cat(xs, dim=0)
    return xs
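A minimal single-process check; the gloo backend, world size of 1, and TCP address are illustrative assumptions:

import torch
import torch.distributed as dist

dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                        rank=0, world_size=1)
x = torch.arange(6.).reshape(2, 3)
# Per-rank input of shape (B, D) gathers to (world_size * B, D), with
# rank r's rows at [r*B : (r+1)*B]; here world_size == 1, so it is identity.
print(allgather(x).shape)  # torch.Size([2, 3])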