Example #1
0
def evaluate(model, orig_model, logger, metrics, data_processor, hps):
    """Run one gradient-free evaluation pass over the train loader.

    Puts both models in eval mode, forwards every batch, accumulates the
    returned per-batch metrics under ``test_*`` keys via ``metrics.update``,
    and logs the averages as scalars at the end.

    Args:
        model: wrapped (possibly distributed) model used for the forward pass.
        orig_model: unwrapped model, used only for logging inputs/outputs.
        logger: project logger providing get_range/set_postfix/add_scalar.
        metrics: running-average tracker with update/avg.
        data_processor: holds the train_loader iterated here.
        hps: hyperparameters; ``hps.prior`` selects prior vs. VQ-VAE mode.

    Returns:
        dict mapping each metric key to its ``test_*`` running average.
    """
    model.eval()
    orig_model.eval()
    # Short postfix labels -> metric keys; the set depends on model type.
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd")
    else:
        _print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")

    with t.no_grad():
        for i, x in logger.get_range(data_processor.train_loader):
            # Loader may yield (audio, labels) tuples or bare audio tensors.
            if isinstance(x, (tuple, list)):
                x, y = x
            else:
                y = None

            x = x.to('cuda', non_blocking=True)
            if y is not None:
                y = y.to('cuda', non_blocking=True)

            x_in = x = audio_preprocess(x, hps)
            # Only log raw inputs/outputs for the first batch.
            log_input_output = (i == 0)

            if hps.prior:
                forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
            else:
                forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

            x_out, loss, _metrics = model(x, **forw_kwargs)

            # Convert tensor metrics to Python floats for logging.
            for key, val in _metrics.items():
                _metrics[key] = val.item()
            _metrics["loss"] = loss = loss.item()  # Make sure to call to free graph

            # Fold batch values into running averages (weighted by batch size).
            for key, val in _metrics.items():
                _metrics[key] = metrics.update(f"test_{key}", val, x.shape[0])

            # Fix: removed a redundant nested t.no_grad() here — we are
            # already inside the outer no_grad context.
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

            logger.set_postfix(**{
                print_key: _metrics[key]
                for print_key, key in _print_keys.items()
            })

    # Fix: iterate keys only (values were unused in the original loop).
    for key in _metrics:
        logger.add_scalar(f"test_{key}", metrics.avg(f"test_{key}"))

    logger.close_range()
    return {key: metrics.avg(f"test_{key}") for key in _metrics.keys()}
