def postprocess(self, original, reconstruction):
    """Inverse of ``preprocess``: turn model tensors back into a numpy array.

    :param original: optional tensor of original sequences. When given, it is
        concatenated in front of ``reconstruction`` along dim 1.
    :param reconstruction: tensor of generated sequences. When ``original`` is
        None it may also be a sequence of tensors, which are concatenated
        along dim 0.
    :return: the result of ``to_numpy`` on the assembled tensor.
    """
    if original is not None:
        # Both operands are moved to CPU so a GPU-resident ``original`` does
        # not cause a device mismatch in torch.cat.
        tensor_score = torch.cat(
            [original.long().cpu(), reconstruction.cpu()], dim=1)
    elif isinstance(reconstruction, torch.Tensor):
        # torch.cat rejects a bare tensor; a single tensor needs no concatenation.
        tensor_score = reconstruction.cpu()
    else:
        tensor_score = torch.cat(list(reconstruction), dim=0)
    return to_numpy(tensor_score)
def generate_from_code_long(self, encoding_indices, temperature, top_k=0, top_p=1.,
                            exclude_meta_symbols=False, num_decodings=1,
                            code_index_start=None, code_index_end=None):
    """Decode a (possibly long) code sequence into scores, one code at a time.

    The model only sees a window of ``num_tokens_indices`` codes at once;
    ``compute_start_end_times`` picks that window for each decoded code, so
    code sequences longer than the model context can be decoded.

    :param encoding_indices: code tensor; dim 0 is the batch, dim 1 indexes codes.
    :param temperature: logits are divided by this value before sampling.
    :param top_k: forwarded to ``top_k_top_p_filtering``.
    :param top_p: forwarded to ``top_k_top_p_filtering``.
    :param exclude_meta_symbols: currently ignored by this method (meta-symbol
        masking is disabled here).
    :param num_decodings: number of samples drawn per input code sequence.
    :param code_index_start: first code to decode (defaults to 0).
    :param code_index_end: one past the last code to decode (defaults to all codes).
    :return: list of scores built by ``self.dataloader_generator.to_score``,
        restricted to the events of ``[code_index_start, code_index_end)``.
    """
    self.eval()
    size_encoding = encoding_indices.size(1)
    # Resolve the defaults BEFORE they are used in the event-offset arithmetic
    # below (multiplying a None default used to raise a TypeError).
    if code_index_start is None:
        code_index_start = 0
    if code_index_end is None:
        code_index_end = size_encoding

    total_upscaling = int(np.prod(self.encoder.downscaler.downscale_factors))
    num_tokens_indices = self.data_processor.num_tokens // total_upscaling
    num_events_full_chorale = (size_encoding * total_upscaling
                               // self.data_processor.num_channels)
    num_events_before_start = code_index_start * total_upscaling // self.num_channels
    num_events_before_end = code_index_end * total_upscaling // self.num_channels
    batch_size = num_decodings * encoding_indices.size(0)

    with torch.no_grad():
        chorale = self.init_generation_chorale(num_events=num_events_full_chorale,
                                               start_index=num_events_before_start)
        # Duplicate along batch dimension
        chorale = chorale.repeat(batch_size, 1, 1)
        encoding_indices = encoding_indices.repeat_interleave(num_decodings, dim=0)

        for code_index in range(code_index_start, code_index_end):
            # The model window only depends on the code being decoded, so it is
            # computed once per code rather than once per event and channel.
            t_begin, t_end, t_relative = self.compute_start_end_times(
                code_index, num_blocks=size_encoding, num_blocks_model=num_tokens_indices
            )
            input_encoding_indices = encoding_indices[:, t_begin:t_end]
            for relative_event in range(self.num_events_per_code):
                # Absolute event index in the chorale and its position in the window
                event_index = code_index * self.num_events_per_code + relative_event
                window_position = t_relative * self.num_events_per_code + relative_event
                for channel_index in range(self.data_processor.num_channels):
                    # Sliced at every step so the input includes the tokens
                    # sampled so far.
                    input_chorale = chorale[:, t_begin * self.num_events_per_code:
                                            t_end * self.num_events_per_code, :]
                    weights_per_voice = self.forward(
                        input_encoding_indices, input_chorale)['weights_per_category']
                    # Predictions at the current position for every batch item,
                    # scaled by the temperature.
                    weights = weights_per_voice[channel_index]
                    logits = weights[:, window_position, :] / temperature

                    # NOTE(review): meta-symbol masking (see ``generate``) is not
                    # applied here, so ``exclude_meta_symbols`` has no effect.

                    # Top-k / top-p filtering, one batch item at a time
                    filtered_logits = torch.stack(
                        [top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)
                         for logit in logits],
                        dim=0)
                    # Sample from the filtered distribution
                    p = to_numpy(torch.softmax(filtered_logits, dim=-1))
                    for batch_index in range(batch_size):
                        new_pitch_index = np.random.choice(
                            np.arange(self.num_tokens_per_channel[channel_index]),
                            p=p[batch_index])
                        chorale[batch_index, event_index, channel_index] = int(
                            new_pitch_index)

    # Keep only the events covered by the decoded code range
    chorale = chorale[:, num_events_before_start:num_events_before_end]
    tensor_scores = to_numpy(chorale)
    return [self.dataloader_generator.to_score(tensor_score)
            for tensor_score in tensor_scores]
