コード例 #1
0
    def validation_inference_function(batch):
        """Compute the CVAE loss on one validation batch without updating weights.

        Args:
            batch: an ``(inputs, targets)`` pair from the data loader;
                ``targets`` are presumably integer class labels (they are fed
                to ``to_one_hot`` with ``n_classes=10``).

        Returns:
            The batch loss as a Python number.

        Uses ``model``, ``criterion`` and ``use_gpu`` from the enclosing scope.
        """
        model.eval()

        inputs, targets = batch
        # One-hot encode with the size of *this* batch, not the configured
        # ``batch_size``: the last batch of an epoch can be smaller, and a
        # one-hot tensor built for the nominal size would likely not line up
        # with ``inputs``.
        n_samples = targets.size(0)
        inputs = flatten(Variable(inputs))
        targets = Variable(
            to_one_hot(targets, batch_size=n_samples, n_classes=10))

        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        # NOTE(review): autograd still records this forward pass; on
        # PyTorch >= 0.4, wrapping it in ``torch.no_grad()`` would save memory.
        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)

        # Indexing a 0-dim tensor (``loss.data[0]``) raises on PyTorch >= 0.5;
        # prefer ``.item()`` where it exists and keep the legacy path otherwise.
        return loss.item() if hasattr(loss, 'item') else loss.data[0]
コード例 #2
0
    def training_update_function(batch):
        """Run one optimisation step of the CVAE on a training batch.

        Clears gradients, computes the loss, back-propagates it and steps
        the optimiser.

        Args:
            batch: an ``(inputs, targets)`` pair from the data loader;
                ``targets`` are presumably integer class labels (they are fed
                to ``to_one_hot`` with ``n_classes=10``).

        Returns:
            The batch loss as a Python number.

        Uses ``model``, ``optimizer``, ``criterion`` and ``use_gpu`` from the
        enclosing scope.
        """
        model.train()
        optimizer.zero_grad()

        inputs, targets = batch
        # One-hot encode with the size of *this* batch, not the configured
        # ``batch_size``: the last batch of an epoch can be smaller, and a
        # one-hot tensor built for the nominal size would likely not line up
        # with ``inputs``.
        n_samples = targets.size(0)
        inputs = flatten(Variable(inputs))
        targets = Variable(
            to_one_hot(targets, batch_size=n_samples, n_classes=10))

        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        output, mu, logvar = model(inputs, targets)
        loss = criterion(output, inputs, mu, logvar)
        loss.backward()
        optimizer.step()

        # Indexing a 0-dim tensor (``loss.data[0]``) raises on PyTorch >= 0.5;
        # prefer ``.item()`` where it exists and keep the legacy path otherwise.
        return loss.item() if hasattr(loss, 'item') else loss.data[0]
コード例 #3
0
ファイル: sampling.py プロジェクト: jongold/autoencoders
def cvae_reconstructions(model, loader, n=10, n_classes=10):
    """Build an image grid comparing inputs with their CVAE reconstructions.

    Takes the first batch from ``loader`` and keeps at most ``n`` samples.
    Each sample is reconstructed by ``model``, conditioned on its one-hot
    label. The result is a ``make_grid`` of the originals followed by the
    reconstructions, laid out ``n`` images per row.

    Args:
        model: conditional VAE called as ``model(flat_inputs, one_hot_targets)``
            and returning a 3-tuple whose first element is the reconstruction.
        loader: iterable yielding ``(inputs, targets)`` batches.
        n: maximum number of samples to show; clipped to the batch size.
        n_classes: number of label classes used for one-hot encoding
            (default 10, the previously hard-coded value).

    Returns:
        The grid tensor produced by ``torchvision.utils.make_grid``.

    Note:
        The reconstruction must have as many elements per sample as the input,
        since it is reshaped to the inputs' shape before concatenation.
    """
    model.eval()

    inputs, targets = next(iter(loader))

    # Never ask for more samples than the batch actually contains.
    n = min(n, inputs.size(0))

    inputs = Variable(inputs)[:n]
    targets = Variable(to_one_hot(
        targets[:n], batch_size=n, n_classes=n_classes))

    if use_gpu:
        inputs = inputs.cuda()
        targets = targets.cuda()

    reconstructions, _, _ = model(flatten(inputs), targets)
    # Reshape to the inputs' own shape rather than a hard-coded (1, 28, 28);
    # for 1x28x28 inputs this is exactly the previous behaviour.
    reconstructions = reconstructions.view(inputs.size())

    return vutils.make_grid(
        torch.cat([inputs.data, reconstructions.data]), nrow=n)
コード例 #4
0
 def encode(self, x):
     """Map ``x`` to the outputs of the two encoder heads.

     ``x`` is run through ``self.encoder`` and the result is flattened.
     The flattened features are then fed to ``self.fc21`` and ``self.fc22``.

     Returns:
         ``(self.fc21(h), self.fc22(h))`` for the flattened features ``h`` --
         presumably ``(mu, logvar)``; confirm against the caller.
     """
     encoded = self.encoder(x)
     hidden = flatten(encoded)
     first_head = self.fc21(hidden)
     second_head = self.fc22(hidden)
     return first_head, second_head
コード例 #5
0
    def forward(self, x):
        """Encode ``x``, reparameterize the encoding and decode it.

        Returns:
            A ``(decoded, mu, logvar)`` tuple: the output of ``self.decode``
            followed by the two values returned by ``self.encode``.
        """
        flat_x = flatten(x)
        mu, logvar = self.encode(flat_x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar