Ejemplo n.º 1
0
# Training hyperparameters: only the first `num_train` / `num_test` examples of
# each split are drawn, `l2_reg` is Adam's weight decay (0. disables it).
num_train, num_test = 2000, 1000
batch_size = 100
num_epochs = 20
learn_rate = 1e-4
l2_reg = 0.

# MPS classifier with one input per 28x28 MNIST pixel and 10 output classes.
# NOTE(review): bond_dim, adaptive_mode and periodic_bc are not defined in this
# fragment; presumably set earlier in the script.
mps = MPS(
    input_dim=28**2,
    output_dim=10,
    bond_dim=bond_dim,
    adaptive_mode=adaptive_mode,
    periodic_bc=periodic_bc,
)

# Cross-entropy objective optimised with Adam
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=learn_rate, weight_decay=l2_reg)

# MNIST train/test splits as tensors (download=True fetches them into ./mnist
# when missing)
transform = transforms.ToTensor()
train_set = datasets.MNIST('./mnist', download=True, transform=transform)
test_set = datasets.MNIST('./mnist', train=False, download=True,
                          transform=transform)

# Random sampling restricted to the leading subset of each split; drop_last
# guarantees every batch holds exactly `batch_size` examples.
samplers = {
    'train': torch.utils.data.SubsetRandomSampler(range(num_train)),
    'test': torch.utils.data.SubsetRandomSampler(range(num_test)),
}
loaders = {
    split: torch.utils.data.DataLoader(
        split_set,
        batch_size=batch_size,
        sampler=samplers[split],
        drop_last=True,
    )
    for split, split_set in zip(('train', 'test'), (train_set, test_set))
}
num_batches = {name: total_num // batch_size for (name, total_num) in
Ejemplo n.º 2
0
#!/usr/bin/env python3
import torch
import sys

sys.path.append('/home/jemis/torch_mps')
from torchmps import MPS

batch_size, input_size = 11, 21
output_dim, bond_dim = 4, 5

input_data = torch.randn([batch_size, input_size])

# Check both open (periodic_bc=False) and periodic (periodic_bc=True) boundary
# conditions with label_site=None, label_site=0 and label_site=input_size.
# num_params is the number of parameter tensors the module should register for
# that placement; the output must always be [batch_size, output_dim].
#
# Constructor signature, for reference:
#   MPS(input_size, output_dim, bond_dim, d=2, label_site=None,
#       periodic_bc=False, parallel_eval=False, adaptive_mode=False,
#       cutoff=1e-10, merge_threshold=1000)
for bc in (False, True):
    for num_params, label_site in ((3, None), (2, 0), (2, input_size)):
        mps_module = MPS(
            input_size, output_dim, bond_dim,
            periodic_bc=bc, label_site=label_site,
        )
        assert sum(1 for _ in mps_module.parameters()) == num_params

        output = mps_module(input_data)
        assert tuple(output.shape) == (batch_size, output_dim)
Ejemplo n.º 3
0
# Training hyperparameters (`l2_reg` is Adam's weight decay; 0. disables it)
num_test, batch_size = 1000, 100
num_epochs, learn_rate = 20, 1e-4
l2_reg = 0.

# MPS classifier with one input per 28x28 image pixel and 10 output classes.
# NOTE(review): bond_dim, adaptive_mode and periodic_bc are not defined in this
# fragment; presumably set earlier in the script.
mps = MPS(input_dim=28**2, output_dim=10, bond_dim=bond_dim,
          adaptive_mode=adaptive_mode, periodic_bc=periodic_bc)

# Cross-entropy objective optimised with Adam
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=learn_rate,
                             weight_decay=l2_reg)

# Fashion-MNIST train/test splits as tensors (download=True fetches them into
# ./fashionmnist when missing)
transform = transforms.ToTensor()
train_set = datasets.FashionMNIST('./fashionmnist', download=True,
                                  transform=transform)
test_set = datasets.FashionMNIST('./fashionmnist', train=False, download=True,
                                 transform=transform)

# train_set = datasets.MNIST('./mnist', download=True, transform=transform)
# test_set = datasets.MNIST('./mnist', download=True, transform=transform,
Ejemplo n.º 4
0
input_data = torch.randn([batch_size, input_size])

# For both open and periodic boundary conditions, place the label site in
# different locations and check that the basic behavior is correct.
# Each case pairs the expected number of parameter tensors with a label site.
for bc in [False, True]:
    for num_params, label_site in [(8, None), (5, 0), (6, 1), (5, input_size),
                                   (6, input_size - 1)]:
        mps_module = MPS(input_size,
                         output_dim,
                         bond_dim,
                         periodic_bc=bc,
                         label_site=label_site,
                         adaptive_mode=True,
                         merge_threshold=merge_threshold)
        # Put the diagnostic in the assertion message instead of printing the
        # two counts by hand just before an identical assert.
        actual_params = len(list(mps_module.parameters()))
        assert actual_params == num_params, (
            f"expected {num_params} parameter tensors, got {actual_params} "
            f"(periodic_bc={bc}, label_site={label_site})")
        assert mps_module.linear_region.offset == 0

        for _ in range(6):
            output = mps_module(input_data)
            assert list(output.shape) == [batch_size, output_dim]

        # At this point we should have flipped our offset from 0 to 1, but are
        # on the threshold so that the next call will flip offset back to 0
        actual_params = len(list(mps_module.parameters()))
        assert actual_params == num_params, (
            f"expected {num_params} parameter tensors after forward passes, "
            f"got {actual_params} (periodic_bc={bc}, label_site={label_site})")
        assert mps_module.linear_region.offset == 1

        output = mps_module(input_data)
