def test_interpretability(self, batch_size, attr_type):
    """
    Tests the interpretability of the latent space for a particular attribute.

    Every measure of the test split is encoded (one sample from the
    posterior), the attribute is computed for the same measures, and the
    mutual information between each latent dimension and the attribute is
    measured. The most informative dimension is then used as the only
    regressor of a linear model; its R^2 is the interpretability score.

    :param batch_size: int, number of datapoints in mini-batch
    :param attr_type: str, attribute type, one of 'rhy_complexity',
        'num_notes', 'note_range'
    :return: tuple(int, float): index of dimension with highest mutual info,
        interpretability score (R^2 of the linear fit on that dimension)
    :raises ValueError: if attr_type is not a supported attribute
    """
    # attribute name -> dataset method computing it (looked up lazily so only
    # the requested method has to exist on the dataset)
    attr_methods = {
        'rhy_complexity': 'get_rhy_complexity',
        'num_notes': 'get_note_density_in_measure',
        'note_range': 'get_pitch_range_in_measure',
    }
    # fail fast: an unknown attribute used to surface as a NameError on
    # `attr` only after the first encoder forward pass
    if attr_type not in attr_methods:
        raise ValueError('Invalid attribute type')
    get_attr = getattr(self.dataset, attr_methods[attr_type])

    _, _, gen_test = self.dataset.data_loaders(
        batch_size=batch_size,
        split=(0.01, 0.01)
    )

    # compute latent vectors and attribute values
    z_all = []
    attr_all = []
    for score_tensor, _ in tqdm(gen_test):
        if isinstance(self.dataset, FolkNBarDataset):
            # split each n-bar sequence into one row per bar
            num_seqs = score_tensor.size(0)
            score_tensor = score_tensor.view(num_seqs, self.dataset.n_bars, -1)
            score_tensor = score_tensor.view(num_seqs * self.dataset.n_bars, -1)
        score_tensor = to_cuda_variable_long(score_tensor)

        # encoder forward pass, then one reparameterised sample per measure
        z_dist = self.model.encoder(score_tensor)
        z_tilde = z_dist.rsample()
        attr = get_attr(score_tensor)

        z_all.append(to_numpy(z_tilde.cpu()))
        attr_all.append(to_numpy(attr.cpu()))
    z_all = np.concatenate(z_all)
    attr_all = np.concatenate(attr_all)

    # mutual information between each latent dimension and the attribute
    # NOTE(review): sklearn's mutual_info_score treats its inputs as discrete
    # labels; continuous latent values are mostly unique, so confirm this
    # estimator is the intended one
    mutual_info = np.zeros(self.z_dim)
    for i in tqdm(range(self.z_dim)):
        mutual_info[i] = mutual_info_score(z_all[:, i], attr_all)
    dim = np.argmax(mutual_info)

    # linear fit of the attribute on the single most informative dimension
    z_col = z_all[:, dim:dim + 1]
    reg = LinearRegression().fit(z_col, attr_all)
    score = reg.score(z_col, attr_all)
    return dim, score
def get_contour(self, measure_tensor):
    """
    Computes the net melodic direction of each measure in a batch.

    Only sounding notes are considered; slurs, rests, padding (None) and the
    START/END symbols are skipped. The contour is the sum of the differences
    between consecutive MIDI pitches, which equals the last pitch minus the
    first. The result is divided by 26.

    :param measure_tensor: torch Variable (batch_size, measure_seq_len)
    :return torch.Variable (batch_size); 0 for measures with fewer than
        two sounding notes
    """
    num_measures, seq_len = measure_tensor.size()
    note_lookup = self.index2note_dicts
    to_index = self.note2index_dicts
    skip_indices = (
        to_index[SLUR_SYMBOL],
        to_index['rest'],
        to_index[None],
        to_index[START_SYMBOL],
        to_index[END_SYMBOL],
    )
    contour = to_cuda_variable(torch.zeros(num_measures))
    for row in range(num_measures):
        tokens = (measure_tensor[row][col].item() for col in range(seq_len))
        pitches = [
            torch.Tensor([music21.pitch.Pitch(note_lookup[tok]).midi])
            for tok in tokens
            if tok not in skip_indices
        ]
        if len(pitches) < 2:
            # not enough notes to define a direction
            contour[row] = 0
            continue
        pitches = torch.cat(pitches, 0)
        steps = pitches[1:] - pitches[:-1]
        contour[row] = to_cuda_variable_long(torch.sum(steps))
    return contour / 26
def get_pitch_range_in_measure(self, measure_tensor):
    """
    Computes the pitch span of every measure in a batch.

    Slurs, rests, padding (None) and the START/END symbols are ignored. The
    span is the highest minus the lowest MIDI pitch of the remaining notes,
    divided by 26 (described as normalising by the dataset's range).

    :param measure_tensor: torch Variable, (batch_size, measure_seq_len)
    :return: torch Variable containing float tensor , (batch_size); 0 for
        measures with fewer than two sounding notes
    """
    n_rows, n_cols = measure_tensor.size()
    idx_to_note = self.index2note_dicts
    note_to_idx = self.note2index_dicts
    non_pitch = (
        note_to_idx[SLUR_SYMBOL],
        note_to_idx['rest'],
        note_to_idx[None],
        note_to_idx[START_SYMBOL],
        note_to_idx[END_SYMBOL],
    )
    spans = to_cuda_variable(torch.zeros(n_rows))
    for r in range(n_rows):
        row_tokens = (measure_tensor[r][c].item() for c in range(n_cols))
        pitches = [
            torch.Tensor([music21.pitch.Pitch(idx_to_note[tok]).midi])
            for tok in row_tokens
            if tok not in non_pitch
        ]
        if len(pitches) >= 2:
            pitch_vec = to_cuda_variable_long(torch.cat(pitches, 0))
            spans[r] = torch.max(pitch_vec) - torch.min(pitch_vec)
        else:
            spans[r] = 0
    return spans / 26
