Ejemplo n.º 1
0
    def test_interpretability(self, batch_size, attr_type):
        """
        Tests the interpretability of the latent space for a particular attribute

        Every batch of the test split is encoded, one latent vector is sampled
        per measure and the requested attribute is computed from the score.
        The latent dimension with the highest mutual information score against
        the attribute is selected, and a linear regression of the attribute on
        that single dimension is fitted.

        :param batch_size: int, number of datapoints in mini-batch
        :param attr_type: str, attribute type; one of 'rhy_complexity',
                'num_notes' or 'note_range'
        :return: tuple(int, float): index of dimension with highest mutual info,
                interpretability score (the value returned by `reg.score` on the
                fitted data -- presumably R^2 if this is sklearn's
                LinearRegression; confirm against the file's imports)
        :raises ValueError: if attr_type is not one of the supported attributes
        """
        # attribute name -> name of the dataset method computing it
        attr_fn_names = {
            'rhy_complexity': 'get_rhy_complexity',
            'num_notes': 'get_note_density_in_measure',
            'note_range': 'get_pitch_range_in_measure',
        }
        # fail fast: previously an unknown attribute only surfaced as an
        # UnboundLocalError after the first batch had already been encoded
        if attr_type not in attr_fn_names:
            raise ValueError('Invalid attribute type')
        compute_attr = getattr(self.dataset, attr_fn_names[attr_type])

        # only the test generator is used
        (_, _, gen_test) = self.dataset.data_loaders(
            batch_size=batch_size,
            split=(0.01, 0.01)
        )

        # compute latent vectors and attribute values
        z_all = []
        attr_all = []
        for _, (score_tensor, metadata_tensor) in tqdm(enumerate(gen_test)):
            if isinstance(self.dataset, FolkNBarDataset):
                # fold the bar axis into the batch axis so that every row is
                # one measure; a separate local avoids clobbering `batch_size`
                n_items = score_tensor.size(0)
                n_bars = self.dataset.n_bars
                score_tensor = score_tensor.view(n_items, n_bars, -1)
                score_tensor = score_tensor.view(n_items * n_bars, -1)
                metadata_tensor = metadata_tensor.view(n_items, n_bars, -1)
                metadata_tensor = metadata_tensor.view(n_items * n_bars, -1)
            # convert input to torch Variables
            score_tensor, metadata_tensor = (
                to_cuda_variable_long(score_tensor),
                to_cuda_variable_long(metadata_tensor)
            )
            # compute encoder forward pass
            z_dist = self.model.encoder(score_tensor)
            # sample from distribution
            z_tilde = z_dist.rsample()

            # compute attributes
            attr = compute_attr(score_tensor)
            z_all.append(to_numpy(z_tilde.cpu()))
            attr_all.append(to_numpy(attr.cpu()))
        z_all = np.concatenate(z_all)
        attr_all = np.concatenate(attr_all)

        # compute mutual information between each latent dimension and the
        # attribute, then keep the most informative dimension
        mutual_info = np.zeros(self.z_dim)
        for i in tqdm(range(self.z_dim)):
            mutual_info[i] = mutual_info_score(z_all[:, i], attr_all)
        dim = np.argmax(mutual_info)

        # regress the attribute on the selected dimension only; the slice
        # dim:dim+1 keeps the 2-D (n_samples, 1) shape
        best_dim_values = z_all[:, dim:dim+1]
        reg = LinearRegression().fit(best_dim_values, attr_all)
        score = reg.score(best_dim_values, attr_all)
        return dim, score
