def sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # Get z already sampled at current level
    z = zs[level][:, start:end].to(prior.device)

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # Get z_conds from level above and move them to the prior's device
    z_conds = prior.get_z_conds(zs, start, end)
    if z_conds is not None:
        for k in range(len(z_conds)):
            z_conds[k] = z_conds[k].to(prior.device)

    # Set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()
    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']

    # Split the batch so each forward pass stays within max_batch_size
    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sampling_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update zs with the newly sampled tokens, keeping them on the CPU
    z_new = z[:, -new_tokens:].cpu()
    del z
    del y
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
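# split_batch is used above but not defined in this section. A minimal sketch
# of the behavior the callers rely on, assuming the semantics of the upstream
# Jukebox helper: split a tensor (or a list of tensors, or None) along the
# batch dimension into chunks of at most split_size.
def split_batch(obj, n_samples, split_size):
    n_passes = (n_samples + split_size - 1) // split_size
    if isinstance(obj, t.Tensor):
        return t.split(obj, split_size, dim=0)
    elif isinstance(obj, list):
        # Split each tensor in the list, then regroup the pieces per chunk
        return list(zip(*[t.split(item, split_size, dim=0) for item in obj]))
    elif obj is None:
        return [None] * n_passes
    else:
        raise TypeError(f'Unknown input type {type(obj)}')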
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0:
            alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
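# For context: _sample is normally driven by thin wrappers that pick which
# levels to sample. A sketch of those wrappers, assuming the upstream Jukebox
# signatures (upsample is called but not defined in this section):
def upsample(zs, labels, sampling_kwargs, priors, hps):
    # Re-sample every level below the top, conditioning on the level above
    sample_levels = list(range(len(priors) - 1))
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    return zs

def ancestral_sample(labels, sampling_kwargs, priors, hps):
    # Sample all levels from scratch, starting from empty token sequences
    sample_levels = list(range(len(priors)))
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    return zs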
def run(**kwargs):
    from jukebox.utils.dist_utils import setup_dist_from_mpi
    model = "1b_lyrics"
    port = 29500
    rank, local_rank, device = setup_dist_from_mpi(port=port)

    hps = Hyperparams()
    hps.sr = 44100
    hps.n_samples = 1
    hps.name = kwargs["sample_name"]
    hps.levels = 3
    hps.hop_fraction = [.5, .5, .125]

    vqvae, *priors = MODELS[model]
    vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
    top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

    # Round the requested length down to a whole number of top-level tokens
    sample_length_in_seconds = kwargs["sample_length"]
    hps.sample_length = (int(sample_length_in_seconds * hps.sr) // top_prior.raw_to_tokens) * top_prior.raw_to_tokens
    assert hps.sample_length >= top_prior.n_ctx * top_prior.raw_to_tokens, 'Please choose a longer sample length'

    metas = [dict(
        artist=kwargs["artist"],
        genre=kwargs["genre"],
        total_length=hps.sample_length,
        offset=0,
        lyrics=kwargs["lyrics"],
    )] * hps.n_samples
    labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

    sampling_temperature = .98
    lower_batch_size = 16
    max_batch_size = 16
    lower_level_chunk_size = 32
    chunk_size = 32
    sampling_kwargs = [
        dict(temp=.99, fp16=True, max_batch_size=lower_batch_size, chunk_size=lower_level_chunk_size),
        dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size, chunk_size=lower_level_chunk_size),
        dict(temp=sampling_temperature, fp16=True, max_batch_size=max_batch_size, chunk_size=chunk_size)
    ]

    # Sample the top level first, then free it and upsample with the lower-level priors
    zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
    zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

    del top_prior
    empty_cache()
    top_prior = None

    upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
    labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]
    zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
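# A hedged usage sketch: the kwargs consumed by run() above imply a call like
# the following. All argument values are illustrative, not from the source.
run(
    sample_name='samples/my_song',   # used as hps.name / output directory
    sample_length=90,                # seconds; must cover at least one top-level context window
    artist='unknown',
    genre='unknown',
    lyrics='...',
)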
def _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
        zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        # if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
        #     alignments = get_alignment(x, zs, labels[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)

        # Rename the decoded wav files to carry the sample name, with a level
        # suffix below the top level
        lepath = hps.name
        base = lepath.split('/')[-1]
        suffix = {2: '-', 1: '-L1-', 0: '-L0-'}[level]
        for filex in glob(os.path.join(lepath, f'level_{level}', 'item_*.wav')):
            os.rename(filex, filex.replace('item_', base + suffix))
        # save_html(logdir, x, zs, labels[-1], alignments, hps)
    return zs
def sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # Get z already sampled at current level
    z = zs[level][:, start:end]

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # Get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # Set y offset, sample_length and lyrics tokens
    y = prior.get_y(labels, start)

    empty_cache()
    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y_list = split_batch(y, n_samples, max_batch_size)

    # MIDI conditioning: load once, outside the per-batch loop.
    # NOTE: the path is hardcoded to a local file.
    midi_path = r"C:\Users\Yousef\Desktop\UNiz\MidiDataset\Cleaned\acdc\Big Balls.mid"
    midi = load_sample_midi(midi_path)

    z_samples = []
    for z_i, z_conds_i, y_i in zip(z_list, z_conds_list, y_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y=y_i, **sampling_kwargs, midi=midi)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update zs with the newly sampled tokens
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs
def _sample(zs, labels_1, labels_2, sampling_kwargs, priors, sample_levels, hps):
    alignments = None
    for level in reversed(sample_levels):
        prior = priors[level]
        prior.cuda()
        empty_cache()

        # Set correct total_length, hop_length, labels and sampling_kwargs for level
        assert hps.sample_length % prior.raw_to_tokens == 0, f"Expected sample_length {hps.sample_length} to be multiple of {prior.raw_to_tokens}"
        total_length = hps.sample_length // prior.raw_to_tokens
        hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
        zs = sample_level(zs, labels_1[level], labels_2[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)

        prior.cpu()
        empty_cache()

        # Decode sample
        x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])

        if dist.get_world_size() > 1:
            logdir = f"{hps.name}_rank_{dist.get_rank()}/level_{level}"
        else:
            logdir = f"{hps.name}/level_{level}"
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        t.save(dict(zs=zs, labels=labels_1, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
        save_wav(logdir, x, hps.sr)
        if alignments is None and priors[-1] is not None and priors[-1].n_tokens > 0 and not isinstance(priors[-1].labeller, EmptyLabeller):
            try:
                # Sanity-check the alignment inputs; drop into the debugger if any are missing
                labels_1[-1], priors[-1], sampling_kwargs[-1]['fp16']
            except:
                import ipdb
                ipdb.set_trace()
            alignments = get_alignment(x, zs, labels_1[-1], priors[-1], sampling_kwargs[-1]['fp16'], hps)
        # don't care
        # save_html(logdir, x, zs, labels_1[-1], alignments, hps)
    return zs
def sample_level(zs, labels, sampling_kwargs, level, prior, total_length, hop_length, hps):
    print_once(f"Sampling level {level}")
    if total_length >= prior.n_ctx:
        starts = get_starts(total_length, prior.n_ctx, hop_length)
        counterr = 0
        x = None
        for start in starts:
            counterr += 1
            datea = datetime.now()

            # Track whether this window actually produced new tokens
            prev_length = zs[level].shape[1]
            zs = sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
            newtosample = zs[level].shape[1] > prev_length

            if newtosample and counterr < len(starts):
                # Decode and save an intermediate result after each window
                del x
                x = None
                prior.cpu()
                empty_cache()
                x = prior.decode(zs[level:], start_level=level, bs_chunks=zs[level].shape[0])
                logdir = f"{hps.name}/level_{level}"
                if not os.path.exists(logdir):
                    os.makedirs(logdir)
                t.save(dict(zs=zs, labels=labels, sampling_kwargs=sampling_kwargs, x=x), f"{logdir}/data.pth.tar")
                save_wav(logdir, x, hps.sr)
                del x
                prior.cuda()
                empty_cache()
                x = None

            dateb = datetime.now()
            timex = ((dateb - datea).total_seconds() / 60.0) * (len(starts) - counterr)
            eta = colored(timex, 'magenta') if (counterr > 1 and newtosample) else colored('???', 'yellow')
            print("Step " + colored(counterr, 'blue') + "/" + colored(len(starts), 'red') +
                  " ~ New to Sample: " + str(newtosample) +
                  " ~ estimated remaining minutes: " + eta)
    else:
        zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior, total_length, hps)
    return zs
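# get_starts is used above and in get_alignment but not shown in this section.
# A sketch of the behavior the callers assume, matching the upstream Jukebox
# helper: window start offsets spaced hop_length apart, with the last window
# pulled back so it still covers a full n_ctx of context.
def get_starts(total_length, n_ctx, hop_length):
    starts = []
    for start in range(0, total_length - n_ctx + hop_length, hop_length):
        if start + n_ctx >= total_length:
            # The last hop could be smaller; use a full window ending at total_length
            start = total_length - n_ctx
        starts.append(start)
    return starts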
def get_alignment(x, zs, labels, prior, fp16, hps):
    level = hps.levels - 1  # Top level used
    n_ctx, n_tokens = prior.n_ctx, prior.n_tokens
    z = zs[level]
    bs, total_length = z.shape[0], z.shape[1]
    if total_length < n_ctx:
        # Pad z out to a full context window
        padding_length = n_ctx - total_length
        z = t.cat([z, t.zeros(bs, padding_length, dtype=z.dtype, device=z.device)], dim=1)
        total_length = z.shape[1]
    else:
        padding_length = 0

    hop_length = int(hps.hop_fraction[level] * prior.n_ctx)
    n_head = prior.prior.transformer.n_head
    alignment_head, alignment_layer = prior.alignment_head, prior.alignment_layer
    attn_layers = {alignment_layer}
    alignment_hops = {}
    indices_hops = {}

    prior.cuda()
    empty_cache()
    for start in get_starts(total_length, n_ctx, hop_length):
        end = start + n_ctx

        # Set y offset, sample_length and lyrics tokens
        y, indices_hop = prior.get_y(labels, start, get_indices=True)
        assert len(indices_hop) == bs
        for indices in indices_hop:
            assert len(indices) == n_tokens

        # Run one item at a time to keep memory bounded
        z_bs = t.chunk(z, bs, dim=0)
        y_bs = t.chunk(y, bs, dim=0)
        w_hops = []
        for z_i, y_i in zip(z_bs, y_bs):
            w_hop = prior.z_forward(z_i[:, start:end], [], y_i, fp16=fp16, get_attn_weights=attn_layers)
            assert len(w_hop) == 1
            w_hops.append(w_hop[0][:, alignment_head])
            del w_hop
        w = t.cat(w_hops, dim=0)
        del w_hops
        assert_shape(w, (bs, n_ctx, n_tokens))
        alignment_hop = w.float().cpu().numpy()
        assert_shape(alignment_hop, (bs, n_ctx, n_tokens))
        del w

        # alignment_hop has shape (bs, n_ctx, n_tokens)
        # indices_hop is a list of len=bs, each entry of len hps.n_tokens
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop
    prior.cpu()
    empty_cache()

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(bs):
        # Note each item has different length lyrics
        full_tokens = labels['info'][item]['full_tokens']
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        for start in reversed(get_starts(total_length, n_ctx, hop_length)):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            assert len(indices) == n_tokens
            assert alignment_hop.shape == (n_ctx, n_tokens)
            alignment[start:end, indices] = alignment_hop
        alignment = alignment[:total_length - padding_length, :-1]  # remove token padding, and last lyric index
        alignments.append(alignment)
    return alignments
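# A hedged sketch of how the returned alignments might be inspected: each
# entry is a (total_length, n_lyric_tokens) numpy array of attention weights,
# so it can be rendered directly as an image. matplotlib and plot_alignment
# are assumptions; neither appears elsewhere in this code.
import matplotlib.pyplot as plt

def plot_alignment(alignment, path='alignment.png'):
    plt.figure(figsize=(12, 4))
    plt.imshow(alignment.T, aspect='auto', origin='lower', interpolation='none')
    plt.xlabel('music token position')
    plt.ylabel('lyric token index')
    plt.colorbar(label='attention weight')
    plt.savefig(path)
    plt.close()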
def primed_sample(self, n_samples, x, memory, x_cond=None, y_cond=None, encoder_kv=None, fp16=False,
                  temp=1.0, top_k=0, top_p=0.0, get_preds=False, chunk_size=None, sample_tokens=None):
    assert not self.training
    if sample_tokens is None:
        sample_tokens = self.input_dims

    # Preprocess
    with t.no_grad():
        x = self.preprocess(x)
    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()
    assert x.shape[0] == n_samples
    xs = t.split(x, 1, dim=1)
    xs = list(xs)
    assert len(xs) < sample_tokens

    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None

    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

    with t.no_grad():
        if get_preds:
            preds = []

        x = xs[-1]
        assert x.shape == (n_samples, 1)
        empty_cache()
        for sample_t in get_range(range(len(xs), sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            # self.transformer.check_cache(n_samples, sample_t, fp16)
            x, memory = self.transformer(x[:, -1, :], memory)  # Transformer
            x = t.unsqueeze(x, 1)
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x)
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())
        del x
        # self.transformer.del_cache()

        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x, memory
def primed_sample(self, n_samples, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False,
                  temp=1.0, top_k=0, top_p=0.0, get_preds=False, chunk_size=None, sample_tokens=None):
    assert not self.training
    if sample_tokens is None:
        sample_tokens = self.input_dims

    # Preprocess
    with t.no_grad():
        x = self.preprocess(x)
    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()
    assert x.shape[0] == n_samples
    xs = t.split(x, 1, dim=1)
    xs = list(xs)
    assert len(xs) < sample_tokens

    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None

    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (N, 1, self.width), \
            f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

    with t.no_grad():
        if get_preds:
            preds = []

        # Fill up key/value cache for past context by running a forward pass.
        # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage.
        if chunk_size is None:
            chunk_size = len(xs)
        # assert len(xs) % chunk_size == 0, f'expected {len(xs)} to be divisible by {chunk_size}'
        chunk_sizes = split_chunks(len(xs), chunk_size)
        x_primes = []
        start = 0
        x = None
        for current_chunk_size in get_range(chunk_sizes):
            xs_prime, conds_prime = [], []
            for sample_t in range(start, start + current_chunk_size):
                x_prime, cond_prime = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                x = xs[sample_t]
                xs_prime.append(x_prime)
                conds_prime.append(cond_prime)
            start = start + current_chunk_size

            x_prime, cond_prime = t.cat(xs_prime, dim=1), t.cat(conds_prime, dim=1)
            assert x_prime.shape == (n_samples, current_chunk_size, self.width)
            assert cond_prime.shape == (n_samples, current_chunk_size, self.width)
            del xs_prime
            del conds_prime
            if not get_preds:
                del cond_prime
            x_prime = self.transformer(x_prime, encoder_kv=encoder_kv, sample=True, fp16=fp16)

            if get_preds:
                if self.add_cond_after_transformer:
                    x_prime = x_prime + cond_prime
                assert x_prime.shape == (n_samples, current_chunk_size, self.width)
                del cond_prime
                x_primes.append(x_prime)
            else:
                del x_prime

        if get_preds:
            x_prime = t.cat(x_primes, dim=1)
            assert x_prime.shape == (n_samples, len(xs), self.width)
            x_prime = self.x_out(x_prime)  # Predictions
            preds.append(x_prime)

        empty_cache()
        self.transformer.check_cache(n_samples, len(xs), fp16)

        x = xs[-1]
        assert x.shape == (n_samples, 1)
        empty_cache()
        for sample_t in get_range(range(len(xs), sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            self.transformer.check_cache(n_samples, sample_t, fp16)
            x = self.transformer(x, encoder_kv=encoder_kv, sample=True, fp16=fp16)  # Transformer
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x)
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())
        del x
        self.transformer.del_cache()

        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)
    if get_preds:
        return x, preds
    else:
        return x
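# A hedged usage sketch for primed_sample: continue the autoregressive prior
# from codebook tokens encoded from an existing clip. `prior_model`,
# `z_prime`, and the conditioning tensors are assumptions; all argument
# values are illustrative only.
with t.no_grad():
    z_out = prior_model.primed_sample(
        n_samples=z_prime.shape[0],   # batch size
        x=z_prime,                    # (n_samples, T) LongTensor of codes, T < sample_tokens
        x_cond=x_cond,                # upsampler conditioning, or None at the top level
        y_cond=y_cond,                # artist/genre/lyrics conditioning, or None
        fp16=True,
        temp=0.98,
        chunk_size=32,                # prime the KV cache 32 tokens at a time
        sample_tokens=8192,           # total tokens returned, including the prime
    )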
def sample_single_window(zs, labels_1, labels_2, sampling_kwargs, level, prior, start, hps, total_length=1):
    n_samples = hps.n_samples
    n_ctx = prior.n_ctx
    end = start + n_ctx

    # Get z already sampled at current level
    z = zs[level][:, start:end]

    if 'sample_tokens' in sampling_kwargs:
        # Support sampling a window shorter than n_ctx
        sample_tokens = sampling_kwargs['sample_tokens']
    else:
        sample_tokens = (end - start)
    conditioning_tokens, new_tokens = z.shape[1], sample_tokens - z.shape[1]

    print_once(f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on {conditioning_tokens} tokens")
    print_once(f"{round((start+sample_tokens)/total_length*100.0)}%-ish, level {level}")

    if new_tokens <= 0:
        # Nothing new to sample
        return zs

    # Get z_conds from level above
    z_conds = prior.get_z_conds(zs, start, end)

    # Set y offset, sample_length and lyrics tokens for both label sets
    y1 = prior.get_y(labels_1, start)
    y2 = prior.get_y(labels_2, start)

    empty_cache()
    max_batch_size = sampling_kwargs['max_batch_size']
    del sampling_kwargs['max_batch_size']

    z_list = split_batch(z, n_samples, max_batch_size)
    z_conds_list = split_batch(z_conds, n_samples, max_batch_size)
    y1_list = split_batch(y1, n_samples, max_batch_size)
    y2_list = split_batch(y2, n_samples, max_batch_size)
    z_samples = []
    for z_i, z_conds_i, y1_i, y2_i in zip(z_list, z_conds_list, y1_list, y2_list):
        z_samples_i = prior.sample(n_samples=z_i.shape[0], z=z_i, z_conds=z_conds_i, y1=y1_i, y2=y2_i, **sampling_kwargs)
        z_samples.append(z_samples_i)
    z = t.cat(z_samples, dim=0)

    sampling_kwargs['max_batch_size'] = max_batch_size

    # Update zs with the newly sampled tokens
    z_new = z[:, -new_tokens:]
    zs[level] = t.cat([zs[level], z_new], dim=1)
    return zs