Example #1
    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        z = self.encoder(x)
        y = self.decoder(x, z)
        z_logits = self.discriminator(z)

        z_classification = torch.max(z_logits, dim=1)[1]

        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        # Build per-sample domain labels (dset_num repeated batch-size times)
        # as the discriminator's classification target.
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item()

        self.eval_total.add(total_loss)

        return total_loss
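
These examples all call a module-level cross_entropy_loss helper that this page does not show. Below is a minimal sketch of what it might look like, assuming the decoder emits per-timestep class logits over a quantized signal; the helper's name comes from the examples, but the shapes and the reduction choice are assumptions, not the original implementation:

    import torch.nn.functional as F

    def cross_entropy_loss(y, x):
        # y: decoder logits, shape (batch, classes, time)
        # x: quantized targets, shape (batch, time), integer class indices
        # reduction='none' keeps a per-element loss so callers can .mean()
        # or log it per dataset, as the examples on this page do.
        return F.cross_entropy(y, x.long(), reduction='none')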
Example #2
    def eval_batch(self, x, x_aug, x_midi, dset_num):
        x, x_aug = x.float(), x_aug.float()
        
        assert dset_num is not None

        z = self.encoder(x)
        y = self.decoders[dset_num](x, z)
        z_logits = self.discriminator(z)

        z_classification = torch.max(z_logits, dim=1)[1]

        z_accuracy = (z_classification == dset_num).float().mean()

        self.eval_d_right.add(z_accuracy.data.item())

        # Build per-sample domain labels as the discriminator's target.
        discriminator_right = F.cross_entropy(
            z_logits,
            torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        recon_loss = cross_entropy_loss(y, x)

        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        
        total_loss = discriminator_right.data.item() * self.args.d_lambda + \
                     recon_loss.mean().data.item()

        # Optional MIDI-alignment term: L2 distance between the MIDI encoding
        # of x_midi and the audio latent z.
        aligned_loss = 0.0
        if x_midi is not None:
            h, _ = self.midi_encoder(x_midi)  # shape: (bs, hidden_size)
            h = h.view(z.shape)
            aligned_loss = F.mse_loss(h, z)
            total_loss += self.args.m_lambda * aligned_loss.mean().data.item()

        self.eval_total.add(total_loss)

        return total_loss
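
A hypothetical evaluation loop around this variant of eval_batch; trainer, eval_loaders, and the CUDA transfers are illustrative assumptions, not part of the original code:

    import torch

    trainer.encoder.eval()
    with torch.no_grad():  # evaluation needs no gradients
        for dset_num, loader in enumerate(eval_loaders):
            for x, x_aug, x_midi in loader:
                trainer.eval_batch(x.cuda(), x_aug.cuda(),
                                   x_midi.cuda() if x_midi is not None else None,
                                   dset_num)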
Example #3
    def train_batch(self, x, x_aug, x_midi=None, dset_num=None):
        x, x_aug = x.float(), x_aug.float()
        assert dset_num is not None
        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
        discriminator_right = F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        z = self.encoder(x_aug)
        y = self.decoders[dset_num](x, z)
        z_logits = self.discriminator(z)
        discriminator_wrong = - F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        # Log diagnostics if the discriminator loss leaves a sane range.
        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        
        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)
        
        # Optional MIDI-alignment term: pull the MIDI encoding toward the
        # audio latent (an L2 loss here; a discriminator over (h, z) would be
        # an alternative, as the original comment notes).
        aligned_loss = 0.0
        if x_midi is not None:
            h, _ = self.midi_encoder(x_midi)  # shape: (bs, hidden_size)
            h = h.view(z.shape)
            aligned_loss = F.mse_loss(h, z)
            loss += self.args.m_lambda * aligned_loss

        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            for decoder in self.decoders:
                clip_grad_value_(decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()
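
As an aside, the torch.tensor([dset_num] * x.size(0)).long().cuda() target that recurs in these examples can be built with torch.full, which avoids the intermediate Python list and the hard-coded CUDA transfer. A sketch, not the original code:

    target = torch.full((x.size(0),), dset_num,
                        dtype=torch.long, device=z_logits.device)
    discriminator_right = F.cross_entropy(z_logits, target)

Note that F.cross_entropy already averages over the batch by default, so the trailing .mean() calls in the examples are no-ops.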
Example #4
    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # Optimize D - discriminator right
        z = self.encoder(x)
        z_logits = self.discriminator(z)
        discriminator_right = F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()
        loss = discriminator_right * self.args.d_lambda
        self.d_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.discriminator.parameters(), self.args.grad_clip)

        self.d_optimizer.step()

        # optimize G - reconstructs well, discriminator wrong
        z = self.encoder(x_aug)
        if self.args.distributed:
            y = self.decoder(x, z)
        else:
            y = self.decoders[dset_num](x, z)
        z_logits = self.discriminator(z)
        discriminator_wrong = - F.cross_entropy(z_logits, torch.tensor([dset_num] * x.size(0)).long().cuda()).mean()

        # Log diagnostics if the discriminator loss leaves a sane range.
        if not (-100 < discriminator_right.data.item() < 100):
            self.logger.debug(f'z_logits: {z_logits.detach().cpu().numpy()}')
            self.logger.debug(f'dset_num: {dset_num}')

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = (recon_loss.mean() + self.args.d_lambda * discriminator_wrong)

        if self.args.distributed:
            self.model_optimizer.zero_grad()
        else:
            self.model_optimizers[dset_num].zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.encoder.parameters(), self.args.grad_clip)
            if self.args.distributed:
                clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
            else:
                for decoder in self.decoders:
                    clip_grad_value_(decoder.parameters(), self.args.grad_clip)
        # Step the optimizer that matches the branch chosen above.
        if self.args.distributed:
            self.model_optimizer.step()
        else:
            self.model_optimizers[dset_num].step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()
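
The distributed/non-distributed branches imply two optimizer layouts: one joint optimizer over the encoder and a single decoder, or one optimizer per dataset-specific decoder. A hypothetical setup consistent with those branches; the variable names and the choice of Adam are assumptions:

    import torch

    if args.distributed:
        model_optimizer = torch.optim.Adam(
            list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr)
    else:
        model_optimizers = [
            torch.optim.Adam(list(encoder.parameters()) + list(dec.parameters()),
                             lr=args.lr)
            for dec in decoders]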
Example #5
    def eval_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()
        z = self.encoder(x)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.evals_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        total_loss = recon_loss.mean().data.item()
        self.eval_total.add(total_loss)

        return total_loss
Example #6
    def train_batch(self, x, x_aug, dset_num):
        """Train a batch without considering the discriminator."""
        x, x_aug = x.float(), x_aug.float()
        z = self.encoder(x_aug)
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())
        loss = recon_loss.mean()

        self.model_optimizer.zero_grad()
        loss.backward()
        self.model_optimizer.step()
        self.loss_total.add(loss.data.item())

        return loss.data.item()
Example #7
    def train_batch(self, x, x_aug, dset_num):
        x, x_aug = x.float(), x_aug.float()

        # optimize G - reconstructs well
        z = self.encoder(x_aug)
        z = z.detach()  # stop gradients
        y = self.decoder(x, z)

        recon_loss = cross_entropy_loss(y, x)
        self.losses_recon[dset_num].add(recon_loss.data.cpu().numpy().mean())

        loss = recon_loss.mean()
        self.model_optimizer.zero_grad()
        loss.backward()
        if self.args.grad_clip is not None:
            clip_grad_value_(self.decoder.parameters(), self.args.grad_clip)
        self.model_optimizer.step()

        self.loss_total.add(loss.data.item())

        return loss.data.item()
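
Because z is detached before decoding, gradients stop at the encoder and only the decoder is updated. A model_optimizer consistent with that would hold decoder parameters only; a sketch, with the optimizer type and learning rate as assumptions:

    # Only decoder parameters receive gradients, so only they need optimizing.
    model_optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)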