def ancestral_sample(self, n_samples, z_conds=None, y=None):
    """Return a deterministic batch of position indices instead of real samples.

    Every row of the result is ``0, 1, ..., self.n_ctx - 1`` (dtype long,
    created on the ``'cuda'`` device), giving shape ``(n_samples, n_ctx)``.

    Args:
        n_samples: number of rows to produce.
        z_conds: optional sequence; only its first entry is inspected. It is
            checked with ``assert_shape`` against ``(n_samples, n_ctx // 4)``
            and must match ``z // 4`` once expanded by ``repeat(z_cond, 4, 1)``
            (``repeat`` is a helper defined outside this block).
        y: accepted for interface compatibility; not used here.

    Returns:
        The ``(n_samples, n_ctx)`` long tensor described above.
    """
    positions = t.arange(0, self.n_ctx, dtype=t.long, device='cuda')
    # Tile the single row of positions once per requested sample.
    z = positions.view(1, self.n_ctx).repeat(n_samples, 1)

    if z_conds is None:
        return z

    z_cond = z_conds[0]
    assert_shape(z_cond, (n_samples, self.n_ctx // 4))
    assert (z // 4 == repeat(z_cond, 4, 1)).all(), \
        f'z: {z}, z_cond: {z_cond}, diff: {(z // 4) - repeat(z_cond, 4, 1)}'
    return z
def primed_sample(self, n_samples, z, z_conds=None, y=None):
    """Extend a primed sequence ``z`` to ``self.n_ctx`` by counting upwards.

    The continuation starts at ``z[:, -1] + 1`` for each row and increases by
    one per position, so no model is evaluated. The continuation is created on
    the ``'cuda'`` device.

    Args:
        n_samples: expected batch size of ``z``.
        z: primed tokens; checked with ``assert_shape`` against
            ``(n_samples, z.shape[1])``.
        z_conds: optional sequence; only its first entry is inspected. It is
            checked against ``(n_samples, n_ctx // 4)`` and must match
            ``z // 4`` once expanded by ``repeat(z_cond, 4, 1)`` (helper
            defined outside this block).
        y: accepted for interface compatibility; not used here.

    Returns:
        ``z`` concatenated with its continuation along dim 1.
    """
    n_primed = z.shape[1]
    assert_shape(z, (n_samples, n_primed))

    n_remaining = self.n_ctx - n_primed
    first_new = z[:, -1:] + 1
    steps = t.arange(0, n_remaining, dtype=t.long, device='cuda').view(1, n_remaining)
    # Broadcasting (1, n_remaining) + (n_samples, 1); the view doubles as a shape check.
    continuation = (steps + first_new).view(n_samples, n_remaining)
    z = t.cat([z, continuation], dim=1)

    if z_conds is None:
        return z

    z_cond = z_conds[0]
    assert_shape(z_cond, (n_samples, self.n_ctx // 4))
    assert (z // 4 == repeat(z_cond, 4, 1)).all(), \
        f'z: {z}, z_cond: {z_cond}, diff: {(z // 4) - repeat(z_cond, 4, 1)}'
    return z
def forward(self, x):
    """Run ``x`` through every level block in order and collect each output.

    Args:
        x: input checked with ``assert_shape`` against
            ``(N, self.input_emb_width, T)``, where ``N`` and ``T`` are taken
            from its first and last dimensions.

    Returns:
        A list with one entry per level. After each level the expected last
        dimension is divided by ``stride_t ** down_t`` and the output is
        checked against ``(N, self.output_emb_width, T)``.
    """
    batch, length = x.shape[0], x.shape[-1]
    assert_shape(x, (batch, self.input_emb_width, length))

    outputs = []
    # 64, 32, ...
    for level, down_t, stride_t in zip(range(self.levels), self.downs_t, self.strides_t):
        x = self.level_blocks[level](x)
        length = length // (stride_t ** down_t)
        assert_shape(x, (batch, self.output_emb_width, length))
        outputs.append(x)

    return outputs
def get_encoder_kv(self, prime, fp16=False, sample=False):
    """Compute the encoder key/value activations from the prime tokens.

    Args:
        prime: prime tokens passed to ``self.prime_prior``; its first
            dimension is taken as the batch size.
        fp16: forwarded to ``self.prime_prior``; when true the result is also
            converted with ``.half()`` before being returned.
        sample: when true, ``self.prime_prior`` is moved to CUDA for the
            duration of the call and moved back to the CPU afterwards.

    Returns:
        ``None`` when ``self.n_tokens`` is 0 or ``self.use_tokens`` is falsy;
        otherwise the layer-normed projection of the prime activations.

    Raises:
        AssertionError: if the activations have an unexpected shape or are not
            ``t.float``.
    """
    if self.n_tokens == 0 or not self.use_tokens:
        return None

    try:
        if sample:
            self.prime_prior.cuda()
        N = prime.shape[0]
        prime_acts = self.prime_prior(prime, None, None, None, fp16=fp16)
        assert_shape(prime_acts, (N, self.prime_loss_dims, self.prime_acts_width))
        assert prime_acts.dtype == t.float, f'Expected t.float, got {prime_acts.dtype}'
        encoder_kv = self.prime_state_ln(self.prime_state_proj(prime_acts))
        assert encoder_kv.dtype == t.float, f'Expected t.float, got {encoder_kv.dtype}'
    finally:
        # Always hand the prime prior back to the CPU, even when the forward
        # pass or one of the checks fails, so a failed call cannot leave it
        # occupying GPU memory.
        if sample:
            self.prime_prior.cpu()

    if fp16:
        encoder_kv = encoder_kv.half()
    return encoder_kv
def forward(self, xs, all_levels=True):
    """Run the level blocks from the last level down to level 0, then ``self.out``.

    Args:
        xs: list of inputs. Must contain ``self.levels`` entries when
            ``all_levels`` is true, otherwise exactly one. Processing starts
            from ``xs[-1]``, which is checked with ``assert_shape`` against
            ``(N, self.output_emb_width, T)``.
        all_levels: when true, after every level except level 0 the input of
            the next lower level (``xs[level - 1]``) is added to the running
            activation.

    Returns:
        The result of ``self.out`` applied to the final activation.
    """
    expected_inputs = self.levels if all_levels else 1
    assert len(xs) == expected_inputs

    x = xs[-1]
    batch, length = x.shape[0], x.shape[-1]
    assert_shape(x, (batch, self.output_emb_width, length))

    # 32, 64 ...
    schedule = list(zip(range(self.levels), self.downs_t, self.strides_t))
    for level, down_t, stride_t in schedule[::-1]:
        x = self.level_blocks[level](x)
        length = length * (stride_t ** down_t)
        assert_shape(x, (batch, self.output_emb_width, length))
        if all_levels and level != 0:
            x = x + xs[level - 1]

    return self.out(x)
def prior_preprocess(self, xs, conds):
    """Shift, flatten and concatenate token inputs and their conditioning.

    Each ``xs[i]`` is offset by ``self.prior_bins_shift[i]`` and reshaped to
    ``(N, -1)``; each missing entry of ``conds`` is replaced by a zero tensor
    of shape ``(N, self.prior_dims[i], self.prior_width)`` on ``'cuda'``.
    The caller's ``xs`` and ``conds`` lists are left unmodified.

    Args:
        xs: list of ``t.cuda.LongTensor`` token tensors; ``N`` is taken from
            ``xs[0].shape[0]``. Values must lie in ``[0, self.prior_bins[i])``.
        conds: list of conditioning tensors or ``None``; non-``None`` entries
            are checked with ``assert_shape`` against
            ``(N, self.prior_dims[i], self.prior_width)``.

    Returns:
        ``(z, x_cond)``: the processed ``xs`` and ``conds``, each concatenated
        along dim 1.

    Raises:
        AssertionError: if an input has the wrong type, range or shape.
    """
    N = xs[0].shape[0]

    shifted = []
    for i, x in enumerate(xs):
        bins, bins_shift = int(self.prior_bins[i]), int(self.prior_bins_shift[i])
        assert isinstance(x, t.cuda.LongTensor), x
        assert (0 <= x).all() and (x < bins).all()
        # NOTE: the per-input shape check against self.prior_shapes[i] is
        # commented out in the original code and is not performed here.
        shifted.append((x + bins_shift).view(N, -1))

    filled = []
    for i, cond in enumerate(conds):
        dims = self.prior_dims[i]
        if cond is not None:
            assert_shape(cond, (N, dims, self.prior_width))
        else:
            cond = t.zeros((N, dims, self.prior_width), dtype=t.float, device='cuda')
        filled.append(cond)

    return t.cat(shifted, dim=1), t.cat(filled, dim=1)
def forward(self, y):
    """Build the start embedding and optional positional embedding from labels.

    Args:
        y: 2-D ``t.cuda.LongTensor`` with ``4 + self.max_bow_genre_size``
            columns, read as ``total_length, offset, length, artist`` followed
            by the genre slots. Genre slots holding a negative value are
            masked out of the genre sum.

    Returns:
        ``(start_emb, pos_emb)``. ``start_emb`` is the artist embedding plus
        the masked sum of genre embeddings, checked against
        ``(N, 1, self.out_width)``. ``pos_emb`` is ``None`` unless
        ``self.include_time_signal`` is set, in which case it is the sum of
        the total-length, absolute-position and relative-position embeddings,
        checked against ``(N, self.n_time, self.out_width)``.
    """
    assert len(y.shape) == 2, f"Expected shape with 2 dims, got {y.shape}"
    assert y.shape[-1] == 4 + self.max_bow_genre_size, f"Expected shape (N,{4 + self.max_bow_genre_size}), got {y.shape}"
    assert isinstance(y, t.cuda.LongTensor), f"Expected dtype {t.cuda.LongTensor}, got {y.dtype}"

    batch = y.shape[0]
    total_length, offset, length, artist = (y[:, col:col + 1] for col in range(4))
    genre = y[:, 4:]

    # Start embedding of length 1
    artist_emb = self.artist_emb(artist)
    # Empty genre slots are denoted by -1. We mask these out.
    genre_mask = (genre >= 0).float().unsqueeze(2)
    genre_lookup = self.bow_genre_emb(genre.clamp(0))
    genre_emb = (genre_lookup * genre_mask).sum(dim=1, keepdim=True)
    start_emb = genre_emb + artist_emb
    assert_shape(start_emb, (batch, 1, self.out_width))

    # Pos embedding of length n_ctx
    if not self.include_time_signal:
        return start_emb, None

    total = total_length.float()
    start = offset.float()
    end = (offset + length).float()
    length_term = self.total_length_emb(total)
    absolute_term = self.absolute_pos_emb(start, end)
    relative_term = self.relative_pos_emb(start/total, end/total)
    pos_emb = length_term + absolute_term + relative_term
    assert_shape(pos_emb, (batch, self.n_time, self.out_width))
    return start_emb, pos_emb
def forward(self, x, x_cond=None):
    """Embed the tokens ``x``, add ``x_cond`` and run the conditioner stack.

    Args:
        x: tokens checked with ``assert_shape`` against ``(N, *self.x_shape)``;
            converted with ``.long()`` before the embedding lookup.
        x_cond: optional conditioning checked against
            ``(N, *self.x_shape, self.width)``. When omitted, ``0.0`` is added
            instead.

    Returns:
        The output of ``self.ln`` after ``self.preprocess``, ``self.cond`` and
        ``self.postprocess`` have been applied in that order.
    """
    batch = x.shape[0]
    assert_shape(x, (batch, *self.x_shape))

    embedded_shape = (batch, *self.x_shape, self.width)
    if x_cond is None:
        x_cond = 0.0
    else:
        assert_shape(x_cond, embedded_shape)

    # Embed x
    embedded = self.x_emb(x.long())
    assert_shape(embedded, embedded_shape)
    hidden = embedded + x_cond

    # Run conditioner
    for stage in (self.preprocess, self.cond, self.postprocess, self.ln):
        hidden = stage(hidden)
    return hidden
def sample(self, n_samples, z=None, z_conds=None, y=None, fp16=False, temp=1.0, top_k=0, top_p=0.0,
           chunk_size=None, sample_tokens=None):
    """Draw ``n_samples`` sequences from the prior, optionally primed with ``z``.

    Sampling is ancestral when ``z`` is ``None`` or has zero length along
    dim 1, and primed otherwise. Everything runs under ``t.no_grad()``.

    Args:
        n_samples: batch size; ``z``, ``y`` and every entry of ``z_conds``
            must have this size along dim 0 when given.
        z: optional priming tokens.
        z_conds: optional conditioning tokens, forwarded to ``self.get_cond``.
        y: optional labels, forwarded to ``self.get_cond``.
        fp16, temp, top_k, top_p: forwarded to the underlying prior's sampler.
        chunk_size: forwarded to ``primed_sample`` only.
        sample_tokens: forwarded to the sampler. In the ``single_enc_dec``
            path it is first increased by ``self.n_tokens``. When ``None`` the
            result is checked against ``(N, *self.z_shape)``.

    Returns:
        The sampled tokens ``z``.
    """
    N = n_samples
    if z is not None:
        assert z.shape[0] == N, f"Expected shape ({N},**), got shape {z.shape}"
    if y is not None:
        assert y.shape[0] == N, f"Expected shape ({N},**), got shape {y.shape}"
    for z_cond in (z_conds if z_conds is not None else ()):
        assert z_cond.shape[0] == N, f"Expected shape ({N},**), got shape {z_cond.shape}"

    no_past_context = (z is None or z.shape[1] == 0)
    if dist.get_rank() == 0:
        name = 'Ancestral' if no_past_context else 'Primed'
        print(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}")

    # Options shared by every sampler call below.
    sampler_opts = dict(fp16=fp16, temp=temp, top_k=top_k, top_p=top_p)

    with t.no_grad():
        # Currently x_cond only uses immediately above layer
        x_cond, y_cond, prime = self.get_cond(z_conds, y)
        if self.single_enc_dec:
            # assert chunk_size % self.prime_loss_dims == 0. TODO: Check if needed
            token_inputs = [prime] if no_past_context else [prime, z]
            z, x_cond = self.prior_preprocess(token_inputs, [None, x_cond])
            if sample_tokens is not None:
                sample_tokens += self.n_tokens
            z = self.prior.primed_sample(n_samples, z, x_cond, y_cond,
                                         chunk_size=chunk_size, sample_tokens=sample_tokens,
                                         **sampler_opts)
            z = self.prior_postprocess(z)
        else:
            encoder_kv = self.get_encoder_kv(prime, fp16=fp16, sample=True)
            if no_past_context:
                z = self.prior.sample(n_samples, x_cond, y_cond, encoder_kv,
                                      sample_tokens=sample_tokens, **sampler_opts)
            else:
                z = self.prior.primed_sample(n_samples, z, x_cond, y_cond, encoder_kv,
                                             chunk_size=chunk_size, sample_tokens=sample_tokens,
                                             **sampler_opts)
        if sample_tokens is None:
            assert_shape(z, (N, *self.z_shape))
    return z
def get_alignment(x, zs, labels, prior, fp16, hps):
    """Collect attention-based alignments between top-level codes and lyric tokens.

    The top-level codes are processed in windows of ``prior.n_ctx`` positions
    (window starts come from ``get_starts``). For every window the attention
    weights of ``prior.alignment_layer`` / ``prior.alignment_head`` are
    obtained from ``prior.z_forward`` one batch item at a time, then written
    into a per-item matrix at the token indices returned by ``prior.get_y``.

    Args:
        x: not used by this function.
        zs: per-level codes; only ``zs[hps.levels - 1]`` is read. It is
            zero-padded along dim 1 up to ``prior.n_ctx`` when shorter.
        labels: passed to ``prior.get_y``; ``labels['info'][i]['full_tokens']``
            supplies the number of lyric tokens of item ``i``.
        prior: model providing ``n_ctx``, ``n_tokens``, ``alignment_head``,
            ``alignment_layer``, ``get_y`` and ``z_forward``. It is moved to
            CUDA while windows are processed and back to the CPU afterwards,
            including when processing fails.
        fp16: forwarded to ``prior.z_forward``.
        hps: hyper-parameters providing ``levels`` and ``hop_fraction``.

    Returns:
        A list with one numpy array per batch item, each of shape
        ``(unpadded length, len(full_tokens))``.
    """
    level = hps.levels - 1  # Top level used
    n_ctx, n_tokens = prior.n_ctx, prior.n_tokens
    z = zs[level]
    bs, total_length = z.shape[0], z.shape[1]
    if total_length < n_ctx:
        padding_length = n_ctx - total_length
        z = t.cat([z, t.zeros(bs, padding_length, dtype=z.dtype, device=z.device)], dim=1)
        total_length = z.shape[1]
    else:
        padding_length = 0

    hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
    alignment_head, alignment_layer = prior.alignment_head, prior.alignment_layer
    attn_layers = {alignment_layer}
    # The same window starts key both passes below, so compute them once.
    starts = list(get_starts(total_length, n_ctx, hop_length))
    alignment_hops = {}
    indices_hops = {}

    prior.cuda()
    empty_cache()
    try:
        for start in starts:
            end = start + n_ctx

            # set y offset, sample_length and lyrics tokens
            y, indices_hop = prior.get_y(labels, start, get_indices=True)
            assert len(indices_hop) == bs
            for indices in indices_hop:
                assert len(indices) == n_tokens

            # Run one batch item at a time and keep only the alignment head.
            z_bs = t.chunk(z, bs, dim=0)
            y_bs = t.chunk(y, bs, dim=0)
            w_hops = []
            for z_i, y_i in zip(z_bs, y_bs):
                w_hop = prior.z_forward(z_i[:,start:end], [], y_i, fp16=fp16, get_attn_weights=attn_layers)
                assert len(w_hop) == 1
                w_hops.append(w_hop[0][:, alignment_head])
                del w_hop
            w = t.cat(w_hops, dim=0)
            del w_hops
            assert_shape(w, (bs, n_ctx, n_tokens))
            alignment_hop = w.float().cpu().numpy()
            assert_shape(alignment_hop, (bs, n_ctx, n_tokens))
            del w

            # alignment_hop has shape (bs, n_ctx, n_tokens)
            # indices_hop is a list of len=bs, each entry of len hps.n_tokens
            indices_hops[start] = indices_hop
            alignment_hops[start] = alignment_hop
    finally:
        # Release the GPU even if a window fails part-way through.
        prior.cpu()
        empty_cache()

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(bs):
        # Note each item has different length lyrics
        full_tokens = labels['info'][item]['full_tokens']
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        # Windows are written last-to-first, so where windows overlap the
        # earlier window's values are the ones that remain.
        for start in reversed(starts):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            assert len(indices) == n_tokens
            assert alignment_hop.shape == (n_ctx, n_tokens)
            alignment[start:end,indices] = alignment_hop
        alignment = alignment[:total_length - padding_length,:-1]  # remove token padding, and last lyric index
        alignments.append(alignment)
    return alignments