Example #1
def make_model(model, device, hps, levels=None):
    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(
        setup_hparams(vqvae, dict(sample_length=hps.get('sample_length', 0),
                                  sample_length_in_seconds=hps.get('sample_length_in_seconds', 0))),
        device)
    hps.sample_length = vqvae.sample_length
    if levels is None:
        levels = range(len(priors))
    priors = [make_prior(setup_hparams(priors[level], dict()), vqvae, 'cpu') for level in levels]
    return vqvae, priors
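A quick usage sketch for the helper above (illustrative only; it assumes `Hyperparams`, `setup_hparams`, `MODELS`, `make_vqvae` and `make_prior` are imported from `jukebox.hparams` / `jukebox.make_models`, and reuses the '1b_lyrics' model key from the other examples on this page):

hps = Hyperparams(sr=44100, n_samples=1, sample_length_in_seconds=20)
# The VQ-VAE is built on the requested device; every prior is created on CPU.
vqvae, priors = make_model('1b_lyrics', 'cuda', hps)
print(hps.sample_length)  # filled in by make_model from the VQ-VAE's sample_length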
Example #2
def run(**kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    from jukebox.hparams import Hyperparams, setup_hparams
    from jukebox.make_models import MODELS, make_vqvae, make_prior
    from jukebox.sample import _sample, upsample
    from jukebox.utils.torch_utils import empty_cache
    import torch as t
    model = "1b_lyrics"
    port = 29500
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    chunk_size = 32
    max_batch_size = 16
    hps.levels = 3
    hps.hop_fraction = [.5,.5,.125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens
    assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, 'Please choose a longer sample_length_in_seconds'

    metas = [dict(
                artist = kwargs["artist"],
                genre = kwargs["genre"],
                total_length = hps.sample_length,
                offset = 0,
                lyrics = kwargs["lyrics"],
            ),
    ] * hps.n_samples

    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98

    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size =  32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
            chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
            chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True, 
            max_batch_size=max_batch_size, chunk_size=chunk_size)
    ]

    zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Free GPU memory held by the top-level prior before building the upsamplers;
    # keep the name bound (to None) so it can still be passed to upsample() below.
    del top_prior
    empty_cache()
    top_prior = None
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
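Not shown above: once upsampling has filled in the lower levels, the level-0 codes are normally decoded back to audio with the VQ-VAE. A minimal sketch of that last step (the `decode(zs, start_level=...)` signature is assumed from the Jukebox VQ-VAE; verify against the version in use):

x = vqvae.decode(zs[:1], start_level=0).cpu().numpy()  # (n_samples, sample_length, 1) waveform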
Example #3
def run(hps="teeny", port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    from jukebox.hparams import setup_hparams
    import sys

    audio_database = kwargs.pop('audio_database')

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    print('device:', device)
    print("hps setup")
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = 0
    hps.nworkers = 0
    hps.argv = " ".join(sys.argv)

    hps.bs_sample, hps.nworkers, hps.bs = 1, 1, 1
    print("setting up database")
    # Setup dataset
    data_processor = DataProcessor(hps, audio_database)
    print("midi chunk call")
    for idx in range(8868):
        chunk = data_processor.dataset.get_midi_chunk(idx)
        if chunk.shape != (95, 128):
            raise RuntimeError(f'Unexpected MIDI chunk shape at index {idx}: {chunk.shape}')
Example #4
def run(hps="teeny", port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar,
                                  ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Train', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in train_metrics.items()
                    ]))
            dist.barrier()

        if hps.test:
            if ema: ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Ema', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in test_metrics.items()
                    ]))
            dist.barrier()
            if ema: ema.swap()
        dist.barrier()
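In the Jukebox repository this kind of training entry point is driven by `fire`, so the module usually ends with something like the sketch below; hyperparameters such as `bs`, `prior`, `train` or `test` then arrive through `**kwargs` from the command line (treat this as a sketch, not the exact file):

import fire

if __name__ == '__main__':
    fire.Fire(run)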
Example #5
def test_dataset_loader():
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset
    import torch as t
    import torch.distributed as dist
    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)
    dataset = hps.dataset
    root = hps.root
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')
    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))

    # Torch Loader
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset,
                              batch_size=hps.bs,
                              num_workers=hps.nworkers,
                              pin_memory=False,
                              sampler=sampler,
                              drop_last=True,
                              collate_fn=collate_fn)

    dist.barrier()
    sampler.set_epoch(0)
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        break
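Because `test_dataset_loader` calls `dist.barrier()` and builds a `DistributedSampler`, it assumes a process group is already initialised; one way to drive it (reusing the MPI helper seen in the other examples, port number chosen for illustration):

if __name__ == '__main__':
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=29500)
    test_dataset_loader()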
Example #6
if len(sys.argv) > 1:
    this_run_slug = sys.argv[1]
else:
    this_run_slug = "co_compose_synth2"

hps.name = '/home/robin/google-drive/samples/' + this_run_slug + '/'

with open(f'{hps.name}meta.p', "rb") as f:
    meta = pickle.load(f)

hps.sample_length = 1048576 if model == "5b_lyrics" else 786432
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16
hps.hop_fraction = [.5, .5, .125]
hps.levels = 3

vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=hps.sample_length)),
                   device)

metas_1 = meta[0]
metas_2 = meta[1]

print(metas_1)
print(metas_2)

zs = t.load(f'{hps.name}zs-top-level-final.t')

top_prior_raw_to_tokens = 128

hps.sample_length = zs[2].shape[1] * top_prior_raw_to_tokens

upsamplers = [
Example #7
    "sample_length_in_seconds": 20,
    "total_sample_length_in_seconds": 180,
    "sr": 44100,
    "n_samples": 6,
    "hop_fraction": [0.5, 0.5, 0.125]
}

train_options = {"bs": 1, "labels": False}

rank, local_rank, device = setup_dist_from_mpi(port=29500)
print("Device: {}".format(device))

hps = Hyperparams(**sample_options)
hps = setup_hparams(
    "vqvae",
    dict(sample_length=hps.get('sample_length', 0),
         sample_length_in_seconds=hps.get('sample_length_in_seconds', 0),
         labels=False,
         bs=1))
# print(hps)

vqvae = make_vqvae(hps, 'cuda:0')


def compute_metrics(vqvae, hps, output_folder):
    json_path = '/home/kevin/feedforward/mp3s_for_jukebox_test.json'
    mp3_dict = load_json(json_path)
    csv = {
        "client": [],
        "media_id": [],
        "external_id": [],
        "s3_key": [],