Example #1
 def plot_attribute_surface(self, dim1=0, dim2=1, grid_res=0.5):
     """
     Plots the value of an attribute over a surface defined by two latent dimensions
     :param dim1: int, index of the first latent dimension to vary
     :param dim2: int, index of the second latent dimension to vary
     :param grid_res: float, grid resolution (step size) along each dimension
     :return: None
     """
     # create the dataspace
     x1 = torch.arange(-5., 5., grid_res)
     x2 = torch.arange(-5., 5., grid_res)
     z1, z2 = torch.meshgrid([x1, x2])
     num_points = z1.size(0) * z1.size(1)
     z = torch.randn(1, self.model.latent_space_dim)
     z = z.repeat(num_points, 1)
     z[:, dim1] = z1.contiguous().view(1, -1)
     z[:, dim2] = z2.contiguous().view(1, -1)
     z = to_cuda_variable(z)
     # pass the points through the model decoder
     mini_batch_size = 500
     num_mini_batches = num_points // mini_batch_size
     nd_all = []
     nr_all = []
     rc_all = []
     # ie_all = []
     for i in tqdm(range(num_mini_batches)):
         z_batch = z[i*mini_batch_size:(i+1)*mini_batch_size, :]
         dummy_score_tensor = to_cuda_variable(torch.zeros(z_batch.size(0), self.measure_seq_len))
         _, samples = self.model.decoder(
             z=z_batch,
             score_tensor=dummy_score_tensor,
             train=self.train
         )
         samples = samples.view(z_batch.size(0), -1)
         note_density = self.dataset.get_note_density_in_measure(samples)
         note_range = self.dataset.get_pitch_range_in_measure(samples)
         rhy_complexity = self.dataset.get_rhy_complexity(samples)
         # interval_entropy = self.dataset.get_interval_entropy(samples)
         nd_all.append(note_density)
         nr_all.append(note_range)
         rc_all.append(rhy_complexity)
         # ie_all.append(interval_entropy)
     nd_all = to_numpy(torch.cat(nd_all, 0))
     nr_all = to_numpy(torch.cat(nr_all, 0))
     rc_all = to_numpy(torch.cat(rc_all, 0))
     # ie_all = to_numpy(torch.cat(ie_all, 0))
     z = to_numpy(z)
     if self.trainer_config == '':
         reg_str = '[no_reg]'
     else:
         reg_str = self.trainer_config
     filename = self.dir_path + '/plots/' + reg_str + 'attr_surf_note_density_[' \
                + str(dim1) + ',' + str(dim2) + '].png'
     self.plot_dim(z, nd_all, filename, dim1=dim1, dim2=dim2)
     filename = self.dir_path + '/plots/' + reg_str + 'attr_surf_note_range_[' \
                + str(dim1) + ',' + str(dim2) + '].png'
     self.plot_dim(z, nr_all, filename, dim1=dim1, dim2=dim2)
     filename = self.dir_path + '/plots/' + reg_str + 'attr_surf_rhy_complexity_[' \
                + str(dim1) + ',' + str(dim2) + '].png'
     self.plot_dim(z, rc_all, filename, dim1=dim1, dim2=dim2)
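
The meshgrid-and-repeat pattern above (and in several later examples) builds a grid of latent points that vary along two dimensions while all other coordinates stay fixed. A minimal standalone sketch of that grid construction, assuming a hypothetical 4-dimensional latent space and a coarse grid; the model, decoder and plotting calls are omitted:

    import torch

    latent_dim, grid_res = 4, 2.5            # small values so the grid is easy to inspect
    dim1, dim2 = 0, 1
    x1 = torch.arange(-5., 5., grid_res)     # 4 grid values per axis
    x2 = torch.arange(-5., 5., grid_res)
    z1, z2 = torch.meshgrid([x1, x2])
    num_points = z1.size(0) * z1.size(1)     # 16
    z = torch.randn(1, latent_dim).repeat(num_points, 1)
    z[:, dim1] = z1.contiguous().view(-1)    # flatten to (num_points,) before the column assignment
    z[:, dim2] = z2.contiguous().view(-1)
    print(z.shape)                           # torch.Size([16, 4])
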
Example #2
 def get_resnet_accuracy(self):
     if self.dataset_type != 'mnist':
         return None
     # instantiate Resnet model
     resnet_model = MnistResNet()
     if torch.cuda.is_available():
         resnet_model.load()
         resnet_model.cuda()
     else:
         resnet_model.load(cpu=True)
     batch_size = 128
     _, _, data_loader = self.dataset.data_loaders(batch_size=batch_size)
     interp_dict = self.metrics["interpretability"]
     input_acc = 0
     recons_acc = 0
     interp_acc = 0
     for sample_id, batch in tqdm(enumerate(data_loader)):
         inputs, digit_labels, _ = batch
         inputs = to_cuda_variable(inputs)
         digit_labels = to_cuda_variable(digit_labels)
         recons, _, _, z, _ = self.model(inputs)
         recons = torch.sigmoid(recons)
         # compute input and reconstruction accuracy on resnet
         pred_inputs = self.compute_mnist_digit_identity(
             resnet_model, inputs)
         pred_recons = self.compute_mnist_digit_identity(
             resnet_model, recons)
         input_acc += self.mean_accuracy_pred(pred_inputs, digit_labels)
         recons_acc += self.mean_accuracy_pred(pred_recons, digit_labels)
         dummy = 0
         num_interps = 10
         z = z.repeat(num_interps, 1)
         for attr_str in interp_dict.keys():
             z_copy = z.clone()
             x1 = np.linspace(-4, 4.0, num_interps)
             x1 = x1.repeat(z.size(0) // num_interps)
             x1 = torch.from_numpy(x1)
             dim = interp_dict[attr_str][0]
             z_copy[:, dim] = x1.contiguous()
             outputs = torch.sigmoid(self.model.decode(z_copy))
             pred_outputs = self.compute_mnist_digit_identity(
                 resnet_model, outputs)
             repeated_labels = digit_labels.repeat(num_interps)
             dummy += self.mean_accuracy_pred(pred_outputs, repeated_labels)
         interp_acc += dummy / len(interp_dict.keys())
     num_batches = sample_id + 1
     return {
         'digit_pred_acc': {
             'inputs': input_acc.item() / num_batches,
             'recons': recons_acc.item() / num_batches,
             'interp': interp_acc.item() / num_batches
         }
     }
Example #3
 def test_attr_reg_interpolations(self, num_points=10, dim=0, num_interps=20):
     for i in range(num_points):
         z = torch.randn(1, self.model.latent_space_dim)
         z1 = z.clone()
         z2 = z.clone()
         z1[:, dim] = -3.
         z2[:, dim] = 3.
         z1 = to_cuda_variable(z1)
         z2 = to_cuda_variable(z2)
         tensor_score = self.decode_mid_point(z1, z2, num_interps)
         score = self.dataset.tensor_to_m21score(tensor_score.cpu())
         score.show()
Example #4
 def compute_latent_interpolations(self, latent_code, labels, dim1=1):
     x1 = torch.arange(0., 1.01, 0.1)
     num_points = x1.size(0)
     z = to_cuda_variable(torch.from_numpy(latent_code[:1, :]))
     z = z.repeat(num_points, 1)
     l = labels[:1, :]
     l = l.repeat(num_points, 0)
     l[:, dim1] = x1.contiguous()
     l = to_cuda_variable(torch.from_numpy(l))
     inputs = torch.cat((z, l), 1)
     outputs = torch.sigmoid(self.model.decode(inputs))
     interp = make_grid(outputs.cpu(), nrow=1, pad_value=1.0)
     return interp
Example #5
    def plot_latent_surface(self, attr_str, dim1=0, dim2=1, grid_res=0.1):
        # create the dataspace
        x1 = torch.arange(-5., 5., grid_res)
        x2 = torch.arange(-5., 5., grid_res)
        z1, z2 = torch.meshgrid([x1, x2])
        num_points = z1.size(0) * z1.size(1)
        z = torch.randn(1, self.model.z_dim)
        z = z.repeat(num_points, 1)
        z[:, dim1] = z1.contiguous().view(1, -1)
        z[:, dim2] = z2.contiguous().view(1, -1)
        z = to_cuda_variable(z)

        mini_batch_size = 500
        num_mini_batches = num_points // mini_batch_size
        attr_labels_all = []
        for i in tqdm(range(num_mini_batches)):
            z_batch = z[i * mini_batch_size:(i + 1) * mini_batch_size, :]
            outputs = torch.sigmoid(self.model.decode(z_batch))
            labels = self.compute_mnist_morpho_labels(outputs, attr_str)
            attr_labels_all.append(torch.from_numpy(labels))
        attr_labels_all = to_numpy(torch.cat(attr_labels_all, 0))
        z = to_numpy(z)[:num_mini_batches * mini_batch_size, :]
        save_filename = os.path.join(Trainer.get_save_dir(self.model),
                                     f'latent_surface_{attr_str}.png')
        plot_dim(z, attr_labels_all, save_filename, dim1=dim1, dim2=dim2)
Example #6
 def get_pitch_range_in_measure(self, measure_tensor):
     """
     Returns the note range of each measure of the input, normalized by the range
     of the dataset
     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable containing float tensor,
             (batch_size)
     """
     batch_size, measure_seq_len = measure_tensor.size()
     index2note = self.index2note_dicts
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     none_index = self.note2index_dicts[None]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     nrange = to_cuda_variable(torch.zeros(batch_size))
     for i in range(batch_size):
         midi_notes = []
         for j in range(measure_seq_len):
             index = measure_tensor[i][j].item()
             if index not in (slur_index, rest_index, none_index,
                              start_index, end_index):
                 midi_note = torch.Tensor(
                     [music21.pitch.Pitch(index2note[index]).midi])
                 midi_notes.append(midi_note)
         if len(midi_notes) == 0 or len(midi_notes) == 1:
             nrange[i] = 0
         else:
             midi_notes = torch.cat(midi_notes, 0)
             midi_notes = to_cuda_variable_long(midi_notes)
             nrange[i] = torch.max(midi_notes) - torch.min(midi_notes)
     return nrange / 26
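
The per-note conversion above goes through music21's pitch parser. A small sketch of the same range computation on plain note names, reusing the normalization constant of 26 semitones from the return statement; the note names are made up for illustration:

    import music21
    import torch

    note_names = ['C4', 'E4', 'G4', 'B-4']                  # toy measure content
    midi = torch.tensor([music21.pitch.Pitch(n).midi for n in note_names],
                        dtype=torch.float)
    pitch_range = (midi.max() - midi.min()) / 26             # same normalization constant as above
    print(float(pitch_range))                                # 10 semitones / 26 ≈ 0.385
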
Example #7
 def get_rhy_complexity(self, measure_tensor):
     """
     Returns the normalized rhythmic complexity of a batch of measures
     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable,
             (batch_size)
     """
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     none_index = self.note2index_dicts[None]
     beat_tensor = torch.ones_like(measure_tensor)
     beat_tensor[measure_tensor == slur_index] = 0
     beat_tensor[measure_tensor == start_index] = 0
     beat_tensor[measure_tensor == end_index] = 0
     beat_tensor[measure_tensor == none_index] = 0
     beat_tensor[measure_tensor == rest_index] = 0
     beat_tensor = beat_tensor.float()
     num_measures = measure_tensor.shape[0]
     weights = to_cuda_variable(RHY_COMPLEXITY_COEFFS.view(1, -1).float())
     norm_coeff = torch.sum(weights, dim=1)
     weights = weights.repeat(num_measures, 1)
     h_product = weights * beat_tensor
     b_str = torch.sum(h_product, dim=1) / norm_coeff
     return b_str
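
A toy illustration of the weighted onset sum computed above. The coefficient values below are placeholders for RHY_COMPLEXITY_COEFFS (which is defined elsewhere in the codebase) and are chosen only to show the shape of the computation:

    import torch

    # hypothetical metrical weights for a 4-tick measure (stand-in for RHY_COMPLEXITY_COEFFS)
    coeffs = torch.tensor([4., 1., 2., 1.])
    # binary onset pattern per measure: 1 where a new note starts, 0 otherwise
    beat_tensor = torch.tensor([[1., 0., 1., 0.],
                                [1., 1., 1., 1.]])
    weights = coeffs.view(1, -1)
    norm_coeff = weights.sum(dim=1)                          # 8.0
    b_str = (weights * beat_tensor).sum(dim=1) / norm_coeff
    print(b_str)                                             # tensor([0.7500, 1.0000])
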
Example #8
 def get_contour(self, measure_tensor):
     """
     Returns the direction of the melodic contour of a batch of measures
     :param measure_tensor: torch Variable
             (batch_size, measure_seq_len)
     :return: torch Variable,
             (batch_size)
     """
     batch_size, measure_seq_len = measure_tensor.size()
     index2note = self.index2note_dicts
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     none_index = self.note2index_dicts[None]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     contour = to_cuda_variable(torch.zeros(batch_size))
     for i in range(batch_size):
         midi_notes = []
         for j in range(measure_seq_len):
             index = measure_tensor[i][j].item()
             if index not in (slur_index, rest_index, none_index,
                              start_index, end_index):
                 midi_note = torch.Tensor(
                     [music21.pitch.Pitch(index2note[index]).midi])
                 midi_notes.append(midi_note)
         if len(midi_notes) == 0 or len(midi_notes) == 1:
             contour[i] = 0
         else:
             midi_notes = torch.cat(midi_notes, 0)
             a = midi_notes[1:] - midi_notes[:-1]
             b = torch.sum(a)
             contour[i] = to_cuda_variable_long(b)
     return contour / 26
Example #9
 def process_batch_data(self, batch):
     """
     Processes the batch returned by the dataloader iterator
     :param batch: object returned by the dataloader iterator
     :return: tuple of Torch Variable objects
     """
     if self.dataset_type == 'mnist':
         inputs, _, morpho_labels = batch
         inputs = to_cuda_variable(inputs)
         morpho_labels = to_cuda_variable(morpho_labels)
         return inputs, morpho_labels
     else:
         inputs, labels = batch
         inputs = to_cuda_variable(inputs)
         labels = to_cuda_variable(labels)
         return inputs, labels
Example #10
    def forward_test(self, measure_score_tensor: Variable):
        """
        Implements the forward pass of the VAE
        :param measure_score_tensor: torch Variable,
                (batch_size, num_measures, measure_seq_length)
        :return: tuple of torch Variables (weights, samples)
        """
        # check input
        batch_size, num_measures, seq_len = measure_score_tensor.size()
        assert (seq_len == self.num_ticks_per_measure)

        # compute output of encoding layer
        z = []
        for i in range(num_measures):
            z_dist = self.encoder(measure_score_tensor[:, i, :])
            z.append(z_dist.rsample().unsqueeze(1))
        z_tilde = torch.cat(z, 1)

        # compute output of decoding layer
        weights = []
        samples = []
        dummy_measure_tensor = to_cuda_variable(
            torch.zeros(batch_size, seq_len))
        for i in range(num_measures):
            w, s = self.decoder(z=z_tilde[:, i, :],
                                score_tensor=dummy_measure_tensor,
                                train=False)
            samples.append(s)
            weights.append(w.unsqueeze(1))
        samples = torch.cat(samples, 2)
        weights = torch.cat(weights, 1)
        return weights, samples
Example #11
 def compute_grad_attr(self, softmax_weights):
     mask = to_cuda_variable(self.note_tensor[None, None, :].expand(
         softmax_weights.size()).detach())
     if self.reg_type == 'rhy_complexity':
         metrical_weights = RHY_COMPLEXITY_COEFFS
         metrical_weights = to_cuda_variable(
             metrical_weights[None, :, None].expand(
                 softmax_weights.size()).detach()).float()
         rhy_complexity = (softmax_weights * metrical_weights *
                           mask).sum(2).sum(1) / sum(RHY_COMPLEXITY_COEFFS)
         return rhy_complexity
     elif self.reg_type == 'num_notes':
         measure_seq_len = softmax_weights.size(1)
         num_notes = (softmax_weights *
                      mask).sum(2).sum(1) / measure_seq_len
         return num_notes
     else:
         raise ValueError('Invalid regularization type')
Example #12
 def hidden_init(self, batch_size):
     """
     Initializes the hidden state of the encoder GRU
     :param batch_size: int
     :return: torch tensor,
             (self.num_layers * self.num_directions, batch_size, self.rnn_hidden_size)
     """
     hidden = torch.zeros(self.num_layers * self.num_directions, batch_size,
                          self.rnn_hidden_size)
     return to_cuda_variable(hidden)
Example #13
    def hidden_init(self, batch_size):
        """
        Initializes the hidden state of the RNN
        :param batch_size: int,
        :return: torch tensor,
                (self.num_layers, batch_size, self.rnn_hidden_size)
        """
        h = to_cuda_variable(
            torch.zeros(self.num_layers, batch_size, self.rnn_hidden_size))
        return h
Example #14
 def get_interval_entropy(self, measure_tensor):
     """
     Returns the normalized interval entropy of a batch of measures
     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable,
             (batch_size)
     """
     batch_size, measure_seq_len = measure_tensor.size()
     index2note = self.index2note_dicts
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     none_index = self.note2index_dicts[None]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     nrange = to_cuda_variable(torch.zeros(batch_size))
     for i in range(batch_size):
         midi_notes = []
         for j in range(measure_seq_len):
             index = measure_tensor[i][j].item()
             if index not in (slur_index, rest_index, none_index,
                              start_index, end_index):
                 midi_note = torch.Tensor(
                     [music21.pitch.Pitch(index2note[index]).midi])
                 midi_notes.append(midi_note)
         if len(midi_notes) == 0 or len(midi_notes) == 1:
             nrange[i] = 0
         else:
             midi_notes = torch.cat(midi_notes, 0)
             midi_notes = to_cuda_variable_long(midi_notes)
             # compute intervals
             midi_notes = torch.abs(midi_notes[1:] - midi_notes[:-1])
             midi_notes = midi_notes % 12
             probs = to_cuda_variable(torch.zeros([12, 1]))
             for k in range(len(midi_notes)):
                 probs[midi_notes[k]] += 1
             # compute entropy of this interval vector
             b = nn.functional.softmax(probs, dim=0) * \
                 nn.functional.log_softmax(probs, dim=0)
             b = -1.0 * b.sum()
             nrange[i] = b
     return nrange
Example #15
    def test_interp(self):
        """
        Tests the interpolation capabilities of the latent space
        :return: None
        """
        (_, gen_val, gen_test) = self.dataset.data_loaders(
            batch_size=1,  # TODO: remove this hard coding
            split=(0.01, 0.5)
        )
        gen_it_test = gen_test.__iter__()
        for _ in range(random.randint(1, len(gen_test))):  # at least one draw so tensor_score1 is bound
            tensor_score1, _ = next(gen_it_test)

        gen_it_val = gen_val.__iter__()
        for _ in range(random.randint(1, len(gen_val))):  # at least one draw so tensor_score2 is bound
            tensor_score2, _ = next(gen_it_val)

        tensor_score1 = to_cuda_variable(tensor_score1.long())
        tensor_score2 = to_cuda_variable(tensor_score2.long())
        self.test_interpolation(tensor_score1, tensor_score2, 10)
Example #16
 def compute_latent_interpolations(self,
                                   latent_code,
                                   dim1=0,
                                   num_points=10):
     x1 = torch.linspace(-4., 4.0, num_points)
     num_points = x1.size(0)
     z = to_cuda_variable(torch.from_numpy(latent_code))
     z = z.repeat(num_points, 1)
     z[:, dim1] = x1.contiguous()
     outputs = torch.sigmoid(self.model.decode(z))
     interp = make_grid(outputs.cpu(), nrow=num_points, pad_value=1.0)
     return interp
Example #17
    def __init__(
        self,
        dataset,
        model: MnistVAE,
        lr=1e-4,
        reg_type: Tuple[str] = (),  # empty tuple instead of None so len(self.reg_type) below is valid
        reg_dim: Tuple[int] = 0,
        dec_dist='bernoulli',
        beta=4.0,
        gamma=10.0,
        capacity=0.0,
        rand=0,
        delta=1.0,
    ):
        super(ImageVAETrainer, self).__init__(dataset, model, lr)
        if dataset.__class__.__name__ == 'MorphoMnistDataset':
            self.dataset_type = 'mnist'
        elif dataset.__class__.__name__ == 'DspritesDataset':
            self.dataset_type = 'dsprites'
        else:
            raise ValueError(
                f"Dataset type not recognized: {dataset.__class__.__name__}")
        self.attr_dict = DATASET_REG_TYPE_DICT[self.dataset_type]

        self.reverse_attr_dict = {v: k for k, v in self.attr_dict.items()}
        self.metrics = {}
        self.beta = beta
        self.capacity = to_cuda_variable(torch.FloatTensor([capacity]))
        self.gamma = 0.0
        self.delta = 0.0
        self.cur_epoch_num = 0
        self.warm_up_epochs = 10
        self.reg_type = reg_type
        self.reg_dim = ()
        self.use_reg_loss = False
        self.rand_seed = rand
        torch.manual_seed(self.rand_seed)
        np.random.seed(self.rand_seed)
        self.trainer_config = f'_r_{self.rand_seed}_b_{self.beta}_'
        if capacity != 0.0:
            self.trainer_config += f'c_{capacity}_'
        self.model.update_trainer_config(self.trainer_config)
        self.dec_dist = dec_dist
        if len(self.reg_type) != 0:
            self.use_reg_loss = True
            self.reg_dim = reg_dim
            self.gamma = gamma
            self.delta = delta
            self.trainer_config += f'g_{self.gamma}_d_{self.delta}_'
            reg_type_str = '_'.join(self.reg_type)
            self.trainer_config += f'{reg_type_str}_'
            self.model.update_trainer_config(self.trainer_config)
Example #18
 def plot_latent_interpolations2d(self,
                                  attr_str1,
                                  attr_str2,
                                  num_points=10):
     x1 = torch.linspace(-4., 4.0, num_points)
     x2 = torch.linspace(-4., 4.0, num_points)
     z1, z2 = torch.meshgrid([x1, x2])
     total_num_points = z1.size(0) * z1.size(1)
     _, _, data_loader = self.dataset.data_loaders(batch_size=1)
     interp_dict = self.compute_eval_metrics()["interpretability"]
     dim1 = interp_dict[attr_str1][0]
     dim2 = interp_dict[attr_str2][0]
     for sample_id, batch in tqdm(enumerate(data_loader)):
         if sample_id == 9:
             inputs, labels = self.process_batch_data(batch)
             inputs = to_cuda_variable(inputs)
             recons, _, _, z, _ = self.model(inputs)
             recons = torch.sigmoid(recons)
             z = z.repeat(total_num_points, 1)
             z[:, dim1] = z1.contiguous().view(1, -1)
             z[:, dim2] = z2.contiguous().view(1, -1)
             # z = torch.flip(z, dims=[0])
             outputs = torch.sigmoid(self.model.decode(z))
             save_filepath = os.path.join(
                 Trainer.get_save_dir(self.model),
                 f'latent_interpolations_2d_({attr_str1},{attr_str2})_{sample_id}.png'
             )
             save_image(outputs.cpu(),
                        save_filepath,
                        nrow=num_points,
                        pad_value=1.0)
             # save original image
             org_save_filepath = os.path.join(
                 Trainer.get_save_dir(self.model),
                 f'original_{sample_id}.png')
             save_image(inputs.cpu(),
                        org_save_filepath,
                        nrow=1,
                        pad_value=1.0)
             # save reconstruction
             recons_save_filepath = os.path.join(
                 Trainer.get_save_dir(self.model),
                 f'recons_{sample_id}.png')
             save_image(recons.cpu(),
                        recons_save_filepath,
                        nrow=1,
                        pad_value=1.0)
         if sample_id == 10:
             break
Example #19
    def compute_reg_loss(self, z, score, epsilon=1e-3):
        """
        Compute the GLSR regularization loss
        :param z: torch Variable, batch of latent vectors (batch_size, z_dim)
        :param score: torch Variable, score tensor (unused here)
        :param epsilon: float, magnitude of the latent perturbation
        :return: torch Variable, scalar regularization loss
        """
        d_z = torch.zeros_like(z)
        deltas = (1 + torch.rand(z.size(0))) * epsilon
        deltas = to_cuda_variable(deltas)
        d_z[:, self.reg_dim] = deltas  # perturb only the regularized latent dimension
        z_plus = z + d_z
        z_minus = z - d_z

        dummy_score_tensor = to_cuda_variable(
            torch.zeros(z.size(0), self.model.num_ticks_per_measure))
        weights_plus, _ = self.model.decoder(z=z_plus,
                                             score_tensor=dummy_score_tensor,
                                             train=False)
        weights_minus, _ = self.model.decoder(z=z_minus,
                                              score_tensor=dummy_score_tensor,
                                              train=False)

        softmax_weights_plus = F.softmax(weights_plus, dim=2)
        softmax_weights_minus = F.softmax(weights_minus, dim=2)
        grad_softmax = softmax_weights_plus - softmax_weights_minus

        grad_attr = self.compute_grad_attr(grad_softmax)
        grad_attr = grad_attr / (2 * deltas)

        prior_mean = to_cuda_variable(torch.ones_like(grad_attr) * 100)
        prior_std = to_cuda_variable(torch.ones_like(grad_attr))
        reg_loss = -torch.distributions.Normal(prior_mean,
                                               prior_std).log_prob(grad_attr)
        return reg_loss.mean()
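
The regularization term above is a central finite difference of a decoded attribute with respect to the perturbed latent dimension, scored under a Gaussian prior. A self-contained numeric sketch in which the decoder-plus-attribute pipeline is replaced by a simple quadratic stand-in, so it illustrates only the arithmetic, not the actual model:

    import torch

    torch.manual_seed(0)
    reg_dim, epsilon = 0, 1e-3
    z = torch.randn(4, 8)                                    # batch of latent codes

    def attr_fn(z):
        # stand-in for decoding z and measuring an attribute; here attr = z_0 ** 2
        return z[:, 0] ** 2

    deltas = (1 + torch.rand(z.size(0))) * epsilon
    d_z = torch.zeros_like(z)
    d_z[:, reg_dim] = deltas                                 # perturb only the regularized dimension
    grad_attr = (attr_fn(z + d_z) - attr_fn(z - d_z)) / (2 * deltas)
    print(torch.allclose(grad_attr, 2 * z[:, reg_dim], atol=1e-3))   # True: d/dz0 of z0**2 is 2*z0

    # GLSR-style penalty: negative log-probability of the gradient under a Normal prior
    prior = torch.distributions.Normal(torch.full_like(grad_attr, 100.),
                                       torch.ones_like(grad_attr))
    reg_loss = -prior.log_prob(grad_attr).mean()
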
Example #20
 def compute_latent_interpolations2d(self,
                                     latent_code,
                                     dim1=0,
                                     dim2=1,
                                     num_points=10):
     x1 = torch.linspace(-4., 4.0, num_points)
     x2 = torch.linspace(-4., 4.0, num_points)
     z1, z2 = torch.meshgrid([x1, x2])
     num_points = z1.size(0) * z1.size(1)
     z = to_cuda_variable(torch.from_numpy(latent_code))
     z = z.repeat(num_points, 1)
     z[:, dim1] = z1.contiguous().view(1, -1)
     z[:, dim2] = z2.contiguous().view(1, -1)
     # z = torch.flip(z, dims=[0])
     outputs = torch.sigmoid(self.model.decode(z))
     interp = make_grid(outputs.cpu(), nrow=z1.size(0), pad_value=1.0)
     return interp
Example #21
 def plot_latent_interpolations(self, attr_str='slant', num_points=10):
     x1 = torch.linspace(-4, 4.0, num_points)
     _, _, data_loader = self.dataset.data_loaders(batch_size=1)
     interp_dict = self.compute_eval_metrics()["interpretability"]
     dim = interp_dict[attr_str][0]
     for sample_id, batch in tqdm(enumerate(data_loader)):
         # for MNIST [5, 1, 30, 19, 23, 21, 17, 61, 9, 28]
         if sample_id in [5, 1, 30, 19, 23, 21, 17, 61, 9, 28]:
             inputs, labels = self.process_batch_data(batch)
             inputs = to_cuda_variable(inputs)
             recons, _, _, z, _ = self.model(inputs)
             recons = torch.sigmoid(recons)
             z = z.repeat(num_points, 1)
             z[:, dim] = x1.contiguous()
             outputs = torch.sigmoid(self.model.decode(z))
             # save interpolation
             save_filepath = os.path.join(
                 Trainer.get_save_dir(self.model),
                 f'latent_interpolations_{attr_str}_{sample_id}.png')
             save_image(outputs.cpu(),
                        save_filepath,
                        nrow=num_points,
                        pad_value=1.0)
             # save original image
             org_save_filepath = os.path.join(
                 Trainer.get_save_dir(self.model),
                 f'original_{sample_id}.png')
             save_image(inputs.cpu(),
                        org_save_filepath,
                        nrow=1,
                        pad_value=1.0)
             # save reconstruction
             recons_save_filepath = os.path.join(
                 Trainer.get_save_dir(self.model),
                 f'recons_{sample_id}.png')
             save_image(recons.cpu(),
                        recons_save_filepath,
                        nrow=1,
                        pad_value=1.0)
         if sample_id == 62:
             break
Example #22
    def loss_and_acc_test(self, data_loader):
        mean_loss = 0
        mean_accuracy = 0

        for sample_id, batch in tqdm(enumerate(data_loader)):
            inputs, _ = self.process_batch_data(batch)
            inputs = to_cuda_variable(inputs)
            # compute forward pass
            outputs, _, _, _, _ = self.model(inputs)
            # compute loss
            recons_loss = self.reconstruction_loss(inputs, outputs,
                                                   self.dec_dist)
            loss = recons_loss
            # compute mean loss and accuracy
            mean_loss += to_numpy(loss.mean())
            accuracy = self.mean_accuracy(weights=torch.sigmoid(outputs),
                                          targets=inputs)
            mean_accuracy += to_numpy(accuracy)
        mean_loss /= len(data_loader)
        mean_accuracy /= len(data_loader)
        return (mean_loss, mean_accuracy)
Example #23
 def decode_mid_point(self, z1, z2, n):
     """
     Decodes the mid-point of two latent vectors
     :param z1: torch tensor, (1, self.z_dim)
     :param z2: torch tensor, (1, self.z_dim)
     :param n: int, number of points for interpolation
     :return: torch tensor, (1, (n+2) * measure_seq_len)
     """
     assert(n >= 1 and isinstance(n, int))
     # compute the score_tensors for z1 and z2
     dummy_score_tensor = to_cuda_variable(torch.zeros(self.batch_size, self.measure_seq_len))
     _, sam1 = self.decoder(z1, dummy_score_tensor, self.train)
     _, sam2 = self.decoder(z2, dummy_score_tensor, self.train)
     # find the interpolation points and run through decoder
     tensor_score = sam1
     for i in range(n):
         z_interp = z1 + (z2 - z1)*(i+1)/(n+1)
         _, sam_interp = self.decoder(z_interp, dummy_score_tensor, self.train)
         tensor_score = torch.cat((tensor_score, sam_interp), 1)
     tensor_score = torch.cat((tensor_score, sam2), 1).view(1, -1)
     # score = self.dataset.tensor_to_score(tensor_score.cpu())
     return tensor_score
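
A minimal numeric sketch of the interpolation schedule used above: n intermediate points spaced evenly between z1 and z2, with the decoder calls omitted:

    import torch

    z1 = torch.zeros(1, 4)
    z2 = torch.ones(1, 4)
    n = 3
    interp_points = [z1 + (z2 - z1) * (i + 1) / (n + 1) for i in range(n)]
    print([p[0, 0].item() for p in interp_points])           # [0.25, 0.5, 0.75]
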
Example #24
 def get_note_density_in_measure(self, measure_tensor):
     """
     Returns the number of notes in each measure of the input, normalized by the
     length of the measure representation
     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable containing float tensor,
             (batch_size)
     """
     _, measure_seq_len = measure_tensor.size()
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     slur_count = torch.sum(measure_tensor == slur_index, 1)
     rest_count = torch.sum(measure_tensor == rest_index, 1)
     start_count = torch.sum(measure_tensor == start_index, 1)
     end_count = torch.sum(measure_tensor == end_index, 1)
     note_count = measure_seq_len - (slur_count + rest_count + start_count +
                                     end_count)
     note_density = to_cuda_variable(note_count.float() / measure_seq_len)
     return note_density
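
A toy version of the counting logic above, using a hypothetical four-symbol vocabulary in place of note2index_dicts (the real index dictionaries come from the dataset; start and end symbols are omitted for brevity):

    import torch

    # hypothetical vocabulary: 0 = slur, 1 = rest, 2 = 'C4', 3 = 'E4'
    slur_index, rest_index = 0, 1
    measure_tensor = torch.tensor([[2, 0, 1, 3],
                                   [1, 1, 0, 0]])
    _, measure_seq_len = measure_tensor.size()
    slur_count = torch.sum(measure_tensor == slur_index, 1)
    rest_count = torch.sum(measure_tensor == rest_index, 1)
    note_count = measure_seq_len - (slur_count + rest_count)
    note_density = note_count.float() / measure_seq_len
    print(note_density)                                      # tensor([0.5000, 0.0000])
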
Example #25
 def plot_latent_reconstructions(self, num_points=10):
     _, _, data_loader = self.dataset.data_loaders(batch_size=num_points)
     for sample_id, batch in tqdm(enumerate(data_loader)):
         inputs, labels = self.process_batch_data(batch)
         inputs = to_cuda_variable(inputs)
         recons, _, _, z, _ = self.model(inputs)
         recons = torch.sigmoid(recons)
         # save original image
         org_save_filepath = os.path.join(Trainer.get_save_dir(self.model),
                                          f'r_original_{sample_id}.png')
         save_image(inputs.cpu(),
                    org_save_filepath,
                    nrow=num_points,
                    pad_value=1.0)
         # save reconstruction
         recons_save_filepath = os.path.join(
             Trainer.get_save_dir(self.model), f'r_recons_{sample_id}.png')
         save_image(recons.cpu(),
                    recons_save_filepath,
                    nrow=num_points,
                    pad_value=1.0)
         break
Example #26
    def create_latent_gifs(self, sample_id=9, num_points=10):
        x1 = torch.linspace(-4, 4.0, num_points)
        _, _, data_loader = self.dataset.data_loaders(batch_size=1)
        interp_dict = self.compute_eval_metrics()["interpretability"]
        for sid, batch in tqdm(enumerate(data_loader)):
            if sid == sample_id:
                inputs, labels = self.process_batch_data(batch)
                inputs = to_cuda_variable(inputs)
                _, _, _, z, _ = self.model(inputs)
                z = z.repeat(num_points, 1)
                outputs = []
                for attr_str in self.attr_dict.keys():
                    if attr_str == 'digit_identity' or attr_str == 'color':
                        continue
                    dim = interp_dict[attr_str][0]
                    z_copy = z.clone()
                    z_copy[:, dim] = x1.contiguous()

                    outputs.append(torch.sigmoid(self.model.decode(z_copy)))
                outputs = torch.unsqueeze(torch.cat(outputs, dim=1), dim=2)
                interps = []
                for n in range(outputs.shape[0]):
                    image_grid = make_grid(outputs[n],
                                           padding=2,
                                           pad_value=1.0).detach().cpu()
                    np_image = image_grid.mul(255).clamp(0, 255).byte()
                    np_image = np_image.permute(1, 2, 0).numpy()
                    interps.append(Image.fromarray(np_image))
                # save gif
                gif_filepath = os.path.join(
                    Trainer.get_save_dir(self.model),
                    f'gif_interpolations_{self.dataset_type}_{sample_id}.gif')
                save_gif_from_list(interps, gif_filepath)
            if sid > sample_id:
                break
Example #27
 def get_rhythmic_entropy(self, measure_tensor):
     """
     Returns the rhythmic entropy in a measure of music
     :param measure_tensor: torch Variable,
             (batch_size, measure_seq_len)
     :return: torch Variable,
             (batch_size)
     """
     slur_index = self.note2index_dicts[SLUR_SYMBOL]
     start_index = self.note2index_dicts[START_SYMBOL]
     end_index = self.note2index_dicts[END_SYMBOL]
     rest_index = self.note2index_dicts['rest']
     none_index = self.note2index_dicts[None]
     beat_tensor = measure_tensor.clone()
     beat_tensor[beat_tensor == start_index] = slur_index
     beat_tensor[beat_tensor == end_index] = slur_index
     beat_tensor[beat_tensor == rest_index] = slur_index
     beat_tensor[beat_tensor == none_index] = slur_index
     beat_tensor[beat_tensor != slur_index] = 1
     beat_tensor[beat_tensor == slur_index] = 0
     ent = stats.entropy(np.transpose(to_numpy(beat_tensor)))
     ent = torch.from_numpy(np.transpose(ent))
     ent = to_cuda_variable(ent)
     return ent
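
scipy.stats.entropy normalizes its input along axis 0, so the transposes above yield one entropy value per measure over its binary onset pattern. A small sketch on two toy onset rows:

    import numpy as np
    from scipy import stats

    # binary onset patterns for two measures, shape (batch_size, measure_seq_len)
    beat = np.array([[1., 0., 1., 0.],
                     [1., 1., 1., 1.]])
    ent = stats.entropy(np.transpose(beat))                  # one entropy value per measure
    print(ent)                                               # ~[0.693 1.386], i.e. ln(2) and ln(4)
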