Exemple #1
0
def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    """Set up distributed state, then call ``save_samples`` for ``model``.

    ``kwargs`` become the model hyperparameters; the sampling-specific
    arguments (mode, codes/audio file, prompt length) are bundled into a
    separate ``Hyperparams`` object. Sampling runs with autograd disabled.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    run_hps = Hyperparams(**kwargs)
    sampling_options = dict(
        mode=mode,
        codes_file=codes_file,
        audio_file=audio_file,
        prompt_length_in_seconds=prompt_length_in_seconds,
    )
    sampling_hps = Hyperparams(sampling_options)

    with t.no_grad():
        save_samples(model, device, run_hps, sampling_hps)
Exemple #2
0
def run(**kwargs):
    """Sample lyrics-conditioned audio with the ``1b_lyrics`` model.

    The top-level prior is sampled first. It is then released, and the
    remaining priors (the upsamplers) are built to fill in the lower levels.

    Keyword Args:
        sample_name: output directory, stored as ``hps.name``.
        sample_length: requested duration in seconds; rounded down to a
            multiple of ``top_prior.raw_to_tokens`` samples.
        artist, genre, lyrics: metadata passed to the priors' labellers.

    Returns:
        The codes ``zs`` returned by ``upsample``.

    Raises:
        AssertionError: if the rounded sample length is shorter than
            ``top_prior.n_ctx * top_prior.raw_to_tokens`` samples.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    model = "1b_lyrics"
    port = 29500
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested length down to a whole multiple of raw_to_tokens.
    raw_to_tokens = top_prior.raw_to_tokens
    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // raw_to_tokens) * raw_to_tokens
    min_length = top_prior.n_ctx * raw_to_tokens
    assert hps.sample_length >= min_length, (
        f'Please choose a longer sample_length: {sample_length_in_seconds}s gives '
        f'{hps.sample_length} samples but at least {min_length} '
        f'({min_length / hps.sr:.2f}s) are required')

    # One independent metadata dict per sample.
    metas = [
        dict(
            artist=kwargs["artist"],
            genre=kwargs["genre"],
            total_length=hps.sample_length,
            offset=0,
            lyrics=kwargs["lyrics"],
        )
        for _ in range(hps.n_samples)
    ]

    # Only the top level is sampled first, so only its labels are needed now.
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, device)]

    sampling_temperature = .98
    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size = 32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True,
             max_batch_size=max_batch_size, chunk_size=chunk_size),
    ]

    # One empty code sequence per level, to be filled by sampling.
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device=device) for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Free the top prior before building the upsamplers.
    del top_prior
    empty_cache()
    top_prior = None

    # Upsamplers are built on the CPU; presumably the sampling helpers move
    # each one to the GPU when its level is sampled -- TODO confirm.
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, device) for prior in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
    return zs
Exemple #3
0
def run(model, port=29500, **kwargs):
    """Initialise distributed state and call ``save_outputs`` for ``model``.

    ``kwargs`` are turned into a ``Hyperparams`` object. The call runs with
    autograd disabled.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    params = Hyperparams(**kwargs)
    with t.no_grad():
        save_outputs(model, device, params)
Exemple #4
0
def run(model,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    """Initialise distributed state and call ``save_samples`` for ``model``.

    Example call:
        model=5b, name=sample_5b, levels=3, sample_length_in_seconds=20,
        total_sample_length_in_seconds=180, sr=44100, n_samples=6,
        hop_fraction=[0.5, 0.5, 0.125]

    ``kwargs`` become the model hyperparameters. ``mode``, ``codes_file``,
    ``audio_file`` and ``prompt_length_in_seconds`` are grouped into a
    separate sampling ``Hyperparams``. Sampling runs with autograd disabled.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    model_hps = Hyperparams(**kwargs)
    sampling = {
        'mode': mode,
        'codes_file': codes_file,
        'audio_file': audio_file,
        'prompt_length_in_seconds': prompt_length_in_seconds,
    }
    sampling_hps = Hyperparams(sampling)

    with t.no_grad():
        save_samples(model, device, model_hps, sampling_hps)
