def _generate_anticausal_mask(self, sz, sz_tgt=None):
    # transpose of the causal mask; if a (longer) target length is given,
    # each source row is repeated so that the mask matches the target length
    mask = cuda_variable(self._generate_square_subsequent_mask(sz)).t()
    if sz_tgt is not None:
        assert sz_tgt % sz == 0
        subsampling_factor = sz_tgt // sz
        mask = torch.repeat_interleave(mask, subsampling_factor, dim=0)
    return mask
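# Illustrative sketch (not part of the model): how an anticausal mask can be obtained
# by transposing a standard causal mask, and how repeat_interleave expands it when the
# target length is a multiple of the source length. `square_subsequent_mask` below is a
# hypothetical stand-in for `self._generate_square_subsequent_mask`.
import torch


def square_subsequent_mask(sz):
    # 0.0 on allowed positions, -inf on masked positions (PyTorch convention)
    mask = torch.triu(torch.ones(sz, sz), diagonal=1)
    return mask.masked_fill(mask == 1, float('-inf'))


causal = square_subsequent_mask(3)       # position i may attend to j <= i
anticausal = causal.t()                  # position i may attend to j >= i
expanded = torch.repeat_interleave(anticausal, 2, dim=0)
print(expanded.shape)                    # torch.Size([6, 3]): 6 target rows, 3 source columns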
def preprocess(self, x):
    """
    Subclasses may reimplement this method, but this is not required.
    :param x: ? -> (batch_size, num_events, num_channels)
    :return:
    """
    return cuda_variable(x.long())
def forward(self, x):
    """
    :param x: sequence of codebooks (batch_size, s_s)
    :return:
    """
    batch_size = x.size(0)

    target = x.unsqueeze(dim=2)

    # embed
    x_seq = self.embedding(x)
    x_seq = self.linear(x_seq)

    # add positional embeddings
    # (batch_size, length, d_model) -> (length, batch_size, d_model)
    x_seq = x_seq.transpose(0, 1)

    # shift target_seq by one
    dummy_input = self.sos.repeat(1, batch_size, 1)
    x_seq = torch.cat([dummy_input, x_seq[:-1]], dim=0)

    mask = cuda_variable(
        self._generate_square_subsequent_mask(x_seq.size(0)))

    # the custom transformer also returns the attention weights
    output, attentions = self.transformer(x_seq, mask=mask)

    output = output.transpose(0, 1).contiguous()
    output = output.view(batch_size, -1, self.num_channels, self.d_model)

    weights_per_category = [
        pre_softmax(t[:, :, 0, :])
        for t, pre_softmax in zip(output.split(1, 2), self.pre_softmaxes)
    ]

    # the loss mask can be changed here; currently every position contributes
    loss = categorical_crossentropy(value=weights_per_category,
                                    target=target,
                                    mask=torch.ones_like(target))

    loss = loss.mean()
    return {
        'loss': loss,
        'weights_per_category': weights_per_category,
        'monitored_quantities': {
            'loss': loss.item()
        }
    }
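# Sketch of a per-channel categorical cross-entropy with the same call signature as the
# `categorical_crossentropy` used above. This is an assumption about its semantics (one
# softmax cross-entropy per channel, weighted by a binary mask); the actual helper in
# this repository may differ.
import torch
import torch.nn.functional as F


def per_channel_crossentropy(weights_per_category, target, mask):
    # weights_per_category: list of (batch, num_events, vocab_c) logits, one per channel
    # target, mask: (batch, num_events, num_channels)
    losses = []
    for c, weights in enumerate(weights_per_category):
        loss_c = F.cross_entropy(weights.transpose(1, 2),
                                 target[:, :, c],
                                 reduction='none')
        losses.append(loss_c * mask[:, :, c].float())
    return torch.stack(losses, dim=-1)


# tiny usage example with two channels of different vocabulary sizes
logits = [torch.randn(2, 5, 10), torch.randn(2, 5, 12)]
tgt = torch.stack([torch.randint(0, 10, (2, 5)),
                   torch.randint(0, 12, (2, 5))], dim=-1)
loss = per_channel_crossentropy(logits, tgt, torch.ones_like(tgt)).mean()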
def mask_teacher(self, x, num_events_masked):
    """
    :param x: (batch_size, num_events, num_channels)
    :param num_events_masked: number of events to be masked before and after
    the masked_event_index
    :return:
    """
    input = flatten(x)
    batch_size, sequence_length = input.size()
    assert sequence_length % self.num_channels == 0
    num_events = sequence_length // self.num_channels

    # TODO different masks for different elements in the batch
    masked_event_index = torch.randint(high=num_events, size=()).item()

    # the tokens to be predicted are exactly the self.num_channels tokens
    # of the masked event
    notes_to_be_predicted = torch.zeros_like(input)
    notes_to_be_predicted[:, masked_event_index * self.num_channels
                          :(masked_event_index + 1) * self.num_channels] = 1

    # the mask token index of each channel is its vocabulary size
    # (self.num_tokens_per_channel)
    mask_tokens = cuda_variable(torch.LongTensor(self.num_tokens_per_channel))
    mask_tokens = mask_tokens.unsqueeze(0).repeat(batch_size, num_events)

    # mask num_events_masked events before and num_events_masked events
    # after masked_event_index
    notes_to_mask = torch.zeros_like(input)
    notes_to_mask[:, max((masked_event_index - num_events_masked) * self.num_channels, 0)
                  :(masked_event_index + num_events_masked + 1) * self.num_channels] = 1

    masked_input = input * (1 - notes_to_mask) + mask_tokens * notes_to_mask

    # unflatten
    masked_x = unflatten(masked_input, self.num_channels)
    notes_to_be_predicted = unflatten(notes_to_be_predicted, self.num_channels)
    return masked_x, notes_to_be_predicted
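# Standalone sketch of the masking mechanism above (assumed shapes; it does not use the
# project's flatten/unflatten helpers): a span of events is replaced by per-channel mask
# tokens whose index equals that channel's vocabulary size.
import torch

batch_size, num_events, num_channels = 2, 4, 3
vocab_sizes = [10, 12, 8]                 # plays the role of num_tokens_per_channel
x = torch.randint(0, 8, (batch_size, num_events, num_channels))

flat = x.view(batch_size, num_events * num_channels)       # flatten events and channels
mask_tokens = torch.tensor(vocab_sizes).repeat(batch_size, num_events)

notes_to_mask = torch.zeros_like(flat)
masked_event = 1
notes_to_mask[:, masked_event * num_channels:(masked_event + 2) * num_channels] = 1

masked_flat = flat * (1 - notes_to_mask) + mask_tokens * notes_to_mask
masked_x = masked_flat.view(batch_size, num_events, num_channels)   # unflatten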
def preprocess(self, x):
    """
    Preprocess a dcpc block
    :param x: (..., num_ticks, num_voices) of appropriate dimensions
    :return: (..., num_blocks, num_tokens_per_block)
    """
    num_ticks, num_voices = x.size()[-2:]
    remaining_dims = x.size()[:-2]
    x = x.view(-1, num_ticks, num_voices).contiguous()
    x = x.view(-1, num_voices * num_ticks)
    assert x.size(1) % self.num_tokens_per_block == 0
    x = x.split(self.num_tokens_per_block, dim=1)
    x = torch.cat([t.unsqueeze(1) for t in x], dim=1)
    num_blocks = x.size(1)
    x = x.view(*remaining_dims, num_blocks, self.num_tokens_per_block)
    return cuda_variable(x.long())
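# Illustrative sketch (assumed sizes): splitting a (num_ticks, num_voices) fragment into
# fixed-size blocks of tokens, as in the preprocessing above.
import torch

num_ticks, num_voices, num_tokens_per_block = 8, 4, 16
x = torch.arange(num_ticks * num_voices).view(1, num_ticks, num_voices)

flat = x.view(1, -1)                                         # (1, num_ticks * num_voices)
blocks = torch.stack(flat.split(num_tokens_per_block, dim=1), dim=1)
print(blocks.shape)                                          # torch.Size([1, 2, 16])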
def init_generation(self, num_events):
    return cuda_variable(
        torch.zeros(1, num_events, self.num_channels).long()
    )
def generate(self, temperature,
             batch_size=1,
             top_k=0,
             top_p=1.,
             seed_set=None,
             exclude_meta_symbols=False,
             plot_attentions=False,
             code_juxtaposition=False):
    self.eval()
    (generator_train, generator_val, _) = self.dataloader_generator.dataloaders(
        batch_size=1,
        shuffle_val=True
    )

    with torch.no_grad():
        if code_juxtaposition:
            # Use the codes of one chorale for the first half
            # and the codes of another chorale for the second half
            if seed_set == 'val':
                tensor_dict_beginning = next(iter(generator_val))
                tensor_dict_end = next(iter(generator_val))
            elif seed_set == 'train':
                tensor_dict_beginning = next(iter(generator_train))
                tensor_dict_end = next(iter(generator_train))
            else:
                raise Exception('Need to indicate seeds dataset')

            num_events_chorale_half = tensor_dict_beginning['x'].shape[1] // 2
            x_beg = tensor_dict_beginning['x'][:, :num_events_chorale_half]
            x_end = tensor_dict_end['x'][:, num_events_chorale_half:]
            x_original_single = torch.cat([x_beg, x_end], dim=1)
            x_original = x_original_single.repeat(batch_size, 1, 1)
        else:
            if seed_set == 'val':
                tensor_dict = next(iter(generator_val))
            elif seed_set == 'train':
                tensor_dict = next(iter(generator_train))
            else:
                raise Exception('Need to indicate seeds dataset')

            x_original_single = tensor_dict['x']
            x_original = x_original_single.repeat(batch_size, 1, 1)

        # compute downscaled version
        zs, encoding_indices, _ = self.encoder(x_original)

        if encoding_indices is None:
            # if no quantization is used, directly use the zs
            encoding_indices = zs
        else:
            encoding_indices = self.encoder.merge_codes(encoding_indices)

        x = self.init_generation(num_events=self.data_processor.num_events)
        # Duplicate along batch dimension
        x = x.repeat(batch_size, 1, 1)

        attentions_decoder_list = []
        attentions_encoder_list = []
        attentions_cross_list = []

        for event_index in range(self.data_processor.num_events):
            for channel_index in range(self.num_channels):
                forward_pass = self.forward(encoding_indices, x)

                weights_per_voice = forward_pass['weights_per_category']
                weights = weights_per_voice[channel_index]

                # logits for the current event, scaled by the temperature,
                # then filtered
                logits = weights[:, event_index, :] / temperature

                # Remove meta symbols
                if exclude_meta_symbols:
                    for sym in [START_SYMBOL, END_SYMBOL, PAD_SYMBOL]:
                        sym_index = self.dataloader_generator.dataset.note2index_dicts[
                            channel_index][sym]
                        logits[:, sym_index] = -float("inf")

                # Top-k / top-p filtering
                filtered_logits = []
                for logit in logits:
                    filter_logit = top_k_top_p_filtering(logit,
                                                         top_k=top_k,
                                                         top_p=top_p)
                    filtered_logits.append(filter_logit)
                filtered_logits = torch.stack(filtered_logits, dim=0)

                # Sample from the filtered distribution
                p = to_numpy(torch.softmax(filtered_logits, dim=-1))

                # update generated sequence
                for batch_index in range(batch_size):
                    new_pitch_index = np.random.choice(
                        np.arange(self.num_tokens_per_channel[channel_index]),
                        p=p[batch_index])
                    x[batch_index, event_index, channel_index] = int(new_pitch_index)

                # store attentions
                if plot_attentions:
                    layer = 2
                    event_index_encoder = (
                        event_index * self.num_channels) // self.total_upscaling
                    # list of dicts with key 'a_self_encoder'
                    attentions_encoder = forward_pass['attentions_encoder']
                    # list of dicts with keys 'a_self_decoder' and 'a_cross'
                    attentions_decoder = forward_pass['attentions_decoder']

                    # get attentions at the corresponding event
                    attn_encoder = attentions_encoder[layer]['a_self_encoder'][
                        :, :, event_index_encoder, :]
                    attn_decoder = attentions_decoder[layer]['a_self_decoder'][
                        :, :, event_index * self.num_channels + channel_index, :]
                    attn_cross = attentions_decoder[layer]['a_cross'][
                        :, :, event_index * self.num_channels + channel_index, :]

                    attentions_encoder_list.append(attn_encoder)
                    attentions_decoder_list.append(attn_decoder)
                    attentions_cross_list.append(attn_cross)

        # Compute codes for generations
        x_re_encode = torch.cat([
            cuda_variable(x_original_single.long()),
            x
        ], dim=0)
        _, recoding_, _ = self.encoder(x_re_encode)
        if recoding_ is not None:
            recoding_ = recoding_.detach().cpu().numpy()
            recoding = self.encoder.merge_codes(recoding_)
        else:
            recoding = None

        # to score
        original_and_reconstruction = self.data_processor.postprocess(
            original=x_original.long(),
            reconstruction=x.cpu())

        ###############################
        # Saving
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if code_juxtaposition:
            save_dir = f'{self.model_dir}/juxtapositions'
        else:
            save_dir = f'{self.model_dir}/generations'
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)

        # Write code sequence
        if recoding is not None:
            with open(f'{save_dir}/{timestamp}.txt', 'w') as ff:
                for batch_ind in range(len(recoding)):
                    aa = recoding[batch_ind]
                    ff.write(' , '.join(map(str, list(aa))))
                    ff.write('\n')

        # Write scores
        scores = []
        for k, tensor_score in enumerate(original_and_reconstruction):
            path_no_extension = f'{save_dir}/{timestamp}_{k}'
            scores.append(self.dataloader_generator.write(tensor_score,
                                                          path_no_extension))
        print(f'Saved in {save_dir}/{timestamp}')
        ###############################

        if plot_attentions:
            self.plot_attention(attentions_cross_list,
                                timestamp=timestamp,
                                name='attns_cross')
            self.plot_attention(attentions_encoder_list,
                                timestamp=timestamp,
                                name='self_attns_encoder')
            self.plot_attention(attentions_decoder_list,
                                timestamp=timestamp,
                                name='self_attns_decoder')

        return scores
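# Minimal sketch of nucleus (top-p) filtering as used above. The project relies on an
# external `top_k_top_p_filtering` helper; this standalone version only illustrates the
# principle and is not that helper.
import torch


def nucleus_filter(logits, top_p=0.9, filter_value=-float('inf')):
    # keep the smallest set of tokens whose cumulative probability exceeds top_p
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_mask = cumulative_probs > top_p
    sorted_mask[1:] = sorted_mask[:-1].clone()   # always keep the most probable token
    sorted_mask[0] = False
    logits = logits.clone()
    logits[sorted_indices[sorted_mask]] = filter_value
    return logits


logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
probs = torch.softmax(nucleus_filter(logits, top_p=0.8), dim=-1)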
def _generate_causal_mask(self, sz):
    return cuda_variable(self._generate_square_subsequent_mask(sz))
def forward(self, inputs, corrupt_labels=False, **kwargs):
    input_shape = inputs.size()

    # Normalize and flatten
    if self.use_batch_norm:
        flat_input = inputs.view(-1, self.codebook_dim).unsqueeze(1)
        flat_input = flat_input.permute(0, 2, 1)
        flat_input = self.batch_norm(flat_input)
        flat_input = flat_input.permute(0, 2, 1).contiguous()
        flat_input = flat_input[:, 0, :]
    else:
        flat_input = inputs.view(-1, self.codebook_dim)

    if self.initialize:
        self._initialize(flat_input=flat_input)

    # Calculate distances
    distances = [
        (torch.sum(input_component ** 2, dim=1, keepdim=True)
         + torch.sum(embedding ** 2, dim=1)
         - 2 * torch.matmul(input_component, embedding.t()))
        for input_component, embedding in zip(
            flat_input.chunk(chunks=self.num_codebooks, dim=1),
            self.embeddings)
    ]

    # Encoding
    encoding_indices_list = [
        torch.argmin(distance, dim=1).unsqueeze(1)
        for distance in distances
    ]

    # corrupt indices
    if self.training and corrupt_labels:
        random_indices_list = [
            torch.randint_like(encoding_indices_list[0],
                               low=0,
                               high=self.codebook_size)
            for _ in range(self.num_codebooks)
        ]
        mask_list = [
            (torch.rand_like(random_indices.float()) > 0.05).long()
            for random_indices in random_indices_list
        ]
        encoding_indices_list = [
            mask * encoding_indices + (1 - mask) * random_indices
            for encoding_indices, random_indices, mask in zip(
                encoding_indices_list, random_indices_list, mask_list)
        ]

    encodings = [
        cuda_variable(
            torch.zeros(encoding_indices.shape[0], self.codebook_size))
        for encoding_indices in encoding_indices_list
    ]
    for encoding, encoding_indices in zip(encodings, encoding_indices_list):
        encoding.scatter_(1, encoding_indices, 1)

    # Quantize and unflatten
    quantized_list = [
        torch.matmul(encoding, embedding)
        for encoding, embedding in zip(encodings, self.embeddings)
    ]
    quantized = torch.cat(quantized_list, dim=1).view(input_shape)

    quantization_loss = self._loss(inputs, quantized)

    # straight-through estimator
    quantized_sg = inputs + (quantized - inputs).detach()

    encoding_indices_shape = list(input_shape[:-1]) + [-1]
    encoding_indices = torch.stack(encoding_indices_list,
                                   dim=-1).view(encoding_indices_shape)
    return quantized_sg, encoding_indices, quantization_loss
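# Standalone sketch of the vector-quantization step above (single codebook, illustrative
# sizes): nearest-code lookup by squared Euclidean distance and the straight-through trick
# that copies gradients from the quantized output back to the encoder output.
import torch

codebook_size, codebook_dim = 8, 4
embedding = torch.randn(codebook_size, codebook_dim)
flat_input = torch.randn(6, codebook_dim, requires_grad=True)

distances = (flat_input.pow(2).sum(dim=1, keepdim=True)
             + embedding.pow(2).sum(dim=1)
             - 2 * flat_input @ embedding.t())        # (6, codebook_size)
encoding_indices = distances.argmin(dim=1)            # nearest code per input row
quantized = embedding[encoding_indices]               # (6, codebook_dim)

# straight-through estimator: forward pass uses `quantized`,
# backward pass behaves as the identity on `flat_input`
quantized_sg = flat_input + (quantized - flat_input).detach()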
def forward(self, inputs, **kwargs):
    # no quantization: the loss is zero and no encoding indices are returned
    loss = cuda_variable(torch.zeros_like(inputs)).sum(dim=-1)
    quantized_sg = inputs
    encoding_indices = None
    return quantized_sg, encoding_indices, loss
def forward(self, q):
    """
    :param q: (batch_size * num_heads, len_q_tgt, d)
    :return:
    """
    sz_b_times_n_head, len_q, d_q = q.size()
    assert sz_b_times_n_head % self.num_heads == 0
    sz_b = sz_b_times_n_head // self.num_heads
    batch_size = sz_b_times_n_head

    ################################
    # Causal
    e1 = self.e1.unsqueeze(0).repeat(sz_b, 1, 1)
    e1 = e1.view(sz_b * self.num_heads, self.seq_len_src, d_q)
    rel_attn_1 = torch.einsum('bld,bmd->blm', (q, e1))
    # tgt * src -> src * tgt
    rel_attn_1 = rel_attn_1.view(batch_size, self.seq_len_src, self.seq_len_tgt)
    # one column padding on dim 2
    rel_attn_1 = torch.cat([
        cuda_variable(torch.ones(batch_size, self.seq_len_src, 1) * -100),
        rel_attn_1,
    ], dim=2)
    # fill in with lines (ensure view can be done)
    bottom_extension = self.seq_len_tgt - self.seq_len_src
    if bottom_extension != 0:
        rel_attn_1 = torch.cat([
            rel_attn_1,
            cuda_variable(
                torch.ones(batch_size, bottom_extension, self.seq_len_tgt + 1) * -100),
        ], dim=1)
    # skewing
    rel_attn_1 = rel_attn_1.view(batch_size, -1, self.seq_len_src)
    # need to remove first line here
    rel_attn_1 = rel_attn_1[:, 1:]
    rel_attn_1 = rel_attn_1[:, :self.seq_len_tgt, :]
    ################################

    ################################
    # Anticausal
    e2 = self.e2.unsqueeze(0).repeat(sz_b, 1, 1)
    e2 = e2.view(sz_b * self.num_heads, self.seq_len_src, d_q)
    rel_attn_2 = torch.einsum('bld,bmd->blm', (q, e2))
    batch_size = rel_attn_2.size(0)
    # tgt * src -> src * tgt
    rel_attn_2 = rel_attn_2.view(batch_size, self.seq_len_src, self.seq_len_tgt)
    # one column padding on dim 2
    rel_attn_2 = torch.cat([
        rel_attn_2,
        cuda_variable(torch.ones(batch_size, self.seq_len_src, 1) * -100),
    ], dim=2)
    # fill in with lines (ensure view can be done)
    bottom_extension = self.seq_len_tgt - self.seq_len_src
    if bottom_extension != 0:
        rel_attn_2 = torch.cat([
            rel_attn_2,
            cuda_variable(
                torch.ones(batch_size, bottom_extension, self.seq_len_tgt + 1) * -100),
        ], dim=1)
    # skewing: (tgt + 1) * (tgt + 1) -> x * tgt
    rel_attn_2 = rel_attn_2.view(batch_size, -1, self.seq_len_src)
    rel_attn_2 = rel_attn_2[:, :self.seq_len_tgt, :]
    ################################

    # mask causal and anticausal
    # Using ones_like is faster than cuda_variable(ones(...))
    masks_down = torch.triu(torch.ones_like(
        rel_attn_1[0, :self.seq_len_src, :self.seq_len_src]).byte(),
                            diagonal=0).unsqueeze(0).repeat(
        sz_b_times_n_head, 1, 1).flip(1).flip(2).type(torch.bool)
    if self.subsampling_ratio != 1:
        masks_down = torch.repeat_interleave(masks_down,
                                             self.subsampling_ratio,
                                             dim=1)

    masks_up = torch.triu(torch.ones_like(
        rel_attn_1[0, :self.seq_len_src, :self.seq_len_src]).byte(),
                          diagonal=1).unsqueeze(0).repeat(
        sz_b_times_n_head, 1, 1).type(torch.bool)
    if self.subsampling_ratio != 1:
        masks_up = torch.repeat_interleave(masks_up,
                                           self.subsampling_ratio,
                                           dim=1)

    rel_attn_1 = rel_attn_1.masked_fill(masks_up, 0)
    rel_attn_2 = rel_attn_2.masked_fill(masks_down, 0)
    rel_attn = rel_attn_1 + rel_attn_2
    return rel_attn
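# Toy illustration of the "skewing" reshape used above, in its square (non-subsampled)
# form, in the spirit of the Music Transformer trick: pad one column, reshape, drop the
# first row, and the relative logits line up with absolute key positions. The class above
# generalizes this padding-and-reshape idea to rectangular source/target lengths.
import torch

L = 3
rel = torch.arange(L * L, dtype=torch.float).view(L, L)            # rel[i, r]: query i, relative position r
padded = torch.cat([torch.full((L, 1), float('nan')), rel], dim=1)  # (L, L + 1)
skewed = padded.view(L + 1, L)[1:, :]                               # (L, L)
# skewed[i, j] == rel[i, j - i + L - 1] for j <= i (valid past positions);
# entries with j > i are junk and are removed by the causal mask.
print(skewed)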
if __name__ == '__main__':
    batch_size = 1
    head_dim = 2
    num_heads = 1
    seq_len_src = 6
    seq_len_tgt = 6
    aa = SubsampledRelativeAttention(head_dim, num_heads, seq_len_src, seq_len_tgt)
    aa.to('cuda')
    q = cuda_variable(
        torch.ones((batch_size * num_heads, seq_len_tgt, head_dim)))
    ret = aa.forward(q)
    exit()
def generate(self,
             num_tokens,
             decoder,
             temperature=1.,
             num_generated_codes=1,
             num_decodings_per_generating_code=1):
    self.eval()
    decoder.eval()
    with torch.no_grad():
        # init
        x = cuda_variable(torch.zeros(1, num_tokens, self.num_channels)).long()
        x = x.repeat(num_generated_codes, 1, 1)

        assert num_tokens % self.num_channels == 0
        # num_tokens is the length of the sequence to be generated,
        # while self.num_tokens is the length of the model's input
        assert num_tokens >= self.num_tokens
        num_events = num_tokens // self.num_channels

        for event_index in range(num_events):
            for channel_index in range(self.num_channels):
                # removes channel dim
                x_input = x[:, :, 0]

                if event_index >= self.num_tokens:
                    x_input = x_input[:, event_index - self.num_tokens + 1:event_index + 1]
                    event_offset = event_index - self.num_tokens + 1
                else:
                    x_input = x_input[:, :self.num_tokens]
                    event_offset = 0

                weights_per_voice = self.forward(x_input)['weights_per_category']

                weights = weights_per_voice[channel_index]
                probs = torch.softmax(weights[:, event_index - event_offset, :],
                                      dim=1)
                p = to_numpy(probs)
                # temperature: raise probabilities to the power `temperature`
                # and renormalize (values above 1 sharpen the distribution)
                p = np.exp(np.log(p + 1e-20) * temperature)
                p = p / p.sum(axis=1, keepdims=True)

                for batch_index in range(num_generated_codes):
                    new_pitch_index = np.random.choice(
                        np.arange(self.num_tokens_per_channel[channel_index]),
                        p=p[batch_index])
                    x[batch_index, event_index, channel_index] = int(new_pitch_index)

        source = x[:, :, 0]
        scores = decoder.generate_from_code_long(
            encoding_indices=source,
            temperature=temperature,
            num_decodings=num_decodings_per_generating_code)

        # save scores in model_dir
        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        if not os.path.exists(f'{self.model_dir}/generations'):
            os.mkdir(f'{self.model_dir}/generations')
        for k, score in enumerate(scores):
            score.write('xml', f'{self.model_dir}/generations/{timestamp}_{k}.xml')
        return scores
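# Small sketch of the temperature re-weighting used above: probabilities are raised to
# the power `temperature` and renormalized. Note that with this convention a temperature
# above 1 sharpens the distribution (unlike the more common logits-divided-by-temperature
# convention).
import numpy as np

p = np.array([[0.6, 0.3, 0.1]])
temperature = 2.0
p_t = np.exp(np.log(p + 1e-20) * temperature)
p_t = p_t / p_t.sum(axis=1, keepdims=True)
print(p_t)   # roughly [[0.783, 0.196, 0.022]]: more peaked than p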