#!/usr/bin/env python3

import dataloader
import itertools
import nnmodules
import torch

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 16

# Encoder/decoder pair with identical ResRnn widths (8 in, 500-wide state,
# 8 out); they differ only in the checkpoint_name they are given.
tgt_enc = nnmodules.ResRnn(
    input_width=8,
    state_width=500,
    output_width=8,
    checkpoint_name='tgt_enc',
)
tgt_dec = nnmodules.ResRnn(
    input_width=8,
    state_width=500,
    output_width=8,
    checkpoint_name='tgt_dec',
)

# Move both networks (encoder first, then decoder) onto the chosen device.
tgt_enc, tgt_dec = tgt_enc.to(device), tgt_dec.to(device)

# A single SGD optimizer updates encoder and decoder parameters jointly
# (encoder parameters first, then decoder parameters).
# NOTE(review): lr=1e5 is far outside the usual SGD range; presumably
# deliberate for ResRnn — confirm before changing.
optimizer = torch.optim.SGD(
    itertools.chain(tgt_enc.parameters(), tgt_dec.parameters()),
    lr=1e5,
    momentum=0.9,
)

# ---- Example #2 ----
# (scraped listing marker "0"; meaning unknown)
# Data loaders. The training split is reshuffled every epoch; the test split
# keeps its order. Both pin host memory. train_dataset, test_dataset and the
# *_batch_size values are defined outside this excerpt.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    pin_memory=True,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    pin_memory=True,
)

# ResRnn classifier; input_width=1 presumably means one pixel per time step
# (checkpoint name 'pmnist') — confirm against nnmodules.ResRnn.
model = nnmodules.ResRnn(
    input_width=1,
    state_width=state_size,
    output_width=output_size,
    linearity=0.99999,
    checkpoint_name='pmnist',
)
model = model.to(device)


# Loss and optimizer
def loss_fn(outputs, labels, num_classes=None):
    """Return the smooth-L1 loss between ``outputs`` and one-hot ``labels``.

    Args:
        outputs: Model outputs, compared elementwise with the one-hot
            targets. Presumably shape (batch, num_classes) — confirm.
        labels: Integer class-index tensor (``one_hot`` requires int64).
        num_classes: Width of the one-hot encoding. When ``None`` (the
            default), falls back to the module-level ``output_size``,
            which keeps existing two-argument calls working.

    Returns:
        A scalar tensor: the mean smooth-L1 loss (``smooth_l1_loss`` uses
        ``reduction='mean'`` by default).
    """
    if num_classes is None:
        # Backward compatibility: the original always used the global width.
        num_classes = output_size
    # one_hot returns int64; cast to the outputs' dtype so both loss
    # arguments share one floating dtype.
    one_hot = torch.nn.functional.one_hot(
        labels, num_classes=num_classes).to(dtype=outputs.dtype)

    return torch.nn.functional.smooth_l1_loss(outputs, one_hot)


# Plain SGD with momentum 0.9 over all model parameters.
# NOTE(review): lr=1e4 is far outside the usual SGD range; presumably
# deliberate for this ResRnn setup — confirm before changing.
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=1e4)
# ---- Example #3 ----
# (scraped listing marker "0"; meaning unknown)
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST test split as tensors. No download is requested, so the files are
# expected to already exist under 'data'.
test_dataset = torchvision.datasets.MNIST(
    root='data',
    train=False,
    transform=transforms.ToTensor(),
)

# Fixed order, one batch of 10000 samples (the size of the MNIST test
# split), with pinned host memory.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=10000,
    shuffle=False,
    pin_memory=True,
)

model = nnmodules.ResRnn(
    input_width=1,
    state_width=1000,
    output_width=10,
    linearity=0.99999,
)
model = model.to(device)

# Presumably restores trained weights from this checkpoint file — see
# nnmodules.ResRnn.load.
model.load('models/pmnist-1-1000-10-98.28.pt')

# Re-seed right before drawing the pixel permutation so it is reproducible.
# Presumably it must match the permutation used in training — confirm.
torch.manual_seed(0)
random_indices = torch.randperm(28 * 28)

# Evaluate the model on the pixel-permuted test set (no gradients)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(images.size(0), 28 * 28)
        images = images[:, random_indices]