import numpy as np
import torch.distributed as dist

from jukebox.utils.dist_utils import allreduce


def backward(loss, params, scalar, fp16, logger):
    # Run the backward pass; with fp16, scale the loss and detect overflow
    if not fp16:
        scale = 1.0
        loss.backward()
        gn = grad_norm(params, scale)
        return loss, scale, gn, False, False
    else:
        scale = scalar.get_scale()
        loss = loss.float() * scale  # scale up so fp16 gradients don't underflow
        overflow_loss = check_overflow(loss.item())
        overflow_loss = allreduce(int(overflow_loss), op=dist.ReduceOp.MAX) > 0
        if not overflow_loss:
            loss.backward()
            gn = grad_norm(params, scale)
            overflow_grad = check_overflow(gn)
            overflow_grad = allreduce(int(overflow_grad),
                                      op=dist.ReduceOp.MAX) > 0
            scalar.update_scale(overflow_grad)
        else:
            gn = 0.0
            overflow_grad = True
        # Detach to free the computation graph (needed on overflow), then
        # unscale so the returned loss is comparable to the fp32 path
        loss = loss.detach().float() / scale
        if logger.rank == 0:
            if loss > 12.:
                print(f"\nWarning: loss is {loss}")
            if overflow_loss:
                print(
                    f"\nOverflow in forward. Loss {loss}, lgscale {np.log2(scale)}. Skipping batch completely (no backward, scale update)"
                )
            elif overflow_grad:
                print(
                    f"\nOverflow in backward. Loss {loss}, grad norm {gn}, lgscale {np.log2(scale)}, new lgscale {np.log2(scalar.get_scale())}"
                )
        return loss, scale, gn, overflow_loss, overflow_grad
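
The function above leans on several helpers the snippet does not define: check_overflow, grad_norm, and the scalar object's get_scale/update_scale interface. A minimal sketch of plausible implementations, assuming grad_norm returns the global L2 norm of the gradients after undoing the loss scale and that the scaler follows the usual halve-on-overflow, grow-after-N-clean-steps policy; the class name and constants are illustrative, not the repo's actual code:

import math


def check_overflow(val):
    # Flag fp16 overflow: inf, -inf, or NaN (NaN != NaN)
    return val == float('inf') or val == -float('inf') or val != val


def grad_norm(params, scale):
    # Global L2 norm of the gradients, undoing the fp16 loss scale
    total = 0.0
    for p in params:
        if p.grad is not None:
            total += p.grad.data.float().norm().item() ** 2
    return math.sqrt(total) / scale


class LossScalar:
    # Illustrative dynamic loss scaler: halve on overflow, double after
    # a run of overflow-free steps
    def __init__(self, init_scale=2.0 ** 16, growth_interval=1000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def get_scale(self):
        return self.scale

    def update_scale(self, overflow):
        if overflow:
            self.scale = max(self.scale / 2, 1.0)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                self.scale *= 2
                self._good_steps = 0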
Example #2
import librosa
import numpy as np
import torch.distributed as dist


def calculate_bandwidth(dataset, hps, duration=600):
    # Estimate per-sample audio statistics over ~duration seconds of data:
    # mean |x| (l1), variance (l2), and mean spectrogram norm (spec).
    # DefaultSTFTValues and print_once are jukebox utility helpers (not shown).
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, _ = x  # drop the label; only the audio is needed here
        samples = x.astype(np.float64)
        # Downmix to mono, then take the magnitude spectrogram
        stft = librosa.stft(np.mean(samples, axis=1),
                            n_fft=hps.n_fft,
                            hop_length=hps.hop_length,
                            win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples**2)
        idx += max(16, dist.get_world_size())  # stride so ranks sample disjoint items

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean**2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
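
Both this example and the previous one call an allreduce wrapper imported from jukebox.utils.dist_utils. A minimal sketch consistent with how it is used here (scalar in, scalar out, summing by default with an optional op), assuming torch.distributed is initialized with a CUDA-capable backend:

import torch
import torch.distributed as dist


def allreduce(x, op=dist.ReduceOp.SUM):
    # Reduce a Python scalar across all ranks and return the result
    t = torch.tensor(float(x), device='cuda')
    dist.all_reduce(t, op=op)
    return t.item()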
Example #3
import numpy as np
import torch as t
import torch.distributed as dist

from jukebox.utils.dist_utils import allreduce


def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    # One pass over the train loader: forward, fp16-aware backward,
    # overflow-safe optimizer step, logging, checkpointing, and sampling
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(f"Train loader has {len(data_processor.train_loader)} batches")

    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                         scalar=scalar, fp16=hps.fp16, logger=logger)
        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        # Chained comparison: the norm check is active only when hps.ignore_grad_norm > 0
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in range(1, 1 + hps.iters_before_update) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        print("Hey there")
        logger.set_postfix(**{print_key:_metrics[key] for print_key, key in _print_keys.items()})
        print("by there")
        if finished_training:
            dist.barrier()
            exit()
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
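
The training loop folds gradient clipping into the fp16 unscale factor via clipped_grad_scale(grad_norm, hps.clip, scale), which the snippet does not define. A minimal sketch, assuming opt.step divides the gradients by the scale it receives: when the gradient norm exceeds hps.clip, the divisor grows proportionally, capping the effective norm at the threshold.

def clipped_grad_scale(grad_norm, max_grad_norm, scale):
    # Grow the divisor when grads exceed the clip threshold, so dividing
    # by the returned scale both unscales fp16 grads and clips them
    clip = grad_norm / max_grad_norm
    if clip > 1:
        scale = clip * scale
    return scale

Dividing by clip * scale leaves the gradient with norm exactly max_grad_norm whenever it would otherwise exceed it; below the threshold, the scale passes through unchanged.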