def generate(self, temperature, batch_size=1, top_k=0, top_p=1., seed_set=None,
             exclude_meta_symbols=False, plot_attentions=False, code_juxtaposition=False):
    """Re-generate a seed chorale from its codes and save originals and samples.

    A seed chorale is drawn from the train or validation dataloader and
    encoded. The decoder then samples ``batch_size`` sequences token by token,
    conditioned on those codes. The results are written under
    ``{model_dir}/generations`` (or ``{model_dir}/juxtapositions``).

    :param temperature: logits are divided by this value before sampling.
    :param batch_size: number of sequences generated from the same codes.
    :param top_k: forwarded to ``top_k_top_p_filtering``.
    :param top_p: forwarded to ``top_k_top_p_filtering``.
    :param seed_set: ``'train'`` or ``'val'``; the dataloader that provides the seed.
    :param exclude_meta_symbols: if True, START/END/PAD symbols are never sampled.
    :param plot_attentions: if True, collect layer-2 attention maps and plot them.
    :param code_juxtaposition: if True, the seed is the first half of one
        chorale followed by the second half of another one.
    :return: list of the values returned by ``self.dataloader_generator.write``,
        one per row of the post-processed original/generation array.
    :raises ValueError: if ``seed_set`` is neither ``'train'`` nor ``'val'``.
    """
    self.eval()
    (generator_train, generator_val, _) = self.dataloader_generator.dataloaders(
        batch_size=1, shuffle_val=True
    )
    if seed_set == 'val':
        seed_generator = generator_val
    elif seed_set == 'train':
        seed_generator = generator_train
    else:
        raise ValueError('Need to indicate seeds dataset')

    with torch.no_grad():
        # A single iterator: in juxtaposition mode the two halves are then
        # guaranteed to come from two different items, even without shuffling.
        seed_iterator = iter(seed_generator)
        if code_juxtaposition:
            # Use the codes of a chorale for the first half, and the codes
            # from another chorale for the last half
            tensor_dict_beginning = next(seed_iterator)
            tensor_dict_end = next(seed_iterator)
            num_events_chorale_half = tensor_dict_beginning['x'].shape[1] // 2
            x_beg = tensor_dict_beginning['x'][:, :num_events_chorale_half]
            x_end = tensor_dict_end['x'][:, num_events_chorale_half:]
            x_original_single = torch.cat([x_beg, x_end], dim=1)
        else:
            x_original_single = next(seed_iterator)['x']
        x_original = x_original_single.repeat(batch_size, 1, 1)

        # compute downscaled version
        zs, encoding_indices, _ = self.encoder(x_original)
        if encoding_indices is None:
            # if no quantization is used, directly use the zs
            encoding_indices = zs
        else:
            encoding_indices = self.encoder.merge_codes(encoding_indices)

        x = self.init_generation(num_events=self.data_processor.num_events)
        # Duplicate along batch dimension
        x = x.repeat(batch_size, 1, 1)

        attentions_decoder_list = []
        attentions_encoder_list = []
        attentions_cross_list = []

        for event_index in range(self.data_processor.num_events):
            for channel_index in range(self.num_channels):
                forward_pass = self.forward(encoding_indices, x)
                weights = forward_pass['weights_per_category'][channel_index]
                # Predictions at the current event for every batch item,
                # scaled by the temperature
                logits = weights[:, event_index, :] / temperature

                # Remove meta symbols
                if exclude_meta_symbols:
                    note2index = self.dataloader_generator.dataset.note2index_dicts[
                        channel_index]
                    for sym in [START_SYMBOL, END_SYMBOL, PAD_SYMBOL]:
                        logits[:, note2index[sym]] = -float("inf")

                # Top-k / top-p filtering, one batch item at a time
                filtered_logits = torch.stack(
                    [top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p)
                     for logit in logits],
                    dim=0)
                # Sample from the filtered distribution and update the sequence
                p = to_numpy(torch.softmax(filtered_logits, dim=-1))
                for batch_index in range(batch_size):
                    new_pitch_index = np.random.choice(
                        np.arange(self.num_tokens_per_channel[channel_index]),
                        p=p[batch_index])
                    x[batch_index, event_index, channel_index] = int(new_pitch_index)

                # store attentions
                if plot_attentions:
                    layer = 2
                    token_index = event_index * self.num_channels + channel_index
                    event_index_encoder = (
                        event_index * self.num_channels) // self.total_upscaling
                    # list of dicts with key 'a_self_encoder'
                    attentions_encoder = forward_pass['attentions_encoder']
                    # list of dicts with keys 'a_self_decoder' and 'a_cross'
                    attentions_decoder = forward_pass['attentions_decoder']
                    # get attentions at corresponding event
                    attentions_encoder_list.append(
                        attentions_encoder[layer]['a_self_encoder'][:, :, event_index_encoder, :])
                    attentions_decoder_list.append(
                        attentions_decoder[layer]['a_self_decoder'][:, :, token_index, :])
                    attentions_cross_list.append(
                        attentions_decoder[layer]['a_cross'][:, :, token_index, :])

        # Compute codes for generations (original seed first, then the samples)
        x_re_encode = torch.cat([
            cuda_variable(x_original_single.long()),
            x
        ], dim=0)
        _, recoding_, _ = self.encoder(x_re_encode)
        if recoding_ is not None:
            recoding_ = recoding_.detach().cpu().numpy()
            recoding = self.encoder.merge_codes(recoding_)
        else:
            recoding = None

    # to score
    original_and_reconstruction = self.data_processor.postprocess(
        original=x_original.long(), reconstruction=x.cpu())

    ###############################
    # Saving
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    subdir = 'juxtapositions' if code_juxtaposition else 'generations'
    save_dir = f'{self.model_dir}/{subdir}'
    # makedirs also creates a missing model_dir and tolerates an existing save_dir
    os.makedirs(save_dir, exist_ok=True)

    # Write code sequence, one line per sequence
    if recoding is not None:
        with open(f'{save_dir}/{timestamp}.txt', 'w') as ff:
            for codes in recoding:
                ff.write(' , '.join(map(str, list(codes))))
                ff.write('\n')

    # Write scores
    scores = [
        self.dataloader_generator.write(tensor_score, f'{save_dir}/{timestamp}_{k}')
        for k, tensor_score in enumerate(original_and_reconstruction)
    ]
    print(f'Saved in {save_dir}/{timestamp}')
    ###############################

    if plot_attentions:
        self.plot_attention(attentions_cross_list,
                            timestamp=timestamp,
                            name='attns_cross')
        self.plot_attention(attentions_encoder_list,
                            timestamp=timestamp,
                            name='self_attns_encoder')
        self.plot_attention(attentions_decoder_list,
                            timestamp=timestamp,
                            name='self_attns_decoder')

    return scores
