Example #1
    def init_generation_chorale(self, num_events, start_index, melody_constraint=None):
        # symbol indices per voice (one note2index dict per channel)
        PAD = [d[PAD_SYMBOL] for d in self.train_dataloader_generator.dataset.note2index_dicts]
        START = [d[START_SYMBOL] for d in self.train_dataloader_generator.dataset.note2index_dicts]
        END = [d[END_SYMBOL] for d in self.train_dataloader_generator.dataset.note2index_dicts]
        # template of shape (1, num_events, num_channels):
        # (start_index - 1) PAD frames, START frames over the middle, one END frame,
        # then (start_index - 1) PAD frames again
        aa = torch.Tensor(PAD).unsqueeze(0).unsqueeze(0).repeat(1, start_index - 1, 1).long()
        bb = torch.Tensor(START).unsqueeze(0).unsqueeze(0).long().repeat(1, num_events - 2 * start_index + 1, 1).long()
        cc = torch.Tensor(END).unsqueeze(0).unsqueeze(0).long()
        dd = torch.Tensor(PAD).unsqueeze(0).unsqueeze(0).repeat(1, start_index - 1, 1).long()
        init_sequence = torch.cat([aa, bb, cc, dd], 1)

        # 1 at positions to be generated, 0 at positions that stay fixed
        masked_positions = torch.ones_like(init_sequence)
        masked_positions[:, :start_index, :] = 0
        masked_positions[:, -start_index:, :] = 0

        if melody_constraint is not None:
            # fix the melody line (channel 0) to the given notes and mark it as not masked
            MELODY_CONSTRAINT = [self.train_dataloader_generator.dataset.note2index_dicts[0][note]
                                 for note in melody_constraint]
            for i in range(num_events - 2 * start_index):
                init_sequence[:, i + start_index, 0] = MELODY_CONSTRAINT[i]
                masked_positions[:, i + start_index, 0] = 0
        return cuda_variable(init_sequence), cuda_variable(masked_positions)
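The template built above is easiest to see on a toy example. The single-voice vocabulary below is made up for illustration (the real note2index_dicts come from the dataset and contain one dict per channel):

import torch

PAD_SYMBOL, START_SYMBOL, END_SYMBOL = 'PAD', 'START', 'END'
note2index = {PAD_SYMBOL: 0, START_SYMBOL: 1, END_SYMBOL: 2, 'C4': 3, 'D4': 4}

num_events, start_index = 8, 2
PAD, START, END = note2index[PAD_SYMBOL], note2index[START_SYMBOL], note2index[END_SYMBOL]

# same layout as init_generation_chorale, written out for a single voice:
# (start_index - 1) PADs, STARTs over the middle, one END, (start_index - 1) PADs
init = torch.tensor(
    [PAD] * (start_index - 1)
    + [START] * (num_events - 2 * start_index + 1)
    + [END]
    + [PAD] * (start_index - 1)
)
mask = torch.ones(num_events, dtype=torch.long)
mask[:start_index] = 0       # the frame at both ends stays fixed
mask[-start_index:] = 0
print(init.tolist())         # [0, 1, 1, 1, 1, 1, 2, 0]
print(mask.tolist())         # [0, 0, 1, 1, 1, 1, 0, 0]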
Example #2
    def forward(self, q):
        """
        :param q: (batch_size * num_heads, len_q, d)
        :return: relative attention logits of shape (batch_size * num_heads, len_q, len_q)
        """
        sz_b_times_n_head, len_q, d_q = q.size()
        assert sz_b_times_n_head % self.num_heads == 0
        sz_b = sz_b_times_n_head // self.num_heads

        # the skewing trick is applied twice because attention here is full
        # (neither causal nor anti-causal): e1 provides the future offsets, e2 the past ones
        e1 = self.e1.unsqueeze(0).repeat(sz_b, 1, 1)

        # WARNING: len_q can be different from self.max_seq_len
        e1 = e1.view(sz_b * self.num_heads, self.max_seq_len, d_q)
        e1 = e1[:, :len_q]

        # (batch_size * num_heads, len_q, len_q): each query scored against every relative embedding
        rel_attn_1 = torch.einsum('bld,bmd->blm', q, e1)
        e2 = self.e2.unsqueeze(0).repeat(sz_b, 1, 1)
        # WARNING: len_q can be different from self.max_seq_len
        # TODO allow len_q > self.max_seq_len
        e2 = e2.view(sz_b * self.num_heads, self.max_seq_len, d_q)
        e2 = e2[:, :len_q]

        rel_attn_2 = torch.einsum('bld,bmd->blm', q, e2)

        batch_size, l, _ = rel_attn_1.size()
        # ==== skewing trick: pad, reshape and slice so that relative offsets
        # line up with key positions
        # ---- Down: afterwards, entry (i, j) with j > i holds the score of
        # query i against the embedding for offset j - i
        # pad one dummy column on the right
        rel_attn_1 = torch.cat(
            [rel_attn_1,
             cuda_variable(torch.ones(1, 1, 1, ) * - 100).repeat(batch_size, l, 1),
             ], dim=2
        )
        rel_attn_1 = rel_attn_1.view(batch_size, l + 1, l)

        rel_attn_1 = rel_attn_1[:, :-1, :]

        # ---- Up: symmetric construction for the past offsets (j <= i)

        # pad one dummy column on the left
        # extension = cuda_variable(torch.ones(batch_size, l, 1, ) * - 100)
        rel_attn_2 = torch.cat(
            [cuda_variable(torch.ones(1, 1, 1, ) * - 100).repeat(batch_size, l, 1),
             rel_attn_2
             ], dim=2
        )
        rel_attn_2 = rel_attn_2.view(batch_size,
                                     l + 1,
                                     l,
                                     )

        rel_attn_2 = rel_attn_2[:, 1:, :]

        # zero the lower triangle (diagonal included) of the down-skewed logits,
        # keeping only the future offsets
        masks_down = torch.triu(torch.ones_like(rel_attn_1[0]).byte(),
                                diagonal=0).unsqueeze(0).repeat(sz_b_times_n_head, 1, 1).flip(
            1).flip(2).type(torch.bool)
        # zero the strict upper triangle of the up-skewed logits,
        # keeping only the past offsets (diagonal included)
        masks_up = torch.triu(torch.ones_like(rel_attn_2[0]).byte(),
                              diagonal=1).unsqueeze(0).repeat(sz_b_times_n_head, 1, 1).type(
            torch.bool)

        rel_attn_1 = rel_attn_1.masked_fill(masks_down, 0)
        rel_attn_2 = rel_attn_2.masked_fill(masks_up, 0)
        # the two halves are disjoint, so their sum is the full relative-attention term
        rel_attn = rel_attn_1 + rel_attn_2
        return rel_attn
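The pad/reshape/slice above is the Music-Transformer-style skewing trick. A minimal standalone sketch of the downward variant (the name skew_down and the toy values are illustrative, not from the repository) makes the index shift visible:

import torch

def skew_down(rel):
    """Down-skew: afterwards, entry (i, j) with j > i equals rel[i, j - i],
    i.e. the score of query i against the embedding for relative offset j - i.
    The remaining entries are left-overs that the caller masks out."""
    batch, l, _ = rel.shape
    pad = torch.full((batch, l, 1), -100.0)  # dummy column, ends up in the discarded slots
    rel = torch.cat([rel, pad], dim=2)       # (batch, l, l + 1)
    rel = rel.view(batch, l + 1, l)          # reinterpreting the storage shifts row i by i
    return rel[:, :-1, :]                    # drop the extra row

l = 4
# rel[0, i, k] = 10 * i + k, so the shift is easy to read off in the printout
rel = (10 * torch.arange(l).view(l, 1) + torch.arange(l).view(1, l)).float().unsqueeze(0)
print(skew_down(rel)[0])

The upward variant pads on the left and drops the first row instead, which aligns the past offsets in the lower triangle.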
Example #3
    def forward(self, x, masked_positions=None):
        """
        :param x: sequence of tokens (batch_size, num_events, num_channels)
        :param masked_positions: (batch_size, num_events, num_channels) binary mask
        with 1 at masked positions; if None, random masked positions are drawn
        :return: dict with the loss, encoder/decoder attentions and the
        per-category output weights
        """
        batch_size = x.size(0)

        # embed
        target = self.data_processor.preprocess(x)
        target_embedded = self.data_processor.embed(target)
        target_seq = flatten(target_embedded)

        # compute masked_x
        if masked_positions is not None:
            masked_positions = flatten(masked_positions)
        source_seq = self.mask(target_seq, masked_positions)

        # add positional embeddings and project to d_model
        target_seq = self.add_positional_embedding(target_seq)
        target_seq = self.linear_target(target_seq)

        source_seq = self.add_positional_embedding(source_seq)
        source_seq = self.linear_target(source_seq)

        source_seq = source_seq.transpose(0, 1)
        target_seq = target_seq.transpose(0, 1)

        # shift target_seq right by one step, prepending the start-of-sequence embedding
        dummy_input = self.sos.repeat(1, batch_size, 1)
        target_seq = torch.cat(
            [
                dummy_input,
                target_seq[:-1]
            ],
            dim=0)

        target_mask = cuda_variable(
            self._generate_square_subsequent_mask(target_seq.size(0))
        )
        # decoder self-attention is causal, encoder self-attention anti-causal,
        # and each decoder step only attends to the memory entry at its own position
        memory_mask = target_mask + target_mask.t()
        source_mask = target_mask.t()

        output, attentions_decoder, attentions_encoder = self.transformer(
            source_seq,
            target_seq,
            src_mask=source_mask,
            tgt_mask=target_mask,
            memory_mask=memory_mask
        )

        output = output.transpose(0, 1).contiguous()

        output = output.view(batch_size,
                             -1,
                             self.num_channels,
                             self.d_model)
        weights_per_category = [
            pre_softmax(t[:, :, 0, :])
            for t, pre_softmax in zip(output.split(1, 2), self.pre_softmaxes)
        ]

        # the loss is computed over all positions here; the mask could be
        # restricted, e.g. to the masked positions only
        loss = categorical_crossentropy(
            value=weights_per_category,
            target=target,
            mask=torch.ones_like(target)
        )

        loss = loss.mean()
        return {
            'loss':                 loss,
            'attentions_decoder':   attentions_decoder,
            'attentions_encoder':   attentions_encoder,
            'weights_per_category': weights_per_category,
            'monitored_quantities': {
                'loss': loss.item()
            }
        }
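The right-shifted decoder input and the causal mask built above follow the standard teacher-forcing pattern. A minimal sketch of both pieces (square_subsequent_mask and shift_right are illustrative stand-ins, not the repository's _generate_square_subsequent_mask):

import torch

def square_subsequent_mask(sz):
    # additive attention mask in torch.nn.Transformer's convention:
    # 0.0 where attention is allowed (j <= i), -inf where it is blocked (j > i)
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

def shift_right(target_seq, sos):
    # prepend a start-of-sequence embedding and drop the last step;
    # target_seq: (seq_len, batch, d_model), sos: (1, 1, d_model)
    batch_size = target_seq.size(1)
    return torch.cat([sos.repeat(1, batch_size, 1), target_seq[:-1]], dim=0)

print(square_subsequent_mask(4))
x = torch.arange(6.).view(3, 2, 1)  # (seq_len=3, batch=2, d_model=1)
print(shift_right(x, torch.zeros(1, 1, 1)).squeeze(-1))

With that convention, target_mask + target_mask.t() is -inf everywhere except the diagonal, which is what restricts each decoder step to the memory entry at its own position.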


if __name__ == '__main__':
    # minimal smoke test for the RelativeAttention.forward method shown above
    batch_size = 1
    head_dim = 2
    num_heads = 1
    seq_len = 6
    aa = RelativeAttention(head_dim, num_heads, seq_len)
    q = cuda_variable(torch.ones((batch_size * num_heads, seq_len, head_dim)))
    ret = aa.forward(q)
    exit()