Ejemplo n.º 2
0
 def get_contour(self, measure_tensor):
     """
     Returns the direction of the melodic contour of a batch of measures

     For each measure the MIDI numbers of all pitched tokens are collected
     (slur, rest, None, start and end tokens are skipped) and the successive
     differences are summed. Measures with fewer than two pitched tokens get
     a contour of 0. The result is divided by the hard-coded constant 26.

     :param measure_tensor: torch Variable
             (batch_size, measure_seq_len)
     :return torch.Variable
             (batch_size)
     """
     n_measures, seq_len = measure_tensor.size()
     note_of = self.index2note_dicts
     # token indices that do not carry a pitch
     non_pitch = (
         self.note2index_dicts[SLUR_SYMBOL],
         self.note2index_dicts['rest'],
         self.note2index_dicts[None],
         self.note2index_dicts[START_SYMBOL],
         self.note2index_dicts[END_SYMBOL],
     )
     contour = to_cuda_variable(torch.zeros(n_measures))
     for row in range(n_measures):
         tokens = (measure_tensor[row][col].item() for col in range(seq_len))
         pitches = [
             torch.Tensor([music21.pitch.Pitch(note_of[tok]).midi])
             for tok in tokens
             if tok not in non_pitch
         ]
         if len(pitches) < 2:
             # no interval can be formed from zero or one note
             contour[row] = 0
             continue
         pitches = torch.cat(pitches, 0)
         steps = pitches[1:] - pitches[:-1]
         contour[row] = to_cuda_variable_long(torch.sum(steps))
     return contour / 26
Ejemplo n.º 3
0
 def get_pitch_range_in_measure(self, measure_tensor):
     """
     Returns the note range of each measure of the input

     The range is the difference between the highest and the lowest MIDI
     number among the pitched tokens of a measure (slur, rest, None, start
     and end tokens are skipped). Measures with fewer than two pitched tokens
     get a range of 0. The result is divided by the hard-coded constant 26
     (presumably the pitch range of the dataset -- confirm).

     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable containing float tensor ,
             (batch_size)
     """
     num_measures, num_ticks = measure_tensor.size()
     idx_to_note = self.index2note_dicts
     # token indices that do not carry a pitch
     ignored = (
         self.note2index_dicts[SLUR_SYMBOL],
         self.note2index_dicts['rest'],
         self.note2index_dicts[None],
         self.note2index_dicts[START_SYMBOL],
         self.note2index_dicts[END_SYMBOL],
     )
     ranges = to_cuda_variable(torch.zeros(num_measures))
     for m in range(num_measures):
         sounding = []
         for t in range(num_ticks):
             token = measure_tensor[m][t].item()
             if token in ignored:
                 continue
             pitch = music21.pitch.Pitch(idx_to_note[token])
             sounding.append(torch.Tensor([pitch.midi]))
         if len(sounding) >= 2:
             stacked = to_cuda_variable_long(torch.cat(sounding, 0))
             ranges[m] = torch.max(stacked) - torch.min(stacked)
         else:
             # a range needs at least two notes
             ranges[m] = 0
     return ranges / 26
Ejemplo n.º 4
0
    def loss_and_acc_test(self, data_loader):
        """
        Computes loss and accuracy for test data

        Runs the model with `train=False` on every batch of the loader,
        accumulates the mean cross-entropy reconstruction loss and the mean
        accuracy of each batch, and divides both sums by `len(data_loader)`.

        :param data_loader: torch data loader object
        :return: (float, float), mean loss and mean accuracy over the batches
        """
        loss_sum = 0
        acc_sum = 0

        for _, (scores, metadata) in tqdm(enumerate(data_loader)):
            if isinstance(self.dataset, FolkNBarDataset):
                # fold the bar axis into the batch axis: one row per measure
                n_items = scores.size(0)
                n_bars = self.dataset.n_bars
                scores = scores.view(n_items, n_bars, -1).view(n_items * n_bars, -1)
                metadata = metadata.view(n_items, n_bars, -1).view(n_items * n_bars, -1)
            # convert input to torch Variables
            scores = to_cuda_variable_long(scores)
            metadata = to_cuda_variable_long(metadata)

            # forward pass; only the first of the six outputs is needed here
            weights, _, _, _, _, _ = self.model(
                measure_score_tensor=scores,
                measure_metadata_tensor=metadata,
                train=False
            )

            # reconstruction loss is the only loss term used at test time
            batch_loss = MeasureVAETrainer.mean_crossentropy_loss(
                weights=weights,
                targets=scores
            )
            loss_sum += to_numpy(batch_loss.mean())

            batch_acc = MeasureVAETrainer.mean_accuracy(
                weights=weights,
                targets=scores
            )
            acc_sum += to_numpy(batch_acc)

        # average over the number of batches
        n_batches = len(data_loader)
        loss_sum /= n_batches
        acc_sum /= n_batches
        return (loss_sum, acc_sum)