def check_sample():
    """Run the ancestral and primed sampling tests against dummy priors.

    Builds one ``DummyPrior`` per level and one label dict per level (all
    sharing the same ``y`` tensor), then calls ``test_ancestral_sample`` and
    ``test_primed_sample``. The label tensor is created on ``'cuda'``, so a
    CUDA device is required.
    """
    n_ctx = 8192
    n_samples = 4
    levels = 3
    priors = [DummyPrior(n_ctx, level, levels) for level in range(levels)]
    max_total_length, offset, sample_length = 4134368, 0, n_ctx*8*4*4
    # A single 9-element label row, repeated once per sample -> shape (n_samples, 9).
    y = t.tensor([max_total_length, offset, sample_length, 10, 1, -1, -1, -1, -1],
                 dtype=t.long, device='cuda').view(1, 9).repeat(n_samples, 1)
    # One independent empty info list per sample (note `[[]*n]` would be just `[[]]`).
    labels = [dict(y=y, info=[[] for _ in range(n_samples)]) for _ in range(levels)]
    hps = Hyperparams({
        'levels': levels,
        'sample_length': sample_length,
        'n_segment': 2,
        'n_ctx': n_ctx,
        'n_tokens': 0,
        'hop_lengths': [n_ctx//2, n_ctx//2, n_ctx//8],
        'n_samples': n_samples,
        'use_tokens': False
    })
    test_ancestral_sample(labels, priors, hps)
    test_primed_sample(labels, priors, hps)
Exemple #6
0
def run(model='1b_lyrics',
        config_file=None,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    """Sample from ``model``, optionally driven by a YAML config file.

    Precedence of settings:
      * ``model``, ``mode``, ``codes_file``, ``audio_file``,
        ``prompt_length_in_seconds`` and ``port`` are taken from the config
        file when present there, otherwise from the arguments.
      * For the remaining hyperparameters, ``kwargs`` (command line) override
        values loaded from the config file.

    When a config file is used, it is copied into the output directory
    ``name`` (taken from the merged settings), so it is stored next to the
    generated audio.

    Exits the process with status -1 if ``audio_file`` is set but missing.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    print('Starting run')

    # Start with an empty configuration
    config_dict = {}

    # Load configuration from yaml file, if specified
    if config_file:
        print("Loading config file %s" % config_file)

        with open(config_file, 'r') as stream:
            # safe_load returns None for an empty document; treat that as "no settings".
            config_dict = yaml.safe_load(stream) or {}

        # Update the named parameters with config values (if set)
        model = config_dict.get('model', model)
        mode = config_dict.get('mode', mode)
        codes_file = config_dict.get('codes_file', codes_file)
        audio_file = config_dict.get('audio_file', audio_file)
        prompt_length_in_seconds = config_dict.get('prompt_length_in_seconds',
                                                   prompt_length_in_seconds)
        port = config_dict.get('port', port)

    # Set up distributed communication only after the config is read, so a
    # `port` given in the config file actually takes effect.
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)

    # Update the config from any values specified on the command line.
    # The command line values will overwrite the ones loaded from the config file.
    config_dict.update(kwargs)

    if config_file:
        # Copy the config file to the output directory so it's stored with the
        # generated audio. Use the merged `name` so a command-line override wins.
        print("Copying config file to output directory")
        output_dir = config_dict['name']
        output_config = os.path.join(output_dir, os.path.basename(config_file))
        print("  %s -> %s" % (config_file, output_config))
        os.makedirs(output_dir, exist_ok=True)
        shutil.copyfile(config_file, output_config)

    print('Using these configuration values:')
    pprint.pprint(config_dict, indent=2)

    # If we're using a primer file, check that it exists
    if audio_file and not os.path.exists(audio_file):
        print("Trying to use audio_file but it doesn't exist!")
        print("Place your priming audio here:")
        print(audio_file)
        sys.exit(-1)

    hps = Hyperparams(**config_dict)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
Exemple #7
0
def run(model,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    """Sample from ``model`` with a fixed set of lyric prompts.

    Builds one metadata entry per lyric line below. All entries share
    artist "Unknown", genre "psychedelic", offset 0 and a total length of
    ``hps.total_sample_length_in_seconds * hps.sr``. The entries are passed
    to ``save_samples`` with autograd disabled.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    start = 0
    song_length = hps.total_sample_length_in_seconds * hps.sr
    # Metatags could come from elsewhere (fixed or generated); here only the
    # lyrics differ between entries.
    lyric_lines = (
        "nothing",
        "hey come on",
        "is this real",
        "love is an elephant in a suit",
        "you, yes you",
        "why do you",
        "no",
        "i i i i i",
    )
    metatags = [
        dict(
            artist="Unknown",
            genre="psychedelic",
            lyrics=line,
            total_length=song_length,
            offset=start,
        )
        for line in lyric_lines
    ]
    with t.no_grad():
        save_samples(model, device, hps, sample_hps, metatags)
Exemple #8
0
def run(mode='ancestral',
        audio_file=None,
        prompt_length_in_seconds=12.0,
        port=29500):
    """Poll the job queue forever and run Jukebox sampling for each job.

    Each iteration:
      1. Opens a DB connection and fetches the next job.
      2. If there is no job, closes the connection and sleeps 60 seconds.
      3. Otherwise, locks the job, marks it "top_started" and logs the results
         URL (built from the output of ``./get_ip.sh``).
      4. Closes the connection and runs ``save_samples``.
      5. Marks the job "upsampling_done" on a fresh connection.

    Args:
        mode: sampling mode, forwarded to ``save_samples`` via ``sample_hps``.
        audio_file: optional priming audio, forwarded via ``sample_hps``.
        prompt_length_in_seconds: priming length, forwarded via ``sample_hps``.
        port: port passed to ``setup_dist_from_mpi``.

    Never returns.
    """
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    from jukebox.utils import queue
    sr = 44100  # TODO: hard-coded sample rate
    # setup distributed communications
    _rank, _local_rank, device = setup_dist_from_mpi(port=port)
    while True:
        # connect to db
        db, cur = queue.connectdb()
        offset = 0
        # get the next job
        job = queue.get_next_job(cur)
        if not job:
            # Release the connection before idling; otherwise every empty poll
            # leaves one connection open.
            queue.closedb(db)
            # pause the program for a minute and check back for new jobs
            print('Zzz...')
            time.sleep(60)
            continue

        print(job)
        job_id = job['job_id']
        params = job['params']
        model_name = params['model']
        # Convert the length once, so total_length below is numeric even if
        # the DB returns the length as text.
        length_in_seconds = int(params['length'])

        kw = dict()
        kw['sr'] = sr
        kw['n_samples'] = 15 if model_name == '5b_lyrics' else 16
        kw['hop_fraction'] = (0.5, 0.5, 0.25)
        kw['model'] = model_name  # record the model that is actually sampled
        kw['levels'] = 3
        kw['sample_length_in_seconds'] = length_in_seconds
        kw['total_sample_length_in_seconds'] = length_in_seconds
        kw['job_id'] = job_id
        kw['name'] = params['name']
        hps = Hyperparams(kw)
        # artist, lyrics, genre
        metas = Hyperparams(
            dict(
                artist=params['artist'],
                genre=params['genre'],
                lyrics=params['lyrics'],
                total_length=length_in_seconds * sr,
                offset=offset))
        print(hps)
        sample_hps = Hyperparams(
            dict(mode=mode,
                 audio_file=audio_file,
                 prompt_length_in_seconds=prompt_length_in_seconds))
        # Lock the job
        queue.lock(cur, job_id)
        # Start the job
        queue.update_status(cur, job_id, "top_started")
        # Log the URL
        curl = subprocess.Popen(os.path.expanduser('./get_ip.sh'),
                                stdout=subprocess.PIPE)
        ip, _ = curl.communicate()  # (ip, error)
        url = "http://{}/jukebox/{}_{}/".format(ip.decode().strip(),
                                                job_id,
                                                params['name'])
        queue.log(cur, job_id, "URL: {}".format(url))
        # close db connection to avoid timeout error after sampling
        queue.closedb(db)
        # Run the full generating script here
        with t.no_grad():
            save_samples(model_name, device, hps, sample_hps, [metas])
        # FINISH: open a fresh db connection and update status
        db, cur = queue.connectdb()
        queue.update_status(cur, job_id, "upsampling_done")
        queue.closedb(db)
Exemple #9
0
import pickle
import random
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
                           sample_partial_window, upsample, \
                           load_prompts
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

import sys  # sys.argv is read below; it was not imported with the other modules

# Random port so concurrent runs on one machine do not collide.
port = random.randint(10000, 20000)
rank, local_rank, device = setup_dist_from_mpi(port=port)

model = "5b_lyrics"  # or "1b_lyrics"
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3 if model == '5b_lyrics' else 16

# Run slug: first command-line argument, or a default.
this_run_slug = sys.argv[1] if len(sys.argv) > 1 else "co_compose_synth2"

# Specifies the directory to save the sample in.
# We set this to the Google Drive mount point.
hps.name = '/home/robin/google-drive/samples/' + this_run_slug + '/'

# NOTE(review): pickle.load can execute arbitrary code; only load meta.p
# from a trusted location.
with open(f'{hps.name}meta.p', "rb") as meta_file:
    meta = pickle.load(meta_file)

hps.sample_length = 1048576 if model == "5b_lyrics" else 786432
Exemple #10
0
sample_options = {
    "name": "sample_5b",
    "levels": 3,
    "sample_length_in_seconds": 20,
    "total_sample_length_in_seconds": 180,
    "sr": 44100,
    "n_samples": 6,
    "hop_fraction": [0.5, 0.5, 0.125]
}

train_options = {"bs": 1, "labels": False}

rank, local_rank, device = setup_dist_from_mpi(port=29500)
print("Device: {}".format(device))

# Only sample_length / sample_length_in_seconds are carried over from
# sample_options; `hps` is then rebuilt from the "vqvae" preset plus
# train_options, so the other sample_options keys are not kept.
hps = Hyperparams(**sample_options)
hps = setup_hparams(
    "vqvae",
    dict(sample_length=hps.get('sample_length', 0),
         sample_length_in_seconds=hps.get('sample_length_in_seconds', 0),
         **train_options))

# Place the VQ-VAE on this rank's device rather than always on cuda:0.
vqvae = make_vqvae(hps, device)


def compute_metrics(vqvae, hps, output_folder):
    json_path = '/home/kevin/feedforward/mp3s_for_jukebox_test.json'
    mp3_dict = load_json(json_path)
    csv = {