def validation_inference_function(batch):
    """Run one validation step on *batch* and return the scalar loss.

    Relies on module-level globals: ``model``, ``criterion``, ``flatten``,
    ``to_one_hot``, ``batch_size``, ``use_gpu``.
    """
    model.eval()
    inputs, targets = batch
    inputs = flatten(Variable(inputs))
    # TODO: remove hardcoded n_classes (mirrors note in cvae_reconstructions)
    targets = Variable(
        to_one_hot(targets, batch_size=batch_size, n_classes=10))
    if use_gpu:
        inputs = inputs.cuda()
        targets = targets.cuda()
    # Validation needs no autograd graph; no_grad saves memory and time.
    with torch.no_grad():
        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)
    # loss.data[0] raises "invalid index of a 0-dim tensor" on torch >= 0.4;
    # .item() is the supported way to extract a Python scalar.
    return loss.item()
def training_update_function(batch):
    """Run one optimizer step on *batch* and return the scalar loss.

    Relies on module-level globals: ``model``, ``optimizer``, ``criterion``,
    ``flatten``, ``to_one_hot``, ``batch_size``, ``use_gpu``.
    """
    model.train()
    optimizer.zero_grad()
    inputs, targets = batch
    inputs = flatten(Variable(inputs))
    # TODO: remove hardcoded n_classes (mirrors note in cvae_reconstructions)
    targets = Variable(
        to_one_hot(targets, batch_size=batch_size, n_classes=10))
    if use_gpu:
        inputs = inputs.cuda()
        targets = targets.cuda()
    output, mu, logvar = model(inputs, targets)
    loss = criterion(output, inputs, mu, logvar)
    loss.backward()
    optimizer.step()
    # loss.data[0] raises "invalid index of a 0-dim tensor" on torch >= 0.4;
    # .item() is the supported way to extract a Python scalar.
    return loss.item()
def cvae_reconstructions(model, loader, n=10, n_classes=10, image_size=28):
    """Return an image grid of up to *n* inputs followed by their CVAE
    reconstructions, taken from the first batch of *loader*.

    Parameters
    ----------
    model : conditional VAE returning ``(reconstruction, mu, logvar)``
    loader : data loader yielding ``(inputs, targets)`` batches
    n : max number of examples to show (capped at the batch size)
    n_classes : number of label classes for one-hot encoding
        (previously hard-coded to 10; kept as the default)
    image_size : side length of the square images (default 28, e.g. MNIST)

    Relies on module-level globals: ``flatten``, ``to_one_hot``, ``use_gpu``,
    ``vutils``, ``torch``, ``Variable``.
    """
    model.eval()
    inputs, targets = next(iter(loader))
    batch_size = inputs.size(0)
    n = min(n, batch_size)
    inputs = Variable(inputs)[:n]
    targets = Variable(to_one_hot(
        targets[:n], batch_size=n, n_classes=n_classes))
    if use_gpu:
        inputs = inputs.cuda()
        targets = targets.cuda()
    reconstructions, _, _ = model(flatten(inputs), targets)
    # Reshape flat decoder output back to single-channel square images.
    reconstructions = reconstructions.view(
        reconstructions.size(0), 1, image_size, image_size)
    # Originals on the first row(s), reconstructions beneath, n per row.
    return vutils.make_grid(
        torch.cat([inputs.data, reconstructions.data]), n)
def encode(self, x):
    """Map input *x* to the parameters of the latent posterior.

    Returns a ``(mu, logvar)`` pair produced by two parallel linear
    heads (``fc21``, ``fc22``) over the flattened encoder features.
    """
    features = self.encoder(x)
    flat = flatten(features)
    mu = self.fc21(flat)
    logvar = self.fc22(flat)
    return mu, logvar
def forward(self, x):
    """Full VAE pass: encode *x*, sample a latent, decode it.

    Returns ``(decoded, mu, logvar)`` so the caller can compute both
    the reconstruction term and the KL term of the loss.
    """
    flat_input = flatten(x)
    mu, logvar = self.encode(flat_input)
    # Sample z ~ N(mu, exp(logvar)) in a differentiable way.
    latent = self.reparameterize(mu, logvar)
    reconstruction = self.decode(latent)
    return reconstruction, mu, logvar