Esempio n. 1
0
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start,
                         hps):
    """Sample the tokens of one context window at ``level`` and append them.

    The window covers positions ``[start, start + prior.n_ctx)``. Tokens that
    already exist in ``zs[level]`` inside that window are used as
    conditioning; only the remaining ``sample_tokens - conditioning_tokens``
    tokens are newly sampled and concatenated onto ``zs[level]`` along dim 1.

    Args:
        zs: Per-level token sequences, sliced as ``zs[level][:, start:end]``
            (presumably ``(n_samples, length)`` tensors -- TODO confirm).
            ``zs[level]`` is replaced in place.
        labels: Conditioning labels, forwarded to ``prior.get_y``.
        sampling_kwargs: Keyword arguments for ``prior.sample``. Must contain
            ``'max_batch_size'``, which is consumed here and not forwarded.
            An optional ``'sample_tokens'`` overrides the window length to
            sample (default ``n_ctx``) and is forwarded as well. This dict is
            not modified.
        level: Index into ``zs`` of the level being sampled.
        prior: Prior model providing ``n_ctx``, ``device``, ``get_z_conds``,
            ``get_y`` and ``sample``.
        start: First token position of the window.
        hps: Hyper-parameters; only ``hps.n_samples`` is read.

    Returns:
        ``zs``, with ``zs[level]`` extended by the newly sampled tokens, or
        unchanged if there is nothing new to sample in this window.
    """
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # Tokens already sampled at this level inside the window act as context.
    z = zs[level][:, start:end].to(prior.device)

    # Support sampling a window shorter than n_ctx.
    sample_tokens = sampling_kwargs.get('sample_tokens', end - start)
    conditioning_tokens = z.shape[1]
    new_tokens = sample_tokens - conditioning_tokens

    print_once(
        f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens"
    )

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)
    if z_conds is not None:
        z_conds = [z_cond.to(prior.device) for z_cond in z_conds]

    # set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()

    # Forward everything except max_batch_size to prior.sample. Building a
    # filtered copy (instead of deleting and later re-inserting the key)
    # leaves the caller's dict intact even if prior.sample raises.
    max_batch_size = sampling_kwargs['max_batch_size']
    prior_kwargs = {key: value for key, value in sampling_kwargs.items()
                    if key != 'max_batch_size'}

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0],
                                   z=z_i,
                                   z_conds=z_conds_i,
                                   y=y_i,
                                   **prior_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    # Keep only the newly sampled tail, moved to CPU.
    z_new = z[:, -new_tokens:].cpu()
    # Drop references to the window tensors before the concatenation;
    # presumably to release device memory early -- confirm.
    del z
    del y
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
Esempio n. 2
0
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps,
                         midi_path=None):
    """Sample one context window at ``level``, conditioned on a MIDI file.

    The window covers positions ``[start, start + prior.n_ctx)``. Tokens that
    already exist in ``zs[level]`` inside that window are used as
    conditioning; only the remaining ``sample_tokens - conditioning_tokens``
    tokens are newly sampled and concatenated onto ``zs[level]`` along dim 1.

    Args:
        zs: Per-level token sequences, sliced as ``zs[level][:, start:end]``
            (presumably ``(n_samples, length)`` tensors -- TODO confirm).
            ``zs[level]`` is replaced in place.
        labels: Conditioning labels, forwarded to ``prior.get_y``.
        sampling_kwargs: Keyword arguments for ``prior.sample``. Must contain
            ``'max_batch_size'``, which is consumed here and not forwarded.
            An optional ``'sample_tokens'`` overrides the window length to
            sample (default ``n_ctx``) and is forwarded as well. This dict is
            not modified.
        level: Index into ``zs`` of the level being sampled.
        prior: Prior model providing ``n_ctx``, ``get_z_conds``, ``get_y`` and
            ``sample``.
        start: First token position of the window.
        hps: Hyper-parameters; only ``hps.n_samples`` is read.
        midi_path: MIDI file passed to ``load_sample_midi``; the result is
            forwarded to ``prior.sample`` as ``midi=``. Defaults to the
            previously hard-coded path, for backward compatibility.

    Returns:
        ``zs``, with ``zs[level]`` extended by the newly sampled tokens, or
        unchanged if there is nothing new to sample in this window.
    """
    if midi_path is None:
        # Historical hard-coded default, kept so existing callers still work.
        midi_path = r"C:\Users\Yousef\Desktop\UNiz\MidiDataset\Cleaned\acdc\Big Balls.mid"

    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # Tokens already sampled at this level inside the window act as context.
    z = zs[level][:, start:end]

    # Support sampling a window shorter than n_ctx.
    sample_tokens = sampling_kwargs.get('sample_tokens', end - start)
    conditioning_tokens = z.shape[1]
    new_tokens = sample_tokens - conditioning_tokens

    print_once(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()

    # Forward everything except max_batch_size to prior.sample. Building a
    # filtered copy (instead of deleting and later re-inserting the key)
    # leaves the caller's dict intact even if prior.sample raises.
    max_batch_size = sampling_kwargs['max_batch_size']
    prior_kwargs = {key: value for key, value in sampling_kwargs.items()
                    if key != 'max_batch_size'}

    # The MIDI conditioning does not depend on the batch, so load it once
    # instead of re-reading the file for every batch.
    midi = load_sample_midi(midi_path)

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **prior_kwargs, midi=midi)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    # Keep only the newly sampled tail of the window.
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
Esempio n. 3
0
def sample_single_window(zs,
                         labels_1,
                         labels_2,
                         sampling_kwargs,
                         level,
                         prior,
                         start,
                         hps,
                         total_length=1):
    """Sample one context window at ``level`` conditioned on two label sets.

    The window covers positions ``[start, start + prior.n_ctx)``. Tokens that
    already exist in ``zs[level]`` inside that window are used as
    conditioning; only the remaining ``sample_tokens - conditioning_tokens``
    tokens are newly sampled and concatenated onto ``zs[level]`` along dim 1.

    Args:
        zs: Per-level token sequences, sliced as ``zs[level][:, start:end]``
            (presumably ``(n_samples, length)`` tensors -- TODO confirm).
            ``zs[level]`` is replaced in place.
        labels_1: First label set; ``prior.get_y`` output is passed to
            ``prior.sample`` as ``y1``.
        labels_2: Second label set; passed to ``prior.sample`` as ``y2``.
        sampling_kwargs: Keyword arguments for ``prior.sample``. Must contain
            ``'max_batch_size'``, which is consumed here and not forwarded.
            An optional ``'sample_tokens'`` overrides the window length to
            sample (default ``n_ctx``) and is forwarded as well. This dict is
            not modified.
        level: Index into ``zs`` of the level being sampled.
        prior: Prior model providing ``n_ctx``, ``get_z_conds``, ``get_y`` and
            ``sample``.
        start: First token position of the window.
        hps: Hyper-parameters; only ``hps.n_samples`` is read.
        total_length: Denominator of the progress percentage that is logged.
            Non-positive values skip the progress message instead of raising
            ``ZeroDivisionError``. NOTE(review): with the default of 1 the
            percentage is meaningless; callers presumably should pass the
            total token length of this level -- confirm.

    Returns:
        ``zs``, with ``zs[level]`` extended by the newly sampled tokens, or
        unchanged if there is nothing new to sample in this window.
    """
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # Tokens already sampled at this level inside the window act as context.
    z = zs[level][:, start:end]

    # Support sampling a window shorter than n_ctx.
    sample_tokens = sampling_kwargs.get('sample_tokens', end - start)
    conditioning_tokens = z.shape[1]
    new_tokens = sample_tokens - conditioning_tokens

    print_once(
        f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens"
    )
    # The progress line is purely informational; never let it divide by zero.
    if total_length > 0:
        print_once(
            f"{round( (start+sample_tokens)/total_length*100.0 )}%-ish, level {level}"
        )
    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # set y offset, sample_length and lyrics tokens
    y1 = prior.get_y(labels_1, start)
    y2 = prior.get_y(labels_2, start)

    empty_cache()

    # Forward everything except max_batch_size to prior.sample. Building a
    # filtered copy (instead of deleting and later re-inserting the key)
    # leaves the caller's dict intact even if prior.sample raises.
    max_batch_size = sampling_kwargs['max_batch_size']
    prior_kwargs = {key: value for key, value in sampling_kwargs.items()
                    if key != 'max_batch_size'}

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y1_list = split_batch(y1, n_samples, max_batch_size)
    y2_list = split_batch(y2, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y1_i, y2_i in zip(z_list, z_conds_list, y1_list,
                                          y2_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0],
                                   z=z_i,
                                   z_conds=z_conds_i,
                                   y1=y1_i,
                                   y2=y2_i,
                                   **prior_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    # Keep only the newly sampled tail of the window.
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs