def init_dataset(self, hps):
    """Discover audio files, measure their lengths and build the file index.

    Searches ``hps.audio_files_dir`` for mp3/opus/m4a/aac/wav files, scales
    each value returned by ``get_duration_sec`` by ``self.sr`` and hands the
    file list plus the resulting array to ``self.filter``. When
    ``self.labels`` is truthy a ``Labeller`` is created and stored on
    ``self.labeller``.
    """
    # Load list of files and starts/durations
    audio_extensions = ['mp3', 'opus', 'm4a', 'aac', 'wav']
    files = librosa.util.find_files(f'{hps.audio_files_dir}', audio_extensions)
    print_all(f"Found {len(files)} files. Getting durations")

    # cache=True when distributed is unavailable; otherwise only on ranks
    # that are a multiple of 8.
    if dist.is_available():
        cache = dist.get_rank() % 8 == 0
    else:
        cache = True

    # Could be approximate
    scaled_durations = [get_duration_sec(path, cache=cache) * self.sr for path in files]
    self.filter(files, np.array(scaled_durations))

    if not self.labels:
        return
    self.labeller = Labeller(
        hps.max_bow_genre_size,
        hps.n_tokens,
        self.sample_length,
        v3=hps.labels_v3,
    )
def filter(self, files, durations):
    """Drop files that are too short or too long and index the remainder.

    A file at index ``i`` is kept unless ``durations[i] / self.sr`` is below
    ``self.min_duration`` or at/above ``self.max_duration``. The kept paths
    are stored in ``self.files``, their durations converted with ``int`` in
    ``self.durations``, and the cumulative sum of those in ``self.cumsum``.
    """
    def _within_limits(idx):
        # Same two checks, in the same order, as the skip conditions.
        if durations[idx] / self.sr < self.min_duration:
            return False
        return not (durations[idx] / self.sr >= self.max_duration)

    kept = [idx for idx in range(len(files)) if _within_limits(idx)]

    print_all(f'self.sr={self.sr}, min: {self.min_duration}, max: {self.max_duration}')
    print_all(f"Keeping {len(kept)} of {len(files)} files")

    self.files = [files[idx] for idx in kept]
    self.durations = [int(durations[idx]) for idx in kept]
    self.cumsum = np.cumsum(self.durations)
def make_vqvae(hps, device='cuda'):
    """Build a VQVAE from ``hps``, move it to ``device`` and restore weights.

    If ``hps.sample_length`` is falsy it is derived from
    ``hps.sample_length_in_seconds`` and ``hps.sr``, rounded down to a
    multiple of the product of all strides, and written back onto ``hps``.

    When ``hps.train`` is set and ``hps.prior`` is not, the model is left in
    train mode; if a checkpoint path was given (``hps.restore_vqvae != ''``)
    every bottleneck level block additionally gets ``restore_k`` called.
    Otherwise the model is put in eval mode and passed to ``freeze_model``.

    Returns the constructed VQVAE.
    """
    from jukebox.vqvae.vqvae import VQVAE

    block_kwargs = {
        'width': hps.width,
        'depth': hps.depth,
        'm_conv': hps.m_conv,
        'dilation_growth_rate': hps.dilation_growth_rate,
        'dilation_cycle': hps.dilation_cycle,
        'reverse_decoder_dilation': hps.vqvae_reverse_decoder_dilation,
    }

    if not hps.sample_length:
        assert hps.sample_length_in_seconds != 0
        top_raw_to_tokens = np.prod(calculate_strides(hps.strides_t, hps.downs_t))
        # Round the requested length down to a whole number of stride products.
        whole_chunks = hps.sample_length_in_seconds * hps.sr // top_raw_to_tokens
        hps.sample_length = whole_chunks * top_raw_to_tokens
        print(f"Setting sample length to {hps.sample_length} (i.e. {hps.sample_length/hps.sr} seconds) to be multiple of {top_raw_to_tokens}")

    vqvae_kwargs = {
        'input_shape': (hps.sample_length, 1),
        'levels': hps.levels,
        'downs_t': hps.downs_t,
        'strides_t': hps.strides_t,
        'emb_width': hps.emb_width,
        'l_bins': hps.l_bins,
        'mu': hps.l_mu,
        'commit': hps.commit,
        'spectral': hps.spectral,
        'multispectral': hps.multispectral,
        'multipliers': hps.hvqvae_multipliers,
        'use_bottleneck': hps.use_bottleneck,
    }
    vqvae = VQVAE(**vqvae_kwargs, **block_kwargs).to(device)
    restore(hps, vqvae, hps.restore_vqvae)

    if not (hps.train and not hps.prior):
        print_all("Loading vqvae in eval mode")
        vqvae.eval()
        freeze_model(vqvae)
        return vqvae

    print_all("Loading vqvae in train mode")
    if hps.restore_vqvae != '':
        print_all("Reseting bottleneck emas")
        for level, level_block in enumerate(vqvae.bottleneck.level_blocks):
            strides = calculate_strides(hps.strides_t, hps.downs_t)
            # Product of the strides up to and including this level.
            raw_to_tokens = np.prod(strides[:level + 1])
            tokens_per_rank = hps.sample_length // raw_to_tokens
            level_block.restore_k(
                num_tokens=tokens_per_rank * dist.get_world_size(),
                threshold=hps.revival_threshold,
            )
    return vqvae
