Example #1
0
 def _generate_anticausal_mask(self, sz, sz_tgt=None):
     """Return the anticausal (upper-triangular) attention mask of size ``sz``.

     The causal mask from ``_generate_square_subsequent_mask`` is transposed
     to mask the past instead of the future.  When ``sz_tgt`` is given, each
     row is repeated so the mask covers a target sequence that is an integer
     multiple of ``sz`` (subsampled source attending to a longer target).

     :param sz: source sequence length
     :param sz_tgt: optional target sequence length; must be a multiple of sz
     :return: (sz_tgt or sz, sz) mask tensor on the current device
     """
     anticausal = cuda_variable(self._generate_square_subsequent_mask(sz)).t()
     if sz_tgt is None:
         return anticausal
     assert sz_tgt % sz == 0
     return torch.repeat_interleave(anticausal, sz_tgt // sz, dim=0)
Example #2
0
    def preprocess(self, x):
        """Cast the input to a long tensor and move it to the current device.

        Subclasses may override this hook; overriding is optional.

        :param x: tensor-like input, presumably
            (batch_size, num_events, num_channels) — confirm with callers
        :return: long tensor on the current device
        """
        return cuda_variable(x.long())
Example #3
0
    def forward(self, x):
        """Teacher-forced forward pass over a sequence of codebook indices.

        :param x: sequence of codebooks (batch_size, s_s)
        :return: dict with keys 'loss' (scalar tensor), 'weights_per_category'
            (one logits tensor per channel) and 'monitored_quantities'
        """
        batch_size = x.size(0)

        # keep the raw indices as prediction targets: (batch, seq, 1)
        target = x.unsqueeze(dim=2)
        # embed
        x_seq = self.embedding(x)
        x_seq = self.linear(x_seq)

        # to (seq_len, batch, d) layout expected by the transformer;
        # NOTE(review): the original comment said "add positional embeddings"
        # but only a transpose is visible here — presumably the positional
        # information is added inside self.transformer; confirm.
        x_seq = x_seq.transpose(0, 1)

        # shift target_seq by one: prepend the SOS embedding and drop the last
        # step, so position t predicts token t from tokens < t
        dummy_input = self.sos.repeat(1, batch_size, 1)
        x_seq = torch.cat([dummy_input, x_seq[:-1]], dim=0)

        # causal mask over the (shifted) input length
        mask = cuda_variable(
            self._generate_square_subsequent_mask(x_seq.size(0)))

        # for custom
        output, attentions = self.transformer(x_seq, mask=mask)

        # back to batch-first, then split the flat sequence into
        # (batch, num_events, num_channels, d_model)
        output = output.transpose(0, 1).contiguous()

        output = output.view(batch_size, -1, self.num_channels, self.d_model)

        # one output head per channel; split(1, 2) yields one
        # (batch, num_events, 1, d_model) slice per channel
        weights_per_category = [
            pre_softmax(t[:, :, 0, :])
            for t, pre_softmax in zip(output.split(1, 2), self.pre_softmaxes)
        ]

        # we can change loss mask (currently all positions contribute)
        loss = categorical_crossentropy(value=weights_per_category,
                                        target=target,
                                        mask=torch.ones_like(target))

        loss = loss.mean()
        return {
            'loss': loss,
            'weights_per_category': weights_per_category,
            'monitored_quantities': {
                'loss': loss.item()
            }
        }
Example #4
0
    def mask_teacher(self, x, num_events_masked):
        """Mask a random window of events and mark the central event to predict.

        A single event index is drawn uniformly; the ``num_events_masked``
        events before and after it (inclusive of the event itself) are
        replaced by per-channel mask tokens, while only the central event is
        flagged for prediction.

        :param x: (batch_size, num_events, num_channels)
        :param num_events_masked: number of events to be masked (before and after) the
        masked_event_index
        :return: (masked_x, notes_to_be_predicted), both shaped like ``x``;
            notes_to_be_predicted is a 0/1 indicator of positions to predict
        """
        # NOTE(review): `input` shadows the builtin; left unchanged here.
        input = flatten(x)
        batch_size, sequence_length = input.size()
        num_events = sequence_length // self.num_channels
        assert sequence_length % self.num_channels == 0

        # TODO different masks for different elements in the batch
        # leave num_events_masked events before and num_events_masked after
        masked_event_index = torch.randint(high=num_events,
                                           size=()).item()

        # the mask indices are precisely the self.num_notes_per_voice
        # (i.e. the mask token of channel c is index num_tokens_per_channel[c],
        # one past the channel's vocabulary — presumably; confirm)
        notes_to_be_predicted = torch.zeros_like(input)

        # flag only the central event's channels for prediction
        notes_to_be_predicted[:,
        masked_event_index * self.num_channels
        :(masked_event_index + 1) * self.num_channels] = 1

        # tile the per-channel mask tokens across the whole flat sequence
        mask_tokens = cuda_variable(torch.LongTensor(self.num_tokens_per_channel))
        mask_tokens = mask_tokens.unsqueeze(0).repeat(batch_size, num_events)

        # 0/1 indicator of the window of events to replace with mask tokens
        # (clamped at the left boundary; right slice end is clipped by slicing)
        notes_to_mask = torch.zeros_like(input)
        notes_to_mask[:,
        max((masked_event_index - num_events_masked) * self.num_channels, 0)
        :(masked_event_index + num_events_masked + 1) * self.num_channels] = 1

        # keep original tokens outside the window, mask tokens inside it
        masked_input = input * (1 - notes_to_mask) + mask_tokens * notes_to_mask

        # unflatten
        masked_x = unflatten(masked_input,
                             self.num_channels)
        notes_to_be_predicted = unflatten(notes_to_be_predicted,
                                          self.num_channels)
        return masked_x, notes_to_be_predicted
    def preprocess(self, x):
        """
        Preprocess a dcpc block.

        Flattens the (num_ticks, num_voices) trailing dimensions tick-major,
        then regroups the flat token stream into fixed-size blocks.

        :param x: (..., num_ticks, num_voices) of appropriate dimensions
        :return: (..., num_blocks, num_tokens_per_block) long tensor on device
        """
        num_ticks, num_voices = x.size()[-2:]
        leading_dims = x.size()[:-2]

        # collapse leading dims, then flatten each example tick-major
        flat = x.view(-1, num_ticks, num_voices).contiguous()
        flat = flat.view(-1, num_voices * num_ticks)

        assert flat.size(1) % self.num_tokens_per_block == 0
        # regroup the flat stream into consecutive fixed-size blocks
        chunks = flat.split(self.num_tokens_per_block, dim=1)
        blocks = torch.stack(chunks, dim=1)

        num_blocks = blocks.size(1)
        blocks = blocks.view(*leading_dims, num_blocks,
                             self.num_tokens_per_block)
        return cuda_variable(blocks.long())
Example #6
0
 def init_generation(self, num_events):
     """Return an all-zero long seed tensor (1, num_events, num_channels)
     on the current device, used to start autoregressive generation."""
     seed = torch.zeros(1, num_events, self.num_channels).long()
     return cuda_variable(seed)
Example #7
0
    def generate(self, temperature,
                 batch_size=1,
                 top_k=0,
                 top_p=1.,
                 seed_set=None,
                 exclude_meta_symbols=False,
                 plot_attentions=False,
                 code_juxtaposition=False):
        """Sample sequences conditioned on codes extracted from a seed example.

        A seed is drawn from the train or validation dataloader, encoded into
        (possibly quantized) codes, and the decoder is sampled autoregressively
        channel by channel, event by event.  Generated scores and the
        re-encoded code sequences are written under ``self.model_dir``.

        :param temperature: softmax temperature divisor applied to the logits
        :param batch_size: number of sequences generated in parallel
        :param top_k: top-k filtering threshold (0 disables it)
        :param top_p: nucleus (top-p) filtering threshold (1. disables it)
        :param seed_set: 'train' or 'val'; which dataloader provides the seed
        :param exclude_meta_symbols: if True, forbid START/END/PAD symbols
        :param plot_attentions: if True, collect and plot attention maps
        :param code_juxtaposition: if True, build the seed from the first half
            of one example and the second half of another
        :return: list of scores as returned by ``self.dataloader_generator.write``
        :raises Exception: if ``seed_set`` is neither 'train' nor 'val'
        """
        self.eval()
        (generator_train, generator_val, _) = self.dataloader_generator.dataloaders(
            batch_size=1,
            shuffle_val=True
        )

        with torch.no_grad():
            if code_juxtaposition:
                # Use the codes of a chorale for the first half, and the codes from another chorale for the last half
                if seed_set == 'val':
                    tensor_dict_beginning = next(iter(generator_val))
                    tensor_dict_end = next(iter(generator_val))
                elif seed_set == 'train':
                    tensor_dict_beginning = next(iter(generator_train))
                    tensor_dict_end = next(iter(generator_train))
                else:
                    raise Exception('Need to indicate seeds dataset')

                num_events_chorale_half = tensor_dict_beginning['x'].shape[1] // 2
                x_beg = tensor_dict_beginning['x'][:, :num_events_chorale_half]
                x_end = tensor_dict_end['x'][:, num_events_chorale_half:]
                x_original_single = torch.cat([x_beg, x_end], dim=1)
                x_original = x_original_single.repeat(batch_size, 1, 1)
            else:
                if seed_set == 'val':
                    tensor_dict = next(iter(generator_val))
                elif seed_set == 'train':
                    tensor_dict = next(iter(generator_train))
                else:
                    raise Exception('Need to indicate seeds dataset')

                x_original_single = tensor_dict['x']
                x_original = x_original_single.repeat(batch_size, 1, 1)

            # compute downscaled version
            zs, encoding_indices, _ = self.encoder(x_original)
            if encoding_indices is None:
                # if no quantization is used, directly use the zs
                encoding_indices = zs
            else:
                encoding_indices = self.encoder.merge_codes(encoding_indices)

            # all-zero seed for the decoder's autoregressive sampling
            x = self.init_generation(num_events=self.data_processor.num_events)

            # Duplicate along batch dimension
            x = x.repeat(batch_size, 1, 1)

            attentions_decoder_list = []
            attentions_encoder_list = []
            attentions_cross_list = []

            # sample one channel of one event per forward pass
            # NOTE(review): the full forward is recomputed for every token —
            # correct but quadratic; no caching is visible here.
            for event_index in range(self.data_processor.num_events):
                for channel_index in range(self.num_channels):
                    forward_pass = self.forward(encoding_indices,
                                                x)

                    weights_per_voice = forward_pass['weights_per_category']
                    weights = weights_per_voice[channel_index]

                    # Keep only the last token predictions of the first batch item (batch size 1), apply a
                    # temperature coefficient and filter
                    logits = weights[:, event_index, :] / temperature

                    # Remove meta symbols
                    if exclude_meta_symbols:
                        for sym in [START_SYMBOL, END_SYMBOL, PAD_SYMBOL]:
                            sym_index = \
                                self.dataloader_generator.dataset.note2index_dicts[channel_index][
                                    sym]
                            logits[:, sym_index] = -float("inf")

                    # Top-p sampling
                    filtered_logits = []
                    for logit in logits:
                        filter_logit = top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)
                        filtered_logits.append(filter_logit)
                    filtered_logits = torch.stack(filtered_logits, dim=0)
                    # Sample from the filtered distribution
                    p = to_numpy(torch.softmax(filtered_logits, dim=-1))

                    # update generated sequence
                    for batch_index in range(batch_size):
                        new_pitch_index = np.random.choice(np.arange(
                            self.num_tokens_per_channel[channel_index]
                        ), p=p[batch_index])
                        x[batch_index, event_index, channel_index] = int(new_pitch_index)

                    # store attentions
                    if plot_attentions:
                        # hard-coded layer to visualize
                        layer = 2
                        event_index_encoder = (
                                                      event_index * self.num_channels) // self.total_upscaling
                        attentions_encoder = forward_pass['attentions_encoder']
                        # list of dicts with key 'a_self_encoder'
                        attentions_decoder = forward_pass['attentions_decoder']
                        # list of dicts with keys 'a_self_decoder' and 'a_cross'

                        # get attentions at corresponding event
                        attn_encoder = attentions_encoder[layer]['a_self_encoder'][:, :,
                                       event_index_encoder, :]
                        attn_decoder = attentions_decoder[layer]['a_self_decoder'][:, :,
                                       event_index * self.num_channels + channel_index, :]
                        attn_cross = attentions_decoder[layer]['a_cross'][:, :,
                                     event_index * self.num_channels + channel_index, :]

                        attentions_encoder_list.append(attn_encoder)
                        attentions_decoder_list.append(attn_decoder)
                        attentions_cross_list.append(attn_cross)

            # Compute codes for generations
            # (seed first, then the generated batch, so file row 0 is the seed)
            x_re_encode = torch.cat([
                cuda_variable(x_original_single.long()),
                x
            ], dim=0)
            _, recoding_, _ = self.encoder(x_re_encode)
            if recoding_ is not None:
                recoding_ = recoding_.detach().cpu().numpy()
                recoding = self.encoder.merge_codes(recoding_)
            else:
                recoding = None

        # to score
        original_and_reconstruction = self.data_processor.postprocess(original=x_original.long(),
                                                                      reconstruction=x.cpu())

        ###############################
        # Saving
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if code_juxtaposition:
            save_dir = f'{self.model_dir}/juxtapositions'
        else:
            save_dir = f'{self.model_dir}/generations'
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        # Write code sequence
        if recoding is not None:
            with open(f'{save_dir}/{timestamp}.txt', 'w') as ff:
                for batch_ind in range(len(recoding)):
                    aa = recoding[batch_ind]
                    ff.write(' , '.join(map(str, list(aa))))
                    ff.write('\n')

        # Write scores
        scores = []
        for k, tensor_score in enumerate(original_and_reconstruction):
            path_no_extension = f'{save_dir}/{timestamp}_{k}'
            scores.append(self.dataloader_generator.write(tensor_score, path_no_extension))
        print(f'Saved in {save_dir}/{timestamp}')
        ###############################

        if plot_attentions:
            self.plot_attention(attentions_cross_list,
                                timestamp=timestamp,
                                name='attns_cross')
            self.plot_attention(attentions_encoder_list,
                                timestamp=timestamp,
                                name='self_attns_encoder')
            self.plot_attention(attentions_decoder_list,
                                timestamp=timestamp,
                                name='self_attns_decoder')

        return scores