Example #2
0
def test_dataset_loader():
    """Smoke-test the FilesAudioDataset pipeline end to end.

    Builds a tiny ("teeny") hyperparameter set, constructs the dataset and a
    distributed DataLoader, then writes the first batch's audio to
    TensorBoard both raw and after a preprocess/postprocess round trip.

    NOTE(review): relies on module-scope ``t`` (torch) and ``dist``
    (distributed utils), an initialized process group, CUDA, and audio files
    on disk at ``hps.root`` — intended as a manual integration check, not a
    unit test.
    """
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset
    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)
    dataset = hps.dataset
    root = hps.root
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')
    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))

    # Torch Loader
    # Fix: named function instead of a lambda bound to a name (PEP 8 / E731).
    def collate_fn(batch):
        return t.stack([t.from_numpy(b) for b in batch], 0)

    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset,
                              batch_size=hps.bs,
                              num_workers=hps.nworkers,
                              pin_memory=False,
                              sampler=sampler,
                              drop_last=True,
                              collate_fn=collate_fn)

    dist.barrier()
    sampler.set_epoch(0)
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        # Round-trip through preprocess/postprocess to check invertibility.
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        break  # one batch is enough for a smoke test
Example #3
0
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    """Train for one epoch over the train loader.

    Per batch: forward, backward (with optional fp16 loss scaling), skip the
    step on overflow or exploding grad norm, optimizer/scheduler/EMA step,
    metric logging, periodic checkpointing and sampling.

    Args:
        model: wrapped model used for forward/backward.
        orig_model: unwrapped model used for checkpointing/sampling/logging.
        opt: optimizer; ``shd``: LR scheduler or None; ``scalar``: fp16 loss
        scaler; ``ema``: EMA tracker or None.
        logger, metrics, data_processor, hps: project logging/metric/data
        helpers and hyperparameters.

    Returns:
        dict mapping each metric key to its running average for the epoch.
    """
    model.train()
    orig_model.train()
    # Short postfix labels -> metric keys; the set depends on model type.
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(data_processor.train_loader)
    print_all(len(data_processor.train_loader))

    for i, x in logger.get_range(data_processor.train_loader):
        # Loader may yield (audio, labels) tuples or bare audio tensors.
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                         scalar=scalar, fp16=hps.fp16, logger=logger)
        # Skip step if overflow. Note the chained comparison: the grad-norm
        # check only applies when hps.ignore_grad_norm > 0 (i.e. enabled).
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        # Record the LR used for this step before the scheduler advances it.
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging: convert tensor metrics to Python floats.
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                # One writer per 8-GPU node to avoid redundant checkpoints.
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        # Fix: removed leftover debug prints ("Hey there" / "by there").
        logger.set_postfix(**{print_key:_metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
Example #4
0
def compute_metrics(vqvae, hps, output_folder,
                    json_path='/home/kevin/feedforward/mp3s_for_jukebox_test.json'):
    """Evaluate a VQ-VAE on a fixed set of mp3s and dump per-track metrics.

    For every mp3 listed in the JSON manifest: write the (truncated) original
    wav, run the VQ-VAE forward pass with all-level reconstructions, write
    each level's reconstructed wav, save a spectrogram comparison plot, and
    append all returned loss metrics to ``<output_folder>/metrics.csv``.

    Args:
        vqvae: trained VQ-VAE model (callable, ``return_all_x_outs`` capable).
        hps: hyperparameters; ``hps.loss_fn`` selects the reconstruction loss.
        output_folder: destination root for audio/spec/csv outputs.
        json_path: manifest mapping client names to mp3 metadata dicts.
            Fix: was a hard-coded path; now a defaulted parameter.

    NOTE(review): depends on module-scope ``audio_mp3s_folder`` and
    ``device``; ``librosa.output.write_wav`` was removed in librosa 0.8 —
    confirm the pinned librosa version or migrate to ``soundfile.write``.
    """
    mp3_dict = load_json(json_path)
    csv = {
        "client": [],
        "media_id": [],
        "external_id": [],
        "s3_key": [],
        'recons_loss_l3': [],
        'spectral_loss_l3': [],
        'multispectral_loss_l3': [],
        'recons_loss_l2': [],
        'spectral_loss_l2': [],
        'multispectral_loss_l2': [],
        'recons_loss_l1': [],
        'spectral_loss_l1': [],
        'multispectral_loss_l1': [],
        'recons_loss': [],
        'spectral_loss': [],
        'multispectral_loss': [],
        'spectral_convergence': [],
        'l2_loss': [],
        'l1_loss': [],
        'linf_loss': [],
        'commit_loss': []
    }

    print("sample_length", vqvae.sample_length)
    print('multipliers', vqvae.multipliers)
    print('x_shape', vqvae.x_shape)
    print('downsamples', vqvae.downsamples)
    print('hop lengths', vqvae.hop_lengths)
    print('z shapes', vqvae.z_shapes)
    print('levels', vqvae.levels)

    print(len(vqvae.encoders))

    forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

    hps.argv = " ".join(sys.argv)
    # Single-sample evaluation: force batch size and workers to 1.
    hps.bs_sample = hps.nworkers = hps.bs = 1

    for client_name in mp3_dict:
        # Fix: exist_ok=True replaces the exists()/makedirs() pairs (racy
        # and verbose).
        os.makedirs(os.path.join(output_folder, client_name), exist_ok=True)
        os.makedirs(os.path.join(output_folder, client_name, 'audio'),
                    exist_ok=True)
        os.makedirs(os.path.join(output_folder, client_name, 'spec'),
                    exist_ok=True)

        for mp3_metadata in mp3_dict[
                client_name]:  # 'external_id', 'media_id', 'num_samples', 's3_key'
            print(mp3_metadata)
            s3_key = mp3_metadata['s3_key']
            filename = s3_key.split('/')[-1]
            mp3_path = os.path.join(audio_mp3s_folder, s3_key)
            mp3, _ = librosa.core.load(mp3_path, sr=44100)
            # Truncate to a fixed 881920-sample window (~20 s @ 44.1 kHz) and
            # save the reference wav.
            librosa.output.write_wav("{}/{}.wav".format(
                os.path.join(output_folder, client_name, 'audio'),
                filename.split('.')[0]),
                                     mp3[:881920],
                                     sr=44100)

            hps.bandwidth = get_bandwidth(mp3, hps)
            inputs = torch.tensor(mp3[:881920]).view(1, -1, 1).to(device)

            mp3_spec = spec(inputs.squeeze().cpu(), hps).numpy()

            inputs = audio_preprocess(inputs, hps)
            x_outs, loss, _metrics = vqvae(
                inputs, **forw_kwargs,
                return_all_x_outs=True)  # x_outs with top level first

            # Write each level's reconstruction and collect its spectrogram.
            out_specs = []
            for i, x_out in enumerate(
                    reversed(x_outs)):  # level 0 (bottom) first
                x_out_np = x_out.cpu().squeeze().numpy()
                librosa.output.write_wav("{}/{}_recon{}.wav".format(
                    os.path.join(output_folder, client_name, 'audio'),
                    filename.split('.')[0], i),
                                         x_out_np,
                                         sr=44100)
                x_out_spec = spec(x_out.squeeze().cpu(), hps).numpy()
                out_specs.append(x_out_spec)

            save_spec_plot([mp3_spec] + out_specs,
                           os.path.join(output_folder, client_name, 'spec',
                                        filename.split('.')[0] + '.png'),
                           title=filename.split('.')[0])

            csv['client'].append(client_name)
            csv['media_id'].append(mp3_metadata['media_id'])
            csv['external_id'].append(mp3_metadata['external_id'])
            csv['s3_key'].append(mp3_metadata['s3_key'])
            for k, v in _metrics.items():
                csv[k].append(float(v.squeeze().cpu().numpy()))

            # Rewrite the CSV every track so partial runs are preserved.
            pd.DataFrame(csv).to_csv(os.path.join(output_folder,
                                                  'metrics.csv'))