def test_lower_triangular(self):
    m = TriangularCausalMask(3)
    self.assertTrue(m.lower_triangular)
    self.assertTrue(torch.all(m.bool_matrix == (torch.tensor([
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)))

    m = FullMask(torch.tensor([
        [1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)
    self.assertTrue(m.lower_triangular)

    m = FullMask(torch.tensor([
        [1, 0, 1],
        [1, 1, 0],
        [1, 1, 1]
    ]) > 0)
    self.assertFalse(m.lower_triangular)

    m = LengthMask(torch.tensor([1, 1, 3]))
    self.assertFalse(m.lower_triangular)

    m = LengthMask(torch.tensor([1, 2, 3]))
    self.assertTrue(m.lower_triangular)

    m = LengthMask(torch.tensor([1, 2, 3]), max_len=4)
    self.assertTrue(m.lower_triangular)
def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"):
    # (queries, keys, values, attn_mask, query_lengths, key_lengths)
    return (
        torch.rand(N, L, H, E).to(device),
        torch.rand(N, S, H, E).to(device),
        torch.rand(N, S, H, D).to(device),
        TriangularCausalMask(L, device=device),
        FullMask(N, L, device=device),
        FullMask(N, S, device=device)
    )
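# Hedged usage sketch (added, not part of the original tests): the tuple returned by
# _get_inputs above is assumed to match the forward signature of fast_transformers
# attention modules, (queries, keys, values, attn_mask, query_lengths, key_lengths),
# with FullAttention returning a tensor of shape (N, L, H, D).
def test_forward_shapes_example(self):
    q, k, v, attn_mask, q_lengths, k_lengths = self._get_inputs()
    att = FullAttention()
    att.eval()
    out = att(q, k, v, attn_mask, q_lengths, k_lengths)
    # Output keeps the query length and the value dimension: (N, L, H, D)
    self.assertEqual(out.shape, (10, 5, 4, 64))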
def forward(self, x, lengths=None, attn_kwargs=None):
    attn_mask = TriangularCausalMask(x.size(1), device=x.device)
    if lengths is not None:
        length_mask = LengthMask(lengths, device=x.device)
    else:
        length_mask = None

    attn_kwargs = dict(attn_kwargs) if attn_kwargs else {}
    if self._spe and self.share_pe and attn_kwargs.get('pos_code', None) is None:
        # shared positional code: computed once and reused by every layer
        attn_kwargs['pos_code'] = self.spe(x.shape[:-1])

    out = x
    for l in range(self.n_layer):
        layer_attn_kwargs = dict(attn_kwargs)
        if self._spe and not self.share_pe and layer_attn_kwargs.get('pos_code', None) is None:
            # per-layer positional code: each layer draws its own
            layer_attn_kwargs['pos_code'] = self.spe[l](x.shape[:-1])
        out = self.decoder_layers[l](
            out,
            attn_mask=attn_mask,
            length_mask=length_mask,
            attn_kwargs=layer_attn_kwargs
        )

    return out
def test_correctness(self):
    # Prepare the inputs
    N = 10
    H = 4
    E = 25
    M = 64
    L = 100
    q = torch.rand(N, L, H, E)
    k = torch.rand(N, L, H, E)
    v = torch.rand(N, L, H, M)
    m1 = TriangularCausalMask(L)
    m2 = LengthMask(torch.full((N,), L, dtype=torch.int64))
    m3 = LengthMask(torch.full((N,), L, dtype=torch.int64))
    att = FullAttention()
    rec_att = RecurrentFullAttention()
    att.eval()
    rec_att.eval()

    v1 = att(q, k, v, m1, m2, m3)
    v2 = []
    memory = None
    for i in range(L):
        v2i, memory = rec_att(q[:, i], k[:, i], v[:, i], memory)
        v2.append(v2i)
    v2 = torch.stack(v2, dim=1)

    self.assertLess(torch.abs(v1 - v2).max(), 1e-5)
def forward(self, x):
    x = self.fourier_coefficient_embedding(x)
    x = self.pos_embedding(x)
    triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
    y_hat = self.encoder(x, attn_mask=triangular_mask)
    y_amp = self.predictor_amp(y_hat)
    y_phase = torch.tanh(self.predictor_phase(y_hat))
    return torch.cat([y_amp, y_phase], dim=-1)
def forward(self, x):
    x = x.view(x.shape[0], -1)
    x = self.value_embedding(x)
    x = self.pos_embedding(x)
    triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
    y_hat = self.transformer(x, attn_mask=triangular_mask)
    y_hat = self.predictor(y_hat)
    return y_hat
def test_compare_with_batch(self):
    N = 10
    L = 42
    S = 100
    D = 1024
    E = D // 4
    x = torch.rand(N, L, D)
    m = torch.rand(N, S, D)
    tests = [
        ("full", FullAttention, FullAttention,
         RecurrentFullAttention, RecurrentCrossFullAttention),
        ("linear", partial(CausalLinearAttention, E), partial(LinearAttention, E),
         partial(RecurrentLinearAttention, E), partial(RecurrentCrossLinearAttention, E))
    ]
    for name, a1, a2, a3, a4 in tests:
        dec = TransformerDecoder([
            TransformerDecoderLayer(
                AttentionLayer(a1(), D, 4),
                AttentionLayer(a2(), D, 4),
                D
            )
            for i in range(4)
        ])
        rdec = RecurrentTransformerDecoder([
            RecurrentTransformerDecoderLayer(
                RecurrentAttentionLayer(a3(), D, 4),
                RecurrentCrossAttentionLayer(a4(), D, 4),
                D
            )
            for i in range(4)
        ])
        dec.eval()
        rdec.eval()
        rdec.load_state_dict(dec.state_dict())

        x_mask = TriangularCausalMask(L)
        x_length = LengthMask(torch.full((N,), L, dtype=torch.int64))
        m_mask = FullMask(L, S)
        m_length = LengthMask(torch.full((N,), S, dtype=torch.int64))

        y1 = dec(x, m, x_mask=x_mask, x_length_mask=x_length,
                 memory_mask=m_mask, memory_length_mask=m_length)

        state = None
        y2 = []
        for i in range(L):
            y2i, state = rdec(x[:, i], m, memory_length_mask=m_length, state=state)
            y2.append(y2i)
        y2 = torch.stack(y2, dim=1)

        self.assertLess(torch.abs(y1 - y2).max(), 1e-5)
def forward_hidden(self, x, memory=None, is_training=True):
    '''
    linear transformer: b x s x f
    x.shape = (bs, nf)
    '''
    # embeddings
    emb_tempo = self.word_emb_tempo(x[..., 0])
    emb_chord = self.word_emb_chord(x[..., 1])
    emb_barbeat = self.word_emb_barbeat(x[..., 2])
    emb_type = self.word_emb_type(x[..., 3])
    emb_pitch = self.word_emb_pitch(x[..., 4])
    emb_duration = self.word_emb_duration(x[..., 5])
    emb_velocity = self.word_emb_velocity(x[..., 6])

    embs = torch.cat([
        emb_tempo,
        emb_chord,
        emb_barbeat,
        emb_type,
        emb_pitch,
        emb_duration,
        emb_velocity,
    ], dim=-1)

    emb_linear = self.in_linear(embs)
    pos_emb = self.pos_emb(emb_linear)
    # assert False

    # transformer
    if is_training:
        # mask
        attn_mask = TriangularCausalMask(pos_emb.size(1), device=x.device)
        h = self.transformer_encoder(pos_emb, attn_mask)  # y: b x s x d_model

        # project type
        y_type = self.proj_type(h)
        return h, y_type
    else:
        pos_emb = pos_emb.squeeze(0)
        h, memory = self.transformer_encoder(
            pos_emb, memory=memory)  # y: s x d_model

        # project type
        y_type = self.proj_type(h)
        return h, y_type, memory
def forward(self, x, meta):
    ar = torch.arange(self.seq_len).float().type_as(x)
    relative_pos = self.positional_encoder(ar).unsqueeze(0).repeat(
        [x.shape[0], 1, 1])
    att_mask = TriangularCausalMask(self.seq_len, self.get_device())
    seq_mask = (x[..., 0] > 0).sum(dim=1).long()
    seq_mask = LengthMask(seq_mask, max_len=self.seq_len)
    x = self.model_projection(x) * math.sqrt(self.project_dimension)
    x = x + relative_pos
    regress_embeddings = self.transformer(x, att_mask, seq_mask)
    pred = self.output_projection(regress_embeddings)
    return pred
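# Hedged illustration (added, not from the original model): how the LengthMask built
# above treats padded positions. Assumes fast_transformers' LengthMask, whose
# bool_matrix has shape (N, max_len) and is True only at valid (non-padded) positions.
def _length_mask_example():
    lengths = torch.tensor([1, 3])
    mask = LengthMask(lengths, max_len=4)
    # mask.bool_matrix ->
    # tensor([[ True, False, False, False],
    #         [ True,  True,  True, False]])
    return mask.bool_matrix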
def forward(self, x, lengths=None, attn_kwargs=None):
    attn_mask = TriangularCausalMask(x.size(1), device=x.device)
    if lengths is not None:
        length_mask = LengthMask(lengths, device=x.device)
    else:
        length_mask = None

    attn_kwargs = dict(attn_kwargs) if attn_kwargs else {}

    out = x
    for l in range(self.n_layer):
        # print(out.size())
        out = self.decoder_layers[l](out,
                                     attn_mask=attn_mask,
                                     length_mask=length_mask,
                                     attn_kwargs=attn_kwargs)

    return out
def sample(self, n_samples, x_cond=None, y_cond=None, encoder_kv=None,
           fp16=False, temp=1.0, top_k=0, top_p=0.0, get_preds=False,
           sample_tokens=None):
    assert self.training == False
    if sample_tokens is None:
        sample_tokens = self.input_dims // 4
    N, D = n_samples, self.input_dims
    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (
            N, 1, self.width
        ), f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

    with t.no_grad():
        xs, x_emb, x = [], [], None
        if get_preds:
            preds = []
        for sample_t in get_range(range(0, sample_tokens)):
            x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
            # self.transformer.check_cache(n_samples, sample_t, fp16)
            x_emb.append(x)
            x = t.cat(x_emb, dim=1)
            triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
            x = self.transformer(
                x, attn_mask=triangular_mask)[:, -1:, :]  # Transformer
            if self.add_cond_after_transformer:
                x = x + cond
            assert x.shape == (n_samples, 1, self.width)
            x = self.x_out(x)  # Predictions
            if get_preds:
                preds.append(x.clone())
            # Adjust logits
            x = x / temp
            x = filter_logits(x, top_k=top_k, top_p=top_p)
            x = t.distributions.Categorical(
                logits=x).sample()  # Sample and replace x
            assert x.shape == (n_samples, 1)
            xs.append(x.clone())

        del x, x_emb
        # self.transformer.del_cache()
        x = t.cat(xs, dim=1)
        if get_preds:
            preds = t.cat(preds, dim=1)
        x = self.postprocess(x, sample_tokens)

    if get_preds:
        return x, preds
    else:
        return x
def forward(self, x, x_cond=None, y_cond=None, encoder_kv=None, fp16=False,
            loss_full=False, encode=False, get_preds=False, get_acts=False,
            get_sep_loss=False):
    # Preprocess.
    with t.no_grad():
        x = self.preprocess(x)

    N, D = x.shape
    assert isinstance(x, t.cuda.LongTensor)
    assert (0 <= x).all() and (x < self.bins).all()

    if self.y_cond:
        assert y_cond is not None
        assert y_cond.shape == (N, 1, self.width)
    else:
        assert y_cond is None
    if self.x_cond:
        assert x_cond is not None
        assert x_cond.shape == (N, D, self.width) or x_cond.shape == (
            N, 1, self.width
        ), f"{x_cond.shape} != {(N, D, self.width)} nor {(N, 1, self.width)}. Did you pass the correct --sample_length?"
    else:
        assert x_cond is None
        x_cond = t.zeros((N, 1, self.width), device=x.device, dtype=t.float)

    x_t = x  # Target
    x = self.x_emb(x)  # X emb
    x = roll(x, 1)  # Shift by 1, and fill in start token
    if self.y_cond:
        x[:, 0] = y_cond.view(N, self.width)
    else:
        x[:, 0] = self.start_token

    x = self.x_emb_dropout(x) + self.pos_emb_dropout(
        self.pos_emb()) + x_cond  # Pos emb and dropout

    triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
    x = self.transformer(x, attn_mask=triangular_mask)  # Transformer

    if self.add_cond_after_transformer:  # Piped doesn't add x_cond
        x = x + x_cond

    acts = x
    if self.only_encode:
        return x
    x = self.x_out(x)  # Predictions

    if get_sep_loss:
        assert self.prime_len is not None
        x_prime = x[:, :self.prime_len].reshape(-1, self.bins)
        x_gen = x[:, self.prime_len:].reshape(-1, self.bins)

        prime_loss = F.cross_entropy(
            x_prime, x_t[:, :self.prime_len].reshape(-1)) / np.log(2.)
        gen_loss = F.cross_entropy(
            x_gen, x_t[:, self.prime_len:].reshape(-1)) / np.log(2.)

        loss = (prime_loss, gen_loss)  # Note order! Prime is first
    else:
        loss = F.cross_entropy(x.view(-1, self.bins),
                               x_t.view(-1)) / np.log(2.)  # Loss

    if get_preds:
        return loss, x
    elif get_acts:
        return loss, acts
    else:
        return loss, None
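# Hedged sketch (added for clarity, not from the original code) of the teacher-forcing
# shift performed by `roll(x, 1)` above: the embedded targets are rotated one step to
# the right along the time axis so that, under the causal mask, position i only sees
# tokens < i, and position 0 is then overwritten with the start token (or y_cond).
# The helper below is an assumption standing in for the repo's own `roll`.
def _shift_right_example(x_emb, start_token):
    # x_emb: (N, D, width) token embeddings; start_token: (width,) learned vector
    shifted = t.cat([x_emb[:, -1:], x_emb[:, :-1]], dim=1)  # rotate right by one step
    shifted[:, 0] = start_token                             # fill in the start token
    return shifted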