Example #8
0
 def _generate_causal_mask(self, sz):
     """Return the causal attention mask of size ``sz`` on the current device."""
     causal = self._generate_square_subsequent_mask(sz)
     return cuda_variable(causal)
Example #9
0
    def forward(self, inputs, corrupt_labels=False, **kwargs):
        """Vector-quantize ``inputs`` with one codebook per chunk of the
        feature dimension (product / multi-codebook quantization).

        :param inputs: (..., codebook_dim) continuous vectors to quantize
        :param corrupt_labels: during training, randomly replace ~5% of the
            selected code indices with uniform random indices
        :return: (quantized_sg, encoding_indices, quantization_loss) where
            quantized_sg carries straight-through gradients w.r.t. inputs and
            encoding_indices has shape inputs.shape[:-1] + (num_codebooks,)
        """

        input_shape = inputs.size()

        # Normalize and flatten
        if self.use_batch_norm:

            # BatchNorm1d expects (N, C, L): insert and move a dummy length dim
            flat_input = inputs.view(-1, self.codebook_dim).unsqueeze(1)

            flat_input = flat_input.permute(0, 2, 1)
            flat_input = self.batch_norm(flat_input)
            flat_input = flat_input.permute(0, 2, 1).contiguous()
            flat_input = flat_input[:, 0, :]
        else:
            flat_input = inputs.view(-1, self.codebook_dim)

        # lazily initialize codebooks from the first batch (see self._initialize)
        if self.initialize:
            self._initialize(flat_input=flat_input)

        # Calculate distances
        # squared Euclidean distance expanded as x^2 + e^2 - 2*x.e, computed
        # per codebook on the corresponding chunk of the feature dimension
        distances = [(torch.sum(input_component**2, dim=1, keepdim=True) +
                      torch.sum(embedding**2, dim=1) -
                      2 * torch.matmul(input_component, embedding.t()))
                     for input_component, embedding in zip(
                         flat_input.chunk(chunks=self.num_codebooks, dim=1),
                         self.embeddings)]

        # Encoding
        # nearest codebook entry per chunk
        encoding_indices_list = [
            torch.argmin(distance, dim=1).unsqueeze(1)
            for distance in distances
        ]

        # corrupt indices
        # (label smoothing by random replacement, training only)
        if self.training and corrupt_labels:
            random_indices_list = [
                torch.randint_like(encoding_indices_list[0],
                                   low=0,
                                   high=self.codebook_size)
                for _ in range(self.num_codebooks)
            ]
            # mask == 1 keeps the true index (~95% of positions)
            mask_list = [
                (torch.rand_like(random_indices.float()) > 0.05).long()
                for random_indices in random_indices_list
            ]
            encoding_indices_list = [
                mask * encoding_indices + (1 - mask) * random_indices
                for encoding_indices, random_indices, mask in zip(
                    encoding_indices_list, random_indices_list, mask_list)
            ]

        # one-hot encodings of the selected indices, per codebook
        encodings = [
            cuda_variable(
                torch.zeros(encoding_indices.shape[0], self.codebook_size))
            for encoding_indices in encoding_indices_list
        ]
        for encoding, encoding_indices in zip(encodings,
                                              encoding_indices_list):
            encoding.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized_list = [
            torch.matmul(encoding, embedding)
            for encoding, embedding in zip(encodings, self.embeddings)
        ]
        quantized = torch.cat(quantized_list, dim=1).view(input_shape)

        quantization_loss = self._loss(inputs, quantized)

        # straight-through estimator: forward value is `quantized`,
        # gradient flows to `inputs`
        quantized_sg = inputs + (quantized - inputs).detach()

        # encoding_indices = torch.zeros_like(encoding_indices_list[0])
        # for encoding_index in encoding_indices_list:
        #     encoding_indices = encoding_indices * self.codebook_size + encoding_index
        # print(len(torch.unique(encoding_indices)))
        # encoding_indices = encoding_indices.view(input_shape[:-1])

        # stack per-codebook indices as a trailing dimension
        encoding_indices_shape = list(input_shape[:-1]) + [-1]
        encoding_indices = torch.stack(encoding_indices_list,
                                       dim=-1).view(encoding_indices_shape)

        return quantized_sg, encoding_indices, quantization_loss
