Example 1
def run(hps="teeny", port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi

    audio_database = kwargs.pop('audio_database')

    rank, local_rank, device = setup_dist_from_mpi(port=port)
    print('device:', device)
    print("hps setup")
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = 0
    hps.nworkers = 0
    hps.argv = " ".join(sys.argv)

    hps.bs_sample = hps.nworkers = hps.bs = 1
    print("setting up database")
    # Setup dataset
    data_processor = DataProcessor(hps, audio_database)
    print("midi chunk call")
    for idx in range(8868):
        chunk = data_processor.dataset.get_midi_chunk(idx)
        if chunk.shape != (95, 128):
            raise RuntimeError(f'Unexpected MIDI chunk shape {chunk.shape} at index {idx}')
Example 2
def run(model, port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)

    with t.no_grad():
        save_outputs(model, device, hps)
Example 3
def run(hps="teeny", port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar,
                                  ema, logger, metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Train', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in train_metrics.items()
                    ]))
            dist.barrier()

        if hps.test:
            if ema: ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics,
                                    data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                print(
                    'Ema', ' '.join([
                        f'{key}: {val:0.4f}'
                        for key, val in test_metrics.items()
                    ]))
            dist.barrier()
            if ema: ema.swap()
        dist.barrier()
Example 4
def run(model, mode='ancestral', codes_file=None, audio_file=None, prompt_length_in_seconds=None, port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(dict(mode=mode, codes_file=codes_file, audio_file=audio_file, prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
Example 5
def run(**kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    model = "1b_lyrics"
    port = 29500
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [0.5, 0.5, 0.125]

    # Load the VQ-VAE and the top-level prior for the selected model
    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens
    assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, \
        f'sample_length {hps.sample_length} is shorter than the top prior context; choose a longer sample_length_in_seconds'

    metas = [dict(
                artist = kwargs["artist"],
                genre = kwargs["genre"],
                total_length = hps.sample_length,
                offset = 0,
                lyrics = kwargs["lyrics"],
            ),
    ] * hps.n_samples

    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98

    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size = 32
    sampling_kwargs = [
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
             chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True,
             max_batch_size=max_batch_size, chunk_size=chunk_size)
    ]

    # Ancestral sampling at the top level (level 2), starting from empty token sequences
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    # Free the top-level prior before loading the upsamplers; keep a None placeholder
    # so the top level is skipped during upsampling
    del top_prior
    empty_cache()
    top_prior = None
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
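
    # Note: a minimal sketch (not part of the example above) that could be appended to
    # run() to decode the level-2 codes back to audio for a quick preview. Assumes the
    # soundfile package; Jukebox's own pipeline writes wavs via its save_wav helper.
    import soundfile as sf
    with t.no_grad():
        x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()  # shape (n_samples, T, 1)
    for i, wav in enumerate(x):
        sf.write(f"{hps.name}_preview_{i}.wav", wav.squeeze(-1), hps.sr)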
Example 6
def run(model,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    # Example call:
    # model=5b, name=sample_5b, levels=3, sample_length_in_seconds=20, total_sample_length_in_seconds=180,
    # sr=44100, n_samples=6, hop_fraction=[0.5, 0.5, 0.125]
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
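
# Note: a minimal sketch (not one of the numbered examples) of how these run() entry
# points are typically dispatched in the Jukebox repository: sample.py hands run to
# python-fire, so the example call noted in the comment above maps directly onto named
# parameters and the remaining **kwargs. Assumes the run() defined just above.
import fire

if __name__ == '__main__':
    # e.g. python jukebox/sample.py --model=5b --name=sample_5b --levels=3 \
    #          --sample_length_in_seconds=20 --total_sample_length_in_seconds=180 \
    #          --sr=44100 --n_samples=6 --hop_fraction=0.5,0.5,0.125
    fire.Fire(run)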
Example 7
def run(model='1b_lyrics',
        config_file=None,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    print('Starting run')

    # Start with an empty configuration
    config_dict = {}

    # Load configuration from yaml file, if specified
    if config_file:
        print("Loading config file %s" % config_file)

        with open(config_file, 'r') as stream:
            config_dict = yaml.safe_load(stream)

        # Update the named parameters with config values (if set)
        model = config_dict.get('model', model)
        mode = config_dict.get('mode', mode)
        codes_file = config_dict.get('codes_file', codes_file)
        audio_file = config_dict.get('audio_file', audio_file)
        prompt_length_in_seconds = config_dict.get('prompt_length_in_seconds',
                                                   prompt_length_in_seconds)
        port = config_dict.get('port', port)

        # Copy the config file to the output directory so
        # it's stored with the generated audio
        print("Copying config file to output directory")
        output_dir = config_dict['name']
        output_config = os.path.join(output_dir, os.path.basename(config_file))
        print("  %s -> %s" % (config_file, output_config))
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        shutil.copyfile(config_file, output_config)

    # Update the config from any values specified on the command line.
    # The command line values will overwrite the ones loaded from the config file.
    config_dict.update(kwargs)
    print('Using these configuration values:')
    pprint.pprint(config_dict, indent=2)

    # If we're using a primer file, check that it exists
    if audio_file:
        if not os.path.exists(audio_file):
            print("Trying to use audio_file but it doesn't exist!")
            print("Place your priming audio here:")
            print(audio_file)
            sys.exit(-1)

    hps = Hyperparams(**config_dict)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    with t.no_grad():
        save_samples(model, device, hps, sample_hps)
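
# Note: a hypothetical illustration (not one of the numbered examples) of the dict that
# yaml.safe_load would have to return for Example 7 above; the key names follow the
# values read in run(), and every concrete value below is made up.
example_config_dict = {
    'model': '5b_lyrics',
    'mode': 'primed',
    'audio_file': 'primer.wav',              # priming audio, checked for existence above
    'prompt_length_in_seconds': 12,
    'port': 29500,
    'name': 'my_sample_run',                 # also used as the output directory for the copied config
    'levels': 3,
    'sr': 44100,
    'n_samples': 2,
    'sample_length_in_seconds': 20,
    'total_sample_length_in_seconds': 180,
    'hop_fraction': [0.5, 0.5, 0.125],
}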
Example 8
def run(model,
        mode='ancestral',
        codes_file=None,
        audio_file=None,
        prompt_length_in_seconds=None,
        port=29500,
        **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams(**kwargs)
    sample_hps = Hyperparams(
        dict(mode=mode,
             codes_file=codes_file,
             audio_file=audio_file,
             prompt_length_in_seconds=prompt_length_in_seconds))

    offset = 0
    total_l = hps.total_sample_length_in_seconds * hps.sr
    # Get metatags from somewhere, fixed or generated
    lyrics_list = ["nothing", "hey come on", "is this real",
                   "love is an elephant in a suit", "you, yes you",
                   "why do you", "no", "i i i i i"]
    metatags = [
        dict(artist="Unknown",
             genre="psychedelic",
             lyrics=lyrics,
             total_length=total_l,
             offset=offset)
        for lyrics in lyrics_list
    ]
    with t.no_grad():
        save_samples(model, device, hps, sample_hps, metatags)
Example 9
                self.check_cache(bs, n, False)
                y_chunk = self.forward(x_chunk,
                                       encoder_kv=encoder_kv,
                                       sample=True)
                y_chunks.append(y_chunk)
                n += x_chunk.shape[1]
            self.check_cache(bs, n, False)
            y_forw_in_chunks = t.cat(y_chunks, dim=1)

            max_err = t.max(t.abs(y_forw - y_forw_in_chunks))
            assert max_err <= 1e-6, f"Max err is {max_err} {[i for i in range(l) if t.max(t.abs(y_forw - y_forw_in_chunks)[:, i, :]) > 1e-6]}"


if __name__ == '__main__':
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    setup_dist_from_mpi(port=29600)
    n_in = 16
    n_ctx = 192
    n_head = 4
    n_depth = 12
    blocks = 16
    for attn_order in [0, 2, 6]:
        encoder_dims = {0: 0, 2: 0, 6: 64}[attn_order]
        prior = Transformer(n_in,
                            n_ctx,
                            n_head,
                            n_depth,
                            mask=True,
                            attn_order=attn_order,
                            encoder_dims=encoder_dims,
                            blocks=blocks).cuda()
Example 10
    def __init__(self, levels):
        super().__init__()
        self.level_blocks = nn.ModuleList()
        self.levels = levels
        for level in range(levels):
            self.level_blocks.append(NoBottleneckBlock())

    def encode(self, xs):
        return xs

    def decode(self, zs, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        return zs

    def forward(self, xs):
        zero = t.zeros(()).cuda()
        commit_losses = [zero for _ in range(self.levels)]
        metrics = [
            dict(entropy=zero, usage=zero, used_curr=zero, pn=zero, dk=zero)
            for _ in range(self.levels)
        ]
        return xs, xs, commit_losses, metrics


if __name__ == '__main__':
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=29600)
    bottleneck = Bottleneck(256, 64, 0.99, 2).to(device)
    bottleneck.check()
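
    # Note: a minimal sketch (not part of the original test) exercising the pass-through
    # bottleneck defined above; the truncated class appears to be NoBottleneck from
    # jukebox/vqvae/bottleneck.py, and the dummy latent shapes below are made up.
    no_bottleneck = NoBottleneck(levels=2).to(device)
    xs = [t.randn(4, 64, 8192, device=device) for _ in range(2)]
    zs, xs_out, commit_losses, metrics = no_bottleneck(xs)
    assert xs_out is xs and all(float(loss) == 0.0 for loss in commit_losses)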
Example 11
def run(mode='ancestral',
        audio_file=None,
        prompt_length_in_seconds=12.0,
        port=29500):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    from jukebox.utils import queue
    # setup distributed communications
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    while True:
        # connect to db
        db, cur = queue.connectdb()
        offset = 0
        # get the next job
        job = queue.get_next_job(cur)
        if job:
            print(job)
            job_id = job['job_id']
            kw = dict()
            kw['sr'] = 44100
            kw['n_samples'] = 3
            kw['hop_fraction'] = (0.5, 0.5, 0.25)
            kw['model'] = '5b_lyrics'
            kw['levels'] = 3
            kw['sample_length_in_seconds'] = int(job['params']['length'])
            kw['total_sample_length_in_seconds'] = int(job['params']['length'])
            kw['n_samples'] = 15 if job['params']['model'] == '5b_lyrics' else 16
            kw['job_id'] = job_id
            kw['name'] = job['params']['name']
            hps = Hyperparams(kw)
            # artist, lyrics, genre
            metas = Hyperparams(
                dict(
                    artist=job['params']['artist'],
                    genre=job['params']['genre'],
                    lyrics=job['params']['lyrics'],
                    total_length=job['params']['length'] *
                    kw['sr'],  # remove hardcoded sr
                    offset=offset))
            print(hps)
            sample_hps = Hyperparams(
                dict(mode=mode,
                     audio_file=audio_file,
                     prompt_length_in_seconds=prompt_length_in_seconds))
            # Lock the job
            queue.lock(cur, job_id)
            # Start the job
            queue.update_status(cur, job_id, "top_started")
            # Log the URL
            curl = subprocess.Popen(os.path.expanduser('./get_ip.sh'),
                                    stdout=subprocess.PIPE)
            ip, _ = curl.communicate()  # (ip, error)
            url = "http://{}/jukebox/{}_{}/".format(ip.decode().strip(),
                                                    job_id,
                                                    job['params']['name'])
            queue.log(cur, job_id, "URL: " + url)
            # close db connection to avoid timeout error after sampling
            queue.closedb(db)
            # Run the full generating script here
            with t.no_grad():
                save_samples(job['params']['model'], device, hps, sample_hps,
                             [metas])
            # FINISH
            # open fresh db connection
            db, cur = queue.connectdb()
            # update status
            queue.update_status(cur, job_id, "upsampling_done")
            queue.closedb(db)
        else:
            # pause the program for a minute and check back for new jobs
            print('Zzz...')
            time.sleep(60)