Ejemplo n.º 5
0
    def _plot_data_attr_dist(self, gen_test, dim1, dim2, reg_type):
        """
        Plots the distribution of an attribute over two latent dimensions

        Every batch of `gen_test` is encoded, one latent vector is sampled per
        measure and the attribute selected by `reg_type` is computed. The
        result is passed to `self.plot_dim` and written under
        `<dir_path>/plots/`.

        :param gen_test: data loader yielding (score_tensor, metadata_tensor)
        :param dim1: int, first latent dimension to plot
        :param dim2: int, second latent dimension to plot
        :param reg_type: str, attribute type; one of 'rhy_complexity',
                'num_notes' or 'note_range'
        :return: None
        :raises ValueError: if reg_type is not one of the supported attributes
        """
        # attribute name -> name of the dataset method computing it
        attr_fn_names = {
            'rhy_complexity': 'get_rhy_complexity',
            'num_notes': 'get_note_density_in_measure',
            'note_range': 'get_pitch_range_in_measure',
        }
        # fail fast: previously an unknown type only surfaced as an
        # UnboundLocalError after the first batch had already been encoded
        if reg_type not in attr_fn_names:
            raise ValueError('Invalid attribute type')
        compute_attr = getattr(self.dataset, attr_fn_names[reg_type])

        z_all = []
        attr_all = []
        for _, (score_tensor, metadata_tensor) in tqdm(enumerate(gen_test)):
            if isinstance(self.dataset, FolkNBarDataset):
                # fold the bar axis into the batch axis: one row per measure
                n_items = score_tensor.size(0)
                n_bars = self.dataset.n_bars
                score_tensor = score_tensor.view(n_items, n_bars, -1)
                score_tensor = score_tensor.view(n_items * n_bars, -1)
                metadata_tensor = metadata_tensor.view(n_items, n_bars, -1)
                metadata_tensor = metadata_tensor.view(n_items * n_bars, -1)
            # convert input to torch Variables
            score_tensor, metadata_tensor = (
                to_cuda_variable_long(score_tensor),
                to_cuda_variable_long(metadata_tensor)
            )
            # compute encoder forward pass
            z_dist = self.model.encoder(score_tensor)
            # sample from distribution
            z_tilde = z_dist.rsample()

            # compute attributes
            attr = compute_attr(score_tensor)
            z_all.append(z_tilde)
            attr_all.append(attr)
        z_all = to_numpy(torch.cat(z_all, 0))
        attr_all = to_numpy(torch.cat(attr_all, 0))

        # an empty trainer config means no regularization was used
        if self.trainer_config == '':
            reg_str = '[no_reg]'
        else:
            reg_str = self.trainer_config
        filename = self.dir_path + '/plots/' + reg_str + 'data_dist_' + reg_type + '_[' \
                   + str(dim1) + ',' + str(dim2) + '].png'
        self.plot_dim(z_all, attr_all, filename, dim1=dim1, dim2=dim2, xlim=6, ylim=6)
