Code example #1
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start,
                         hps):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # get z already sampled at current level
    z = zs[level][:, start:end].to(prior.device)

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(
        f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens"
    )

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    if z_conds is not None:
        z_conds = [z_cond.to(prior.device) for z_cond in z_conds]

    # set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()

    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0],
                                   z=z_i,
                                   z_conds=z_conds_i,
                                   y=y_i,
                                   **sampling_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update z with new sample
    z_new = z[:, -new_tokens:].cpu()
    del z
    del y
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
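Example #1 splits z, z_conds and y into micro-batches with split_batch before calling prior.sample. A minimal sketch of such a helper, written to be consistent with how it is called above (the actual jukebox implementation may differ in details):

import torch as t

def split_batch(obj, n_samples, split_size):
    # Number of micro-batches needed to cover n_samples items
    n_passes = (n_samples + split_size - 1) // split_size
    if obj is None:
        return [None] * n_passes
    if isinstance(obj, t.Tensor):
        return t.split(obj, split_size, dim=0)
    if isinstance(obj, (list, tuple)):
        # Split each conditioning tensor and regroup them per micro-batch
        return list(zip(*[t.split(item, split_size, dim=0) for item in obj]))
    raise TypeError(f"Unknown input type {type(obj)}")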
Code example #2
File: sample.py  Project: mxcjchas/jukebox
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length//prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0:
            alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
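The hop_length computed above controls how far each sampling window advances. A worked example with assumed values: the top-level prior's n_ctx is 8192 tokens, so with hop_fraction[level] = 0.125

hop_length = int(0.125 * 8192)   # = 1024 tokens; each new window advances by 1024
                                 # and re-uses the remaining 7168 tokens as context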
Code example #3
File: sample.py  Project: SirLarion/lmariAI
def run(**kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    model = "1b_lyrics"
    port = 29500
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    chunk_size = 32
    max_batch_size = 16
    hps.levels = 3
    hps.hop_fraction = [.5,.5,.125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds*hps.sr)//top_prior.raw_to_tokens)*top_prior.raw_to_tokens
    assert hps.sample_length >= top_prior.n_ctx*top_prior.raw_to_tokens, 'Please choose a longer sample_length_in_seconds'

    metas = [dict(
                artist = kwargs["artist"],
                genre = kwargs["genre"],
                total_length = hps.sample_length,
                offset = 0,
                lyrics = kwargs["lyrics"],
            ),
    ] * hps.n_samples

    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98

    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size =  32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
            chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
            chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True, 
            max_batch_size=max_batch_size, chunk_size=chunk_size)
    ]

    zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    del top_prior
    empty_cache()
    top_prior=None
    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
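A hypothetical invocation of run() from example #3; the keyword values are placeholders, not taken from any project config:

run(sample_name="my_first_sample",
    sample_length=60,                  # seconds of audio to generate
    artist="Alan Jackson",
    genre="Country",
    lyrics="(lyrics text here)")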
Code example #4
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level,
                          prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:],
                         start_level=level,
                         bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(
            dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x),
            f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        #if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
        #alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        lepath = hps.name
        if level == 2:
            for filex in glob(os.path.join(lepath + '/level_2', 'item_*.wav')):
                os.rename(filex,
                          filex.replace('item_',
                                        lepath.split('/')[-1] + '-'))
        if level == 1:
            for filex in glob(os.path.join(lepath + '/level_1', 'item_*.wav')):
                os.rename(
                    filex,
                    filex.replace('item_',
                                  lepath.split('/')[-1] + '-L1-'))
        if level == 0:
            for filex in glob(os.path.join(lepath + '/level_0', 'item_*.wav')):
                os.rename(
                    filex,
                    filex.replace('item_',
                                  lepath.split('/')[-1] + '-L0-'))
        #save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
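For reference, the renaming loops in example #4 turn the generic item_*.wav files written by save_wav into files prefixed with the last component of the run name. The path below is hypothetical:

# With hps.name = "out/mysong", a level-1 file is renamed like this:
"out/mysong/level_1/item_0.wav".replace('item_', "out/mysong".split('/')[-1] + '-L1-')
# -> "out/mysong/level_1/mysong-L1-0.wav"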
Code example #5
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # get z already sampled at current level
    z = zs[level][:,start:end]

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs
    
    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()

    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']


    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    # Load the conditioning MIDI once, outside the per-batch sampling loop
    midi_path = r"C:\Users\Yousef\Desktop\UNiz\MidiDataset\Cleaned\acdc\Big Balls.mid"
    midi = load_sample_midi(midi_path)
    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sampling_kwargs, midi=midi)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update z with new sample
    z_new = z[:,-new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
Code example #6
File: sample.py  Project: fedden/jukebox
def _sample(zs, labels_1, labels_2, sampling_kwargs, priors, sample_levels,
            hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
        zs = sample_level(zs, labels_1[level], labels_2[level],
                          sampling_kwargs[level], level, prior, total_length,
                          hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:],
                         start_level=level,
                         bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(
            dict(zs=zs, labels=labels_1, sampling_kwargs=sampling_kwargs, x=x),
            f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if (alignments is None and priors[-1] is not None
                and priors[-1].n_tokens > 0
                and not isinstance(priors[-1].labeller, EmptyLabeller)):
            try:
                # Debug probe left in this fork: touch the values needed below
                # and drop into ipdb if any lookup fails
                labels_1[-1], priors[-1], sampling_kwargs[-1]['fp16']
            except:
                import ipdb
                ipdb.set_trace()
            alignments = get_alignment(x, zs, labels_1[-1], priors[-1],
                                       sampling_kwargs[-1]['fp16'], hps)
        # don't care
        # save_html(logdir, x, zs, labels_1[-1], alignments, hps)
    return zs
Code example #7
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    print_once(f"Sampling level {level}")
    if total_length >= prior.n_ctx:
        starts = get_starts(total_length, prior.n_ctx, hop_length)
        counterr = 0
        x = None
        for start in starts:
            counterr += 1
            datea = datetime.now()
            zs = sample_single_window(zs, labels, sampling_kwargs, level,
                                      prior, start, hps)
            # `newtosample` is not defined in this excerpt; it appears to be a
            # module-level flag in this fork marking whether new tokens remain
            # to be sampled
            if newtosample and counterr < len(starts):
                del x
                x = None
                prior.cpu()
                empty_cache()
                x = prior.decode(zs[level:],
                                 start_level=level,
                                 bs_chunks=zs[level].shape[0])
                logdir = f"{hps.name}/level_{level}"
                if not os.path.exists(logdir):
                    os.makedirs(logdir)
                t.save(
                    dict(zs=zs,
                         labels=labels,
                         sampling_kwargs=sampling_kwargs,
                         x=x), f"{logdir}/data.pth.tar")
                save_wav(logdir, x, hps.sr)
                del x
                prior.cuda()
                empty_cache()
                x = None
            dateb = datetime.now()
            timex = ((dateb - datea).total_seconds() / 60.0) * (len(starts) -
                                                                counterr)
            print(f"Step " + colored(counterr, 'blue') + "/" +
                  colored(len(starts), 'red') + " ~ New to Sample: " +
                  str(newtosample) + " ~ estimated remaining minutes: " +
                  (colored('???', 'yellow'),
                   colored(timex, 'magenta'))[counterr > 1 and newtosample])
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                   total_length, hps)
    return zs
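sample_level walks the sequence in overlapping windows whose offsets come from get_starts. A minimal sketch of such a helper, consistent with the windowed sampling loop above (the upstream jukebox version may differ in details):

def get_starts(total_length, n_ctx, hop_length):
    # Window offsets spaced hop_length apart; the final window is pulled back
    # so that it still spans a full n_ctx tokens
    starts = []
    for start in range(0, total_length - n_ctx + hop_length, hop_length):
        if start + n_ctx >= total_length:
            start = total_length - n_ctx
        starts.append(start)
    return starts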
Code example #8
def get_alignment(x, zs, labels, prior, fp16, hps):
    level = hps.levels - 1  # Top level used
    n_ctx, n_tokens = prior.n_ctx, prior.n_tokens
    z = zs[level]
    bs, total_length = z.shape[0], z.shape[1]
    if total_length < n_ctx:
        padding_length = n_ctx - total_length
        z = t.cat([
            z,
            t.zeros(bs, n_ctx - total_length, dtype=z.dtype, device=z.device)
        ],
                  dim=1)
        total_length = z.shape[1]
    else:
        padding_length = 0

    hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
    n_head = prior.prior.transformer.n_head
    alignment_head, alignment_layer = prior.alignment_head, prior.alignment_layer
    attn_layers = set([alignment_layer])
    alignment_hops = {}
    indices_hops = {}

    prior.cuda()
    empty_cache()
    for start in get_starts(total_length, n_ctx, hop_length):
        end = start + n_ctx

        # set y offset, sample_length and lyrics tokens
        y, indices_hop = prior.get_y(labels, start, get_indices=True)
        assert len(indices_hop) == bs
        for indices in indices_hop:
            assert len(indices) == n_tokens

        z_bs = t.chunk(z, bs, dim=0)
        y_bs = t.chunk(y, bs, dim=0)
        w_hops = []
        for z_i, y_i in zip(z_bs, y_bs):
            w_hop = prior.z_forward(z_i[:, start:end], [],
                                    y_i,
                                    fp16=fp16,
                                    get_attn_weights=attn_layers)
            assert len(w_hop) == 1
            w_hops.append(w_hop[0][:, alignment_head])
            del w_hop
        w = t.cat(w_hops, dim=0)
        del w_hops
        assert_shape(w, (bs, n_ctx, n_tokens))
        alignment_hop = w.float().cpu().numpy()
        assert_shape(alignment_hop, (bs, n_ctx, n_tokens))
        del w

        # alignment_hop has shape (bs, n_ctx, n_tokens)
        # indices_hop is a list of len=bs, each entry of len hps.n_tokens
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop
    prior.cpu()
    empty_cache()

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(bs):
        # Note each item has different length lyrics
        full_tokens = labels['info'][item]['full_tokens']
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        for start in reversed(get_starts(total_length, n_ctx, hop_length)):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            assert len(indices) == n_tokens
            assert alignment_hop.shape == (n_ctx, n_tokens)
            alignment[start:end, indices] = alignment_hop
        # Remove token padding and the last lyric index
        alignment = alignment[:total_length - padding_length, :-1]
        alignments.append(alignment)
    return alignments
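A hypothetical post-processing step for the list returned by get_alignment above: for each top-level token (time step), pick the lyric token receiving the most attention to get a rough audio-to-lyrics alignment. `alignments` is assumed to be the return value of the function above.

import numpy as np

alignment = alignments[0]                  # (total_length, n_lyric_tokens) for sample 0
token_per_step = alignment.argmax(axis=1)  # most-attended lyric token per time step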
Code example #9
    def primed_sample(self,
                      n_samples,
                      x,
                      memory,
                      x_cond=None,
                      y_cond=None,
                      encoder_kv=None,
                      fp16=False,
                      temp=1.0,
                      top_k=0,
                      top_p=0.0,
                      get_preds=False,
                      chunk_size=None,
                      sample_tokens=None):
        assert self.training == False

        if sample_tokens is None: sample_tokens = self.input_dims
        # Preprocess.
        with t.no_grad():
            x = self.preprocess(x)
        assert isinstance(x, t.cuda.LongTensor)
        assert (0 <= x).all() and (x < self.bins).all()
        assert x.shape[0] == n_samples
        xs = t.split(x, 1, dim=1)
        xs = list(xs)
        assert len(xs) < sample_tokens

        N, D = n_samples, self.input_dims
        if self.y_cond:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)
        else:
            assert y_cond is None

        if self.x_cond:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (
                N, 1, self.width
            ), f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
        else:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

        with t.no_grad():
            if get_preds:
                preds = []

            x = xs[-1]
            assert x.shape == (n_samples, 1)
            empty_cache()
            for sample_t in get_range(range(len(xs), sample_tokens)):
                x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                #self.transformer.check_cache(n_samples, sample_t, fp16)
                x, memory = self.transformer(x[:, -1, :],
                                             memory)  # Transformer
                x = t.unsqueeze(x, 1)
                if self.add_cond_after_transformer:
                    x = x + cond
                assert x.shape == (n_samples, 1, self.width)
                x = self.x_out(x)  # Predictions
                if get_preds:
                    preds.append(x)
                # Adjust logits
                x = x / temp
                x = filter_logits(x, top_k=top_k, top_p=top_p)
                x = t.distributions.Categorical(
                    logits=x).sample()  # Sample and replace x
                assert x.shape == (n_samples, 1)
                xs.append(x.clone())

            del x
            #self.transformer.del_cache()

            x = t.cat(xs, dim=1)
            if get_preds:
                preds = t.cat(preds, dim=1)
            x = self.postprocess(x, sample_tokens)
        if get_preds:
            return x, preds
        else:
            return x, memory
Code example #10
    def primed_sample(self,
                      n_samples,
                      x,
                      x_cond=None,
                      y_cond=None,
                      encoder_kv=None,
                      fp16=False,
                      temp=1.0,
                      top_k=0,
                      top_p=0.0,
                      get_preds=False,
                      chunk_size=None,
                      sample_tokens=None):
        assert self.training == False

        if sample_tokens is None: sample_tokens = self.input_dims
        # Preprocess.
        with t.no_grad():
            x = self.preprocess(x)
        assert isinstance(x, t.cuda.LongTensor)
        assert (0 <= x).all() and (x < self.bins).all()
        assert x.shape[0] == n_samples
        xs = t.split(x, 1, dim=1)
        xs = list(xs)
        assert len(xs) < sample_tokens

        N, D = n_samples, self.input_dims
        if self.y_cond:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)
        else:
            assert y_cond is None

        if self.x_cond:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (
                N, 1, self.width
            ), f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
        else:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

        with t.no_grad():
            if get_preds:
                preds = []

            # Fill up the key/value cache for the past context by running a forward pass.
            # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage.
            if chunk_size is None:
                chunk_size = len(xs)
            #assert len(xs) % chunk_size == 0, f'expected {len(xs)} to be divisible by {chunk_size}'
            chunk_sizes = split_chunks(len(xs), chunk_size)
            x_primes = []
            start = 0
            x = None
            for current_chunk_size in get_range(chunk_sizes):
                xs_prime, conds_prime = [], []
                for sample_t in range(start, start + current_chunk_size):
                    x_prime, cond_prime = self.get_emb(sample_t, n_samples, x,
                                                       x_cond, y_cond)
                    x = xs[sample_t]
                    xs_prime.append(x_prime)
                    conds_prime.append(cond_prime)
                start = start + current_chunk_size

                x_prime, cond_prime = t.cat(xs_prime,
                                            dim=1), t.cat(conds_prime, dim=1)
                assert x_prime.shape == (n_samples, current_chunk_size,
                                         self.width)
                assert cond_prime.shape == (n_samples, current_chunk_size,
                                            self.width)
                del xs_prime
                del conds_prime
                if not get_preds:
                    del cond_prime
                x_prime = self.transformer(x_prime,
                                           encoder_kv=encoder_kv,
                                           sample=True,
                                           fp16=fp16)

                if get_preds:
                    if self.add_cond_after_transformer:
                        x_prime = x_prime + cond_prime
                    assert x_prime.shape == (n_samples, current_chunk_size,
                                             self.width)
                    del cond_prime
                    x_primes.append(x_prime)
                else:
                    del x_prime

            if get_preds:
                x_prime = t.cat(x_primes, dim=1)
                assert x_prime.shape == (n_samples, len(xs), self.width)
                x_prime = self.x_out(x_prime)  # Predictions
                preds.append(x_prime)

            empty_cache()
            self.transformer.check_cache(n_samples, len(xs), fp16)

            x = xs[-1]
            assert x.shape == (n_samples, 1)
            empty_cache()
            for sample_t in get_range(range(len(xs), sample_tokens)):
                x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                self.transformer.check_cache(n_samples, sample_t, fp16)
                x = self.transformer(x,
                                     encoder_kv=encoder_kv,
                                     sample=True,
                                     fp16=fp16)  # Transformer
                if self.add_cond_after_transformer:
                    x = x + cond
                assert x.shape == (n_samples, 1, self.width)
                x = self.x_out(x)  # Predictions
                if get_preds:
                    preds.append(x)
                # Adjust logits
                x = x / temp
                x = filter_logits(x, top_k=top_k, top_p=top_p)
                x = t.distributions.Categorical(
                    logits=x).sample()  # Sample and replace x
                assert x.shape == (n_samples, 1)
                xs.append(x.clone())

            del x
            self.transformer.del_cache()

            x = t.cat(xs, dim=1)
            if get_preds:
                preds = t.cat(preds, dim=1)
            x = self.postprocess(x, sample_tokens)
        if get_preds:
            return x, preds
        else:
            return x
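The priming pass in example #10 relies on split_chunks to break the primed prefix into chunks of at most chunk_size tokens. A minimal sketch of such a helper, consistent with how it is used above (the real jukebox helper may differ in details):

def split_chunks(length, chunk_size):
    # Number of forward passes needed to cover `length` tokens
    n_passes = (length + chunk_size - 1) // chunk_size
    # All chunks are chunk_size long except possibly the last one
    chunk_sizes = [chunk_size] * (n_passes - 1) + [(length - 1) % chunk_size + 1]
    assert sum(chunk_sizes) == length
    return chunk_sizes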
Code example #11
File: sample.py  Project: fedden/jukebox
def sample_single_window(zs,
                         labels_1,
                         labels_2,
                         sampling_kwargs,
                         level,
                         prior,
                         start,
                         hps,
                         total_length=1):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # get z already sampled at current level
    z = zs[level][:, start:end]

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(
        f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens"
    )
    print_once(
        f"{round( (start+sample_tokens)/total_length*100.0 )}%-ish, level {level}"
    )
    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # set y offset, sample_length and lyrics tokens
    y1 = prior.get_y(labels_1, start)
    y2 = prior.get_y(labels_2, start)

    empty_cache()

    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y1_list = split_batch(y1, n_samples, max_batch_size)
    y2_list = split_batch(y2, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y1_i, y2_i in zip(z_list, z_conds_list, y1_list,
                                          y2_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0],
                                   z=z_i,
                                   z_conds=z_conds_i,
                                   y1=y1_i,
                                   y2=y2_i,
                                   **sampling_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update z with new sample
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs