def forward(self, x):
    """
    :param x: (batch_size, num_events, num_channels, input_dim)
    :return: list of num_channels logits
    (batch_size, num_events, num_tokens_of_channel)
    """
    x = self.linear_to_input_transformer(x)
    embedded_seq = flatten(x)
    batch_size = embedded_seq.size(0)
    num_tokens = embedded_seq.size(1)
    num_events = num_tokens // self.num_channels

    # append a channel embedding to every token
    embedded_seq = torch.cat([
        embedded_seq,
        self.channel_embeddings.repeat(batch_size, num_events, 1)
    ], dim=2)

    # time dimension first for the transformer
    embedded_seq = embedded_seq.transpose(0, 1)
    output, _ = self.transformer(embedded_seq)
    output = output.transpose(0, 1).contiguous()
    output = output.view(batch_size, num_events, self.num_channels, -1)

    # one pre-softmax head per channel
    weights_per_category = [
        pre_softmax(t[:, :, 0, :])
        for t, pre_softmax in zip(output.split(1, 2), self.pre_softmaxes)
    ]
    return weights_per_category
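# The flatten/unflatten helpers used throughout this file are not defined in
# this section. A minimal sketch of their assumed behavior: merge the channel
# dimension into the event dimension and back. This is a hypothetical
# reconstruction, not necessarily the repository's actual implementation.
def flatten(x):
    # (batch_size, num_events, num_channels, ...)
    # -> (batch_size, num_events * num_channels, ...)
    batch_size, num_events, num_channels = x.size(0), x.size(1), x.size(2)
    return x.reshape(batch_size, num_events * num_channels, *x.shape[3:])


def unflatten(x, num_channels):
    # (batch_size, num_tokens, ...)
    # -> (batch_size, num_tokens // num_channels, num_channels, ...)
    batch_size, num_tokens = x.size(0), x.size(1)
    return x.reshape(batch_size, num_tokens // num_channels, num_channels,
                     *x.shape[2:])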
def forward(self, x, corrupt_labels=False):
    """
    :param x: x comes from the dataloader
    :param corrupt_labels: if True, assign with probability 5% a different
    label than the computed centroid
    :return: z_quantized, encoding_indices, quantization_loss
    """
    x_proc = self.data_processor.preprocess(x)
    x_embed = self.data_processor.embed(x_proc)
    x_flat = flatten(x_embed)
    # downscale the flattened sequence, then vector-quantize it
    z = self.downscaler.forward(x_flat)
    z_quantized, encoding_indices, quantization_loss = self.quantizer.forward(
        z, corrupt_labels=corrupt_labels)
    if self.upscaler is not None:
        z_quantized = self.upscaler(z_quantized)
    return z_quantized, encoding_indices, quantization_loss
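# A hedged usage sketch for the encoder forward above; `encoder` and `batch`
# are hypothetical names, and corrupting labels only during training is an
# assumption about how the regularization is meant to be used.
def encode_batch(encoder, batch, training=True):
    z_quantized, encoding_indices, quantization_loss = encoder.forward(
        batch, corrupt_labels=training)
    return z_quantized, encoding_indices, quantization_loss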
def mask_teacher(self, x, num_events_masked):
    """
    :param x: (batch_size, num_events, num_channels)
    :param num_events_masked: number of events to be masked before and after
    the masked_event_index
    :return: masked_x, notes_to_be_predicted
    """
    flat_input = flatten(x)
    batch_size, sequence_length = flat_input.size()
    assert sequence_length % self.num_channels == 0
    num_events = sequence_length // self.num_channels

    # TODO different masks for different elements in the batch
    # mask num_events_masked events before and num_events_masked events
    # after the randomly drawn masked_event_index
    masked_event_index = torch.randint(high=num_events, size=()).item()

    # only the tokens of the masked event have to be predicted
    notes_to_be_predicted = torch.zeros_like(flat_input)
    notes_to_be_predicted[
        :,
        masked_event_index * self.num_channels:
        (masked_event_index + 1) * self.num_channels] = 1

    # the mask indices are precisely self.num_tokens_per_channel:
    # each channel uses the index one past its vocabulary as mask token
    mask_tokens = cuda_variable(torch.LongTensor(self.num_tokens_per_channel))
    mask_tokens = mask_tokens.unsqueeze(0).repeat(batch_size, num_events)

    notes_to_mask = torch.zeros_like(flat_input)
    notes_to_mask[
        :,
        max((masked_event_index - num_events_masked) * self.num_channels, 0):
        (masked_event_index + num_events_masked + 1) * self.num_channels] = 1

    masked_input = (flat_input * (1 - notes_to_mask)
                    + mask_tokens * notes_to_mask)

    # unflatten back to (batch_size, num_events, num_channels)
    masked_x = unflatten(masked_input, self.num_channels)
    notes_to_be_predicted = unflatten(notes_to_be_predicted,
                                      self.num_channels)
    return masked_x, notes_to_be_predicted
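# `cuda_variable` is not defined in this section. A minimal sketch of the
# assumed helper (torch is assumed to be imported at module level); the
# repository's version may differ:
def cuda_variable(tensor):
    # move the tensor to the GPU when one is available
    return tensor.cuda() if torch.cuda.is_available() else tensor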
def forward(self, source, target):
    """
    :param source: sequence of codebooks (batch_size, s_s)
    :param target: sequence of tokens (batch_size, num_events, num_channels)
    :return: dict with loss, attention maps and per-channel logits
    """
    batch_size = source.size(0)

    # embed
    source_seq = self.source_embeddings(source)
    target = self.data_processor.preprocess(target)
    target_embedded = self.data_processor.embed(target)
    target_seq = flatten(target_embedded)
    num_tokens_target = target_seq.size(1)

    if self.transformer_type == 'relative':
        # add channel and event positional embeddings
        target_seq = torch.cat([
            target_seq,
            self.target_channel_embeddings.repeat(
                batch_size, num_tokens_target // self.num_channels, 1),
            self.target_events_positioning_embeddings
                .repeat_interleave(self.num_channels, dim=1)
                .repeat((batch_size,
                         num_tokens_target // self.total_upscaling, 1))
        ], dim=2)
    elif self.transformer_type == 'absolute':
        source_seq = torch.cat([
            source_seq,
            self.source_positional_embeddings.repeat(batch_size, 1, 1)
        ], dim=2)
        target_seq = torch.cat([
            target_seq,
            self.target_positional_embeddings.repeat(batch_size, 1, 1)
        ], dim=2)

    target_seq = self.linear_target(target_seq)

    # time dimension first
    source_seq = source_seq.transpose(0, 1)
    target_seq = target_seq.transpose(0, 1)

    # shift target_seq by one token for teacher forcing
    dummy_input = self.sos.repeat(1, batch_size, 1)
    target_seq = torch.cat([
        dummy_input,
        target_seq[:-1]
    ], dim=0)

    # masks: anti-causal for the encoder, causal for the decoder
    source_length = source_seq.size(0)
    target_length = target_seq.size(0)

    # cross-attention masks
    if self.cross_attention_type in ['diagonal', 'full']:
        memory_mask = None
    elif self.cross_attention_type == 'causal':
        # building the rectangular causal cross-mask would require a
        # tedious repeat_interleave, so it is not implemented yet
        raise NotImplementedError
    elif self.cross_attention_type == 'anticausal':
        memory_mask = self._generate_anticausal_mask(source_length,
                                                     target_length)

    # encoder self-attention masks
    if self.encoder_attention_type in ['diagonal', 'full']:
        source_mask = None
    elif self.encoder_attention_type == 'causal':
        source_mask = self._generate_causal_mask(source_length)
    elif self.encoder_attention_type == 'anticausal':
        source_mask = self._generate_anticausal_mask(source_length)

    # causal target mask
    target_mask = self._generate_causal_mask(target_length)

    # the custom transformer also returns the attention maps
    output, attentions_decoder, attentions_encoder = self.transformer(
        source_seq,
        target_seq,
        tgt_mask=target_mask,
        src_mask=source_mask,
        memory_mask=memory_mask)

    output = output.transpose(0, 1).contiguous()
    output = output.view(batch_size, -1, self.num_channels, self.d_model)
    weights_per_category = [
        pre_softmax(t[:, :, 0, :])
        for t, pre_softmax in zip(output.split(1, 2), self.pre_softmaxes)
    ]

    # the loss mask could be changed here; for now every token contributes
    loss = categorical_crossentropy(
        value=weights_per_category,
        target=target,
        mask=torch.ones_like(target))
    loss = loss.mean()
    return {
        'loss': loss,
        'attentions_decoder': attentions_decoder,
        'attentions_encoder': attentions_encoder,
        'weights_per_category': weights_per_category,
        'monitored_quantities': {
            'loss': loss.item()
        }
    }
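# The mask generators called above are not defined in this section. Minimal
# sketches of their assumed semantics, written as free functions mirroring
# the model's _generate_causal_mask/_generate_anticausal_mask methods and
# following the additive float('-inf')/0.0 mask convention of
# torch.nn.Transformer; the two-argument rectangular variant is an assumption:
def generate_causal_mask(length):
    # position i may only attend to positions j <= i
    mask = torch.triu(torch.ones(length, length), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))


def generate_anticausal_mask(source_length, target_length=None):
    # query position i may only attend to key positions j >= i;
    # the returned shape (target_length, source_length) matches the
    # memory_mask expected by torch.nn.Transformer
    if target_length is None:
        target_length = source_length
    mask = torch.tril(torch.ones(target_length, source_length), diagonal=-1)
    return mask.masked_fill(mask == 1, float('-inf'))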