Ejemplo n.º 6
0
 def get_interval_entropy(self, measure_tensor):
     """
     Returns the interval entropy of a batch of measures

     For each measure the MIDI numbers of all pitched tokens are collected
     (slur, rest, None, start and end tokens are skipped), the absolute
     successive intervals are reduced modulo 12 and counted into a 12-bin
     histogram. The returned value is -sum(p * log p) where p is the
     *softmax of the raw counts*.
     NOTE(review): p is not the normalized interval frequency, so this is
     not the Shannon entropy of the empirical interval distribution. It is
     kept as is because downstream results may depend on it -- confirm the
     intent.
     Measures with fewer than two pitched tokens get an entropy of 0.

     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable,
             (batch_size)
     """
     batch_size, measure_seq_len = measure_tensor.size()
     index2note = self.index2note_dicts
     # token indices that do not carry a pitch
     non_pitch = (
         self.note2index_dicts[SLUR_SYMBOL],
         self.note2index_dicts['rest'],
         self.note2index_dicts[None],
         self.note2index_dicts[START_SYMBOL],
         self.note2index_dicts[END_SYMBOL],
     )
     entropy = to_cuda_variable(torch.zeros(batch_size))
     for i in range(batch_size):
         midi_notes = []
         for j in range(measure_seq_len):
             index = measure_tensor[i][j].item()
             if index not in non_pitch:
                 midi_note = torch.Tensor(
                     [music21.pitch.Pitch(index2note[index]).midi])
                 midi_notes.append(midi_note)
         if len(midi_notes) < 2:
             # no interval can be formed from zero or one note
             entropy[i] = 0
             continue
         midi_notes = torch.cat(midi_notes, 0)
         midi_notes = to_cuda_variable_long(midi_notes)
         # compute absolute intervals, folded into one octave
         intervals = torch.abs(midi_notes[1:] - midi_notes[:-1]) % 12
         # histogram of interval classes
         counts = to_cuda_variable(torch.zeros([12, 1]))
         for interval in intervals:
             counts[interval] += 1
         # compute entropy of this interval vector (softmax over counts)
         plogp = nn.functional.softmax(counts, dim=0) * \
             nn.functional.log_softmax(counts, dim=0)
         entropy[i] = -1.0 * plogp.sum()
     return entropy
Ejemplo n.º 7
0
    def plot_transposition_points(self, plt_type='pca'):
        """
        Plots a PCA or t-SNE projection of the latent codes of one randomly
        chosen validation score under all of its possible transpositions

        Each transposed version is split into bars and encoded; the mean
        (`loc`) of the encoder distribution is used as the latent code. Points
        are labelled with the index of the bar inside the score, so the same
        bar in different transpositions shares a label.

        :param plt_type: str, 'tsne' or 'pca'
        :return: None
        :raises ValueError: if plt_type is neither 'pca' nor 'tsne', or if the
                dataset has no validation file paths
        """
        # fail fast, before any score is loaded or encoded
        if plt_type not in ('pca', 'tsne'):
            raise ValueError('Invalid plot type')

        filepaths = self.dataset.valid_filepaths
        # randrange excludes the upper bound; random.randint(0, len(...)) is
        # inclusive and could index one past the end of the list
        idx = random.randrange(len(filepaths))
        original_score = get_music21_score_from_path(filepaths[idx])
        possible_transpositions = self.dataset.all_transposition_intervals(original_score)
        z_all = []
        n_all = []
        n = 0
        for trans_int in possible_transpositions:
            score_tensor = self.dataset.get_transposed_tensor(
                original_score,
                trans_int
            )
            score_tensor = self.dataset.split_tensor_to_bars(score_tensor)
            score_tensor = to_cuda_variable_long(score_tensor)
            z_dist = self.model.encoder(score_tensor)
            # use the distribution mean rather than a random sample
            z_tilde = z_dist.loc
            z_all.append(z_tilde)
            # label every bar with its position inside the score
            t = np.arange(0, z_tilde.size(0))
            n_all.append(torch.from_numpy(t))
            n += 1
        # number of transpositions that were encoded
        print(n)
        z_all = torch.cat(z_all, 0)
        n_all = torch.cat(n_all, 0)
        z_all = to_numpy(z_all)
        n_all = to_numpy(n_all)

        filename = self.dir_path + '/plots/' + plt_type + '_transposition_measure_vae.png'
        if plt_type == 'pca':
            self.plot_pca(z_all, n_all, filename)
        else:  # 'tsne', validated above
            self.plot_tsne(z_all, n_all, filename)
