Ejemplo n.º 1
0
def run(hps="teeny", port=29500, **kwargs):
    """Train and/or evaluate a VQ-VAE, or a prior built on top of it, over epochs.

    Sets up the distributed process group via ``setup_dist_from_mpi`` and builds
    the hyperparameters, data processor, model(s), optimizer, EMA and DDP wrapper.
    It then loops over epochs ``hps.curr_epoch .. hps.epochs - 1``:

    * when ``hps.train`` is set, it runs ``train``;
    * when ``hps.test`` is set, it runs ``evaluate`` with the EMA weights swapped
      in (if an EMA is configured).

    Rank 0 prints the resulting metrics.

    Args:
        hps: Name of the hyperparameter preset passed to ``setup_hparams``.
        port: Port handed to ``setup_dist_from_mpi``.
        **kwargs: Hyperparameter overrides forwarded to ``setup_hparams``.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    # Sampling batch size and worker count are both tied to the training batch size.
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models: when hps.prior is set, the prior is built on top of the VQ-VAE
    # and becomes the model being optimized.
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar,
                                  ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print('Train', _format_metrics(train_metrics))
            dist.barrier()

        if hps.test:
            # Evaluate with the EMA weights swapped in. The `finally` always swaps
            # back, so the training weights are restored even if evaluation fails.
            if ema:
                ema.swap()
            try:
                test_metrics = evaluate(distributed_model, model, logger, metrics,
                                        data_processor, hps)
                test_metrics['epoch'] = epoch
                if rank == 0:
                    print('Ema', _format_metrics(test_metrics))
                dist.barrier()
            finally:
                if ema:
                    ema.swap()
        dist.barrier()


def _format_metrics(values):
    """Render a metrics mapping as space-separated ``key: value`` pairs (4 decimals)."""
    return ' '.join(f'{key}: {val:0.4f}' for key, val in values.items())
Ejemplo n.º 2
0
def run(**kwargs):
    """Sample with the ``1b_lyrics`` model: top level first, then upsample.

    Required ``kwargs`` keys (each is read with ``[]``, so a missing key raises
    ``KeyError``):
        sample_name: Stored as ``hps.name``. Presumably this is the output location
            used by ``_sample``/``upsample`` (TODO confirm).
        sample_length: Requested length in seconds. It is rounded down to a whole
            number of top-level tokens.
        artist, genre, lyrics: Conditioning metadata passed to the labellers.

    Returns:
        The ``zs`` returned by ``upsample``.

    Raises:
        AssertionError: If the rounded sample length is shorter than one full
            top-prior context (``n_ctx * raw_to_tokens`` samples).
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    model = "1b_lyrics"
    port = 29500
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested duration down to a whole number of top-level tokens.
    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds * hps.sr)
                         // top_prior.raw_to_tokens) * top_prior.raw_to_tokens
    min_sample_length = top_prior.n_ctx * top_prior.raw_to_tokens
    assert hps.sample_length >= min_sample_length, (
        f'Please choose a longer sample length: got {hps.sample_length} samples, '
        f'need at least {min_sample_length} (one full top-prior context)')

    # The same metadata dict is repeated for every sample; it is only read.
    metas = [dict(
                artist=kwargs["artist"],
                genre=kwargs["genre"],
                total_length=hps.sample_length,
                offset=0,
                lyrics=kwargs["lyrics"],
            ),
    ] * hps.n_samples

    # Only the top level (index 2) gets labels now; the lower levels are filled in
    # after the upsamplers are built.
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98

    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size = 32
    # One kwargs dict per level: the two lower levels, then the top level.
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True,
             max_batch_size=max_batch_size, chunk_size=chunk_size)
    ]

    # Start each level from an empty (n_samples, 0) token tensor and sample level 2 only.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Drop the top prior and clear the cache before building the upsamplers;
    # presumably this frees GPU memory (TODO confirm empty_cache is torch.cuda's).
    del top_prior
    empty_cache()
    top_prior = None
    # Upsamplers are built on CPU here; presumably `upsample` moves them as needed.
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
    return zs
Ejemplo n.º 3
0
    this_run_slug = sys.argv[1]
else:
    this_run_slug = "co_compose_synth2"

hps.name = '/home/robin/google-drive/samples/' + this_run_slug + '/'

# Load the saved per-run metadata. The `with` block closes the file handle
# (the original `pickle.load(open(...))` leaked it).
# NOTE(review): pickle.load can execute arbitrary code; only load meta.p files
# this project produced itself.
with open(f'{hps.name}meta.p', "rb") as meta_file:
    meta = pickle.load(meta_file)

# Model-dependent settings: the 5b_lyrics model uses a longer sample length and
# smaller chunk / batch sizes.
hps.sample_length = 1048576 if model == "5b_lyrics" else 786432
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16
hps.hop_fraction = [.5, .5, .125]
hps.levels = 3

vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=hps.sample_length)),
                   device)

metas_1 = meta[0]
metas_2 = meta[1]

print(metas_1)
print(metas_2)

# Previously saved top-level tokens; level index 2 is the top level.
zs = t.load(f'{hps.name}zs-top-level-final.t')

# Hard-coded here instead of read from a loaded top prior.
# Presumably this matches the top prior's raw_to_tokens (TODO confirm).
top_prior_raw_to_tokens = 128

# Recompute the sample length (in audio samples) from the saved top-level tokens.
hps.sample_length = zs[2].shape[1] * top_prior_raw_to_tokens
upsamplers = [
    make_prior(setup_hparams(prior, dict()), vqvae, 'cpu')
Ejemplo n.º 4
0
train_options = {"bs": 1, "labels": False}

rank, local_rank, device = setup_dist_from_mpi(port=29500)
print(f"Device: {device}")

# Build hparams from the sample options, then rebuild them from the "vqvae" preset.
# Only the requested sample length is carried over (in samples and in seconds,
# 0 when unset); labels are disabled and batch size is forced to 1.
hps = Hyperparams(**sample_options)
hps = setup_hparams("vqvae", {
    "sample_length": hps.get('sample_length', 0),
    "sample_length_in_seconds": hps.get('sample_length_in_seconds', 0),
    "labels": False,
    "bs": 1,
})

# NOTE(review): the device is hard-coded to cuda:0 rather than `device` above.
vqvae = make_vqvae(hps, 'cuda:0')


def compute_metrics(vqvae, hps, output_folder):
    json_path = '/home/kevin/feedforward/mp3s_for_jukebox_test.json'
    mp3_dict = load_json(json_path)
    csv = {
        "client": [],
        "media_id": [],
        "external_id": [],
        "s3_key": [],
        'recons_loss_l3': [],
        'spectral_loss_l3': [],
        'multispectral_loss_l3': [],
        'recons_loss_l2': [],
        'spectral_loss_l2': [],