Beispiel #1
0
def sample_prior(orig_model, ema, logger, x_in, y, hps):
    """Sample from the prior at temperature 1.0 and log samples plus reconstructions.

    If ``ema`` is given, its weights are swapped in for the duration of the
    call, and ``orig_model`` is put in eval mode. Both are restored
    (``orig_model.train()`` and a second ``ema.swap()``) before returning,
    even if encoding, sampling or logging raises.

    Args:
        orig_model: model exposing ``encode``, ``decode``, ``sample`` and, when
            labels are logged, ``labeller``.
        ema: ``None`` or an object with a ``swap()`` method.
        logger: logger handed to ``log_aud`` / ``log_labels``; flushed at the end.
        x_in: input batch; only the first ``hps.bs_sample`` rows are used.
        y: labels aligned with ``x_in``; replaced by ``None`` when
            ``hps.labels`` is false.
        hps: hyperparameters; reads ``bs_sample``, ``levels``, ``labels``
            and ``prior``.
    """
    if ema is not None:
        ema.swap()
    try:
        orig_model.eval()

        x_in = x_in[:hps.bs_sample]
        # Actual number of rows kept; can be smaller than hps.bs_sample when the
        # incoming batch is short, so everything below is sized by it.
        bs = x_in.shape[0]

        zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
        assert len(zs_in) == hps.levels
        # One reconstruction per level: decode the codes from `level` upward.
        x_ds = [
            orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs)
            for level in range(hps.levels)
        ]

        # Labels, when enabled, are truncated to the same rows as x_in so they
        # stay aligned with the conditioning codes (identical at every level).
        y = y[:hps.bs_sample] if hps.labels else None

        # Temp 1.0: the first code list returned by encode() is discarded, and
        # the remaining ones condition the sample.
        _, *z_conds = orig_model.encode(x_in, bs_chunks=bs)
        # Ask for `bs` samples (not hps.bs_sample) so the batch matches z_conds.
        z = orig_model.sample(bs, z_conds=z_conds, y=y, fp16=False, temp=1.0)
        x_sample = orig_model.decode([z, *z_conds], bs_chunks=bs)

        log_aud(logger, 'sample_x_T1', x_sample, hps)
        if hps.prior and hps.labels:
            log_labels(logger, orig_model.labeller, 'sample_x_T1', allgather(y.cuda()), hps)

        # Recons
        for level, x_d in enumerate(x_ds):
            log_aud(logger, f'x_ds_start_{level}', x_d, hps)
    finally:
        # Always leave the model in training mode with its own weights restored.
        orig_model.train()
        if ema is not None:
            ema.swap()
    logger.flush()
Beispiel #2
0
def log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="train"):
    """Log a batch's input and output audio, plus labels or reconstructions.

    Always logs ``x_in`` and ``x_out`` as ``{tag}_x_in`` / ``{tag}_x_out``.
    When ``hps.prior`` is true, the labels ``y`` are logged as ``{tag}_y_in``,
    but only if ``hps.labels`` is set. Otherwise, ``x_in`` is encoded and
    decoded starting from each level, and each result is logged as
    ``{tag}_x_ds_start_{level}``. The logger is flushed at the end.

    Args:
        orig_model: model exposing ``encode``/``decode`` and, for labels,
            ``labeller``.
        logger: logger handed to ``log_aud`` / ``log_labels``.
        x_in: input batch.
        y: labels for ``x_in``; only used when ``hps.prior`` and ``hps.labels``.
        x_out: model output for ``x_in``.
        hps: hyperparameters; reads ``prior``, ``labels`` and ``levels``.
        tag: prefix for every logged name.
    """
    print(f"Logging {tag} inputs/outputs")
    log_aud(logger, f'{tag}_x_in', x_in, hps)
    log_aud(logger, f'{tag}_x_out', x_out, hps)
    if hps.prior:
        if hps.labels:
            log_labels(logger, orig_model.labeller, f'{tag}_y_in', allgather(y.cuda()), hps)
    else:
        zs_in = orig_model.encode(x_in, start_level=0)
        # Decode every level first, then log, which keeps the original call order.
        x_ds = [orig_model.decode(zs_in[level:], start_level=level) for level in range(hps.levels)]
        for level, x_d in enumerate(x_ds):
            log_aud(logger, f'{tag}_x_ds_start_{level}', x_d, hps)
    logger.flush()
Beispiel #3
0
def prepare_aud(x, hps):
    """Return ``allgather`` applied to the post-processed audio tensor ``x``.

    ``x`` is detached and made contiguous before being handed to
    ``audio_postprocess`` together with ``hps``.
    """
    return allgather(audio_postprocess(x.detach().contiguous(), hps))