0
    def plot_attribute_dist(self, attribute='num_notes', plt_type='pca'):
        """
        Plots the distribution of a particular attribute in the latent space

        A few batches of the test split are encoded (the mean `loc` of the
        encoder distribution is used as the latent code) and the requested
        attribute is computed for every measure. Measures whose first token is
        the START symbol are given the sentinel value -0.1 and those whose
        first token is the END symbol -0.2, overriding the computed attribute.

        :param attribute: str,
                num_notes, note_range, rhy_entropy, beat_strength, rhy_complexity
        :param plt_type: str, 'tsne', 'pca' or 'dim'
        :return: None
        :raises ValueError: if attribute or plt_type is not supported
        """
        # attribute name -> name of the dataset method computing it
        attr_fn_names = {
            'num_notes': 'get_note_density_in_measure',
            'note_range': 'get_pitch_range_in_measure',
            'rhy_entropy': 'get_rhythmic_entropy',
            'beat_strength': 'get_beat_strength',
            'rhy_complexity': 'get_rhy_complexity',
        }
        # fail fast, before any data is loaded or encoded
        if attribute not in attr_fn_names:
            raise ValueError('Invalid attribute type')
        if plt_type not in ('pca', 'tsne', 'dim'):
            raise ValueError('Invalid plot type')
        compute_attr = getattr(self.dataset, attr_fn_names[attribute])
        # loop-invariant token indices
        start_idx = self.dataset.note2index_dicts[START_SYMBOL]
        end_idx = self.dataset.note2index_dicts[END_SYMBOL]

        (_, _, gen_test) = self.dataset.data_loaders(
            batch_size=64,  # TODO: remove this hard coding
            split=(0.01, 0.01)
        )
        z_all = []
        n_all = []
        num_samples = 5
        for sample_id, (score_tensor, _) in tqdm(enumerate(gen_test)):
            # convert input to torch Variables
            if isinstance(self.dataset, FolkNBarDataset):
                # fold the bar axis into the batch axis: one row per measure
                n_items = score_tensor.size(0)
                n_bars = self.dataset.n_bars
                score_tensor = score_tensor.view(n_items, n_bars, -1)
                score_tensor = score_tensor.view(n_items * n_bars, -1)
            score_tensor = to_cuda_variable_long(score_tensor)
            # compute encoder forward pass
            z_dist = self.model.encoder(score_tensor)
            z_all.append(z_dist.loc)
            attr = compute_attr(score_tensor)
            # overwrite the attribute of START/END measures with sentinels so
            # that they stand out in the plot
            for i in range(attr.size(0)):
                first_token = score_tensor[i, 0]
                if first_token == start_idx:
                    attr[i] = -0.1
                elif first_token == end_idx:
                    attr[i] = -0.2
            n_all.append(attr)
            # NOTE(review): the check runs after the batch is processed, so
            # num_samples + 1 batches are used -- confirm this is intended
            if sample_id == num_samples:
                break
        z_all = torch.cat(z_all, 0)
        n_all = torch.cat(n_all, 0)
        z_all = to_numpy(z_all)
        n_all = to_numpy(n_all)

        filename = self.dir_path + '/plots/' + plt_type + '_' + attribute + '_' + \
                   str(num_samples) + '_measure_vae.png'
        if plt_type == 'pca':
            self.plot_pca(z_all, n_all, filename)
        elif plt_type == 'tsne':
            self.plot_tsne(z_all, n_all, filename)
        else:  # 'dim', validated above
            self.plot_dim(z_all, n_all, filename)