def finish_reduce(self):
    for tag, layer, val, work in self.works:
        work.wait()
        if self.rank == 0:
            val = val.item() / dist.get_world_size()
            self.lw[layer].add_scalar(tag, val, self.iters)
    self.works = []
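
# Hypothetical producer sketch (not part of the original class): the kind of call that
# fills self.works with the (tag, layer, val, work) tuples finish_reduce drains.
# dist.all_reduce(..., async_op=True) returns a handle immediately; once wait() returns,
# val holds the sum over ranks, and finish_reduce divides by the world size to log the
# mean on rank 0. The names below (queue_reduce, scalar) are illustrative only.
#
#     def queue_reduce(self, tag, layer, scalar):
#         val = torch.tensor(float(scalar), device='cuda')
#         work = dist.all_reduce(val, async_op=True)
#         self.works.append((tag, layer, val, work))
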
def make_vqvae(hps, device='cuda'):
    from app.jukebox.vqvae.vqvae import VQVAE
    block_kwargs = dict(width=hps.width, depth=hps.depth, m_conv=hps.m_conv,
                        dilation_growth_rate=hps.dilation_growth_rate,
                        dilation_cycle=hps.dilation_cycle,
                        reverse_decoder_dilation=hps.vqvae_reverse_decoder_dilation)

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0
        # Round the sample length down to a multiple of the top-level hop size.
        downsamples = calculate_strides(hps.strides_t, hps.downs_t)
        top_raw_to_tokens = np.prod(downsamples)
        hps.sample_length = (hps.sample_length_in_seconds * hps.sr // top_raw_to_tokens) * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be a multiple of {top_raw_to_tokens}")

    vqvae = VQVAE(input_shape=(hps.sample_length, 1), levels=hps.levels,
                  downs_t=hps.downs_t, strides_t=hps.strides_t,
                  emb_width=hps.emb_width, l_bins=hps.l_bins, mu=hps.l_mu,
                  commit=hps.commit, spectral=hps.spectral, multispectral=hps.multispectral,
                  multipliers=hps.hvqvae_multipliers, use_bottleneck=hps.use_bottleneck,
                  **block_kwargs)

    vqvae = vqvae.to(device)
    restore_model(hps, vqvae, hps.restore_vqvae)
    if hps.train and not hps.prior:
        print_all("Loading vqvae in train mode")
        if hps.restore_vqvae != '':
            print_all("Resetting bottleneck emas")
            for level, bottleneck in enumerate(vqvae.bottleneck.level_blocks):
                num_samples = hps.sample_length
                downsamples = calculate_strides(hps.strides_t, hps.downs_t)
                raw_to_tokens = np.prod(downsamples[:level + 1])
                num_tokens = (num_samples // raw_to_tokens) * dist.get_world_size()
                bottleneck.restore_k(num_tokens=num_tokens, threshold=hps.revival_threshold)
    else:
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
    return vqvae
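
# Hypothetical usage sketch (not part of the original file): building the VQ-VAE from a
# hyperparameter set. setup_hparams and the 'vqvae' preset follow the upstream Jukebox
# layout; the import path and preset name are assumptions and may differ in this fork
# (e.g. app.jukebox.hparams). Also assumes the default hyperparameters supply
# train/prior/restore_vqvae and that torch.distributed is initialised.
def _example_make_vqvae():
    from jukebox.hparams import setup_hparams
    hps = setup_hparams('vqvae', dict(sample_length_in_seconds=24, sr=44100))
    return make_vqvae(hps, device='cuda')
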
def allgather_lists(xs):
    bs = len(xs)
    total_bs = dist.get_world_size() * len(xs)
    lengths = torch.tensor([len(x) for x in xs], dtype=torch.long, device='cuda')
    lengths = allgather(lengths)
    assert lengths.shape == (total_bs,)
    max_length = torch.max(lengths).item()
    # Right-pad every list to max_length so all ranks gather tensors of the same shape.
    xs = torch.tensor([[*x, *[0] * (max_length - len(x))] for x in xs], device='cuda')
    assert xs.shape == (bs, max_length), f'Expected {(bs, max_length)}, got {xs.shape}'
    xs = allgather(xs)
    assert xs.shape == (total_bs, max_length), f'Expected {(total_bs, max_length)}, got {xs.shape}'
    # Strip the padding back off using the gathered per-list lengths.
    return [xs[i][:lengths[i]].cpu().numpy().tolist() for i in range(total_bs)]
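
# Hypothetical usage sketch (not part of the original file): each rank passes its own
# batch of variable-length lists; the return value is every rank's lists concatenated
# in rank order, with the zero padding stripped again. Assumes the default process
# group is initialised with a CUDA-capable backend; the list contents are illustrative.
def _example_allgather_lists():
    local_lists = [[1, 2, 3], [4]]             # this rank's lists
    all_lists = allgather_lists(local_lists)   # world_size * 2 lists in rank order
    return all_lists
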
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)

        zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)

        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
            alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
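
# Hypothetical driver sketch mirroring the upstream ancestral-sampling entry point (the
# exact wrapper in this codebase may differ): start every level from an empty token
# sequence and sample all levels from the top prior down. Assumes `t` is the
# module-level torch alias already used by _sample and that hps provides n_samples.
def _example_ancestral_sample(labels, sampling_kwargs, priors, hps):
    sample_levels = list(range(len(priors)))
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    return _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
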
def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        stft = librosa.core.stft(np.mean(samples, axis=1), hps.n_fft,
                                 hop_length=hps.hop_length, win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean ** 2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
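
# Hypothetical usage sketch (not part of the original file): the returned dict is
# typically attached to the hyperparameters so that reconstruction losses can be
# normalised by the dataset's l1/l2/spectral bandwidth. FilesAudioDataset and its
# import path follow the upstream Jukebox layout and are assumptions here.
def _example_calculate_bandwidth(hps):
    from jukebox.data.files_dataset import FilesAudioDataset
    dataset = FilesAudioDataset(hps)
    hps.bandwidth = calculate_bandwidth(dataset, hps, duration=600)
    return hps.bandwidth
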
def allgather(x):
    xs = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(xs, x)
    xs = torch.cat(xs, dim=0)
    return xs
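
# Hypothetical usage sketch (not part of the original file): every rank contributes a
# tensor with the same shape and dtype; the result is the concatenation over ranks
# along dim 0. Assumes the default process group is initialised with a CUDA backend.
def _example_allgather():
    local = torch.full((2,), float(dist.get_rank()), device='cuda')
    gathered = allgather(local)   # shape (2 * world_size,), ordered by rank
    return gathered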