Example 1 (score: 0)
 def forward(self, x, Psi):
     """Log-transform compositional counts, embed them, and project to balances.

     Parameters
     ----------
     x : torch.Tensor
         Count tensor, B x D (batch by features) — assumed non-negative
         since ``log(closure(x + 1))`` is taken; TODO confirm with caller.
     Psi : torch.Tensor
         Basis matrix mapping the D features to D-1 balances.

     Returns
     -------
     torch.Tensor
         Projected output of shape B x (D-1).
     """
     # NOTE(review): the original first computed
     # ``a = torch.arcsin(torch.sqrt(closure(x)))`` and immediately
     # overwrote it on the next line — that dead store is removed here.
     a = torch.log(closure(x + 1))             # B x D, log(1 + p) transform
     a = a - a.mean(axis=1).reshape(-1, 1)     # center each row around its mean
     x_ = a[:, :, None] * self.embed           # B x D x H
     fx = self.ffn(x_).squeeze()
     fx = (Psi @ fx.T).T                       # B x D-1
     return fx
Example 2 (score: 0)
    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and reconstruction error for one batch.

        Parameters
        ----------
        batch : torch.Tensor
            Count tensor for this validation batch.
        batch_idx : int
            Index of the batch (unused; required by the Lightning API).

        Returns
        -------
        dict
            ``{'val_loss': loss, 'log': {'val_loss': ..., 'val_rec_err': ...}}``.
        """
        with torch.no_grad():
            counts = batch
            if self.hparams['tss']:  # only for benchmarking
                counts = closure(counts)
            loss = self.vae(counts)
            # Guard against diverged training. `not x` is the idiomatic
            # form (the original compared the bool with `is False`).
            assert not torch.isnan(loss).item()

            # Record the actual reconstruction error alongside the loss.
            rec_err = self.vae.get_reconstruction_loss(batch)
            tensorboard_logs = {'val_loss': loss, 'val_rec_err': rec_err}
            # log the learning rate
            return {'val_loss': loss, 'log': tensorboard_logs}
Example 3 (score: 0)
 def training_step(self, batch, batch_idx):
     """Run one training step and return the loss plus logging metrics.

     Parameters
     ----------
     batch : torch.Tensor
         Count tensor for this training batch; moved to ``self.device``.
     batch_idx : int
         Index of the batch (unused; required by the Lightning API).

     Returns
     -------
     dict
         ``{'loss': loss, 'log': {'train_loss', 'elbo', 'lr'}}``.
     """
     self.vae.train()
     counts = batch.to(self.device)
     if self.hparams['tss']:  # only for benchmarking
         counts = closure(counts)
     loss = self.vae(counts)
     # Guard against diverged training. `not x` is the idiomatic form
     # (the original compared the bool with `is False`).
     assert not torch.isnan(loss).item()
     # Report the scheduler's live LR when one exists; otherwise fall
     # back to the configured constant. (Truthiness replaces the
     # original `len(...) >= 1`, and the redundant temp is removed.)
     if self.trainer.lr_schedulers:
         current_lr = self.trainer.lr_schedulers[0]['scheduler'].get_last_lr()[0]
     else:
         current_lr = self.hparams['learning_rate']
     tensorboard_logs = {
         'train_loss': loss,
         'elbo': -loss,
         'lr': current_lr
     }
     # log the learning rate
     return {'loss': loss, 'log': tensorboard_logs}
Example 4 (score: 0)
 def forward(self, x, Psi):
     """Arcsine-sqrt transform the input, embed it, and project to balances.

     Takes a B x D tensor ``x`` (closed to proportions via ``closure``)
     and a (D-1) x D basis ``Psi``; returns a B x (D-1) tensor.
     """
     # Variance-stabilizing arcsine-square-root transform of proportions.
     transformed = torch.arcsin(torch.sqrt(closure(x)))   # B x D
     # Broadcast each feature value across the embedding dimension.
     embedded = transformed[:, :, None] * self.embed      # B x D x H
     out = self.ffn(embedded).squeeze()
     out = (Psi @ out.T).T                                # B x D-1
     return out