def evaluate(model, orig_model, logger, metrics, data_processor, hps):
    model.eval()
    orig_model.eval()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd")
    else:
        _print_keys = dict(l="loss", rl="recons_loss", sl="spectral_loss")

    with t.no_grad():
        # Evaluate on the held-out loader; all metrics below are logged with a
        # "test_" prefix.
        for i, x in logger.get_range(data_processor.test_loader):
            if isinstance(x, (tuple, list)):
                x, y = x
            else:
                y = None

            x = x.to('cuda', non_blocking=True)
            if y is not None:
                y = y.to('cuda', non_blocking=True)

            x_in = x = audio_preprocess(x, hps)
            log_input_output = (i == 0)

            if hps.prior:
                forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
            else:
                forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

            x_out, loss, _metrics = model(x, **forw_kwargs)

            # Logging
            for key, val in _metrics.items():
                _metrics[key] = val.item()
            _metrics["loss"] = loss = loss.item()  # Make sure to call item() to free the graph

            # Average and log
            for key, val in _metrics.items():
                _metrics[key] = metrics.update(f"test_{key}", val, x.shape[0])

            with t.no_grad():
                if log_input_output:
                    log_inputs(orig_model, logger, x_in, y, x_out, hps)

            logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})

    for key, val in _metrics.items():
        logger.add_scalar(f"test_{key}", metrics.avg(f"test_{key}"))

    logger.close_range()
    return {key: metrics.avg(f"test_{key}") for key in _metrics.keys()}
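
# evaluate() above (and train() below) rely on a running-average helper behind
# `metrics.update(key, val, n)` and `metrics.avg(key)`. The class below is a
# minimal sketch of that assumed contract, not the project's actual Metrics
# implementation: update() folds in a value weighted by batch size n and returns
# the running average; avg() reads it back.
class _AverageMetricsSketch:
    def __init__(self):
        self.sums = {}    # key -> weighted sum of values
        self.counts = {}  # key -> total number of samples seen

    def update(self, key, val, n):
        self.sums[key] = self.sums.get(key, 0.0) + val * n
        self.counts[key] = self.counts.get(key, 0) + n
        return self.avg(key)

    def avg(self, key):
        return self.sums[key] / max(self.counts[key], 1)
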
def test_dataset_loader():
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset

    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)

    dataset = hps.dataset
    root = hps.root
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')
    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))

    # Torch Loader
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset, batch_size=hps.bs, num_workers=hps.nworkers,
                              pin_memory=False, sampler=sampler, drop_last=True,
                              collate_fn=collate_fn)

    dist.barrier()
    sampler.set_epoch(0)
    # Write one batch of raw and preprocess->postprocess round-tripped audio to
    # TensorBoard, then stop.
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        break
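
# test_dataset_loader() uses DistributedSampler and dist.barrier(), so
# torch.distributed must already be initialized before calling it. Below is a
# minimal single-process setup sketch for running it standalone; the address,
# port, and backend are placeholder assumptions, and the repo's real launcher
# may initialize this differently.
def _init_single_process_dist_sketch():
    import os
    import torch.distributed as dist
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group(backend='nccl', rank=0, world_size=1)
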
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy",
                           u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(data_processor.train_loader)
    print_all(len(data_processor.train_loader))
    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(
            loss=loss, params=list(model.parameters()), scalar=scalar,
            fp16=hps.fp16, logger=logger)

        # Skip step if overflow, or if the grad norm exceeds the ignore threshold
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None:
            shd.step()
        if ema is not None:
            ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call item() to free the graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None:
                    ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None:
                    ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()

    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
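
# train() steps the optimizer with opt.step(scale=clipped_grad_scale(...)),
# folding gradient clipping into the fp16 loss scale (see the "Step opt"
# comment above: grads are later divided by this scale). The function below is
# a sketch of what clipped_grad_scale is assumed to compute, for illustration
# only, and not necessarily the repo's exact implementation: if the gradient
# norm exceeds the clip threshold, the scale is inflated by the same ratio, so
# dividing by it caps the effective grad norm at the threshold.
def _clipped_grad_scale_sketch(grad_norm, max_grad_norm, scale):
    clip_ratio = grad_norm / max_grad_norm
    if clip_ratio > 1.0:
        scale = scale * clip_ratio  # dividing grads by this scale clips them to max_grad_norm
    return scale
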
def compute_metrics(vqvae, hps, output_folder):
    json_path = '/home/kevin/feedforward/mp3s_for_jukebox_test.json'
    mp3_dict = load_json(json_path)
    csv = {"client": [], "media_id": [], "external_id": [], "s3_key": [],
           'recons_loss_l3': [], 'spectral_loss_l3': [], 'multispectral_loss_l3': [],
           'recons_loss_l2': [], 'spectral_loss_l2': [], 'multispectral_loss_l2': [],
           'recons_loss_l1': [], 'spectral_loss_l1': [], 'multispectral_loss_l1': [],
           'recons_loss': [], 'spectral_loss': [], 'multispectral_loss': [],
           'spectral_convergence': [], 'l2_loss': [], 'l1_loss': [], 'linf_loss': [],
           'commit_loss': []}

    print("sample_length", vqvae.sample_length)
    print('multipliers', vqvae.multipliers)
    print('x_shape', vqvae.x_shape)
    print('downsamples', vqvae.downsamples)
    print('hop lengths', vqvae.hop_lengths)
    print('z shapes', vqvae.z_shapes)
    print('levels', vqvae.levels)
    print(len(vqvae.encoders))
    # print(vqvae.encoders[0])

    forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)
    # hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs = 1

    for client_name in mp3_dict:
        if not os.path.exists(os.path.join(output_folder, client_name)):
            os.makedirs(os.path.join(output_folder, client_name))
        if not os.path.exists(os.path.join(output_folder, client_name, 'audio')):
            os.makedirs(os.path.join(output_folder, client_name, 'audio'))
        if not os.path.exists(os.path.join(output_folder, client_name, 'spec')):
            os.makedirs(os.path.join(output_folder, client_name, 'spec'))

        for mp3_metadata in mp3_dict[client_name]:  # 'external_id', 'media_id', 'num_samples', 's3_key'
            print(mp3_metadata)
            s3_key = mp3_metadata['s3_key']
            filename = s3_key.split('/')[-1]
            mp3_path = os.path.join(audio_mp3s_folder, s3_key)
            mp3, _ = librosa.core.load(mp3_path, sr=44100)
            librosa.output.write_wav(
                "{}/{}.wav".format(os.path.join(output_folder, client_name, 'audio'),
                                   filename.split('.')[0]),
                mp3[:881920], sr=44100)
            hps.bandwidth = get_bandwidth(mp3, hps)

            inputs = torch.tensor(mp3[:881920]).view(1, -1, 1).to(device)
            mp3_spec = spec(inputs.squeeze().cpu(), hps).numpy()
            # save_spec_plot(mp3_spec, os.path.join(output_folder, client_name, 'spec', filename.split('.')[0] + '.png'),
            #                title=filename.split('.')[0])
            inputs = audio_preprocess(inputs, hps)
            x_outs, loss, _metrics = vqvae(inputs, **forw_kwargs,
                                           return_all_x_outs=True)  # x_outs with top level first
            # print("Loss: {}".format(loss))
            # print("Metrics:", _metrics)

            out_specs = []
            for i, x_out in enumerate(reversed(x_outs)):  # level 0 (bottom) first
                x_out_np = x_out.cpu().squeeze().numpy()
                librosa.output.write_wav(
                    "{}/{}_recon{}.wav".format(os.path.join(output_folder, client_name, 'audio'),
                                               filename.split('.')[0], i),
                    x_out_np, sr=44100)
                x_out_spec = spec(x_out.squeeze().cpu(), hps).numpy()
                out_specs.append(x_out_spec)

            save_spec_plot([mp3_spec] + out_specs,
                           os.path.join(output_folder, client_name, 'spec',
                                        filename.split('.')[0] + '.png'),
                           title=filename.split('.')[0])

            csv['client'].append(client_name)
            csv['media_id'].append(mp3_metadata['media_id'])
            csv['external_id'].append(mp3_metadata['external_id'])
            csv['s3_key'].append(mp3_metadata['s3_key'])
            for k, v in _metrics.items():
                csv[k].append(float(v.squeeze().cpu().numpy()))

    pd.DataFrame(csv).to_csv(os.path.join(output_folder, 'metrics.csv'))
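
# A hypothetical driver for compute_metrics(), assuming the standard Jukebox
# helpers setup_hparams/make_vqvae and a restored VQ-VAE checkpoint. The hparams
# preset name, checkpoint path, and output folder below are placeholders, not
# values from this repo.
def _compute_metrics_driver_sketch():
    from jukebox.hparams import setup_hparams
    from jukebox.make_models import make_vqvae
    _hps = setup_hparams('vqvae', dict(sample_length=881920,
                                       restore_vqvae='path/to/checkpoint.pth.tar'))
    _vqvae = make_vqvae(_hps, 'cuda')
    _vqvae.eval()
    compute_metrics(_vqvae, _hps, output_folder='metrics_out')
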