def loss_and_acc_test(self, data_loader):
    """
    Computes loss and accuracy for test data.

    The forward passes run under ``torch.no_grad()``. Nothing is
    back-propagated here, so building an autograd graph for every batch would
    only waste memory.

    :param data_loader: torch data loader object yielding
        (score_tensor, metadata_tensor) pairs
    :return: (float, float) reconstruction loss and accuracy, each averaged
        over the batches of the loader
    """
    mean_loss = 0
    mean_accuracy = 0
    with torch.no_grad():
        for score_tensor, metadata_tensor in tqdm(data_loader):
            if isinstance(self.dataset, FolkNBarDataset):
                # split each n-bar sequence into one row per bar
                num_seqs = score_tensor.size(0)
                score_tensor = score_tensor.view(num_seqs, self.dataset.n_bars, -1)
                score_tensor = score_tensor.view(num_seqs * self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(num_seqs, self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(num_seqs * self.dataset.n_bars, -1)
            # convert input to torch Variables
            score_tensor, metadata_tensor = (
                to_cuda_variable_long(score_tensor),
                to_cuda_variable_long(metadata_tensor)
            )
            # compute forward pass; only the output weights are needed here
            weights, _, _, _, _, _ = self.model(
                measure_score_tensor=score_tensor,
                measure_metadata_tensor=metadata_tensor,
                train=False
            )
            # reconstruction loss and accuracy of this batch
            recons_loss = MeasureVAETrainer.mean_crossentropy_loss(
                weights=weights,
                targets=score_tensor
            )
            mean_loss += to_numpy(recons_loss.mean())
            accuracy = MeasureVAETrainer.mean_accuracy(
                weights=weights,
                targets=score_tensor
            )
            mean_accuracy += to_numpy(accuracy)
    mean_loss /= len(data_loader)
    mean_accuracy /= len(data_loader)
    return (
        mean_loss,
        mean_accuracy
    )
def _plot_data_attr_dist(self, gen_test, dim1, dim2, reg_type):
    """
    Encodes a data loader and plots two latent dimensions against an attribute.

    :param gen_test: torch data loader yielding (score_tensor, metadata_tensor)
    :param dim1: int, first latent dimension to plot
    :param dim2: int, second latent dimension to plot
    :param reg_type: str, attribute, one of 'rhy_complexity', 'num_notes',
        'note_range'
    :raises ValueError: if reg_type is not a supported attribute
    """
    # attribute name -> dataset method computing it
    attr_methods = {
        'rhy_complexity': 'get_rhy_complexity',
        'num_notes': 'get_note_density_in_measure',
        'note_range': 'get_pitch_range_in_measure',
    }
    # an unknown reg_type used to surface as a NameError on `attr`
    if reg_type not in attr_methods:
        raise ValueError('Invalid attribute type')
    get_attr = getattr(self.dataset, attr_methods[reg_type])

    z_all = []
    attr_all = []
    for score_tensor, _ in tqdm(gen_test):
        if isinstance(self.dataset, FolkNBarDataset):
            # split each n-bar sequence into one row per bar
            num_seqs = score_tensor.size(0)
            score_tensor = score_tensor.view(num_seqs, self.dataset.n_bars, -1)
            score_tensor = score_tensor.view(num_seqs * self.dataset.n_bars, -1)
        score_tensor = to_cuda_variable_long(score_tensor)

        # encoder forward pass, then one reparameterised sample per measure
        z_dist = self.model.encoder(score_tensor)
        z_tilde = z_dist.rsample()
        attr = get_attr(score_tensor)

        # detach so each batch's autograd graph is freed instead of being
        # kept alive until the final torch.cat
        z_all.append(z_tilde.detach())
        attr_all.append(attr)
    z_all = to_numpy(torch.cat(z_all, 0))
    attr_all = to_numpy(torch.cat(attr_all, 0))

    if self.trainer_config == '':
        reg_str = '[no_reg]'
    else:
        reg_str = self.trainer_config
    filename = self.dir_path + '/plots/' + reg_str + 'data_dist_' + reg_type + '_[' \
        + str(dim1) + ',' + str(dim2) + '].png'
    self.plot_dim(z_all, attr_all, filename, dim1=dim1, dim2=dim2, xlim=6, ylim=6)
def get_interval_entropy(self, measure_tensor):
    """
    Returns the interval entropy of each measure in a batch.

    Slurs, rests, padding (None) and the START/END symbols are ignored. The
    absolute MIDI intervals between consecutive remaining notes are reduced
    modulo 12 and counted into 12 bins. The returned value is the entropy of
    softmax(counts).

    :param measure_tensor: torch Variable, (batch_size, measure_seq_len)
    :return: torch Variable, (batch_size); 0 for measures with fewer than
        two sounding notes
    """
    batch_size, measure_seq_len = measure_tensor.size()
    index2note = self.index2note_dicts
    slur_index = self.note2index_dicts[SLUR_SYMBOL]
    rest_index = self.note2index_dicts['rest']
    none_index = self.note2index_dicts[None]
    start_index = self.note2index_dicts[START_SYMBOL]
    end_index = self.note2index_dicts[END_SYMBOL]
    skip_indices = (slur_index, rest_index, none_index, start_index, end_index)

    entropy = to_cuda_variable(torch.zeros(batch_size))
    for i in range(batch_size):
        midi_notes = []
        for j in range(measure_seq_len):
            index = measure_tensor[i][j].item()
            if index not in skip_indices:
                midi_notes.append(torch.Tensor(
                    [music21.pitch.Pitch(index2note[index]).midi]))
        if len(midi_notes) < 2:
            # no interval can be formed
            entropy[i] = 0
            continue
        midi_notes = to_cuda_variable_long(torch.cat(midi_notes, 0))
        # interval classes in [0, 11]
        intervals = torch.abs(midi_notes[1:] - midi_notes[:-1]) % 12
        # 12-bin histogram as a float (12, 1) column, same as the former
        # element-by-element accumulation but in a single call
        counts = torch.bincount(intervals, minlength=12).float().unsqueeze(1)
        # NOTE(review): entropy is taken over softmax(counts) rather than
        # counts / counts.sum(); kept as-is because existing values depend on it
        b = nn.functional.softmax(counts, dim=0) * \
            nn.functional.log_softmax(counts, dim=0)
        entropy[i] = -1.0 * b.sum()
    return entropy
def plot_transposition_points(self, plt_type='pca'):
    """
    Plots the latent codes of all transpositions of one random validation piece.

    The piece is transposed by every interval returned by
    ``self.dataset.all_transposition_intervals`` and split into bars. Each bar
    is encoded using the posterior mean (``z_dist.loc``). The codes are then
    plotted with PCA or t-SNE. The labels passed to the plot are each bar's
    index within the piece.

    :param plt_type: str, 'tsne' or 'pca'
    :return:
    :raises ValueError: if plt_type is neither 'pca' nor 'tsne'
    """
    # validate before the expensive encoding loop
    if plt_type not in ('pca', 'tsne'):
        raise ValueError('Invalid plot type')

    filepaths = self.dataset.valid_filepaths
    # random.randint is inclusive on both ends, so randint(0, len(filepaths))
    # could pick an out-of-range index; randrange excludes the upper bound
    idx = random.randrange(len(filepaths))
    original_score = get_music21_score_from_path(filepaths[idx])
    possible_transpositions = self.dataset.all_transposition_intervals(original_score)

    z_all = []
    n_all = []
    for trans_int in possible_transpositions:
        score_tensor = self.dataset.get_transposed_tensor(
            original_score,
            trans_int
        )
        score_tensor = self.dataset.split_tensor_to_bars(score_tensor)
        score_tensor = to_cuda_variable_long(score_tensor)
        z_dist = self.model.encoder(score_tensor)
        # detach so the autograd graph is not kept alive for every transposition
        z_tilde = z_dist.loc.detach()
        z_all.append(z_tilde)
        # label every bar by its position within the piece
        t = np.arange(0, z_tilde.size(0))
        n_all.append(torch.from_numpy(t))
    z_all = to_numpy(torch.cat(z_all, 0))
    n_all = to_numpy(torch.cat(n_all, 0))

    filename = self.dir_path + '/plots/' + plt_type + '_transposition_measure_vae.png'
    if plt_type == 'pca':
        self.plot_pca(z_all, n_all, filename)
    else:
        self.plot_tsne(z_all, n_all, filename)
def plot_attribute_dist(self, attribute='num_notes', plt_type='pca',
                        batch_size=64, num_samples=5):
    """
    Plots the distribution of a particular attribute in the latent space.

    The first ``num_samples`` batches of the test loader are encoded using the
    posterior mean (``z_dist.loc``) and plotted with the chosen method, labelled
    by the attribute value. Measures whose first token is the START symbol are
    labelled -0.1, and those starting with the END symbol are labelled -0.2.

    :param attribute: str, num_notes, note_range, rhy_entropy, beat_strength, rhy_complexity
    :param plt_type: str, 'tsne', 'pca' or 'dim'
    :param batch_size: int, mini-batch size of the test loader (default 64)
    :param num_samples: int, number of test batches to encode (default 5);
        also embedded in the output filename
    :return:
    :raises ValueError: if attribute, plt_type or num_samples is invalid
    """
    # attribute name -> dataset method computing it
    attr_methods = {
        'num_notes': 'get_note_density_in_measure',
        'note_range': 'get_pitch_range_in_measure',
        'rhy_entropy': 'get_rhythmic_entropy',
        'beat_strength': 'get_beat_strength',
        'rhy_complexity': 'get_rhy_complexity',
    }
    # validate everything before loading data and running the encoder
    if attribute not in attr_methods:
        raise ValueError('Invalid attribute type')
    if plt_type not in ('pca', 'tsne', 'dim'):
        raise ValueError('Invalid plot type')
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1')
    get_attr = getattr(self.dataset, attr_methods[attribute])
    # loop-invariant special-token indices
    start_idx = self.dataset.note2index_dicts[START_SYMBOL]
    end_idx = self.dataset.note2index_dicts[END_SYMBOL]

    _, _, gen_test = self.dataset.data_loaders(
        batch_size=batch_size,
        split=(0.01, 0.01)
    )
    z_all = []
    n_all = []
    for sample_id, (score_tensor, _) in enumerate(tqdm(gen_test)):
        if isinstance(self.dataset, FolkNBarDataset):
            # split each n-bar sequence into one row per bar
            num_seqs = score_tensor.size(0)
            score_tensor = score_tensor.view(num_seqs, self.dataset.n_bars, -1)
            score_tensor = score_tensor.view(num_seqs * self.dataset.n_bars, -1)
        score_tensor = to_cuda_variable_long(score_tensor)

        # encoder forward pass; detach so no autograd graph is retained
        z_dist = self.model.encoder(score_tensor)
        z_all.append(z_dist.loc.detach())

        attr = get_attr(score_tensor)
        # give measures that begin with START / END distinct sentinel labels
        for i in range(attr.size(0)):
            first_token = score_tensor[i, 0]
            if first_token == start_idx:
                attr[i] = -0.1
            elif first_token == end_idx:
                attr[i] = -0.2
        n_all.append(attr)

        # stop after exactly num_samples batches (previously num_samples + 1
        # were encoded because the check used sample_id == num_samples)
        if sample_id + 1 >= num_samples:
            break
    z_all = to_numpy(torch.cat(z_all, 0))
    n_all = to_numpy(torch.cat(n_all, 0))

    filename = self.dir_path + '/plots/' + plt_type + '_' + attribute + '_' + \
        str(num_samples) + '_measure_vae.png'
    if plt_type == 'pca':
        self.plot_pca(z_all, n_all, filename)
    elif plt_type == 'tsne':
        self.plot_tsne(z_all, n_all, filename)
    else:
        self.plot_dim(z_all, n_all, filename)