def generate(
        self,
        num_tokens,
        decoder,
        temperature=1.,
        num_generated_codes=1,
        num_decodings_per_generating_code=1,
):
    """Sample code sequences with this model, then decode them with ``decoder``.

    :param num_tokens: length of the sequence to generate. It must be a multiple
        of ``self.num_channels`` and at least ``self.num_tokens`` (the model
        input length).
    :param decoder: model exposing ``generate_from_code_long``.
    :param temperature: logits are divided by this value before sampling. The
        same value is forwarded to the decoder, which uses the same convention.
    :param num_generated_codes: number of code sequences sampled here.
    :param num_decodings_per_generating_code: decodings per sampled code sequence.
    :return: list of scores returned by ``decoder.generate_from_code_long``. Each
        score is also written as xml under ``{model_dir}/generations``.
    :raises ValueError: if ``num_tokens`` violates the constraints above.
    """
    # num_tokens is the number of the sequence to be generated
    # while self.num_tokens is the number of tokens of the input of the model
    if num_tokens % self.num_channels != 0:
        raise ValueError(
            f'num_tokens ({num_tokens}) must be a multiple of '
            f'num_channels ({self.num_channels})')
    if num_tokens < self.num_tokens:
        raise ValueError(
            f'num_tokens ({num_tokens}) must be at least the model input '
            f'length ({self.num_tokens})')

    self.eval()
    decoder.eval()
    num_events = num_tokens // self.num_channels
    with torch.no_grad():
        # init
        x = cuda_variable(torch.zeros(1, num_tokens, self.num_channels)).long()
        x = x.repeat(num_generated_codes, 1, 1)

        for event_index in range(num_events):
            for channel_index in range(self.num_channels):
                # removes channel dim: only channel 0 is fed to the model
                x_input = x[:, :, 0]
                if event_index >= self.num_tokens:
                    # Slide the window so that it ends on the current event
                    event_offset = event_index - self.num_tokens + 1
                    x_input = x_input[:, event_offset:event_index + 1]
                else:
                    event_offset = 0
                    x_input = x_input[:, :self.num_tokens]

                weights = self.forward(x_input)['weights_per_category'][channel_index]
                # Same temperature convention as the decoder: divide the logits.
                # (The former p ** temperature acted as an inverse temperature.)
                logits = weights[:, event_index - event_offset, :] / temperature
                p = to_numpy(torch.softmax(logits, dim=1))
                for batch_index in range(num_generated_codes):
                    new_pitch_index = np.random.choice(
                        np.arange(self.num_tokens_per_channel[channel_index]),
                        p=p[batch_index])
                    x[batch_index, event_index, channel_index] = int(new_pitch_index)

        source = x[:, :, 0]
        # The code range is passed explicitly so the whole sequence is decoded
        # without relying on the decoder's None defaults.
        scores = decoder.generate_from_code_long(
            encoding_indices=source,
            temperature=temperature,
            num_decodings=num_decodings_per_generating_code,
            code_index_start=0,
            code_index_end=source.size(1),
        )

    # save scores in model_dir
    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_dir = f'{self.model_dir}/generations'
    # makedirs also creates a missing model_dir and tolerates an existing save_dir
    os.makedirs(save_dir, exist_ok=True)
    for k, score in enumerate(scores):
        score.write('xml', f'{save_dir}/{timestamp}_{k}.xml')

    return scores