def sample_prior(orig_model, ema, logger, x_in, y, hps):
    """Draw temperature-1.0 samples from ``orig_model`` and log them with reconstructions.

    Steps, in order:
      1. If ``ema`` is given, call ``ema.swap()`` (presumably exchanging the
         live weights for the EMA weights -- confirm against the EMA class),
         then put the model in eval mode.
      2. Truncate ``x_in`` to its first ``hps.bs_sample`` items, encode it from
         level 0, and decode once per level to obtain reconstructions.
      3. Encode ``x_in`` again with the model's default levels; the first code
         is discarded and the remaining ones condition ``orig_model.sample``.
      4. Log the decoded sample as ``sample_x_T1`` (plus its labels when both
         ``hps.prior`` and ``hps.labels`` are set) and every reconstruction as
         ``x_ds_start_{i}``.
      5. Put the model back in train mode and swap ``ema`` back.

    Args:
        orig_model: model exposing ``encode``, ``decode``, ``sample``, ``eval``,
            ``train`` and (when labels are logged) ``labeller``.
        ema: object with a ``swap()`` method, or ``None`` to skip swapping.
        logger: passed to ``log_aud`` / ``log_labels``; flushed at the end.
        x_in: input batch; only the first ``hps.bs_sample`` items are used.
        y: labels; sliced to the first ``hps.bs_sample`` rows when
            ``hps.labels`` is truthy, otherwise ignored.
        hps: hyperparameters; reads ``bs_sample``, ``levels``, ``labels`` and
            ``prior``.

    Raises:
        AssertionError: if ``encode`` does not return ``hps.levels`` codes.
    """
    if ema is not None:
        ema.swap()
    try:
        orig_model.eval()

        x_in = x_in[:hps.bs_sample]
        bs = x_in.shape[0]
        zs_in = orig_model.encode(x_in, start_level=0, bs_chunks=bs)
        assert len(zs_in) == hps.levels
        # All reconstructions are computed up front and logged after the sample.
        x_ds = [
            orig_model.decode(zs_in[level:], start_level=level, bs_chunks=bs)
            for level in range(hps.levels)
        ]

        # Labels are truncated the same way at every level; without labels the
        # sampler receives y=None.
        y = y[:hps.bs_sample] if hps.labels else None

        # Temp 1.0
        _, *z_conds = orig_model.encode(x_in, bs_chunks=bs)
        # NOTE(review): the sampler is asked for hps.bs_sample items while
        # z_conds has x_in.shape[0] items; these differ if the caller passes
        # fewer than bs_sample inputs -- confirm callers always pass enough.
        z = orig_model.sample(hps.bs_sample, z_conds=z_conds, y=y, fp16=False, temp=1.0)
        x_sample = orig_model.decode([z, *z_conds], bs_chunks=bs)
        log_aud(logger, 'sample_x_T1', x_sample, hps)
        if hps.prior and hps.labels:
            log_labels(logger, orig_model.labeller, 'sample_x_T1', allgather(y.cuda()), hps)

        # Recons
        for i, x_d in enumerate(x_ds):
            log_aud(logger, f'x_ds_start_{i}', x_d, hps)
    finally:
        # Always undo the eval/swap above, even if sampling or logging raised,
        # so a caller that recovers from the error does not continue training
        # in eval mode with swapped weights.
        orig_model.train()
        if ema is not None:
            ema.swap()
    logger.flush()
def log_inputs(orig_model, logger, x_in, y, x_out, hps, tag="train"):
    """Log a batch's input and output audio, plus labels or per-level reconstructions.

    ``x_in`` and ``x_out`` are always logged as ``{tag}_x_in`` and
    ``{tag}_x_out``. When ``hps.prior`` is truthy, labels are logged as
    ``{tag}_y_in`` only if ``hps.labels`` is also truthy. Otherwise ``x_in`` is
    encoded from level 0 and decoded once per level, and each result is logged
    as ``{tag}_x_ds_start_{i}``. The logger is flushed at the end.

    Args:
        orig_model: model exposing ``encode``/``decode`` (non-prior case) or
            ``labeller`` (prior case with labels).
        logger: passed to ``log_aud`` / ``log_labels`` and flushed on exit.
        x_in: input batch.
        y: labels; only used when both ``hps.prior`` and ``hps.labels`` are set.
        x_out: model output batch.
        hps: hyperparameters; reads ``prior``, ``labels`` and ``levels``.
        tag: prefix for every logged name.
    """
    print(f"Logging {tag} inputs/ouputs")
    for suffix, audio in (("x_in", x_in), ("x_out", x_out)):
        log_aud(logger, f'{tag}_{suffix}', audio, hps)

    if not hps.prior:
        # Decode every level first, then log, keeping model calls ahead of logging.
        codes = orig_model.encode(x_in, start_level=0)
        recons = [
            orig_model.decode(codes[lvl:], start_level=lvl)
            for lvl in range(0, hps.levels)
        ]
        for idx, recon in enumerate(recons):
            log_aud(logger, f'{tag}_x_ds_start_{idx}', recon, hps)
    elif hps.labels:
        gathered = allgather(y.cuda())
        log_labels(logger, orig_model.labeller, f'{tag}_y_in', gathered, hps)

    logger.flush()
def prepare_aud(x, hps):
    """Post-process ``x`` for logging and return the gathered result.

    ``x`` is detached and made contiguous, passed through
    ``audio_postprocess`` together with ``hps``, and the output is handed to
    ``allgather`` (presumably collecting it across distributed ranks --
    confirm against ``allgather``'s definition).
    """
    detached = x.detach().contiguous()
    processed = audio_postprocess(detached, hps)
    return allgather(processed)