def load_checkpoint(path):
    print("Loading checkpoint...")
    restore = path
    if restore[:5] == 'gs://':
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
        gdrive_path = os.path.join("/content/gdrive/My Drive/samples/", gs_path[5:])
        print(f'local path: {local_path}')
        print(f'gdrive path: {gdrive_path}')
        if dist.get_rank() % 8 == 0:
            if os.path.exists(gdrive_path):
                print("Using priors on Google Drive")
                restore = gdrive_path
            elif os.path.exists(os.path.dirname(gdrive_path)):
                print("Downloading priors to Google Drive")
                download(gs_path, gdrive_path)
                restore = gdrive_path
            else:
                print("Downloading from gce")
                if not os.path.exists(os.path.dirname(local_path)):
                    os.makedirs(os.path.dirname(local_path))
                if not os.path.exists(local_path):
                    download(gs_path, local_path)
                restore = local_path
        dist.barrier()
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("RS // Restored from {}".format(restore))
    return checkpoint
def get_optimizer(model, hps):
    # Optimizer
    betas = (hps.beta1, hps.beta2)
    if hps.fp16_opt:
        opt = FP16FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)
    else:
        opt = FusedAdam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, betas=betas, eps=hps.eps)

    # lr scheduler
    shd = get_lr_scheduler(opt, hps)

    # fp16 dynamic loss scaler
    scalar = None
    if hps.fp16:
        rank = dist.get_rank()
        local_rank = rank % 8
        scalar = LossScalar(hps.fp16_loss_scale, scale_factor=2 ** (1. / hps.fp16_scale_window))
        if local_rank == 0:
            print(scalar.__dict__)

    zero_grad(model)
    return opt, shd, scalar
def __init__(self, n_in, n_depth, m_conv=1.0, dilation_growth_rate=1, dilation_cycle=None, zero_out=False,
             res_scale=False, reverse_dilation=False, checkpoint_res=False):
    super().__init__()

    def _get_depth(depth):
        if dilation_cycle is None:
            return depth
        else:
            return depth % dilation_cycle

    blocks = [ResConv1DBlock(n_in, int(m_conv * n_in),
                             dilation=dilation_growth_rate ** _get_depth(depth),
                             zero_out=zero_out,
                             res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth))
              for depth in range(n_depth)]
    if reverse_dilation:
        blocks = blocks[::-1]
    self.checkpoint_res = checkpoint_res
    if self.checkpoint_res == 1:
        if dist.get_rank() == 0:
            print("Checkpointing convs")
        self.blocks = nn.ModuleList(blocks)
    else:
        self.model = nn.Sequential(*blocks)
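# Editorial note (assumed forward behavior, not shown in this excerpt): when checkpoint_res is set,
# the blocks are kept in an nn.ModuleList rather than fused into an nn.Sequential so that the
# forward pass can apply activation checkpointing to each residual block individually, trading
# extra recomputation in the backward pass for lower activation memory.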
def get_ddp(model, hps):
    rank = dist.get_rank()
    local_rank = rank % 8
    ddp = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank,
                                  broadcast_buffers=False, bucket_cap_mb=hps.bucket)
    return ddp
def init_dataset(self, hps):
    # Load list of files and starts/durations
    files = librosa.util.find_files(f'{hps.audio_files_dir}', ['mp3', 'opus', 'm4a', 'aac', 'wav'])
    print_all(f"Found {len(files)} files. Getting durations")
    cache = (dist.get_rank() % 8 == 0) if dist.is_available() else True
    durations = np.array([get_duration_sec(file, cache=cache) * self.sr for file in files])  # Could be approximate
    self.filter(files, durations)

    if self.labels:
        self.labeller = Labeller(hps.max_bow_genre_size, hps.n_tokens, self.sample_length, v3=hps.labels_v3)
def lr_lambda(step):
    if hps.lr_use_linear_decay:
        lr_scale = hps.lr_scale * min(1.0, step / hps.lr_warmup)
        decay = max(0.0, 1.0 - max(0.0, step - hps.lr_start_linear_decay) / hps.lr_decay)
        if decay == 0.0:
            if dist.get_rank() == 0:
                print("Reached end of training")
        return lr_scale * decay
    else:
        return hps.lr_scale * (hps.lr_gamma ** (step // hps.lr_decay)) * min(1.0, step / hps.lr_warmup)
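# A minimal sketch (an assumption about the enclosing helper, not verified repo code) of how
# lr_lambda above is typically attached to the optimizer: get_lr_scheduler would wrap it in
# torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer's base lr by lr_lambda(step)
# every time shd.step() is called.
import torch

def get_lr_scheduler_sketch(opt, hps):
    def lr_lambda(step):
        # warmup then exponential step decay, matching the non-linear-decay branch above
        return hps.lr_scale * (hps.lr_gamma ** (step // hps.lr_decay)) * min(1.0, step / hps.lr_warmup)
    return torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)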
def get_ema(model, hps):
    mu = hps.mu or (1. - (hps.bs * hps.ngpus / 8.) / 1000)
    ema = None
    if hps.ema and hps.train:
        if hps.cpu_ema:
            if dist.get_rank() == 0:
                print("Using CPU EMA")
            ema = CPUEMA(model.parameters(), mu=mu, freq=hps.cpu_ema_freq)
        elif hps.ema_fused:
            ema = FusedEMA(model.parameters(), mu=mu)
        else:
            ema = EMA(model.parameters(), mu=mu)
    return ema
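# Worked example of the default decay above (an illustration, not a value from the repo): with
# hps.bs = 4 samples per GPU and hps.ngpus = 8, mu = 1 - (4 * 8 / 8) / 1000 = 0.996; larger
# effective batch sizes give a smaller mu, so the EMA weights track the live weights more quickly.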
def _setup_dist_from_mpi(master_addr, backend, port, n_attempts, verbose):
    from mpi4py import MPI  # This must be imported in order to get errors from all ranks to show up

    mpi_rank = MPI.COMM_WORLD.Get_rank()
    mpi_size = MPI.COMM_WORLD.Get_size()

    os.environ["RANK"] = str(mpi_rank)
    os.environ["WORLD_SIZE"] = str(mpi_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    os.environ["NCCL_LL_THRESHOLD"] = "0"
    os.environ["NCCL_NSOCKS_PERTHREAD"] = "2"
    os.environ["NCCL_SOCKET_NTHREADS"] = "8"

    # Pin this rank to a specific GPU on the node
    local_rank = mpi_rank % 8
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    print(f"CUDA available: {torch.cuda.is_available()}")

    if verbose:
        print(f"Connecting to master_addr: {master_addr}")

    # There is a race condition when initializing NCCL with a large number of ranks (e.g. 500 ranks)
    # We guard against the failure and then retry
    for attempt_idx in range(n_attempts):
        try:
            dist.init_process_group(backend=backend, init_method="env://")
            assert dist.get_rank() == mpi_rank

            use_cuda = torch.cuda.is_available()
            print(f'Using cuda {use_cuda}')
            local_rank = mpi_rank % 8
            device = torch.device("cuda", local_rank) if use_cuda else torch.device("cpu")
            if use_cuda:
                torch.cuda.set_device(local_rank)

            return mpi_rank, local_rank, device
        except RuntimeError as e:
            print(f"Caught error during NCCL init (attempt {attempt_idx} of {n_attempts}): {e}")
            sleep(1 + (0.01 * mpi_rank))  # Sleep to avoid thundering herd

    raise RuntimeError("Failed to initialize NCCL")
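# A hedged usage sketch (argument values are illustrative; hps is the repo's hyperparameter
# namespace defined elsewhere): initialize the process group from MPI first, move the model onto
# the pinned local GPU, then wrap it with get_ddp so gradients are all-reduced across ranks.
import torch.nn as nn

rank, local_rank, device = _setup_dist_from_mpi(master_addr="127.0.0.1", backend="nccl",
                                                port=29500, n_attempts=5, verbose=True)
model = nn.Linear(64, 64).to(device)  # placeholder model for illustration
ddp_model = get_ddp(model, hps)       # hps.bucket supplies DDP's bucket_cap_mb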
def load_checkpoint(path):
    restore = path
    if restore[:5] == 'gs://':
        gs_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), gs_path[5:])
        if dist.get_rank() % 8 == 0:
            print("Downloading from gce")
            if not os.path.exists(os.path.dirname(local_path)):
                os.makedirs(os.path.dirname(local_path))
            if not os.path.exists(local_path):
                download(gs_path, local_path)
        restore = local_path
        dist.barrier()
    checkpoint = t.load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint
def load_checkpoint(path):
    restore = path
    if restore.startswith(REMOTE_PREFIX):
        remote_path = restore
        local_path = os.path.join(os.path.expanduser("~/.cache"), remote_path[len(REMOTE_PREFIX):])
        if dist.get_rank() % 8 == 0:
            print("Downloading from azure")
            if not os.path.exists(os.path.dirname(local_path)):
                os.makedirs(os.path.dirname(local_path))
            if not os.path.exists(local_path):
                download(remote_path, local_path)
        restore = local_path
        dist.barrier()
    checkpoint = custom_load(restore, map_location=t.device('cpu'))
    print("Restored from {}".format(restore))
    return checkpoint
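# A hedged usage sketch (the 'model' key is an assumption about the checkpoint layout): the
# helpers above return a dict loaded onto CPU, from which the caller restores module weights.
def restore_weights(model, path):
    checkpoint = load_checkpoint(path)          # downloads to a local cache if needed, then loads on CPU
    model.load_state_dict(checkpoint['model'])  # 'model' entry assumed to hold the state_dict
    return model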
def sample(self, n_samples, z=None, memory=None, z_conds=None, y=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
           chunk_size=None, sample_tokens=None, train=False):
    N = n_samples
    if z is not None:
        assert z.shape[0] == N, f"Expected shape ({N},**), got shape {z.shape}"
    if y is not None:
        assert y.shape[0] == N, f"Expected shape ({N},**), got shape {y.shape}"
    if z_conds is not None:
        for z_cond in z_conds:
            assert z_cond.shape[0] == N, f"Expected shape ({N},**), got shape {z_cond.shape}"

    no_past_context = (z is None or z.shape[1] == 0)
    if dist.get_rank() == 0:
        name = {True: 'Ancestral', False: 'Primed'}[no_past_context]
        print(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}")

    with t.no_grad():
        # Currently x_cond only uses immediately above layer
        x_cond, y_cond, prime = self.get_cond(z_conds, y)
        if self.single_enc_dec:
            # assert chunk_size % self.prime_loss_dims == 0. TODO: Check if needed
            if no_past_context:
                z, x_cond = self.prior_preprocess([prime], [None, x_cond])
            else:
                z, x_cond = self.prior_preprocess([prime, z], [None, x_cond])
            if sample_tokens is not None:
                sample_tokens += self.n_tokens
            z = self.prior.primed_sample(n_samples, z, x_cond, y_cond, fp16=fp16, temp=temp, top_k=top_k,
                                         top_p=top_p, chunk_size=chunk_size, sample_tokens=sample_tokens)
            z = self.prior_postprocess(z)
        else:
            encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
            if no_past_context:
                if train:
                    z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16, temp=temp,
                                          top_k=top_k, top_p=top_p, sample_tokens=sample_tokens)
                else:
                    z, memory = self.prior.sample_recurrent(n_samples, x_cond, y_cond, encoder_kv, fp16=fp16,
                                                            temp=temp, top_k=top_k, top_p=top_p,
                                                            sample_tokens=sample_tokens)
            else:
                z, memory = self.prior.primed_sample(n_samples, z, memory, x_cond, y_cond, encoder_kv, fp16=fp16,
                                                     temp=temp, top_k=top_k, top_p=top_p, chunk_size=chunk_size,
                                                     sample_tokens=sample_tokens)
        if sample_tokens is None:
            assert_shape(z, (N, *self.z_shape))
    return z, memory
def calculate_bandwidth(dataset, hps, duration=600):
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            x, y = x
        samples = x.astype(np.float64)
        stft = librosa.core.stft(np.mean(samples, axis=1), hps.n_fft, hop_length=hps.hop_length,
                                 win_length=hps.window_size)
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += max(16, dist.get_world_size())

    if dist.is_available():
        from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2=total_sq / n_seen - mean ** 2,
                     l1=l1 / n_seen,
                     spec=spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
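# Editorial note (assumed usage): the returned dict holds per-sample dataset statistics -- l2 is
# the signal variance, l1 the mean absolute amplitude, and spec the average spectrogram norm --
# which the training code can use to normalize reconstruction and spectral losses so their scale
# is comparable across datasets.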
def train(model, orig_model, opt, shd, scalar, ema, logger, metrics, data_processor, hps):
    model.train()
    orig_model.train()
    if hps.prior:
        _print_keys = dict(l="loss", bpd="bpd", gn="gn", g_l="gen_loss", p_l="prime_loss")
    else:
        _print_keys = dict(l="loss", sl="spectral_loss", rl="recons_loss", e="entropy", u="usage",
                           uc="used_curr", gn="gn", pn="pn", dk="dk")

    print_all(data_processor.train_loader)
    print_all(len(data_processor.train_loader))
    for i, x in logger.get_range(data_processor.train_loader):
        if isinstance(x, (tuple, list)):
            x, y = x
        else:
            y = None

        x = x.to('cuda', non_blocking=True)
        if y is not None:
            y = y.to('cuda', non_blocking=True)

        x_in = x = audio_preprocess(x, hps)
        log_input_output = (logger.iters % hps.save_iters == 0)

        if hps.prior:
            forw_kwargs = dict(y=y, fp16=hps.fp16, decode=log_input_output)
        else:
            forw_kwargs = dict(loss_fn=hps.loss_fn, hps=hps)

        # Forward
        x_out, loss, _metrics = model(x, **forw_kwargs)

        # Backward
        loss, scale, grad_norm, overflow_loss, overflow_grad = backward(loss=loss, params=list(model.parameters()),
                                                                        scalar=scalar, fp16=hps.fp16, logger=logger)

        # Skip step if overflow
        grad_norm = allreduce(grad_norm, op=dist.ReduceOp.MAX)
        if overflow_loss or overflow_grad or grad_norm > hps.ignore_grad_norm > 0:
            zero_grad(orig_model)
            continue

        # Step opt. Divide by scale to include clipping and fp16 scaling
        logger.step()
        opt.step(scale=clipped_grad_scale(grad_norm, hps.clip, scale))
        zero_grad(orig_model)
        lr = hps.lr if shd is None else shd.get_lr()[0]
        if shd is not None:
            shd.step()
        if ema is not None:
            ema.step()
        next_lr = hps.lr if shd is None else shd.get_lr()[0]
        finished_training = (next_lr == 0.0)

        # Logging
        for key, val in _metrics.items():
            _metrics[key] = val.item()
        _metrics["loss"] = loss = loss.item() * hps.iters_before_update  # Make sure to call to free graph
        _metrics["gn"] = grad_norm
        _metrics["lr"] = lr
        _metrics["lg_loss_scale"] = np.log2(scale)

        # Average and log
        for key, val in _metrics.items():
            _metrics[key] = metrics.update(key, val, x.shape[0])
            if logger.iters % hps.log_steps == 0:
                logger.add_scalar(key, _metrics[key])

        # Save checkpoint
        with t.no_grad():
            if hps.save and (logger.iters % hps.save_iters == 1 or finished_training):
                if ema is not None:
                    ema.swap()
                orig_model.eval()
                name = 'latest' if hps.prior else f'step_{logger.iters}'
                if dist.get_rank() % 8 == 0:
                    save_checkpoint(logger, name, orig_model, opt, dict(step=logger.iters), hps)
                orig_model.train()
                if ema is not None:
                    ema.swap()

        # Sample
        with t.no_grad():
            if (logger.iters % 12000) in list(range(1, 1 + hps.iters_before_update)) or finished_training:
                if hps.prior:
                    sample_prior(orig_model, ema, logger, x_in, y, hps)

        # Input/Output
        with t.no_grad():
            if log_input_output:
                log_inputs(orig_model, logger, x_in, y, x_out, hps)

        logger.set_postfix(**{print_key: _metrics[key] for print_key, key in _print_keys.items()})
        if finished_training:
            dist.barrier()
            exit()

    logger.close_range()
    return {key: metrics.avg(key) for key in _metrics.keys()}
def __init__(self, z_shapes, l_bins, encoder, decoder, level,
             downs_t, strides_t, labels, prior_kwargs, x_cond_kwargs, y_cond_kwargs,
             prime_kwargs, copy_input, labels_v3=False,
             merged_decoder=False, single_enc_dec=False):
    super().__init__()

    self.use_tokens = prime_kwargs.pop('use_tokens')
    self.n_tokens = prime_kwargs.pop('n_tokens')
    self.prime_loss_fraction = prime_kwargs.pop('prime_loss_fraction')

    self.copy_input = copy_input
    if self.copy_input:
        prime_kwargs['bins'] = l_bins

    self.z_shapes = z_shapes
    self.levels = len(self.z_shapes)
    self.z_shape = self.z_shapes[level]

    self.level = level
    assert level < self.levels, f"Total levels {self.levels}, got level {level}"

    self.l_bins = l_bins

    # Passing functions instead of the vqvae module to avoid getting params
    self.encoder = encoder
    self.decoder = decoder

    # X conditioning
    self.x_cond = (level != (self.levels - 1))
    self.cond_level = level + 1

    # Y conditioning
    self.y_cond = labels

    self.single_enc_dec = single_enc_dec

    # X conditioning
    if self.x_cond:
        self.conditioner_blocks = nn.ModuleList()
        conditioner_block = lambda _level: Conditioner(input_shape=z_shapes[_level],
                                                       bins=l_bins,
                                                       down_t=downs_t[_level],
                                                       stride_t=strides_t[_level],
                                                       **x_cond_kwargs)
        if dist.get_rank() == 0:
            print("Conditioning on 1 above level(s)")
        self.conditioner_blocks.append(conditioner_block(self.cond_level))

    # Y conditioning
    if self.y_cond:
        self.n_time = self.z_shape[0]  # Assuming STFT=TF order and raw=T1 order, so T is first dim
        self.y_emb = LabelConditioner(n_time=self.n_time, include_time_signal=not self.x_cond, **y_cond_kwargs)

    # Lyric conditioning
    if single_enc_dec:
        # Single encoder-decoder transformer
        self.prior_shapes = [(self.n_tokens,), prior_kwargs.pop('input_shape')]
        self.prior_bins = [prime_kwargs['bins'], prior_kwargs.pop('bins')]
        self.prior_dims = [np.prod(shape) for shape in self.prior_shapes]
        self.prior_bins_shift = np.cumsum([0, *self.prior_bins])[:-1]
        self.prior_width = prior_kwargs['width']
        print_once(f'Creating cond. autoregress with prior bins {self.prior_bins}, ')
        print_once(f'dims {self.prior_dims}, ')
        print_once(f'shift {self.prior_bins_shift}')
        print_once(f'input shape {sum(self.prior_dims)}')
        print_once(f'input bins {sum(self.prior_bins)}')
        print_once(f'Self copy is {self.copy_input}')

        self.prime_loss_dims, self.gen_loss_dims = self.prior_dims[0], self.prior_dims[1]
        self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
        self.prior = ConditionalAutoregressive2D(input_shape=(sum(self.prior_dims),),
                                                 bins=sum(self.prior_bins),
                                                 x_cond=(self.x_cond or self.y_cond), y_cond=True,
                                                 prime_len=self.prime_loss_dims,
                                                 **prior_kwargs)
    else:
        # Separate encoder-decoder transformer
        if self.n_tokens != 0 and self.use_tokens:
            from jukebox.transformer.ops import Conv1D
            prime_input_shape = (self.n_tokens,)
            self.prime_loss_dims = np.prod(prime_input_shape)
            self.prime_acts_width, self.prime_state_width = prime_kwargs['width'], prior_kwargs['width']
            self.prime_prior = ConditionalAutoregressive2D(input_shape=prime_input_shape, x_cond=False, y_cond=False,
                                                           only_encode=True, **prime_kwargs)
            self.prime_state_proj = Conv1D(self.prime_acts_width, self.prime_state_width,
                                           init_scale=prime_kwargs['init_scale'])
            self.prime_state_ln = LayerNorm(self.prime_state_width)
            self.prime_bins = prime_kwargs['bins']
            self.prime_x_out = nn.Linear(self.prime_state_width, self.prime_bins, bias=False)
            nn.init.normal_(self.prime_x_out.weight, std=0.02 * prior_kwargs['init_scale'])
        else:
            self.prime_loss_dims = 0
        self.gen_loss_dims = np.prod(self.z_shape)
        self.total_loss_dims = self.prime_loss_dims + self.gen_loss_dims
        self.prior = ConditionalAutoregressive2D(x_cond=(self.x_cond or self.y_cond), y_cond=self.y_cond,
                                                 encoder_dims=self.prime_loss_dims, merged_decoder=merged_decoder,
                                                 **prior_kwargs)

    self.n_ctx = self.gen_loss_dims
    self.downsamples = calculate_strides(strides_t, downs_t)
    self.cond_downsample = self.downsamples[level + 1] if level != self.levels - 1 else None
    self.raw_to_tokens = np.prod(self.downsamples[:level + 1])
    self.sample_length = self.n_ctx * self.raw_to_tokens
    if labels:
        self.labels_v3 = labels_v3
        self.labeller = Labeller(self.y_emb.max_bow_genre_size, self.n_tokens, self.sample_length, v3=self.labels_v3)
    print(f"Level:{level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample length:{self.sample_length}")
def print_once(msg):
    if (not dist.is_available()) or dist.get_rank() == 0:
        print(msg)
def print_all(msg):
    if (not dist.is_available()):
        print(msg)
    elif dist.get_rank() % 8 == 0:
        print(f'{dist.get_rank() // 8}: {msg}')
def get_range(x):
    if dist.get_rank() == 0:
        return def_tqdm(x)
    else:
        return x