Example #1
    def plot_latent_surface(self, attr_str, dim1=0, dim2=1, grid_res=0.1):
        # create the dataspace
        x1 = torch.arange(-5., 5., grid_res)
        x2 = torch.arange(-5., 5., grid_res)
        z1, z2 = torch.meshgrid([x1, x2])
        num_points = z1.size(0) * z1.size(1)
        z = torch.randn(1, self.model.z_dim)
        z = z.repeat(num_points, 1)
        z[:, dim1] = z1.contiguous().view(-1)
        z[:, dim2] = z2.contiguous().view(-1)
        z = to_cuda_variable(z)

        mini_batch_size = 500
        num_mini_batches = num_points // mini_batch_size
        attr_labels_all = []
        for i in tqdm(range(num_mini_batches)):
            z_batch = z[i * mini_batch_size:(i + 1) * mini_batch_size, :]
            outputs = torch.sigmoid(self.model.decode(z_batch))
            labels = self.compute_mnist_morpho_labels(outputs, attr_str)
            attr_labels_all.append(torch.from_numpy(labels))
        attr_labels_all = to_numpy(torch.cat(attr_labels_all, 0))
        # keep only the grid points that were decoded in full mini-batches
        z = to_numpy(z)[:num_mini_batches * mini_batch_size, :]
        save_filename = os.path.join(Trainer.get_save_dir(self.model),
                                     f'latent_surface_{attr_str}.png')
        plot_dim(z, attr_labels_all, save_filename, dim1=dim1, dim2=dim2)
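Both surface plots in this listing rely on the same trick: sweep two latent dimensions over a 2-D grid while freezing all other dimensions at a single random draw. A minimal self-contained sketch of just the grid construction (plain PyTorch ≥ 1.10; z_dim, dim1, dim2, and grid_res are placeholder values):

import torch

# sweep latent dims `dim1` and `dim2` over a grid, keeping the
# remaining dimensions fixed at one random draw
z_dim, dim1, dim2, grid_res = 8, 0, 1, 0.5
x = torch.arange(-5., 5., grid_res)
z1, z2 = torch.meshgrid(x, x, indexing='ij')
num_points = z1.numel()
z = torch.randn(1, z_dim).repeat(num_points, 1)
z[:, dim1] = z1.reshape(-1)
z[:, dim2] = z2.reshape(-1)
print(z.shape)  # torch.Size([400, 8])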
Example #2
 def plot_attribute_surface(self, dim1=0, dim2=1, grid_res=0.5):
     """
     Plots the value of an attribute over a surface defined by the dimensions
     :param dim1: int,
     :param dim2: int,
     :param grid_res: float,
     :return:
     """
     # create the dataspace
     x1 = torch.arange(-5., 5., grid_res)
     x2 = torch.arange(-5., 5., grid_res)
     z1, z2 = torch.meshgrid([x1, x2])
     num_points = z1.size(0) * z1.size(1)
     z = torch.randn(1, self.model.latent_space_dim)
     z = z.repeat(num_points, 1)
     z[:, dim1] = z1.contiguous().view(-1)
     z[:, dim2] = z2.contiguous().view(-1)
     z = to_cuda_variable(z)
     # pass the points through the model decoder in mini-batches;
     # ceil division keeps the points in a trailing partial batch
     mini_batch_size = 500
     num_mini_batches = (num_points + mini_batch_size - 1) // mini_batch_size
     nd_all = []
     nr_all = []
     rc_all = []
     # ie_all = []
     for i in tqdm(range(num_mini_batches)):
         z_batch = z[i*mini_batch_size:(i+1)*mini_batch_size, :]
         dummy_score_tensor = to_cuda_variable(torch.zeros(z_batch.size(0), self.measure_seq_len))
         _, samples = self.model.decoder(
             z=z_batch,
             score_tensor=dummy_score_tensor,
             train=self.train
         )
         samples = samples.view(z_batch.size(0), -1)
         note_density = self.dataset.get_note_density_in_measure(samples)
         note_range = self.dataset.get_pitch_range_in_measure(samples)
         rhy_complexity = self.dataset.get_rhy_complexity(samples)
         # interval_entropy = self.dataset.get_interval_entropy(samples)
         nd_all.append(note_density)
         nr_all.append(note_range)
         rc_all.append(rhy_complexity)
         # ie_all.append(interval_entropy)
     nd_all = to_numpy(torch.cat(nd_all, 0))
     nr_all = to_numpy(torch.cat(nr_all, 0))
     rc_all = to_numpy(torch.cat(rc_all, 0))
     # ie_all = to_numpy(torch.cat(ie_all, 0))
     z = to_numpy(z)
     if self.trainer_config == '':
         reg_str = '[no_reg]'
     else:
         reg_str = self.trainer_config
     filename = self.dir_path + '/plots/' + reg_str + 'attr_surf_note_density_[' \
                + str(dim1) + ',' + str(dim2) + '].png'
     self.plot_dim(z, nd_all, filename, dim1=dim1, dim2=dim2)
     filename = self.dir_path + '/plots/' + reg_str + 'attr_surf_note_range_[' \
                + str(dim1) + ',' + str(dim2) + '].png'
     self.plot_dim(z, nr_all, filename, dim1=dim1, dim2=dim2)
     filename = self.dir_path + '/plots/' + reg_str + 'attr_surf_rhy_complexity_[' \
                + str(dim1) + ',' + str(dim2) + '].png'
     self.plot_dim(z, rc_all, filename, dim1=dim1, dim2=dim2)
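Decoding the whole grid at once can exhaust GPU memory, which is why both surface methods chunk z into mini-batches. A standalone sketch of the chunking logic (the iter_minibatches helper is illustrative, not part of the original code):

import torch

def iter_minibatches(z, mini_batch_size=500):
    # ceil division so a trailing partial batch is not silently dropped
    num_batches = (z.size(0) + mini_batch_size - 1) // mini_batch_size
    for i in range(num_batches):
        yield z[i * mini_batch_size:(i + 1) * mini_batch_size]

z = torch.randn(400, 8)
assert sum(batch.size(0) for batch in iter_minibatches(z)) == 400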
Example #3
    def loss_and_acc_on_epoch(self, data_loader, epoch_num=None, train=True):
        """
        Computes the loss and accuracy for an epoch
        :param data_loader: torch dataloader object
        :param epoch_num: int, used to change training schedule
        :param train: bool, performs the backward pass and gradient descent if True
        :return: loss values and accuracy percentages
        """
        mean_loss = 0
        mean_accuracy = 0
        for batch_num, batch in tqdm(enumerate(data_loader)):
            log = False
            if train and self.writer is not None:
                if self.cur_epoch_num != epoch_num:
                    log = True
                    self.cur_epoch_num = epoch_num
            # process batch data
            batch_data = self.process_batch_data(batch)
            inputs, labels = batch_data
            norm_labels = self.normalize_labels(labels.clone())
            flipped_norm_labels = 1.0 - norm_labels.clone()

            # Encode data
            z = self.model.encode(inputs)

            # TRAIN DISCRIMINATOR
            # zero the disc gradients
            self.disc_zero_grad()
            # compute loss for discriminator
            disc_loss = self.disc_loss_for_batch(z.detach(), norm_labels,
                                                 epoch_num, log)
            # compute backward and step discriminator if train
            if train:
                disc_loss.backward()
                self.disc_step()

            # TRAIN FADER MODEL
            # zero the model gradients
            self.zero_grad()
            # compute fader model loss
            fader_loss, accuracy = self.fader_loss_for_batch(
                inputs, z, norm_labels, flipped_norm_labels, epoch_num, log)
            # compute backward and step if train
            if train:
                fader_loss.backward()
                self.step()

            # compute mean loss and accuracy
            mean_loss += to_numpy(fader_loss.mean())
            if accuracy is not None:
                mean_accuracy += to_numpy(accuracy)

            # update beta
            self.curr_beta += self.beta_delta

        mean_loss /= len(data_loader)
        mean_accuracy /= len(data_loader)
        return (mean_loss, mean_accuracy)
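The loop above alternates two updates per batch, in the style of Fader Networks: first the discriminator learns to predict the attribute from the detached latent code, then the encoder/decoder is trained with an adversarial term that pushes the discriminator towards the flipped labels. A toy sketch of that two-phase step, with linear modules standing in for the real encoder, decoder, and discriminator:

import torch
import torch.nn as nn

encoder = nn.Linear(16, 4)
decoder = nn.Linear(4 + 1, 16)   # decoder is conditioned on the attribute
disc = nn.Linear(4, 1)
opt_model = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))
opt_disc = torch.optim.Adam(disc.parameters())
bce = nn.BCEWithLogitsLoss()

inputs = torch.randn(8, 16)
labels = torch.randint(0, 2, (8, 1)).float()
z = encoder(inputs)

# phase 1: train the discriminator on detached codes (no encoder gradients)
opt_disc.zero_grad()
bce(disc(z.detach()), labels).backward()
opt_disc.step()

# phase 2: train the fader model; the flipped labels in the adversarial
# term push z towards carrying no attribute information
opt_model.zero_grad()
recons = decoder(torch.cat([z, labels], dim=1))
fader_loss = ((recons - inputs) ** 2).mean() + bce(disc(z), 1.0 - labels)
fader_loss.backward()
opt_model.step()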
Example #4
    def test_interpretability(self, batch_size, attr_type):
        """
        Tests the interpretability of the latent space for a particular attribute
        :param batch_size: int, number of datapoints in mini-batch
        :param attr_type: str, attribute type
        :return: tuple(int, float): index of dimension with highest mutual info, interpretability score
        """
        (_, gen_val, gen_test) = self.dataset.data_loaders(
            batch_size=batch_size,
            split=(0.01, 0.01)
        )

        # compute latent vectors and attribute values
        z_all = []
        attr_all = []
        for sample_id, (score_tensor, metadata_tensor) in tqdm(enumerate(gen_test)):
            if isinstance(self.dataset, FolkNBarDataset):
                batch_size = score_tensor.size(0)
                score_tensor = score_tensor.view(batch_size, self.dataset.n_bars, -1)
                score_tensor = score_tensor.view(batch_size * self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(batch_size, self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(batch_size * self.dataset.n_bars, -1)
            # convert input to torch Variables
            score_tensor, metadata_tensor = (
                to_cuda_variable_long(score_tensor),
                to_cuda_variable_long(metadata_tensor)
            )
            # compute encoder forward pass
            z_dist = self.model.encoder(score_tensor)
            # sample from distribution
            z_tilde = z_dist.rsample()

            # compute attributes
            if attr_type == 'rhy_complexity':
                attr = self.dataset.get_rhy_complexity(score_tensor)
            elif attr_type == 'num_notes':
                attr = self.dataset.get_note_density_in_measure(score_tensor)
            elif attr_type == 'note_range':
                attr = self.dataset.get_pitch_range_in_measure(score_tensor)
            else:
                raise ValueError('Invalid attribute type')
            z_all.append(to_numpy(z_tilde.cpu()))
            attr_all.append(to_numpy(attr.cpu()))
        z_all = np.concatenate(z_all)
        attr_all = np.concatenate(attr_all)

        # compute mutual information
        mutual_info = np.zeros(self.z_dim)
        for i in tqdm(range(self.z_dim)):
            mutual_info[i] = mutual_info_score(z_all[:, i], attr_all)
        dim = np.argmax(mutual_info)
        max_mutual_info = np.max(mutual_info)

        reg = LinearRegression().fit(z_all[:, dim:dim+1], attr_all)
        score = reg.score(z_all[:, dim:dim+1], attr_all)
        return dim, score
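The interpretability score picks the latent dimension with the highest mutual information to the attribute, then reports the R² of a 1-D linear regression from that dimension. A self-contained sketch on synthetic data; note it uses mutual_info_regression, which handles continuous values, whereas the snippet above calls mutual_info_score, which treats its inputs as discrete labels:

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from sklearn.linear_model import LinearRegression

# synthetic latent space: dimension 2 linearly encodes the attribute
rng = np.random.default_rng(0)
z_all = rng.normal(size=(1000, 8))
attr_all = 2.0 * z_all[:, 2] + 0.1 * rng.normal(size=1000)

mi = mutual_info_regression(z_all, attr_all, random_state=0)
dim = int(np.argmax(mi))
reg = LinearRegression().fit(z_all[:, dim:dim + 1], attr_all)
print(dim, reg.score(z_all[:, dim:dim + 1], attr_all))  # 2 0.99...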
Example #5
 def compute_representations(self, data_loader):
     latent_codes = []
     attributes = []
     for sample_id, batch in tqdm(enumerate(data_loader)):
         inputs, labels = self.process_batch_data(batch)
         _, _, _, z_tilde, _ = self.model(inputs)
         latent_codes.append(to_numpy(z_tilde.cpu()))
         attributes.append(to_numpy(labels))
         # stop after a fixed number of mini-batches
         if sample_id == 200:
             break
     latent_codes = np.concatenate(latent_codes, 0)
     attributes = np.concatenate(attributes, 0)
     attributes, attr_list = self._extract_relevant_attributes(attributes)
     return latent_codes, attributes, attr_list
Example #6
    def loss_and_acc_test(self, data_loader):
        """
        Computes loss and accuracy for test data
        :param data_loader: torch data loader object
        :return: (float, float)
        """
        mean_loss = 0
        mean_accuracy = 0

        for sample_id, (score_tensor, metadata_tensor) in tqdm(enumerate(data_loader)):
            if isinstance(self.dataset, FolkNBarDataset):
                batch_size = score_tensor.size(0)
                score_tensor = score_tensor.view(batch_size, self.dataset.n_bars, -1)
                score_tensor = score_tensor.view(batch_size * self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(batch_size, self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(batch_size * self.dataset.n_bars, -1)
            # convert input to torch Variables
            score_tensor, metadata_tensor = (
                to_cuda_variable_long(score_tensor),
                to_cuda_variable_long(metadata_tensor)
            )
            # compute forward pass
            weights, samples, _, _, _, _ = self.model(
                measure_score_tensor=score_tensor,
                measure_metadata_tensor=metadata_tensor,
                train=False
            )

            # compute loss
            recons_loss = MeasureVAETrainer.mean_crossentropy_loss(
                weights=weights,
                targets=score_tensor
            )
            loss = recons_loss
            # compute mean loss and accuracy
            mean_loss += to_numpy(loss.mean())
            accuracy = MeasureVAETrainer.mean_accuracy(
                weights=weights,
                targets=score_tensor
            )
            mean_accuracy += to_numpy(accuracy)
        mean_loss /= len(data_loader)
        mean_accuracy /= len(data_loader)
        return (
            mean_loss,
            mean_accuracy
        )
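MeasureVAETrainer.mean_crossentropy_loss is not shown in these snippets; assuming weights are per-time-step logits of shape (batch, seq_len, num_notes) and targets are integer note indices of shape (batch, seq_len), a plausible stand-in is:

import torch
import torch.nn.functional as F

def mean_crossentropy_loss(weights, targets):
    # flatten time into the batch dimension and average over all steps
    return F.cross_entropy(weights.reshape(-1, weights.size(-1)),
                           targets.reshape(-1))

loss = mean_crossentropy_loss(torch.randn(4, 24, 53), torch.randint(0, 53, (4, 24)))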
Example #7
    def loss_and_acc_test(self, data_loader):
        """
        Computes loss and accuracy for test data
        :param data_loader: torch data loader object
        :return: (float, float)
        """
        mean_loss = 0
        mean_accuracy = 0

        for sample_id, batch in tqdm(enumerate(data_loader)):
            inputs, _ = self.process_batch_data(batch)
            inputs = to_cuda_variable(inputs)
            # compute forward pass
            outputs, _, _, _, _ = self.model(inputs)
            # compute loss
            recons_loss = self.reconstruction_loss(inputs, outputs,
                                                   self.dec_dist)
            loss = recons_loss
            # compute mean loss and accuracy
            mean_loss += to_numpy(loss.mean())
            accuracy = self.mean_accuracy(weights=torch.sigmoid(outputs),
                                          targets=inputs)
            mean_accuracy += to_numpy(accuracy)
        mean_loss /= len(data_loader)
        mean_accuracy /= len(data_loader)
        return (mean_loss, mean_accuracy)
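Here the model is an image VAE with binary targets, so accuracy compares thresholded sigmoid outputs to the input pixels. A minimal stand-in for self.mean_accuracy, under that assumption:

import torch

def mean_accuracy(weights, targets, threshold=0.5):
    # fraction of pixels whose thresholded probability matches the
    # (binarized) target pixel
    preds = (weights >= threshold).float()
    return (preds == targets.round()).float().mean()

acc = mean_accuracy(torch.sigmoid(torch.randn(8, 1, 28, 28)),
                    torch.rand(8, 1, 28, 28).round())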
Example #8
    def plot_transposition_points(self, plt_type='pca'):
        """
        Plots a t-SNE plot for data-points comprising of transposed measures
        :param plt_type: str, 'tsne' or 'pca'
        :return:
        """
        filepaths = self.dataset.valid_filepaths
        idx = random.randint(0, len(filepaths) - 1)
        original_score = get_music21_score_from_path(filepaths[idx])
        possible_transpositions = self.dataset.all_transposition_intervals(original_score)
        z_all = []
        n_all = []
        n = 0
        for trans_int in possible_transpositions:
            score_tensor = self.dataset.get_transposed_tensor(
                original_score,
                trans_int
            )
            score_tensor = self.dataset.split_tensor_to_bars(score_tensor)
            score_tensor = to_cuda_variable_long(score_tensor)
            z_dist = self.model.encoder(score_tensor)
            z_tilde = z_dist.loc
            z_all.append(z_tilde)
            t = np.arange(0, z_tilde.size(0))
            n_all.append(torch.from_numpy(t))
            # n_all.append(torch.ones(z_tilde.size(0)) * n)
            n += 1
        z_all = torch.cat(z_all, 0)
        n_all = torch.cat(n_all, 0)
        z_all = to_numpy(z_all)
        n_all = to_numpy(n_all)

        filename = self.dir_path + '/plots/' + plt_type + '_transposition_measure_vae.png'
        if plt_type == 'pca':
            self.plot_pca(z_all, n_all, filename)
        elif plt_type == 'tsne':
            self.plot_tsne(z_all, n_all, filename)
        else:
            raise ValueError('Invalid plot type')
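self.plot_pca and self.plot_tsne are defined elsewhere in the project; a minimal stand-in for the PCA variant, with a hypothetical signature matching the calls above (scikit-learn and matplotlib):

import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend for saving to file
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_pca(z_all, labels, filename):
    # project the latent codes to 2-D and colour points by label
    z2d = PCA(n_components=2).fit_transform(z_all)
    plt.scatter(z2d[:, 0], z2d[:, 1], c=labels, s=5, cmap='viridis')
    plt.colorbar()
    plt.savefig(filename)
    plt.close()

plot_pca(np.random.randn(120, 16), np.repeat(np.arange(12), 10), 'pca_demo.png')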
Example #9
File: trainer.py Project: zbxzc35/ar-vae
    def loss_and_acc_on_epoch(self, data_loader, epoch_num=None, train=True):
        """
        Computes the loss and accuracy for an epoch
        :param data_loader: torch dataloader object
        :param epoch_num: int, used to change training schedule
        :param train: bool, performs the backward pass and gradient descent if True
        :return: loss values and accuracy percentages
        """
        mean_loss = 0
        mean_accuracy = 0
        for batch_num, batch in tqdm(enumerate(data_loader)):
            # process batch data
            batch_data = self.process_batch_data(batch)

            # zero the gradients
            self.zero_grad()

            # compute loss for batch
            loss, accuracy = self.loss_and_acc_for_batch(
                batch_data, epoch_num, batch_num, train=train
            )

            # compute backward and step if train
            if train:
                loss.backward()
                # self.plot_grad_flow()
                self.step()

            # compute mean loss and accuracy
            mean_loss += to_numpy(loss.mean())
            if accuracy is not None:
                mean_accuracy += to_numpy(accuracy)

        mean_loss /= len(data_loader)
        mean_accuracy /= len(data_loader)
        return (
            mean_loss,
            mean_accuracy
        )
Example #10
    def _plot_data_attr_dist(self, gen_test, dim1, dim2, reg_type):
        z_all = []
        attr_all = []
        for sample_id, (score_tensor, metadata_tensor) in tqdm(enumerate(gen_test)):
            if isinstance(self.dataset, FolkNBarDataset):
                batch_size = score_tensor.size(0)
                score_tensor = score_tensor.view(batch_size, self.dataset.n_bars, -1)
                score_tensor = score_tensor.view(batch_size * self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(batch_size, self.dataset.n_bars, -1)
                metadata_tensor = metadata_tensor.view(batch_size * self.dataset.n_bars, -1)
            # convert input to torch Variables
            score_tensor, metadata_tensor = (
                to_cuda_variable_long(score_tensor),
                to_cuda_variable_long(metadata_tensor)
            )
            # compute encoder forward pass
            z_dist = self.model.encoder(score_tensor)
            # sample from distribution
            z_tilde = z_dist.rsample()

            # compute attributes
            if reg_type == 'rhy_complexity':
                attr = self.dataset.get_rhy_complexity(score_tensor)
            elif reg_type == 'num_notes':
                attr = self.dataset.get_note_density_in_measure(score_tensor)
            elif reg_type == 'note_range':
                attr = self.dataset.get_pitch_range_in_measure(score_tensor)
            else:
                raise ValueError('Invalid regularization type')
            z_all.append(z_tilde)
            attr_all.append(attr)
        z_all = to_numpy(torch.cat(z_all, 0))
        attr_all = to_numpy(torch.cat(attr_all, 0))
        if self.trainer_config == '':
            reg_str = '[no_reg]'
        else:
            reg_str = self.trainer_config
        filename = self.dir_path + '/plots/' + reg_str + 'data_dist_' + reg_type + '_[' \
                   + str(dim1) + ',' + str(dim2) + '].png'
        self.plot_dim(z_all, attr_all, filename, dim1=dim1, dim2=dim2, xlim=6, ylim=6)
Example #11
 def get_rhythmic_entropy(self, measure_tensor):
     """
     Returns the rhythmic entropy in a measure of music
     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable,
             (batch_size)
     """
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     none_index = self.note2index_dicts[None]
     # map all non-onset symbols (start/end/rest/none) to the slur index
     beat_tensor = measure_tensor.clone()
     beat_tensor[beat_tensor == start_index] = slur_index
     beat_tensor[beat_tensor == end_index] = slur_index
     beat_tensor[beat_tensor == rest_index] = slur_index
     beat_tensor[beat_tensor == none_index] = slur_index
     # binarize: 1 = note onset, 0 = everything else
     beat_tensor[beat_tensor != slur_index] = 1
     beat_tensor[beat_tensor == slur_index] = 0
     ent = stats.entropy(np.transpose(to_numpy(beat_tensor)))
     ent = torch.from_numpy(np.transpose(ent))
     ent = to_cuda_variable(ent)
     return ent
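The entropy here treats each measure's binary onset pattern as an (unnormalized) distribution over time steps; scipy.stats.entropy normalizes and reduces over axis 0, hence the transposes. A self-contained numeric check:

import numpy as np
from scipy import stats

# one row per measure: 1 = note onset, 0 = slur/rest/start/end
patterns = np.array([
    [1., 0., 1., 0., 1., 0., 1., 0.],  # 4 onsets -> entropy ln(4)
    [1., 1., 1., 1., 1., 1., 1., 1.],  # 8 onsets -> entropy ln(8)
])
ent = stats.entropy(patterns.T)  # one value per measure
print(ent)  # [1.3862..., 2.0794...]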
Example #12
    def plot_attribute_dist(self, attribute='num_notes', plt_type='pca'):
        """
        Plots the distribution of a particular attribute in the latent space
        :param attribute: str,
                num_notes, note_range, rhy_entropy, beat_strength, rhy_complexity
        :param plt_type: str, 'tsne' or 'pca'
        :return:
        """
        (_, _, gen_test) = self.dataset.data_loaders(
            batch_size=64,  # TODO: remove this hard coding
            split=(0.01, 0.01)
        )
        z_all = []
        n_all = []
        num_samples = 5
        for sample_id, (score_tensor, _) in tqdm(enumerate(gen_test)):
            # convert input to torch Variables
            if isinstance(self.dataset, FolkNBarDataset):
                batch_size = score_tensor.size(0)
                score_tensor = score_tensor.view(batch_size, self.dataset.n_bars, -1)
                score_tensor = score_tensor.view(batch_size * self.dataset.n_bars, -1)
            score_tensor = to_cuda_variable_long(score_tensor)
            # compute encoder forward pass
            z_dist = self.model.encoder(score_tensor)
            z_tilde = z_dist.loc
            z_all.append(z_tilde)
            if attribute == 'num_notes':
                attr = self.dataset.get_note_density_in_measure(score_tensor)
            elif attribute == 'note_range':
                attr = self.dataset.get_pitch_range_in_measure(score_tensor)
            elif attribute == 'rhy_entropy':
                attr = self.dataset.get_rhythmic_entropy(score_tensor)
            elif attribute == 'beat_strength':
                attr = self.dataset.get_beat_strength(score_tensor)
            elif attribute == 'rhy_complexity':
                attr = self.dataset.get_rhy_complexity(score_tensor)
            else:
                raise ValueError('Invalid attribute type')
            # use sentinel attribute values for measures that begin with start/end tokens
            for i in range(attr.size(0)):
                tensor_score = score_tensor[i, :]
                start_idx = self.dataset.note2index_dicts[START_SYMBOL]
                end_idx = self.dataset.note2index_dicts[END_SYMBOL]
                if tensor_score[0] == start_idx:
                    attr[i] = -0.1
                elif tensor_score[0] == end_idx:
                    attr[i] = -0.2
            n_all.append(attr)
            if sample_id == num_samples:
                break
        z_all = torch.cat(z_all, 0)
        n_all = torch.cat(n_all, 0)
        z_all = to_numpy(z_all)
        n_all = to_numpy(n_all)

        filename = self.dir_path + '/plots/' + plt_type + '_' + attribute + '_' + \
                   str(num_samples) + '_measure_vae.png'
        if plt_type == 'pca':
            self.plot_pca(z_all, n_all, filename)
        elif plt_type == 'tsne':
            self.plot_tsne(z_all, n_all, filename)
        elif plt_type == 'dim':
            self.plot_dim(z_all, n_all, filename)
        else:
            raise ValueError('Invalid plot type')