Ejemplo n.º 5
0
print("merge_threshold =", merge_threshold)
print("cutoff =", cutoff)
print("Using device:", device)
print()
print("path =", path)
print()
sys.stdout.flush()

# Global torch configuration. This must happen *before* the MPS is built:
# previously the seed was set after MPS(...), so any random initialisation
# performed inside the constructor was not covered by it and was not
# reproducible between runs.
# NOTE(review): torch.set_default_tensor_type is deprecated in recent PyTorch
# releases (torch.set_default_dtype / torch.set_default_device replace it).
torch.set_default_tensor_type('torch.FloatTensor')
torch.manual_seed(0)

# Initialize the MPS module
mps = MPS(input_dim=input_dim, output_dim=output_dim, bond_dim=bond_dim,
          adaptive_mode=adaptive_mode, periodic_bc=periodic_bc,
          merge_threshold=merge_threshold, path=path)

# Set loss function and optimizer
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mps.parameters(), lr=lr, weight_decay=l2_reg)

# Wall-clock reference for reporting elapsed time
start_time = time.time()

# Get the training and test sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(torchmps_dir+'/mnist', download=True,
                           transform=transform)
test_set = datasets.MNIST(torchmps_dir+'/mnist', download=True,
                          transform=transform, train=False)

# Put MNIST data into Pytorch tensors, one flattened input_dim vector per image
train_inputs = torch.stack([data[0].view(input_dim) for data in train_set])
Ejemplo n.º 6
0
def mps_train(loaders, num_channel):
    """Train a fresh MPS classifier and evaluate it on the test set each epoch.

    Besides its arguments, this function reads module-level settings defined
    elsewhere in the script: ``dim``, ``bond_dim``, ``adaptive_mode``,
    ``periodic_bc``, ``learn_rate``, ``l2_reg``, ``batch_size``,
    ``num_epochs``, ``num_batches`` and ``start_time``.

    Args:
        loaders: Mapping with ``'train'`` and ``'test'`` entries, each an
            iterable of ``(inputs, labels)`` batches. Every batch must hold
            exactly ``batch_size`` samples because inputs are reshaped with
            that fixed size (e.g. DataLoaders built with ``drop_last=True``).
        num_channel: Number of image channels. With 3, inputs are fed to the
            MPS as ``[num_channel, batch_size, dim**2]``; otherwise as
            ``[batch_size, dim**2]``.

    Returns:
        ``(run_time, ave_loss, train_acc, test_acc)``: per-epoch lists of
        whole seconds elapsed since ``start_time``, mean training loss (as
        Python floats), mean training accuracy and test accuracy.
    """
    # Initialize the MPS module
    mps = MPS(input_dim=dim**2,
              output_dim=10,
              bond_dim=bond_dim,
              adaptive_mode=adaptive_mode,
              periodic_bc=periodic_bc)

    # Set our loss function and optimizer
    loss_fun = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(mps.parameters(),
                                 lr=learn_rate,
                                 weight_decay=l2_reg)

    def to_mps_input(inputs):
        """Reshape one batch of images into the layout passed to the MPS."""
        if num_channel == 3:
            # DataLoader batches are stacked along dim 0 ([batch, channel, ...]).
            # A bare .view() to [channel, batch, ...] reinterprets memory and
            # mixes pixels of different samples, misaligning them with labels.
            # Flatten per sample and channel first, then move channels to the
            # front. Assumes channel-first images (as transforms.ToTensor
            # produces) -- confirm if a custom transform is used.
            per_channel = inputs.reshape(batch_size, num_channel, dim**2)
            return per_channel.transpose(0, 1).contiguous()
        return inputs.view(batch_size, dim**2)

    train_acc = []
    test_acc = []
    ave_loss = []
    run_time = []

    for epoch_num in range(1, num_epochs + 1):
        running_loss = 0.
        running_acc = 0.

        for inputs, labels in loaders['train']:
            # Call our MPS to get logit scores and predictions
            scores = mps(to_mps_input(inputs))
            _, preds = torch.max(scores, 1)

            # Compute the loss and accuracy, add them to the running totals.
            # .item() keeps the totals (and the returned losses) as plain
            # floats instead of accumulating tensors.
            loss = loss_fun(scores, labels)
            running_loss += loss.item()
            running_acc += torch.sum(preds == labels).item() / batch_size

            # Backpropagate and update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = running_loss / num_batches['train']
        epoch_train_acc = running_acc / num_batches['train']
        ave_loss.append(epoch_loss)
        train_acc.append(epoch_train_acc)
        print(f"### Epoch {epoch_num} ###")
        print(f"Average loss:           {epoch_loss:.4f}")
        print(f"Average train accuracy: {epoch_train_acc:.4f}")

        # Evaluate accuracy of MPS classifier on the test set
        with torch.no_grad():
            running_acc = 0.
            for inputs, labels in loaders['test']:
                scores = mps(to_mps_input(inputs))
                _, preds = torch.max(scores, 1)
                running_acc += torch.sum(preds == labels).item() / batch_size

        epoch_test_acc = running_acc / num_batches['test']
        test_acc.append(epoch_test_acc)
        print(f"Test accuracy:          {epoch_test_acc:.4f}")

        # Read the clock once so the printed and recorded runtimes agree
        elapsed = int(time.time() - start_time)
        print(f"Runtime so far:         {elapsed} sec\n")
        run_time.append(elapsed)

    return run_time, ave_loss, train_acc, test_acc