Example #1
    def test_lower_triangular(self):
        m = TriangularCausalMask(3)
        self.assertTrue(m.lower_triangular)
        self.assertTrue(torch.all(m.bool_matrix == (torch.tensor([
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]) > 0)))

        m = FullMask(torch.tensor([
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 1]
        ]) > 0)
        self.assertTrue(m.lower_triangular)

        m = FullMask(torch.tensor([
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1]
        ]) > 0)
        self.assertFalse(m.lower_triangular)

        m = LengthMask(torch.tensor([1, 1, 3]))
        self.assertFalse(m.lower_triangular)
        m = LengthMask(torch.tensor([1, 2, 3]))
        self.assertTrue(m.lower_triangular)
        m = LengthMask(torch.tensor([1, 2, 3]), max_len=4)
        self.assertTrue(m.lower_triangular)
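
The masks constructed in Example #1 are the same objects that attention modules consume directly. The following minimal sketch is not taken from any of the repositories on this page: it builds a causal mask and a length mask and passes them to FullAttention with the (N, L, H, E) tensor layout used in the test of Example #4, assuming the pytorch-fast-transformers package is installed.

    def minimal_mask_usage_sketch():
        # Hypothetical helper (not from the examples above): build the masks and
        # feed them to FullAttention with batch N, length L, heads H, head dim E.
        import torch
        from fast_transformers.attention import FullAttention
        from fast_transformers.masking import LengthMask, TriangularCausalMask

        N, L, H, E = 2, 5, 4, 8
        q = torch.rand(N, L, H, E)
        k = torch.rand(N, L, H, E)
        v = torch.rand(N, L, H, E)

        attn_mask = TriangularCausalMask(L)    # lower-triangular causal mask
        length_mask = LengthMask(torch.full((N,), L, dtype=torch.int64))

        # attn_mask restricts which positions may attend to which; the two length
        # masks mark valid (non-padding) query and key positions respectively
        out = FullAttention()(q, k, v, attn_mask, length_mask, length_mask)
        return out    # shape (N, L, H, E)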
Example #2
    def _get_inputs(self, N=10, L=5, S=8, H=4, E=32, D=64, device="cpu"):
        return (
            torch.rand(N, L, H, E).to(device),       # queries
            torch.rand(N, S, H, E).to(device),       # keys
            torch.rand(N, S, H, D).to(device),       # values
            TriangularCausalMask(L, device=device),  # causal attention mask
            FullMask(N, L, device=device),           # query length mask (all ones)
            FullMask(N, S, device=device),           # key length mask (all ones)
        )
Example #3
  def forward(self, x, lengths=None, attn_kwargs=None):
    attn_mask = TriangularCausalMask(x.size(1), device=x.device)

    if lengths is not None:
      length_mask = LengthMask(lengths, device=x.device)
    else:
      length_mask = None

    attn_kwargs = dict(attn_kwargs) if attn_kwargs else {}

    if self._spe and self.share_pe and attn_kwargs.get('pos_code', None) is None:
      attn_kwargs['pos_code'] = self.spe(x.shape[:-1])

    out = x
    for l in range(self.n_layer):
      layer_attn_kwargs = dict(attn_kwargs)
      if self._spe and not self.share_pe and layer_attn_kwargs.get('pos_code', None) is None:
        layer_attn_kwargs['pos_code'] = self.spe[l](x.shape[:-1])

      out = self.decoder_layers[l](
        out, 
        attn_mask=attn_mask, 
        length_mask=length_mask, 
        attn_kwargs=layer_attn_kwargs
      )

    return out
Example #4
    def test_correctness(self):
        # Prepare the inputs
        N = 10
        H = 4
        E = 25
        M = 64
        L = 100
        q = torch.rand(N, L, H, E)
        k = torch.rand(N, L, H, E)
        v = torch.rand(N, L, H, M)
        m1 = TriangularCausalMask(L)
        m2 = LengthMask(torch.full((N,), L, dtype=torch.int64))
        m3 = LengthMask(torch.full((N,), L, dtype=torch.int64))
        att = FullAttention()
        rec_att = RecurrentFullAttention()
        att.eval()
        rec_att.eval()

        v1 = att(q, k, v, m1, m2, m3)
        v2 = []
        memory = None
        for i in range(L):
            v2i, memory = rec_att(q[:, i], k[:, i], v[:, i], memory)
            v2.append(v2i)
        v2 = torch.stack(v2, dim=1)
        self.assertLess(torch.abs(v1-v2).max(), 1e-5)
Example #5
    def forward(self, x):
        x = self.fourier_coefficient_embedding(x)
        x = self.pos_embedding(x)
        triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
        y_hat = self.encoder(x, attn_mask=triangular_mask)
        y_amp = self.predictor_amp(y_hat)
        y_phase = torch.tanh(self.predictor_phase(y_hat))
        return torch.cat([y_amp, y_phase], dim=-1)
Example #6
    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.value_embedding(x)
        x = self.pos_embedding(x)
        triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)
        y_hat = self.transformer(x, attn_mask=triangular_mask)
        y_hat = self.predictor(y_hat)

        return y_hat
Example #7
    def test_compare_with_batch(self):
        N = 10
        L = 42
        S = 100
        D = 1024
        E = D // 4
        x = torch.rand(N, L, D)
        m = torch.rand(N, S, D)

        tests = [
            ("full", FullAttention, FullAttention,
             RecurrentFullAttention, RecurrentCrossFullAttention),
            ("linear", partial(CausalLinearAttention, E),
             partial(LinearAttention, E),
             partial(RecurrentLinearAttention, E),
             partial(RecurrentCrossLinearAttention, E)),
        ]

        for name, a1, a2, a3, a4 in tests:
            dec = TransformerDecoder([
                TransformerDecoderLayer(AttentionLayer(a1(), D, 4),
                                        AttentionLayer(a2(), D, 4), D)
                for i in range(4)
            ])
            rdec = RecurrentTransformerDecoder([
                RecurrentTransformerDecoderLayer(
                    RecurrentAttentionLayer(a3(), D, 4),
                    RecurrentCrossAttentionLayer(a4(), D, 4), D)
                for i in range(4)
            ])
            dec.eval()
            rdec.eval()
            rdec.load_state_dict(dec.state_dict())

            x_mask = TriangularCausalMask(L)
            x_length = LengthMask(torch.full((N, ), L, dtype=torch.int64))
            m_mask = FullMask(L, S)
            m_length = LengthMask(torch.full((N, ), S, dtype=torch.int64))

            y1 = dec(x,
                     m,
                     x_mask=x_mask,
                     x_length_mask=x_length,
                     memory_mask=m_mask,
                     memory_length_mask=m_length)
            state = None
            y2 = []
            for i in range(L):
                y2i, state = rdec(x[:, i],
                                  m,
                                  memory_length_mask=m_length,
                                  state=state)
                y2.append(y2i)
            y2 = torch.stack(y2, dim=1)

            self.assertLess(torch.abs(y1 - y2).max(), 1e-5)
Example #8
    def forward_hidden(self, x, memory=None, is_training=True):
        '''
        linear transformer: b x s x f
        x.shape=(bs, nf)
        '''

        # embeddings
        emb_tempo = self.word_emb_tempo(x[..., 0])
        emb_chord = self.word_emb_chord(x[..., 1])
        emb_barbeat = self.word_emb_barbeat(x[..., 2])
        emb_type = self.word_emb_type(x[..., 3])
        emb_pitch = self.word_emb_pitch(x[..., 4])
        emb_duration = self.word_emb_duration(x[..., 5])
        emb_velocity = self.word_emb_velocity(x[..., 6])

        embs = torch.cat([
            emb_tempo,
            emb_chord,
            emb_barbeat,
            emb_type,
            emb_pitch,
            emb_duration,
            emb_velocity,
        ], dim=-1)

        emb_linear = self.in_linear(embs)
        pos_emb = self.pos_emb(emb_linear)

        # assert False

        # transformer
        if is_training:
            # mask
            attn_mask = TriangularCausalMask(pos_emb.size(1), device=x.device)
            h = self.transformer_encoder(pos_emb,
                                         attn_mask)  # y: b x s x d_model

            # project type
            y_type = self.proj_type(h)
            return h, y_type
        else:
            pos_emb = pos_emb.squeeze(0)
            h, memory = self.transformer_encoder(
                pos_emb, memory=memory)  # y: s x d_model

            # project type
            y_type = self.proj_type(h)
            return h, y_type, memory
Example #9
    def forward(self, x, meta):
        ar = torch.arange(self.seq_len).float().type_as(x)
        relative_pos = self.positional_encoder(ar).unsqueeze(0).repeat(
            [x.shape[0], 1, 1])

        att_mask = TriangularCausalMask(self.seq_len, self.get_device())
        seq_mask = (x[..., 0] > 0).sum(dim=1).long()
        seq_mask = LengthMask(seq_mask, max_len=self.seq_len)

        x = self.model_projection(x) * math.sqrt(self.project_dimension)

        x = x + relative_pos
        regress_embeddings = self.transformer(x, att_mask, seq_mask)
        pred = self.output_projection(regress_embeddings)

        return pred
Example #10
    def forward(self, x, lengths=None, attn_kwargs=None):
        attn_mask = TriangularCausalMask(x.size(1), device=x.device)

        if lengths is not None:
            length_mask = LengthMask(lengths, device=x.device)
        else:
            length_mask = None

        attn_kwargs = dict(attn_kwargs) if attn_kwargs else {}

        out = x
        for l in range(self.n_layer):
            # print (out.size())
            out = self.decoder_layers[l](out,
                                         attn_mask=attn_mask,
                                         length_mask=length_mask,
                                         attn_kwargs=attn_kwargs)

        return out
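
A hedged usage sketch for the decoder forward above: the model object and its dimensions are hypothetical, and only the forward(x, lengths=None, attn_kwargs=None) signature is taken from the example. Passing lengths makes the layers mask padded positions in addition to the causal mask built inside forward.

    def run_causal_decoder_sketch(model):
        # Hypothetical driver: `model` is assumed to expose the forward shown above
        # and to accept inputs of shape (batch, seq_len, d_model) with d_model = 512.
        import torch
        x = torch.rand(3, 10, 512)            # padded batch of 3 sequences
        lengths = torch.tensor([10, 7, 4])    # true lengths; later positions are padding
        return model(x, lengths=lengths)      # causal + padding masks are built inside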
Example #11
    def sample(self,
               n_samples,
               x_cond=None,
               y_cond=None,
               encoder_kv=None,
               fp16=False,
               temp=1.0,
               top_k=0,
               top_p=0.0,
               get_preds=False,
               sample_tokens=None):
        assert self.training == False

        if sample_tokens is None: sample_tokens = self.input_dims // 4
        N, D = n_samples, self.input_dims
        if self.y_cond:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)
        else:
            assert y_cond is None

        if self.x_cond:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (
                N, 1, self.width
            ), f"Got {x_cond.shape}, expected ({N}, {D}/{1}, {self.width})"
        else:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width), dtype=t.float).cuda()

        with t.no_grad():
            xs, x_emb, x = [], [], None
            if get_preds:
                preds = []
            for sample_t in get_range(range(0, sample_tokens)):
                x, cond = self.get_emb(sample_t, n_samples, x, x_cond, y_cond)
                #self.transformer.check_cache(n_samples, sample_t, fp16)
                x_emb.append(x)
                x = t.cat(x_emb, dim=1)
                triangular_mask = TriangularCausalMask(x.shape[1],
                                                       device=x.device)

                x = self.transformer(
                    x, attn_mask=triangular_mask)[:, -1:, :]  # Transformer

                if self.add_cond_after_transformer:
                    x = x + cond
                assert x.shape == (n_samples, 1, self.width)
                x = self.x_out(x)  # Predictions
                if get_preds:
                    preds.append(x.clone())
                # Adjust logits
                x = x / temp
                x = filter_logits(x, top_k=top_k, top_p=top_p)
                x = t.distributions.Categorical(
                    logits=x).sample()  # Sample and replace x
                assert x.shape == (n_samples, 1)
                xs.append(x.clone())

            del x, x_emb

            #self.transformer.del_cache()

            x = t.cat(xs, dim=1)
            if get_preds:
                preds = t.cat(preds, dim=1)
            x = self.postprocess(x, sample_tokens)
        if get_preds:
            return x, preds
        else:
            return x
Example #12
    def forward(self,
                x,
                x_cond=None,
                y_cond=None,
                encoder_kv=None,
                fp16=False,
                loss_full=False,
                encode=False,
                get_preds=False,
                get_acts=False,
                get_sep_loss=False):
        # Preprocess.
        with t.no_grad():
            x = self.preprocess(x)

        N, D = x.shape
        assert isinstance(x, t.cuda.LongTensor)
        assert (0 <= x).all() and (x < self.bins).all()

        if self.y_cond:
            assert y_cond is not None
            assert y_cond.shape == (N, 1, self.width)
        else:
            assert y_cond is None

        if self.x_cond:
            assert x_cond is not None
            assert x_cond.shape == (N, D, self.width) or x_cond.shape == (
                N, 1, self.width
            ), f"{x_cond.shape} != {(N, D, self.width)} nor {(N, 1, self.width)}. Did you pass the correct --sample_length?"
        else:
            assert x_cond is None
            x_cond = t.zeros((N, 1, self.width),
                             device=x.device,
                             dtype=t.float)

        x_t = x  # Target
        x = self.x_emb(x)  # X emb
        x = roll(x, 1)  # Shift by 1, and fill in start token
        if self.y_cond:
            x[:, 0] = y_cond.view(N, self.width)
        else:
            x[:, 0] = self.start_token

        x = self.x_emb_dropout(x) + self.pos_emb_dropout(
            self.pos_emb()) + x_cond  # Pos emb and dropout

        triangular_mask = TriangularCausalMask(x.shape[1], device=x.device)

        x = self.transformer(x, attn_mask=triangular_mask)  # Transformer
        if self.add_cond_after_transformer:  # Piped doesn't add x_cond
            x = x + x_cond

        acts = x
        if self.only_encode:
            return x
        x = self.x_out(x)  # Predictions

        if get_sep_loss:
            assert self.prime_len is not None
            x_prime = x[:, :self.prime_len].reshape(-1, self.bins)
            x_gen = x[:, self.prime_len:].reshape(-1, self.bins)

            prime_loss = F.cross_entropy(
                x_prime, x_t[:, :self.prime_len].reshape(-1)) / np.log(2.)
            gen_loss = F.cross_entropy(
                x_gen, x_t[:, self.prime_len:].reshape(-1)) / np.log(2.)

            loss = (prime_loss, gen_loss)  # Note order! Prime is first
        else:
            loss = F.cross_entropy(x.view(-1, self.bins),
                                   x_t.view(-1)) / np.log(2.)  # Loss

        if get_preds:
            return loss, x
        elif get_acts:
            return loss, acts
        else:
            return loss, None