Example #1
0
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    """Sample the tokens of one level, window by window when possible.

    If ``total_length`` covers at least one full context window
    (``prior.n_ctx``), one window is sampled at every offset returned by
    ``get_starts``, printing a ``* i / n *`` progress line before each.
    Otherwise a single partial window of ``total_length`` tokens is sampled.

    Args:
        zs: sampled tokens, threaded through every sampler call (presumably
            one tensor per level — confirm against callers).
        labels: conditioning labels forwarded to the window samplers.
        sampling_kwargs: sampling options forwarded to the window samplers.
        level: index of the level being sampled.
        prior: the prior for this level; ``prior.n_ctx`` is its window size.
        total_length: number of tokens to sample at this level.
        hop_length: stride passed to ``get_starts`` between window starts.
        hps: hyper-parameters forwarded to the window samplers.

    Returns:
        The ``zs`` returned by the last sampler call.
    """
    print_once(f"Sampling level {level}")
    if total_length >= prior.n_ctx:
        # Compute the window offsets once; they drive both the loop and the
        # progress total (the original called get_starts twice, and even
        # when the short-input branch was taken).
        starts = get_starts(total_length, prior.n_ctx, hop_length)
        n_windows = len(starts)
        for window_no, start in enumerate(starts, start=1):
            print('*', window_no, '/', n_windows, '*')
            zs = sample_single_window(zs, labels, sampling_kwargs, level,
                                      prior, start, hps)
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                   total_length, hps)
    return zs
Example #2
0
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    """Sample the tokens of one level, reporting per-window progress.

    Prints a blank line and a ``-- SAMPLING LEVEL n --`` banner, then either
    samples one window at each offset from ``get_starts`` (when
    ``total_length >= prior.n_ctx``), printing ``- Level n [i/k]`` before each,
    or samples a single partial window of ``total_length`` tokens.

    Args:
        zs: sampled tokens, threaded through every sampler call (presumably
            one tensor per level — confirm against callers).
        labels: conditioning labels forwarded to the window samplers.
        sampling_kwargs: sampling options forwarded to the window samplers.
        level: index of the level being sampled.
        prior: the prior for this level; ``prior.n_ctx`` is its window size.
        total_length: number of tokens to sample at this level.
        hop_length: stride passed to ``get_starts`` between window starts.
        hps: hyper-parameters forwarded to the window samplers.

    Returns:
        The ``zs`` returned by the last sampler call.
    """
    print_once('')
    print_once(f"-- SAMPLING LEVEL {level} --")
    if total_length >= prior.n_ctx:
        # One get_starts call serves both the count and the iteration.
        starts = get_starts(total_length, prior.n_ctx, hop_length)
        num_starts = len(starts)
        for start_index, start in enumerate(starts, start=1):
            print_once(f'- Level {level} [{start_index}/{num_starts}]')
            zs = sample_single_window(zs, labels, sampling_kwargs, level,
                                      prior, start, hps)
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                   total_length, hps)
    return zs
Example #3
0
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length, hop_length, hps):
    """Sample every token of one level of the hierarchy.

    Inputs shorter than one context window (``prior.n_ctx``) are handled in a
    single partial-window pass. Longer inputs are covered by full windows
    placed at the offsets produced by ``get_starts``.

    Returns the ``zs`` produced by the last sampler invocation.
    """
    print_once(f"Sampling level {level}")

    # Short input: one partial window covers the whole level.
    if total_length < prior.n_ctx:
        return sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                     total_length, hps)

    window_offsets = get_starts(total_length, prior.n_ctx, hop_length)
    for offset in window_offsets:
        zs = sample_single_window(zs, labels, sampling_kwargs, level, prior,
                                  offset, hps)
    return zs
def _save_level_checkpoint(zs, labels, sampling_kwargs, level, prior, hps):
    """Decode the tokens sampled so far at ``level`` and write them to disk.

    Writes ``{hps.name}/level_{level}/data.pth.tar`` (a dict of ``zs``,
    ``labels``, ``sampling_kwargs`` and the decoded ``x``) and a wav file via
    ``save_wav``. The decoded tensor is a local, so it is released as soon as
    this function returns.
    """
    x = prior.decode(zs[level:],
                     start_level=level,
                     bs_chunks=zs[level].shape[0])
    logdir = f"{hps.name}/level_{level}"
    # exist_ok avoids the race between a separate exists() check and makedirs.
    os.makedirs(logdir, exist_ok=True)
    t.save(
        dict(zs=zs,
             labels=labels,
             sampling_kwargs=sampling_kwargs,
             x=x), f"{logdir}/data.pth.tar")
    save_wav(logdir, x, hps.sr)


