import numpy as np
import torch as t
import torch.distributed as dist
import librosa

# Repo-internal helpers (grad_norm, check_overflow, allreduce, print_once,
# DefaultSTFTValues, etc.) are assumed to be imported from jukebox.utils
# elsewhere in this file.


def backward(loss, params, scalar, fp16, logger):
    # Perform backward. In fp32 this is a plain backward pass; in fp16 the
    # loss is multiplied by a dynamic scale before backward, and both the
    # loss and the gradients are checked for overflow across all ranks.
    if not fp16:
        scale = 1.0
        loss.backward()
        gn = grad_norm(params, scale)
        return loss, scale, gn, False, False
    else:
        scale = scalar.get_scale()
        loss = (loss.float()) * scale
        # Check the scaled loss for overflow; take the max across ranks so
        # every rank agrees on whether to skip the batch.
        overflow_loss = check_overflow(loss.item())
        overflow_loss = allreduce(int(overflow_loss), op=dist.ReduceOp.MAX) > 0
        if not overflow_loss:
            loss.backward()
            gn = grad_norm(params, scale)
            overflow_grad = check_overflow(gn)
            overflow_grad = allreduce(int(overflow_grad), op=dist.ReduceOp.MAX) > 0
            scalar.update_scale(overflow_grad)
        else:
            gn = 0.0
            overflow_grad = True
        loss = (loss.detach().float()) / scale  # Detach to free the computation graph on overflow
        if logger.rank == 0:
            if loss > 12.:
                print(f"\nWarning. Loss is {loss}")
            if overflow_loss:
                print(f"\nOverflow in forward. Loss {loss}, lgscale {np.log2(scale)}. "
                      f"Skipping batch completely (no backward, scale update)")
            elif overflow_grad:
                print(f"\nOverflow in backward. Loss {loss}, grad norm {gn}, "
                      f"lgscale {np.log2(scale)}, new lgscale {np.log2(scalar.get_scale())}")
        return loss, scale, gn, overflow_loss, overflow_grad
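# `backward` leans on a few small helpers and on `scalar` exposing a
# get_scale()/update_scale(overflow) interface. The sketches below show the
# assumed behaviour for reference; they are hedged reconstructions (note the
# `_sketch` names), not the repo's actual implementations, and the doubling
# interval in the scaler is an illustrative constant.

def _check_overflow_sketch(val):
    # A value overflowed if it is +/-inf or NaN (NaN != NaN).
    return val == float('inf') or val == -float('inf') or val != val

def _grad_norm_sketch(params, scale):
    # L2 norm over all gradients, divided by the fp16 loss scale so the
    # reported norm is in unscaled units.
    total = 0.0
    for p in params:
        if p.grad is not None:
            total += float(p.grad.data.norm(2)) ** 2
    return (total ** 0.5) / scale

class _DynamicLossScalerSketch:
    # Halve the scale on overflow; double it after a run of clean steps.
    def __init__(self, scale=2.0 ** 16, growth_interval=1000):
        self.scale, self.clean_steps, self.growth_interval = scale, 0, growth_interval

    def get_scale(self):
        return self.scale

    def update_scale(self, overflow):
        if overflow:
            self.scale, self.clean_steps = max(self.scale / 2.0, 1.0), 0
        else:
            self.clean_steps += 1
            if self.clean_steps % self.growth_interval == 0:
                self.scale *= 2.0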
def calculate_bandwidth(dataset, hps, duration=600):
    # Estimate per-sample signal statistics (mean absolute value, variance,
    # and mean spectrogram norm) over ~`duration` seconds of audio, sharded
    # across ranks by index and then allreduced.
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        stft = librosa.core.stft(np.mean(samples, axis=1), n_fft=hps.n_fft,
                                 hop_length=hps.hop_length, win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean ** 2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
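# Hedged sanity-check sketch: the same l1/l2/spec statistics computed on a
# single in-memory clip, without the dataset sharding or allreduce machinery.
# `_bandwidth_stats_sketch` and its default STFT parameters are illustrative
# assumptions, not part of the repo; it relies on the numpy/librosa imports
# at the top of this section.

def _bandwidth_stats_sketch(samples, n_fft=1024, hop_length=256, win_length=1024):
    # samples: (T, channels) float64 array; mixed down to mono for the STFT.
    spec = np.absolute(librosa.stft(np.mean(samples, axis=1), n_fft=n_fft,
                                    hop_length=hop_length, win_length=win_length))
    n = samples.size
    mean = samples.sum() / n
    return dict(l2=(samples ** 2).sum() / n - mean ** 2,  # variance
                l1=np.abs(samples).sum() / n,             # mean absolute value
                spec=np.linalg.norm(spec))                # spectrogram norm

# e.g. _bandwidth_stats_sketch(np.random.randn(44100, 2)) on one second of noise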
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy",
                           u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)
        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(
            loss=loss, params=list(model.parameters()), scalar=scalar,
            fp16=hps.fp16, logger=logger)

        # Skip step if overflow, or if the grad norm exceeds the ignore
        # threshold (chained comparison: also requires ignore_grad_norm > 0)
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None:
            shd.step()
        if ema is not None:
            ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Calling .item() frees the graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None:
                    ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None:
                    ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()

    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
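# The step above passes clipped_grad_scale(grad_norm, hps.clip, scale) as the
# divisor for opt.step, folding gradient clipping and fp16 loss scaling into
# a single division. A minimal sketch of the assumed behaviour (a hedged
# reconstruction, not the repo's actual helper):

def _clipped_grad_scale_sketch(grad_norm, max_grad_norm, scale):
    # Grow the divisor when the unscaled grad norm exceeds the clip
    # threshold, so dividing gradients by the result both undoes the loss
    # scale and caps the effective update norm at max_grad_norm.
    clip = grad_norm / max_grad_norm
    if clip > 1:
        scale = clip * scale
    return scale

# e.g. with grad_norm=4.0, max_grad_norm=1.0, scale=2**16, the divisor
# becomes 4 * 2**16, shrinking the applied gradients by an extra factor of 4.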