Exemplo n.º 1
0
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    """Sample each requested level (in reverse order), decode it and save outputs.

    For every level in ``reversed(sample_levels)`` the level's prior is moved
    to the GPU, ``sample_level`` produces the updated ``zs``, the prior is moved
    back to the CPU, and ``zs[level:]`` is decoded to ``x``. The results are
    written to ``{hps.name}/level_{level}``: a ``data.pth.tar`` checkpoint, a
    wav export (``save_wav``) and an HTML report (``save_html``).

    Args:
        zs: Sequence indexed by level; ``zs[level]`` must expose ``.shape``.
            Replaced by the return value of ``sample_level`` at every level.
        labels: Indexed by level; ``labels[level]`` is passed to
            ``sample_level`` and ``labels[-1]`` to alignment / HTML export.
        sampling_kwargs: Indexed by level; ``sampling_kwargs[-1]['fp16']`` is
            read when alignments are computed.
        priors: Indexed by level; ``priors[-1]`` may be ``None``.
        sample_levels: Iterable of levels to sample.
        hps: Reads ``sample_length``, ``hop_fraction``, ``name`` and ``sr``.

    Returns:
        The updated ``zs``.

    Raises:
        AssertionError: if ``hps.sample_length`` is not a multiple of a
            prior's ``raw_to_tokens``.
    """
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        # Offload the prior again even if sampling fails, so an exception does
        # not leave the model occupying GPU memory.
        try:
            empty_cache()

            # Set correct total_length, hop_length, labels and sampling_kwargs for level
            assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
            total_length = hps.sample_length // prior.raw_to_tokens
            hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
            zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)
        finally:
            prior.cpu()
            empty_cache()

        # Decode sample. NOTE(review): this runs after prior.cpu(); presumably
        # the decoder lives on its own device — confirm.
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        logdir = f"{hps.name}/level_{level}"
        # exist_ok avoids the exists()/makedirs() race when several processes
        # create the same directory concurrently.
        os.makedirs(logdir, exist_ok=True)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        # Alignments are computed at most once (first eligible level) and then
        # reused for the HTML report of every later level.
        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0:
            alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
Exemplo n.º 2
0
def _sample(zs, labels_1, labels_2, sampling_kwargs, priors, sample_levels,
            hps):
    """Sample each requested level (in reverse order), decode it and save outputs.

    For every level in ``reversed(sample_levels)`` the level's prior is moved
    to the GPU, ``sample_level`` produces the updated ``zs`` from both label
    sets, the prior is moved back to the CPU, and ``zs[level:]`` is decoded to
    ``x``. A ``data.pth.tar`` checkpoint and a wav export are written to
    ``{hps.name}/level_{level}``; in multi-process runs (world size > 1) the
    directory is ``{hps.name}_rank_{rank}/level_{level}`` instead.

    Args:
        zs: Sequence indexed by level; ``zs[level]`` must expose ``.shape``.
            Replaced by the return value of ``sample_level`` at every level.
        labels_1: Indexed by level; first label set passed to
            ``sample_level``. ``labels_1[-1]`` is used for alignment.
        labels_2: Indexed by level; second label set passed to
            ``sample_level``.
        sampling_kwargs: Indexed by level; ``sampling_kwargs[-1]['fp16']`` is
            read when alignments are computed.
        priors: Indexed by level; ``priors[-1]`` may be ``None``.
        sample_levels: Iterable of levels to sample.
        hps: Reads ``sample_length``, ``hop_fraction``, ``name`` and ``sr``.

    Returns:
        The updated ``zs``.

    Raises:
        AssertionError: if ``hps.sample_length`` is not a multiple of a
            prior's ``raw_to_tokens``.
        KeyError / IndexError: propagated if the alignment inputs
            (``labels_1[-1]``, ``sampling_kwargs[-1]['fp16']``) are missing.
    """
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        # Offload the prior again even if sampling fails, so an exception does
        # not leave the model occupying GPU memory.
        try:
            empty_cache()

            # Set correct total_length, hop_length, labels and sampling_kwargs for level
            assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
            total_length = hps.sample_length // prior.raw_to_tokens
            hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
            zs = sample_level(zs, labels_1[level], labels_2[level],
                              sampling_kwargs[level], level, prior, total_length,
                              hop_length, hps)
        finally:
            prior.cpu()
            empty_cache()

        # Decode sample. NOTE(review): this runs after prior.cpu(); presumably
        # the decoder lives on its own device — confirm.
        x = prior.decode(zs[level:],
                         start_level=level,
                         bs_chunks=zs[level].shape[0])

        # Give each rank its own directory so processes don't overwrite each
        # other's outputs.
        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        # exist_ok avoids the exists()/makedirs() race.
        os.makedirs(logdir, exist_ok=True)
        # Save both label sets so the checkpoint reflects every input that
        # sample_level received; the existing 'labels' key is unchanged.
        t.save(
            dict(zs=zs, labels=labels_1, labels_2=labels_2,
                 sampling_kwargs=sampling_kwargs, x=x),
            f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)

        last_prior = priors[-1]
        if (alignments is None and last_prior is not None
                and last_prior.n_tokens > 0
                and not isinstance(last_prior.labeller, EmptyLabeller)):
            alignments = get_alignment(x, zs, labels_1[-1], last_prior,
                                       sampling_kwargs[-1]['fp16'], hps)
        # HTML export (save_html) is intentionally disabled in this variant.
        # NOTE(review): `alignments` is therefore never consumed; consider
        # dropping the get_alignment call if it has no needed side effects.
    return zs