Example #10
0
 def forward(self, inputs, **kwargs):
     """Identity 'quantizer': pass inputs through unchanged.

     :param inputs: (..., d) tensor
     :return: (inputs, None, per-position zero loss) — None signals to
         callers that no discrete codes were produced
     """
     zero_loss = cuda_variable(torch.zeros_like(inputs)).sum(dim=-1)
     return inputs, None, zero_loss
    def forward(self, q):
        """Compute relative-position attention logits for ``q``.

        Combines a causal component (embeddings ``e1``, attending to the past)
        and an anticausal component (``e2``, attending to the future), each
        realigned with the Music-Transformer "skewing" trick (pad one column,
        reshape, drop one row), then masked so each contributes only on its
        own side of the diagonal.

        Fix: removed a corrupted duplicate of the masking epilogue that was
        pasted after the ``return`` statement (it began with a dangling
        ``dim=1)`` fragment and was a syntax error).

        :param q: (batch_size * num_heads, len_q_tgt, d)
        :return: (batch_size * num_heads, len_q_tgt, seq_len_src) relative
            attention term
        """
        sz_b_times_n_head, len_q, d_q = q.size()
        assert sz_b_times_n_head % self.num_heads == 0
        sz_b = sz_b_times_n_head // self.num_heads

        batch_size = sz_b_times_n_head

        ################################
        # Causal
        e1 = self.e1.unsqueeze(0).repeat(sz_b, 1, 1)
        e1 = e1.view(sz_b * self.num_heads, self.seq_len_src, d_q)
        rel_attn_1 = torch.einsum('bld,bmd->blm', (q, e1))
        # tgt * src -> src * tgt
        rel_attn_1 = rel_attn_1.view(batch_size, self.seq_len_src,
                                     self.seq_len_tgt)

        # one column padding on dim 2 (-100 acts as a "very negative" filler)
        rel_attn_1 = torch.cat([
            cuda_variable(torch.ones(batch_size, self.seq_len_src, 1) * -100),
            rel_attn_1,
        ],
                               dim=2)

        # fill in with lines (ensure view can be done)
        bottom_extension = self.seq_len_tgt - self.seq_len_src
        if bottom_extension != 0:
            rel_attn_1 = torch.cat([
                rel_attn_1,
                cuda_variable(
                    torch.ones(batch_size, bottom_extension,
                               self.seq_len_tgt + 1) * -100),
            ],
                                   dim=1)

        # skewing: reshape shifts each row by one relative position
        rel_attn_1 = rel_attn_1.view(batch_size, -1, self.seq_len_src)
        # need to remove first line here
        rel_attn_1 = rel_attn_1[:, 1:]
        rel_attn_1 = rel_attn_1[:, :self.seq_len_tgt, :]
        ################################

        ################################
        # Anticausal
        e2 = self.e2.unsqueeze(0).repeat(sz_b, 1, 1)
        e2 = e2.view(sz_b * self.num_heads, self.seq_len_src, d_q)
        rel_attn_2 = torch.einsum('bld,bmd->blm', (q, e2))

        batch_size = rel_attn_2.size(0)

        # tgt * src -> src * tgt
        rel_attn_2 = rel_attn_2.view(batch_size, self.seq_len_src,
                                     self.seq_len_tgt)

        # one column padding on dim 2 (on the right for the anticausal side)
        rel_attn_2 = torch.cat([
            rel_attn_2,
            cuda_variable(torch.ones(batch_size, self.seq_len_src, 1) * -100),
        ],
                               dim=2)

        # fill in with lines (ensure view can be done)
        bottom_extension = self.seq_len_tgt - self.seq_len_src
        if bottom_extension != 0:
            rel_attn_2 = torch.cat([
                rel_attn_2,
                cuda_variable(
                    torch.ones(batch_size, bottom_extension,
                               self.seq_len_tgt + 1) * -100),
            ],
                                   dim=1)

        # skewing: (tgt + 1) * (tgt + 1) -> x * tgt
        rel_attn_2 = rel_attn_2.view(batch_size, -1, self.seq_len_src)
        rel_attn_2 = rel_attn_2[:, :self.seq_len_tgt, :]
        ################################

        # mask causal and anticausal
        # Using ones_like is faster than cuda_variable(ones(...))
        # masks_down covers the diagonal and below (flipped upper triangle);
        # it zeroes the anticausal term there
        masks_down = torch.triu(torch.ones_like(
            rel_attn_1[0, :self.seq_len_src, :self.seq_len_src]).byte(),
                                diagonal=0).unsqueeze(0).repeat(
                                    sz_b_times_n_head, 1,
                                    1).flip(1).flip(2).type(torch.bool)
        if self.subsampling_ratio != 1:
            # expand each mask row to cover subsampling_ratio target steps
            masks_down = torch.repeat_interleave(masks_down,
                                                 self.subsampling_ratio,
                                                 dim=1)

        # masks_up covers strictly above the diagonal; it zeroes the causal term
        masks_up = torch.triu(torch.ones_like(
            rel_attn_1[0, :self.seq_len_src, :self.seq_len_src]).byte(),
                              diagonal=1).unsqueeze(0).repeat(
                                  sz_b_times_n_head, 1, 1).type(torch.bool)
        if self.subsampling_ratio != 1:
            masks_up = torch.repeat_interleave(masks_up,
                                               self.subsampling_ratio,
                                               dim=1)

        rel_attn_1 = rel_attn_1.masked_fill(masks_up, 0)
        rel_attn_2 = rel_attn_2.masked_fill(masks_down, 0)
        rel_attn = rel_attn_1 + rel_attn_2
        return rel_attn