def forward(self, x, update_k=True):
    """Quantise ``x`` and return codes, dequantised output and metrics.

    Args:
        x: 3-D input; its shape is unpacked as ``(N, width, T)``.
        update_k: When true, ``self.init_k`` is called first if
            ``self.init`` is falsy, and ``self.update_k`` is called after
            quantisation; its returned metrics are merged into the output.

    Returns:
        A tuple ``(x_l, x_d, commit_loss, metrics)`` where ``x_l`` and
        ``x_d`` are the outputs of ``self.postprocess``, ``commit_loss`` is
        the squared norm of ``x_d.detach() - x`` divided by the number of
        elements of the preprocessed ``x``, and ``metrics`` is a dict with
        ``fit``, ``pn`` and any entries returned by ``self.update_k``.
    """
    # Unpacking three names also rejects inputs that are not 3-D.
    N, _width, T = x.shape

    # Preprocess
    x, prenorm = self.preprocess(x)

    # Init k if not inited
    if update_k and not self.init:
        self.init_k(x)

    # Quantise and dequantise through bottleneck
    x_l, fit = self.quantise(x)
    x_d = self.dequantise(x_l)

    # Update embeddings
    update_metrics = self.update_k(x, x_l) if update_k else {}

    # Loss: x_d is detached, so gradients of this term reach only x.
    commit_loss = t.norm(x_d.detach() - x) ** 2 / np.prod(x.shape)

    # Passthrough: the value equals x_d, but the difference is detached so
    # gradients flow to x unchanged.
    x_d = x + (x_d - x).detach()

    # Postprocess
    x_l, x_d = self.postprocess(x_l, x_d, (N, T))
    return x_l, x_d, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics)
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    """Run one pass over ``data_processor.train_loader`` and update the model.

    For every batch this runs the forward and backward pass, skips the
    optimiser step on loss/gradient overflow or an over-large gradient norm,
    otherwise steps the optimiser (and the scheduler / EMA when given), logs
    metrics, and periodically saves checkpoints, samples and input/output
    logs. When the scheduler's next learning rate is exactly 0.0 the process
    synchronises on ``dist.barrier()`` and exits.

    Args:
        model: Callable module used for the forward pass.
        orig_model: Module whose gradients are zeroed and which is
            checkpointed, sampled from and logged.
        opt: Optimiser; ``opt.step`` is called with a ``scale`` keyword.
        shd: Learning-rate scheduler, or ``None`` to use ``hps.lr``.
        scalar: Passed through to ``backward``.
        ema: EMA helper, or ``None``.
        logger: Provides ``get_range``, ``iters``, ``step``, ``add_scalar``,
            ``set_postfix`` and ``close_range``.
        metrics: Accumulator providing ``update`` and ``avg``.
        data_processor: Object exposing ``train_loader``.
        hps: Hyper-parameters.

    Returns:
        Dict mapping each metric key of the last processed batch to
        ``metrics.avg(key)``; empty when the loader yielded no batches.
    """
    import sys  # local import: only needed for the end-of-training exit

    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage",
                           uc="used_curr", gn="gn", pn="pn", dk="dk")

    # Sampling is triggered during the first `iters_before_update` iterations
    # of every window of this many iterations.
    sample_every = 12000

    # Bound up front so the final summary cannot hit an unbound name when
    # the loader yields no batches.
    _metrics = {}

    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(
            loss=loss, params=list(model.parameters()), scalar=scalar, fp16=hps.fp16, logger=logger)

        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None:
            shd.step()
        if ema is not None:
            ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                # Swap EMA weights in for the checkpoint, then back out.
                if ema is not None:
                    ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None:
                    ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % sample_every) in range(1, 1 + hps.iters_before_update) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            sys.exit()

    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics}
