import os
import sys

import numpy as np
import torch as t
import torch.distributed as dist  # or the codebase's dist adapter wrapping it

# download, setup_hparams, make_model/make_vqvae/make_prior, DataProcessor, and
# the optimizer/EMA/DDP/logging helpers are assumed to be imported from the
# surrounding jukebox modules.


def load_checkpoint(path):
    print("Loading checkpoint...")
    restore = path
    if restore[:5] == 'gs://':
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
        gdrive_path = os.path.join("/content/gdrive/My Drive/samples/", gs_path[5:])
        print(f'local path: {local_path}')
        print(f'gdrive path: {gdrive_path}')
        if dist.get_rank() % 8 == 0:
            if os.path.exists(gdrive_path):
                # Prefer a copy already cached on Google Drive.
                print("Using priors on Google Drive")
                restore = gdrive_path
            elif os.path.exists(os.path.dirname(gdrive_path)):
                # Drive is mounted but the file is missing: download it there.
                print("Downloading priors to Google Drive")
                download(gs_path, gdrive_path)
                restore = gdrive_path
            else:
                # No Drive mount: fall back to the local cache.
                print("Downloading from gce")
                if not os.path.exists(os.path.dirname(local_path)):
                    os.makedirs(os.path.dirname(local_path))
                if not os.path.exists(local_path):
                    download(gs_path, local_path)
                restore = local_path
        dist.barrier()
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint
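# Path mapping used above, illustrated (the bucket path is an arbitrary example,
# not a real checkpoint): stripping the 'gs://' prefix maps
#   'gs://bucket/prior.tar' -> '~/.cache/bucket/prior.tar'
#                           -> '/content/gdrive/My Drive/samples/bucket/prior.tar'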
def save_outputs(model, device, hps):
    # Check logits: dump inputs, outputs, and predictions at every level so they
    # can be compared against a reference run.
    if hps.labels_v3:
        n_ctx = 6144
        n_tokens = 384
        prime_bins = 79
    else:
        n_ctx = 8192
        n_tokens = 512
        prime_bins = 80

    rng = t.random.manual_seed(0)
    x = 2 * t.rand((1, n_ctx * 8 * 4 * 4, 1), generator=rng, dtype=t.float).cuda() - 1.0  # -1 to 1
    lyric_tokens = t.randint(0, prime_bins, (1, n_tokens), generator=rng, dtype=t.long).view(-1).numpy()
    artist_id = 10
    genre_ids = [1]
    total_length = 2 * 2646000
    offset = 2646000

    vqvae, priors = make_model(model, device, hps)

    # Encode with the top-level prior's VQ-VAE, then decode from every level.
    vq_prior = priors[-1]
    zs = vq_prior.encode(x, start_level=0)
    x_ds = [vq_prior.decode(zs[level:], start_level=level) for level in range(0, len(zs))]

    # Run each prior and record its predictions.
    data = dict(zs=zs, x_ds=x_ds)
    for level in range(len(priors)):
        print(f"Doing level {level}")
        if hps.labels_v3 and level != hps.levels - 1:
            print(f"Skipping level {level}")
            continue
        prior = priors[level]
        prior.cuda()
        x_in = x[:, :n_ctx * 8 * (4 ** level)]
        y_in = t.from_numpy(prior.labeller.get_y_from_ids(artist_id, genre_ids, lyric_tokens,
                                                          total_length, offset)).view(1, -1).cuda().long()
        x_out, _, metrics = prior(x_in, y_in, fp16=hps.fp16, get_preds=True, decode=True)
        preds = metrics['preds']
        data[level] = dict(x=x_in, y=y_in, x_out=x_out, preds=preds)
        prior.cpu()

    t.save(data, 'data.pth.tar')
    dist.barrier()
    print("Saved data")
    exit()
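# A minimal sketch of how the dump written by save_outputs might be consumed as
# a regression check. The file names, the 1e-4 tolerance, and the
# compare_outputs helper are illustrative assumptions, not part of this codebase.
def compare_outputs(old_path='data.pth.tar', new_path='data_new.pth.tar', atol=1e-4):
    old = t.load(old_path, map_location='cpu')
    new = t.load(new_path, map_location='cpu')
    # Integer keys hold the per-level dicts; 'zs' and 'x_ds' hold lists.
    for level in sorted(k for k in old if isinstance(k, int)):
        for key in ('x_out', 'preds'):
            max_diff = (old[level][key].float() - new[level][key].float()).abs().max().item()
            print(f"level {level} {key}: max abs diff {max_diff:.2e}")
            assert max_diff < atol, f"level {level} {key} diverged"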
def test_dataset_loader():
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from jukebox.utils.audio_utils import audio_preprocess, audio_postprocess
    from jukebox.hparams import setup_hparams
    from jukebox.data.files_dataset import FilesAudioDataset

    hps = setup_hparams("teeny", {})
    hps.sr = 22050  # 44100
    hps.hop_length = 512
    hps.labels = False
    hps.channels = 2
    hps.aug_shift = False
    hps.bs = 2
    hps.nworkers = 2  # Getting 20 it/s with 2 workers, 10 it/s with 1 worker
    print(hps)

    dataset = hps.dataset
    root = hps.root
    from tensorboardX import SummaryWriter
    sr = {22050: '22k', 44100: '44k', 48000: '48k'}[hps.sr]
    writer = SummaryWriter(f'{root}/{dataset}/logs/{sr}/logs')

    dataset = FilesAudioDataset(hps)
    print("Length of dataset", len(dataset))

    # Torch loader: stack per-item numpy windows into a batch tensor.
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    sampler = DistributedSampler(dataset)
    train_loader = DataLoader(dataset, batch_size=hps.bs, num_workers=hps.nworkers,
                              pin_memory=False, sampler=sampler, drop_last=True,
                              collate_fn=collate_fn)

    dist.barrier()
    sampler.set_epoch(0)
    for i, x in enumerate(tqdm(train_loader)):
        x = x.to('cuda', non_blocking=True)
        for j, aud in enumerate(x):
            writer.add_audio('in_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote in")
        x = audio_preprocess(x, hps)
        x = audio_postprocess(x, hps)
        for j, aud in enumerate(x):
            writer.add_audio('out_' + str(i * hps.bs + j), aud, 1, hps.sr)
        print("Wrote out")
        dist.barrier()
        break
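# Self-contained illustration of what the collate_fn above does: it only stacks
# per-item numpy windows into a (bs, sample_length, channels) float tensor. The
# window length below is an arbitrary assumption, not FilesAudioDataset's actual
# sample_length.
def _collate_demo():
    collate_fn = lambda batch: t.stack([t.from_numpy(b) for b in batch], 0)
    dummy = [np.random.randn(22050 * 4, 2).astype(np.float32) for _ in range(2)]
    print(collate_fn(dummy).shape)  # torch.Size([2, 88200, 2])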
def load_checkpoint(path):
    restore = path
    if restore[:5] == 'gs://':
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
        # Only one rank per node downloads; the others wait at the barrier.
        if dist.get_rank() % 8 == 0:
            print("Downloading from gce")
            if not os.path.exists(os.path.dirname(local_path)):
                os.makedirs(os.path.dirname(local_path))
            if not os.path.exists(local_path):
                download(gs_path, local_path)
        restore = local_path
        dist.barrier()
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint
def load_checkpoint(path):
    restore = path
    if restore.startswith(REMOTE_PREFIX):
        remote_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), remote_path[len(REMOTE_PREFIX):])
        # Only one rank per node downloads; all ranks then load the local copy.
        if dist.get_rank() % 8 == 0:
            print("Downloading from azure")
            if not os.path.exists(os.path.dirname(local_path)):
                os.makedirs(os.path.dirname(local_path))
            if not os.path.exists(local_path):
                download(remote_path, local_path)
        restore = local_path
        dist.barrier()
    checkpoint = custom_load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint
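# A minimal usage sketch for the loader above. The 'model' key mirrors how
# checkpoints are written by save_checkpoint elsewhere in this codebase; treat
# the exact checkpoint layout, and the restore_model_from helper itself, as
# assumptions.
def restore_model_from(path, model):
    checkpoint = load_checkpoint(path)
    model.load_state_dict(checkpoint['model'])  # assumed key; layout may differ
    return model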
def run(hps="teeny", port=29500, **kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    rank, local_rank, device = setup_dist_from_mpi(port=port)
    hps = setup_hparams(hps, kwargs)
    hps.ngpus = dist.get_world_size()
    hps.argv = " ".join(sys.argv)
    hps.bs_sample = hps.nworkers = hps.bs

    # Setup dataset
    data_processor = DataProcessor(hps)

    # Setup models
    vqvae = make_vqvae(hps, device)
    print_once(f"Parameters VQVAE:{count_parameters(vqvae)}")
    if hps.prior:
        prior = make_prior(hps, vqvae, device)
        print_once(f"Parameters Prior:{count_parameters(prior)}")
        model = prior
    else:
        model = vqvae

    # Setup opt, ema and distributed_model.
    opt, shd, scalar = get_optimizer(model, hps)
    ema = get_ema(model, hps)
    distributed_model = get_ddp(model, hps)

    logger, metrics = init_logging(hps, local_rank, rank)
    logger.iters = model.step

    # Run training, eval, sample
    for epoch in range(hps.curr_epoch, hps.epochs):
        metrics.reset()
        data_processor.set_epoch(epoch)
        if hps.train:
            train_metrics = train(distributed_model, model, opt, shd, scalar, ema, logger,
                                  metrics, data_processor, hps)
            train_metrics['epoch'] = epoch
            if rank == 0:
                print('Train', ' '.join([f'{key}: {val:0.4f}' for key, val in train_metrics.items()]))
            dist.barrier()

        if hps.test:
            if ema: ema.swap()
            test_metrics = evaluate(distributed_model, model, logger, metrics, data_processor, hps)
            test_metrics['epoch'] = epoch
            if rank == 0:
                print('Ema', ' '.join([f'{key}: {val:0.4f}' for key, val in test_metrics.items()]))
            dist.barrier()
            if ema: ema.swap()
        dist.barrier()
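# run() is meant to be driven from the command line under MPI so that
# setup_dist_from_mpi can size the job. A typical entry point, with illustrative
# flags (the fire dependency, module path, and exact hyperparameters are
# assumptions):
#
#   if __name__ == '__main__':
#       import fire
#       fire.Fire(run)
#
#   mpiexec -n 8 python jukebox/train.py --hps=teeny --bs=4 --sr=22050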
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy",
                           u="usage", uc="used_curr", gn="gn", pn="pn", dk="dk")

    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss,
                                                                        params=list(model.parameters()),
                                                                        scalar=scalar, fp16=hps.fp16,
                                                                        logger=logger)

        # Skip step if overflow, or if the grad norm exceeds the (positive) ignore threshold
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None: shd.step()
        if ema is not None: ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None: ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None: ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()
    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
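# For reference, the scale handed to opt.step in train() folds gradient clipping
# into the fp16 unscaling division. A plausible reading of clipped_grad_scale
# (an assumption -- the real helper lives with the other fp16 utilities):
def _clipped_grad_scale_sketch(grad_norm, max_grad_norm, scale):
    clip = grad_norm / max_grad_norm
    if clip > 1:
        # Gradients exceed the clip norm: grow the divisor so that dividing the
        # gradients by it both unscales fp16 and clips to max_grad_norm.
        scale = clip * scale
    return scale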