if __name__ == '__main__':
    # Smoke test: one forward pass of SubsampledRelativeAttention on a
    # tiny all-ones query batch.
    batch_size = 1
    head_dim = 2
    num_heads = 1
    seq_len_src = 6
    seq_len_tgt = 6
    attention = SubsampledRelativeAttention(head_dim, num_heads, seq_len_src,
                                            seq_len_tgt)
    attention.to('cuda')
    q = cuda_variable(
        torch.ones((batch_size * num_heads, seq_len_tgt, head_dim)))
    # Call the module rather than .forward() directly so nn.Module hooks run;
    # the trailing bare exit() was removed (the script ends here anyway).
    ret = attention(q)
Example #13
0
    def generate(
        self,
        num_tokens,
        decoder,
        temperature=1.,
        num_generated_codes=1,
        num_decodings_per_generating_code=1,
    ):
        """Sample code sequences from the prior, then decode them to scores.

        Codes are generated autoregressively over a sliding window of size
        ``self.num_tokens``, passed to ``decoder.generate_from_code_long``,
        and the resulting scores are written as MusicXML under
        ``self.model_dir/generations``.

        :param num_tokens: total number of tokens to generate (must be a
            multiple of self.num_channels and >= self.num_tokens)
        :param decoder: model exposing ``generate_from_code_long``
        :param temperature: sampling temperature (see NOTE below)
        :param num_generated_codes: number of code sequences sampled in parallel
        :param num_decodings_per_generating_code: decodings per code sequence
        :return: list of music21-like scores (objects with a ``.write`` method)
        """
        self.eval()
        decoder.eval()
        with torch.no_grad():
            # init: all-zero seed (num_generated_codes, num_tokens, num_channels)
            x = cuda_variable(torch.zeros(1, num_tokens,
                                          self.num_channels)).long()

            x = x.repeat(num_generated_codes, 1, 1)
            assert num_tokens % self.num_channels == 0
            # num_tokens is the number of the sequence to be generated
            # while self.num_tokens is the number of tokens of the input of the model
            assert num_tokens >= self.num_tokens
            num_events = num_tokens // self.num_channels

            for event_index in range(num_events):
                for channel_index in range(self.num_channels):
                    # removes channel dim
                    # NOTE(review): only channel 0 is fed back to the model
                    # even though all channels are written below — presumably
                    # intentional for this code-prior; confirm.
                    x_input = x[:, :, 0]
                    # NOTE(review): event_index (an event counter) is compared
                    # against self.num_tokens (a token count) — works only if
                    # they share units here; verify.
                    if event_index >= self.num_tokens:
                        # sliding window: keep the last self.num_tokens events
                        x_input = x_input[:, event_index - self.num_tokens +
                                          1:event_index + 1]
                        event_offset = event_index - self.num_tokens + 1
                    else:
                        x_input = x_input[:, :self.num_tokens]
                        event_offset = 0

                    weights_per_voice = self.forward(
                        x_input)['weights_per_category']

                    weights = weights_per_voice[channel_index]
                    probs = torch.softmax(weights[:, event_index -
                                                  event_offset, :],
                                          dim=1)
                    p = to_numpy(probs)
                    # temperature ?!
                    # NOTE(review): temperature is applied as p**T on the
                    # probabilities (T>1 sharpens), not as logits/T — the
                    # opposite convention of the other generate(); confirm.
                    p = np.exp(np.log(p + 1e-20) * temperature)
                    p = p / p.sum(axis=1, keepdims=True)

                    for batch_index in range(num_generated_codes):
                        new_pitch_index = np.random.choice(np.arange(
                            self.num_tokens_per_channel[channel_index]),
                                                           p=p[batch_index])
                        x[batch_index, event_index,
                          channel_index] = int(new_pitch_index)

        # channel 0 holds the generated code sequence
        source = x[:, :, 0]
        scores = decoder.generate_from_code_long(
            encoding_indices=source,
            temperature=temperature,
            num_decodings=num_decodings_per_generating_code)

        # save scores in model_dir
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if not os.path.exists(f'{self.model_dir}/generations'):
            os.mkdir(f'{self.model_dir}/generations')

        for k, score in enumerate(scores):
            score.write('xml',
                        f'{self.model_dir}/generations/{timestamp}_{k}.xml')

        return scores