def forward(self, x, Psi):
    """Embed log-transformed, per-sample-centred inputs and project with ``Psi``.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape ``(B, D)`` (the ``a[:, :, None]`` indexing below fixes
        it as 2-D). The ``+ 1`` pseudocount suggests non-negative counts --
        TODO confirm against callers.
    Psi : torch.Tensor
        Projection matrix applied along the feature axis; presumably of shape
        ``(D - 1, D)`` given the output shape -- TODO confirm.

    Returns
    -------
    torch.Tensor
        Projected features of shape ``(B, D - 1)``.
    """
    # Log of the closed pseudocounts, then centre each sample (row) on its own
    # mean. NOTE(review): if `closure` rescales rows to sum to 1, it cancels
    # out under this centring (log(y / s) - mean(log(y / s)) == log y - mean log y).
    a = torch.log(closure(x + 1))  # B x D
    a = a - a.mean(dim=1, keepdim=True)
    x_ = a[:, :, None] * self.embed  # B x D x H
    # Squeeze only the trailing axis (presumably size 1 after `ffn`); a bare
    # `.squeeze()` would also drop the batch axis when B == 1.
    fx = self.ffn(x_).squeeze(-1)  # B x D
    fx = (Psi @ fx.T).T  # B x D-1
    return fx
def validation_step(self, batch, batch_idx):
    """Evaluate the VAE on one validation batch without tracking gradients.

    When ``self.hparams['tss']`` is set (benchmarking only), the batch is
    passed through ``closure`` before being scored. The reconstruction error
    is always computed on the raw ``batch``.

    Returns
    -------
    dict
        ``{'val_loss': loss, 'log': {'val_loss': loss, 'val_rec_err': rec_err}}``.

    Raises
    ------
    AssertionError
        If the loss is NaN (only while assertions are enabled).
    """
    with torch.no_grad():
        # 'tss' is only used for benchmarking.
        counts = closure(batch) if self.hparams['tss'] else batch
        loss = self.vae(counts)
        assert not torch.isnan(loss).item()
        # NOTE(review): uses the untransformed batch even when 'tss' is on --
        # confirm this is intended.
        rec_err = self.vae.get_reconstruction_loss(batch)
        logged = {'val_loss': loss, 'val_rec_err': rec_err}
    return {'val_loss': loss, 'log': logged}
def training_step(self, batch, batch_idx):
    """Run the VAE in training mode on one batch and report loss/ELBO/LR.

    The batch is moved to ``self.device``. When ``self.hparams['tss']`` is set
    (benchmarking only), it is also passed through ``closure`` first.

    The reported learning rate comes from ``get_last_lr()[0]`` of the first
    scheduler in ``self.trainer.lr_schedulers``. If there is no scheduler, it
    falls back to ``self.hparams['learning_rate']``.

    Returns
    -------
    dict
        ``{'loss': loss, 'log': {'train_loss', 'elbo' (= -loss), 'lr'}}``.

    Raises
    ------
    AssertionError
        If the loss is NaN (only while assertions are enabled).
    """
    self.vae.train()
    counts = batch.to(self.device)
    if self.hparams['tss']:  # only for benchmarking
        counts = closure(counts)

    loss = self.vae(counts)
    assert not torch.isnan(loss).item()

    schedulers = self.trainer.lr_schedulers
    if schedulers:
        current_lr = schedulers[0]['scheduler'].get_last_lr()[0]
    else:
        current_lr = self.hparams['learning_rate']

    logged = {'train_loss': loss, 'elbo': -loss, 'lr': current_lr}
    return {'loss': loss, 'log': logged}
def forward(self, x, Psi):
    """Embed arcsine-square-root-transformed inputs and project with ``Psi``.

    Parameters
    ----------
    x : torch.Tensor
        Input of shape ``(B, D)`` (the ``a[:, :, None]`` indexing below fixes
        it as 2-D). NOTE(review): ``sqrt``/``arcsin`` are only defined here if
        ``closure(x)`` yields values in ``[0, 1]`` -- presumably row
        proportions; confirm.
    Psi : torch.Tensor
        Projection matrix applied along the feature axis; presumably of shape
        ``(D - 1, D)`` given the output shape -- TODO confirm.

    Returns
    -------
    torch.Tensor
        Projected features of shape ``(B, D - 1)``.
    """
    a = torch.arcsin(torch.sqrt(closure(x)))  # B x D
    x_ = a[:, :, None] * self.embed  # B x D x H
    # Squeeze only the trailing axis (presumably size 1 after `ffn`); a bare
    # `.squeeze()` would also drop the batch axis when B == 1.
    fx = self.ffn(x_).squeeze(-1)  # B x D
    fx = (Psi @ fx.T).T  # B x D-1
    return fx