def make_prior(hps, vqvae, device='cuda'):
    """Build a SimplePrior on top of ``vqvae``, move it to ``device`` and restore weights.

    Collects the prior, x-conditioner, y-conditioner and prime keyword
    arguments from ``hps``, rescales ``vqvae.z_shapes`` relative to
    ``hps.level`` and ``hps.n_ctx``, and constructs the prior. Conv weights
    are converted to fp16 when ``hps.fp16_params`` is set. Unless
    ``hps.train`` is set, the prior is put in eval mode and passed to
    ``freeze_model``.

    Returns the constructed prior.
    """
    from jukebox.prior.prior import SimplePrior

    prior_kwargs = {
        'input_shape': (hps.n_ctx,),
        'bins': hps.l_bins,
        'width': hps.prior_width,
        'depth': hps.prior_depth,
        'heads': hps.heads,
        'attn_order': hps.attn_order,
        'blocks': hps.blocks,
        'spread': hps.spread,
        'attn_dropout': hps.attn_dropout,
        'resid_dropout': hps.resid_dropout,
        'emb_dropout': hps.emb_dropout,
        'zero_out': hps.zero_out,
        'res_scale': hps.res_scale,
        'pos_init': hps.pos_init,
        'init_scale': hps.init_scale,
        'm_attn': hps.m_attn,
        'm_mlp': hps.m_mlp,
        'checkpoint_res': hps.c_res if hps.train else 0,
        'checkpoint_attn': hps.c_attn if hps.train else 0,
        'checkpoint_mlp': hps.c_mlp if hps.train else 0,
    }

    x_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'width': hps.cond_width,
        'depth': hps.cond_depth,
        'm_conv': hps.cond_m_conv,
        'dilation_growth_rate': hps.cond_dilation_growth_rate,
        'dilation_cycle': hps.cond_dilation_cycle,
        'zero_out': hps.cond_zero_out,
        'res_scale': hps.cond_res_scale,
        'checkpoint_res': hps.cond_c_res,  # have to keep this else names wrong
    }

    y_cond_kwargs = {
        'out_width': hps.prior_width,
        'init_scale': hps.init_scale,
        'y_bins': hps.y_bins,
        't_bins': hps.t_bins,
        't_ranges': hps.t_ranges,
        'max_bow_genre_size': hps.max_bow_genre_size,
    }

    # These four entries are passed in both configurations.
    prime_kwargs = {
        'use_tokens': hps.use_tokens,
        'prime_loss_fraction': hps.prime_loss_fraction,
        'n_tokens': hps.n_tokens,
        'bins': hps.n_vocab,
    }
    if hps.use_tokens and not hps.single_enc_dec:
        # Tokens with a separate encoder: add the full prime_* settings.
        prime_kwargs.update({
            'width': hps.prime_width,
            'depth': hps.prime_depth,
            'heads': hps.prime_heads,
            'attn_order': hps.prime_attn_order,
            'blocks': hps.prime_blocks,
            'spread': hps.prime_spread,
            'attn_dropout': hps.prime_attn_dropout,
            'resid_dropout': hps.prime_resid_dropout,
            'emb_dropout': hps.prime_emb_dropout,
            'zero_out': hps.prime_zero_out,
            'res_scale': hps.prime_res_scale,
            'pos_init': hps.prime_pos_init,
            'init_scale': hps.prime_init_scale,
            'm_attn': hps.prime_m_attn,
            'm_mlp': hps.prime_m_mlp,
            'checkpoint_res': hps.prime_c_res if hps.train else 0,
            'checkpoint_attn': hps.prime_c_attn if hps.train else 0,
            'checkpoint_mlp': hps.prime_c_mlp if hps.train else 0,
        })

    # z_shapes for other levels given this level gets n_ctx codes
    def _rescaled(z_shape):
        return (z_shape[0] * hps.n_ctx // vqvae.z_shapes[hps.level][0],)

    z_shapes = [_rescaled(z_shape) for z_shape in vqvae.z_shapes]

    prior = SimplePrior(
        z_shapes=z_shapes,
        l_bins=hps.l_bins,
        encoder=vqvae.encode,
        decoder=vqvae.decode,
        level=hps.level,
        downs_t=hps.downs_t,
        strides_t=hps.strides_t,
        labels=hps.labels,
        prior_kwargs=prior_kwargs,
        x_cond_kwargs=x_cond_kwargs,
        y_cond_kwargs=y_cond_kwargs,
        prime_kwargs=prime_kwargs,
        copy_input=hps.copy_input,
        labels_v3=hps.labels_v3,
        merged_decoder=hps.merged_decoder,
        single_enc_dec=hps.single_enc_dec,
    )

    prior.alignment_head = hps.get('alignment_head', None)
    prior.alignment_layer = hps.get('alignment_layer', None)

    if hps.fp16_params:
        print_all("Converting to fp16 params")
        from jukebox.transformer.ops import _convert_conv_weights_to_fp16
        prior.apply(_convert_conv_weights_to_fp16)

    prior = prior.to(device)
    restore(hps, prior, hps.restore_prior)

    if hps.train:
        print_all("Loading prior in train mode")
        return prior

    print_all("Loading prior in eval mode")
    prior.eval()
    freeze_model(prior)
    return prior
def print_stats(self, hps):
    """Report dataset sizes, the train sampler and the train loader length.

    ``hps`` is accepted but not used by this method.
    """
    train_count = len(self.train_dataset)
    test_count = len(self.test_dataset)
    print_all(f"Train {train_count} samples. Test {test_count} samples")
    print_all(f'Train sampler: {self.train_sampler}')
    loader_length = len(self.train_loader)
    print_all(f'Train loader: {loader_length}')