def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    """Sample one level window by window, with optional intermediate previews.

    When ``total_length >= prior.n_ctx`` one window is sampled at each offset
    from ``get_starts``. After every window except the last, and only if the
    flag ``newtosample`` is truthy, the partial result is decoded and saved
    (see ``_save_level_checkpoint``). The prior is moved to the CPU while this
    happens and back to CUDA afterwards. A coloured progress line with an
    estimated remaining time is printed after every window.

    Shorter inputs are sampled in one partial window.

    Note: ``newtosample`` is not defined in this function; it is read from
    the enclosing (module) scope.

    Returns:
        The ``zs`` returned by the last sampler call.
    """
    print_once(f"Sampling level {level}")
    if total_length >= prior.n_ctx:
        starts = get_starts(total_length, prior.n_ctx, hop_length)
        n_windows = len(starts)
        for step, start in enumerate(starts, start=1):
            step_started = datetime.now()
            zs = sample_single_window(zs, labels, sampling_kwargs, level,
                                      prior, start, hps)
            # Skip the preview after the final window (presumably the caller
            # decodes the finished level itself — confirm).
            if newtosample and step < n_windows:
                # Park the prior on the CPU while decoding, then always move
                # it back so the next window can be sampled even if saving
                # raised.
                prior.cpu()
                empty_cache()
                try:
                    _save_level_checkpoint(zs, labels, sampling_kwargs, level,
                                           prior, hps)
                finally:
                    prior.cuda()
                    empty_cache()
            # Estimate: duration of this step times the steps still to go.
            step_minutes = (datetime.now() - step_started).total_seconds() / 60.0
            remaining_minutes = step_minutes * (n_windows - step)
            # A conditional expression instead of `(a, b)[cond]`, which raised
            # IndexError when `newtosample` was a truthy non-bool such as 2.
            if step > 1 and newtosample:
                eta = colored(remaining_minutes, 'magenta')
            else:
                eta = colored('???', 'yellow')
            print("Step " + colored(step, 'blue') + "/" +
                  colored(n_windows, 'red') + " ~ New to Sample: " +
                  str(newtosample) + " ~ estimated remaining minutes: " +
                  eta)
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                   total_length, hps)
    return zs
Example #5
0
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    """Sample one level, with a batched "speed upsampling" shortcut.

    For a non-top level whose ``hps.hop_fraction`` is exactly 1, the window
    span passed to ``get_starts`` and the hop are both multiplied by
    ``sampling_kwargs['max_batch_size']``. The intent, per the original
    comment, is to split the level into batch-sized chunks sampled in
    parallel. Otherwise this behaves like plain window-by-window sampling.

    NOTE(review): the full-window branch is guarded by
    ``total_length >= prior.n_ctx`` but ``get_starts`` receives
    ``prior.n_ctx * batch_size``. When ``total_length`` falls between the two,
    ``get_starts`` may yield a negative start — confirm how
    ``sample_single_window`` (not visible here) handles that, and whether it
    was modified to consume the whole batched span.
    """
    print_once(f"Sampling level {level}")

    if hps.hop_fraction[level] != 1 or level == hps.levels - 1:
        batch_size = 1
    else:
        print('hop_fraction 1 detected, enabling speed upsampling')
        print('thank me later, -MichaelsLab')
        # Break the level into batch-sized chunks that are sampled in
        # parallel as a batch, so windows advance batch_size times faster.
        batch_size = sampling_kwargs['max_batch_size']
        hop_length *= batch_size

    if total_length < prior.n_ctx:
        return sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                     total_length, hps)

    window_span = prior.n_ctx * batch_size
    for offset in get_starts(total_length, window_span, hop_length):
        zs = sample_single_window(zs, labels, sampling_kwargs, level,
                                  prior, offset, hps)
    return zs
Example #6
0
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length,
                 hop_length, hps):
    """Sample one level window by window, checkpointing after each window.

    Windows whose start offset is ``<= hps.restart`` are skipped, which lets
    a run resume past windows that were already sampled. After each newly
    sampled window, the current tokens are saved to
    ``{hps.name}/level_{level}/data_part.pth.tar`` (keys ``zs``, ``labels``,
    ``sampling_kwargs``, ``x=None`` and ``level``). Only then is
    ``hps.restart`` advanced to that window's start. Inputs shorter than
    ``prior.n_ctx`` are sampled in one partial window with no checkpoint.

    NOTE(review): ``hps.restart`` holds a token offset of the current level
    and is never reset here. If it carries over from a previously sampled
    level, early windows of this level would be skipped — confirm that the
    caller resets it between levels.

    Returns:
        The ``zs`` returned by the last sampler call.
    """
    print_once(f"Sampling level {level}")
    if total_length >= prior.n_ctx:
        logdir = f"{hps.name}/level_{level}"
        for start in get_starts(total_length, prior.n_ctx, hop_length):
            if start <= hps.restart:
                print(f"Skipping: {start}")
                continue
            zs = sample_single_window(zs, labels, sampling_kwargs, level,
                                      prior, start, hps)
            # exist_ok avoids the race between exists() and makedirs().
            os.makedirs(logdir, exist_ok=True)
            t.save(
                dict(zs=zs,
                     labels=labels,
                     sampling_kwargs=sampling_kwargs,
                     x=None,
                     level=level), f"{logdir}/data_part.pth.tar")
            # Advance the resume marker only after the window has been
            # sampled AND checkpointed, so a failure in either step never
            # marks an unsaved window as done.
            hps.restart = start
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior,
                                   total_length, hps)
    return zs
