Example #1
def save_samples(model, device, hps, sample_hps, metas: list):
    """Generate and save samples, alignment, and webpage for visualization."""
    print(hps)
    from jukebox.lyricdict import poems, gpt_2_lyrics
    vqvae, priors = make_model(model, device, hps)

    assert hps.sample_length // priors[-2].raw_to_tokens >= priors[-2].n_ctx, \
        "Upsampling needs at least one ctx in get_z_conds. Please choose a longer sample length."
    assert isinstance(metas, list)
    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0
    while len(metas) < hps.n_samples:
        metas.extend(metas)
    metas = metas[:hps.n_samples]

    labels = [
        prior.labeller.get_batch_labels(metas, 'cuda') for prior in priors
    ]
    for label in labels:
        assert label['y'].shape[0] == hps.n_samples

    lower_level_chunk_size = 32
    lower_level_max_batch_size = 16
    if model == '1b_lyrics':
        chunk_size = 32
        max_batch_size = 16
    else:
        chunk_size = 16
        max_batch_size = 3
    sampling_kwargs = [
        dict(temp=0.99,
             fp16=True,
             chunk_size=lower_level_chunk_size,
             max_batch_size=lower_level_max_batch_size),
        dict(temp=0.99,
             fp16=True,
             chunk_size=lower_level_chunk_size,
             max_batch_size=lower_level_max_batch_size),
        dict(temp=0.99,
             fp16=True,
             chunk_size=chunk_size,
             max_batch_size=max_batch_size)
    ]

    if sample_hps.mode == 'ancestral':
        ancestral_sample(labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode == 'primed':
        assert sample_hps.audio_file is not None
        audio_files = sample_hps.audio_file.split(',')
        top_raw_to_tokens = priors[-1].raw_to_tokens
        duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) //
                    top_raw_to_tokens) * top_raw_to_tokens
        x = load_prompts(audio_files, duration, hps)
        primed_sample(x, labels, sampling_kwargs, priors, hps)
    else:
        raise ValueError(f'Unknown sample mode {sample_hps.mode}.')

def save_alignment(model, device, hps):
    """Compute lyric/audio alignments for saved top-level samples and write the HTML visualization."""
    print(hps)
    vqvae, priors = make_model(model, device, hps, levels=[-1])

    logdir = f"{hps.logdir}/level_0"
    data = t.load(f"{logdir}/data.pth.tar")
    if model == '1b_lyrics':
        fp16 = False
    else:
        fp16 = True

    data['alignments'] = get_alignment(data['x'], data['zs'],
                                       data['labels'][-1], priors[-1], fp16,
                                       hps)
    t.save(data, f"{logdir}/data_align.pth.tar")
    save_html(logdir, data['x'], data['zs'], data['labels'][-1],
              data['alignments'], hps)
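
A minimal driver for this variant might look like the following sketch. It assumes the standard Jukebox Hyperparams setup; the exact hyperparameter fields, model choice, and lyrics are illustrative, and the real CLI wiring (fire.Fire in Jukebox's sample.py) is not shown.

# Hypothetical driver for the Example #1 variant, which takes metas explicitly.
# Field values are illustrative; setup_hparams normally derives fields such as
# sample_length from sample_length_in_seconds.
import torch as t
from jukebox.hparams import Hyperparams

hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 2
hps.name = 'samples'
hps.levels = 3
hps.hop_fraction = [0.5, 0.5, 0.125]
hps.total_sample_length_in_seconds = 180

sample_hps = Hyperparams(dict(mode='ancestral', audio_file=None,
                              prompt_length_in_seconds=None))
metas = [dict(artist='Alan Jackson', genre='Country', lyrics='...',
              total_length=hps.total_sample_length_in_seconds * hps.sr,
              offset=0)]

device = 'cuda' if t.cuda.is_available() else 'cpu'
save_samples('5b_lyrics', device, hps, sample_hps, metas)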
Example #3
def save_samples(model, device, hps, sample_hps):
    print(hps)
    from jukebox.lyricdict import poems, gpt_2_lyrics
    vqvae, priors = make_model(model, device, hps)

    wandb_utils.watch_model(vqvae)

    assert hps.sample_length//priors[-2].raw_to_tokens >= priors[-2].n_ctx, "Upsampling needs at least one ctx in get_z_conds. Please choose a longer sample length."

    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0

    # Set artist/genre/lyrics for your samples here!
    # We used different label sets in our models, but you can write the human-friendly names here and we'll map them under the hood for each model.
    # For the 5b/5b_lyrics model and the upsamplers, the labeller will look up artists and genres in the v2 set (after lowercasing, removing non-alphanumerics, and collapsing whitespace runs to _); a sketch of this normalization follows this example.
    # For the 1b_lyrics top level, the labeller will look up artists and genres in the v3 set (after lowercasing).
    metas = [dict(artist = "Alan Jackson",
                  genre = "Country",
                  lyrics = poems['ozymandias'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Joe Bonamassa",
                  genre="Blues Rock",
                  lyrics=gpt_2_lyrics['hottub'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Frank Sinatra",
                  genre="Classic Pop",
                  lyrics=gpt_2_lyrics['alone'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Ella Fitzgerald",
                  genre="Jazz",
                  lyrics=gpt_2_lyrics['count'],
                  total_length=total_length,
                  offset=offset,
                  ),
             dict(artist="Céline Dion",
                  genre="Pop",
                  lyrics=gpt_2_lyrics['darkness'],
                  total_length=total_length,
                  offset=offset,
                  ),
             ]
    while len(metas) < hps.n_samples:
        metas.extend(metas)
    metas = metas[:hps.n_samples]

    labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in priors]
    for label in labels:
        assert label['y'].shape[0] == hps.n_samples

    lower_level_chunk_size = 32
    lower_level_max_batch_size = 16
    if model == '1b_lyrics':
        chunk_size = 32
        max_batch_size = 16
    else:
        chunk_size = 16
        max_batch_size = 3
    sampling_kwargs = [dict(temp=0.99, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
                       dict(temp=0.99, fp16=True, chunk_size=lower_level_chunk_size, max_batch_size=lower_level_max_batch_size),
                       dict(temp=0.99, fp16=True, chunk_size=chunk_size, max_batch_size=max_batch_size)]

    if sample_hps.mode == 'ancestral':
        ancestral_sample(labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode in ['continue', 'upsample']:
        assert sample_hps.codes_file is not None
        top_raw_to_tokens = priors[-1].raw_to_tokens
        if sample_hps.prompt_length_in_seconds is not None:
            duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) // top_raw_to_tokens) * top_raw_to_tokens
        else:
            duration = None
        zs = load_codes(sample_hps.codes_file, duration, priors, hps)
        if sample_hps.mode == 'continue':
            continue_sample(zs, labels, sampling_kwargs, priors, hps)
        elif sample_hps.mode == 'upsample':
            upsample(zs, labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode == 'primed':
        assert sample_hps.audio_file is not None
        assert sample_hps.prompt_length_in_seconds is not None
        audio_files = sample_hps.audio_file.split(',')
        top_raw_to_tokens = priors[-1].raw_to_tokens
        duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) // top_raw_to_tokens) * top_raw_to_tokens
        x = load_prompts(audio_files, duration, hps)
        primed_sample(x, labels, sampling_kwargs, priors, hps)
    else:
        raise ValueError(f'Unknown sample mode {sample_hps.mode}.')
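
The v2-style key normalization described in the comment block above can be sketched as a small helper. This is a hypothetical illustration of the described rule, not the labeller's actual implementation.

import re

def normalize_label_v2(name: str) -> str:
    # Sketch of the v2 lookup rule quoted above: lowercase, drop
    # non-alphanumerics, collapse whitespace runs to '_'.
    name = name.lower()
    name = re.sub(r'[^a-z0-9\s]', '', name)    # remove non-alphanumerics
    return re.sub(r'\s+', '_', name.strip())   # whitespace runs -> '_'

assert normalize_label_v2("Alan Jackson") == "alan_jackson"
assert normalize_label_v2("Blues Rock") == "blues_rock"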
Example #4
def save_samples(model, device, hps, sample_hps):
    print(hps)
    from jukebox.lyricdict import poems, gpt_2_lyrics
    vqvae, priors = make_model(model, device, hps)

    assert hps.sample_length // priors[-2].raw_to_tokens >= priors[-2].n_ctx, \
        "Upsampling needs at least one ctx in get_z_conds. Please choose a longer sample length."

    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0
    metas = [
        dict(
            artist="Alan Jackson",
            genre="Country",
            lyrics=poems['ozymandias'],
            total_length=total_length,
            offset=offset,
        ),
        dict(
            artist="Joe Bonamassa",
            genre="Blues Rock",
            lyrics=gpt_2_lyrics['hottub'],
            total_length=total_length,
            offset=offset,
        ),
        dict(
            artist="Frank Sinatra",
            genre="Classic Pop",
            lyrics=gpt_2_lyrics['alone'],
            total_length=total_length,
            offset=offset,
        ),
        dict(
            artist="Ella Fitzgerald",
            genre="Jazz",
            lyrics=gpt_2_lyrics['count'],
            total_length=total_length,
            offset=offset,
        ),
        dict(
            artist="Celine Dion",
            genre="Pop",
            lyrics=gpt_2_lyrics['darkness'],
            total_length=total_length,
            offset=offset,
        ),
    ]
    while len(metas) < hps.n_samples:
        metas.extend(metas)
    metas = metas[:hps.n_samples]

    labels = [
        prior.labeller.get_batch_labels(metas, 'cuda') for prior in priors
    ]
    for label in labels:
        assert label['y'].shape[0] == hps.n_samples

    lower_level_chunk_size = 32
    lower_level_max_batch_size = 16
    if model == '1b_lyrics':
        chunk_size = 32
        max_batch_size = 16
    else:
        chunk_size = 16
        max_batch_size = 3
    sampling_kwargs = [
        dict(temp=0.99,
             fp16=True,
             chunk_size=lower_level_chunk_size,
             max_batch_size=lower_level_max_batch_size),
        dict(temp=0.99,
             fp16=True,
             chunk_size=lower_level_chunk_size,
             max_batch_size=lower_level_max_batch_size),
        dict(temp=0.99,
             fp16=True,
             chunk_size=chunk_size,
             max_batch_size=max_batch_size)
    ]

    if sample_hps.mode == 'ancestral':
        ancestral_sample(labels, sampling_kwargs, priors, hps)
    elif sample_hps.mode == 'primed':
        assert sample_hps.audio_file is not None
        audio_files = sample_hps.audio_file.split(',')
        top_raw_to_tokens = priors[-1].raw_to_tokens
        duration = (int(sample_hps.prompt_length_in_seconds * hps.sr) //
                    top_raw_to_tokens) * top_raw_to_tokens
        x = load_prompts(audio_files, duration, hps)
        primed_sample(x, labels, sampling_kwargs, priors, hps)
    else:
        raise ValueError(f'Unknown sample mode {sample_hps.mode}.')
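
All of the primed paths above floor the prompt duration to a whole number of top-level tokens, since the prior conditions on full tokens only. A quick worked example (sr=44100 and raw_to_tokens=128 match the published configs, but the exact figures depend on the model):

sr = 44100
top_raw_to_tokens = 128
prompt_length_in_seconds = 12

raw = int(prompt_length_in_seconds * sr)                    # 529200 raw samples
duration = (raw // top_raw_to_tokens) * top_raw_to_tokens   # 529152 = 4134 tokens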
Example #5
def save_samples(model, device, hps, sample_hps):
    print(hps)
    from jukebox.lyricdict import poems, gpt_2_lyrics
    vqvae, priors = make_model(model, device, hps)

    assert hps.sample_length // priors[-2].raw_to_tokens >= priors[-2].n_ctx, \
        "Upsampling needs at least one ctx in get_z_conds. Please choose a longer sample length."

    total_length = hps.total_sample_length_in_seconds * hps.sr
    offset = 0

    # Set artist/genre/lyrics for your samples here!
    # We used different label sets in our models, but you can write the human-friendly names here and we'll map them under the hood for each model.
    # For the 5b/5b_lyrics model and the upsamplers, the labeller will look up artists and genres in the v2 set (after lowercasing, removing non-alphanumerics, and collapsing whitespace runs to _; see the sketch after Example #3).
    # For the 1b_lyrics top level, the labeller will look up artists and genres in the v3 set (after lowercasing).

    artists = ["Pink Floyd", "Beat Farmers"]
    genres = ["nintendocore", "worship", "glitch", "synthpop", "latin jazz"]
    lyrics = [
        "Almost anybody can learn to think or believe or know, but not a single human being can be taught to feel. Why? Because whenever you think or you believe or you know, you’re a lot of other people: but the moment you feel, you’re nobody-but-yourself."
    ]
    temps = [0.98]

    for artist in artists:
        for genre in genres:
            for lyric in lyrics:
                for temp in temps:
                    # Give each (artist, genre, temp) combination its own run name.
                    hps.name = "_".join([artist, genre, f"temp{temp}"]).replace(" ", "_")
                    metas = [
                        dict(
                            artist=artist,
                            genre=genre,
                            lyrics=lyric,
                            total_length=total_length,
                            offset=offset,
                        ),
                    ]
                    while len(metas) < hps.n_samples:
                        metas.extend(metas)
                    metas = metas[:hps.n_samples]
                    print(metas)

                    labels = [
                        prior.labeller.get_batch_labels(metas, 'cuda')
                        for prior in priors
                    ]
                    for label in labels:
                        assert label['y'].shape[0] == hps.n_samples

                    lower_level_chunk_size = 32
                    lower_level_max_batch_size = 1  # reduced from 16, presumably to fit GPU memory
                    if model == '1b_lyrics':
                        chunk_size = 32
                        max_batch_size = 16
                    else:
                        chunk_size = 16
                        max_batch_size = 1  # reduced from 3, presumably to fit GPU memory
                    sampling_kwargs = [
                        dict(temp=0.99,
                             fp16=True,
                             chunk_size=lower_level_chunk_size,
                             max_batch_size=lower_level_max_batch_size),
                        dict(temp=0.99,
                             fp16=True,
                             chunk_size=lower_level_chunk_size,
                             max_batch_size=lower_level_max_batch_size),
                        dict(temp=temp,
                             fp16=True,
                             chunk_size=chunk_size,
                             max_batch_size=max_batch_size)
                    ]

                    if sample_hps.mode == 'ancestral':
                        ancestral_sample(labels, sampling_kwargs, priors, hps)
                    elif sample_hps.mode in ['continue', 'upsample']:
                        assert sample_hps.codes_file is not None
                        top_raw_to_tokens = priors[-1].raw_to_tokens
                        if sample_hps.prompt_length_in_seconds is not None:
                            duration = (int(sample_hps.prompt_length_in_seconds
                                            * hps.sr) //
                                        top_raw_to_tokens) * top_raw_to_tokens
                        else:
                            duration = None
                        zs = load_codes(sample_hps.codes_file, duration,
                                        priors, hps)
                        if sample_hps.mode == 'continue':
                            continue_sample(zs, labels, sampling_kwargs,
                                            priors, hps)
                        elif sample_hps.mode == 'upsample':
                            upsample(zs, labels, sampling_kwargs, priors, hps)
                    elif sample_hps.mode == 'primed':
                        assert sample_hps.audio_file is not None
                        assert sample_hps.prompt_length_in_seconds is not None
                        audio_files = sample_hps.audio_file.split(',')
                        top_raw_to_tokens = priors[-1].raw_to_tokens
                        duration = (int(
                            sample_hps.prompt_length_in_seconds * hps.sr) //
                                    top_raw_to_tokens) * top_raw_to_tokens
                        x = load_prompts(audio_files, duration, hps)
                        primed_sample(x, labels, sampling_kwargs, priors, hps)
                    else:
                        raise ValueError(
                            f'Unknown sample mode {sample_hps.mode}.')
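
The per-level sampling_kwargs blocks are copied verbatim across every variant above. A refactoring sketch that deduplicates them (a suggestion, not part of the original code):

def build_sampling_kwargs(model, top_temp=0.99):
    # The two lower-level upsamplers share one config; only the top level
    # varies with the model size and the requested temperature.
    if model == '1b_lyrics':
        top = dict(chunk_size=32, max_batch_size=16)
    else:
        top = dict(chunk_size=16, max_batch_size=3)
    lower = dict(temp=0.99, fp16=True, chunk_size=32, max_batch_size=16)
    return [dict(lower), dict(lower), dict(temp=top_temp, fp16=True, **top)]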