def get_alignment(x, zs, labels, prior, fp16, hps):
    """Compute per-item lyric alignment matrices from top-level attention.

    The top-level tokens ``zs[hps.levels - 1]`` are right-padded with zeros to
    ``prior.n_ctx`` if shorter, then run through ``prior.z_forward`` one hop
    at a time. Each hop is processed one batch item at a time. For every hop
    the attention weights of ``prior.alignment_head`` in layer
    ``prior.alignment_layer`` are kept. They are then written into a full
    matrix per item at the lyric-token columns reported by ``prior.get_y``.
    Hops are written in reverse start order, so where hops overlap the
    earlier hop's values win.

    Args:
        x: not used by this function.
        zs: token tensors per level; the top-level entry is indexed as
            ``(bs, length)``.
        labels: passed to ``prior.get_y``; ``labels['info'][i]['full_tokens']``
            is item ``i``'s full lyric token sequence.
        prior: top-level prior providing ``n_ctx``, ``n_tokens``, ``get_y``,
            ``z_forward``, ``alignment_head`` and ``alignment_layer``.
        fp16: forwarded to ``prior.z_forward``.
        hps: hyper-parameters; ``levels`` and ``hop_fraction`` are read.

    Returns:
        A list of ``bs`` numpy arrays, each of shape
        ``(unpadded_length, len(full_tokens))``.
    """
    level = hps.levels - 1  # Top level used
    n_ctx, n_tokens = prior.n_ctx, prior.n_tokens
    z = zs[level]
    bs, total_length = z.shape[0], z.shape[1]
    if total_length < n_ctx:
        # Right-pad with zero tokens so at least one full window fits; the
        # padding rows are cut off again before returning.
        padding_length = n_ctx - total_length
        z = t.cat([
            z,
            t.zeros(bs, padding_length, dtype=z.dtype, device=z.device)
        ],
                  dim=1)
        total_length = z.shape[1]
    else:
        padding_length = 0

    hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
    alignment_head, alignment_layer = prior.alignment_head, prior.alignment_layer
    attn_layers = {alignment_layer}
    # Window offsets are needed for the forward pass and again for stitching.
    starts = get_starts(total_length, n_ctx, hop_length)
    alignment_hops = {}
    indices_hops = {}

    prior.cuda()
    empty_cache()
    try:
        for start in starts:
            end = start + n_ctx

            # set y offset, sample_length and lyrics tokens
            y, indices_hop = prior.get_y(labels, start, get_indices=True)
            assert len(indices_hop) == bs
            for indices in indices_hop:
                assert len(indices) == n_tokens

            # Forward one batch item at a time to bound peak memory.
            z_bs = t.chunk(z, bs, dim=0)
            y_bs = t.chunk(y, bs, dim=0)
            w_hops = []
            for z_i, y_i in zip(z_bs, y_bs):
                w_hop = prior.z_forward(z_i[:, start:end], [],
                                        y_i,
                                        fp16=fp16,
                                        get_attn_weights=attn_layers)
                assert len(w_hop) == 1
                w_hops.append(w_hop[0][:, alignment_head])
                del w_hop
            w = t.cat(w_hops, dim=0)
            del w_hops
            assert_shape(w, (bs, n_ctx, n_tokens))
            alignment_hop = w.float().cpu().numpy()
            assert_shape(alignment_hop, (bs, n_ctx, n_tokens))
            del w

            # alignment_hop has shape (bs, n_ctx, n_tokens)
            # indices_hop is a list of len=bs, each entry of len hps.n_tokens
            indices_hops[start] = indices_hop
            alignment_hops[start] = alignment_hop
    finally:
        # Move the prior back off the GPU even if a forward pass raised.
        prior.cpu()
        empty_cache()

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(bs):
        # Note each item has different length lyrics
        full_tokens = labels['info'][item]['full_tokens']
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        for start in reversed(starts):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            assert len(indices) == n_tokens
            assert alignment_hop.shape == (n_ctx, n_tokens)
            alignment[start:end, indices] = alignment_hop
        # Remove the token padding rows and the last (extra) lyric column.
        alignment = alignment[:total_length - padding_length, :-1]
        alignments.append(alignment)
    return alignments