Example #1
import test_utils

FLAGS = test_utils.parse_common_options(
    datadir='/tmp/mnist-data',
    batch_size=128,
    momentum=0.5,
    lr=0.01,
    target_accuracy=98.0,
    num_epochs=18)

from common_utils import TestCase, run_tests
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla_py.data_parallel as dp
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import unittest
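
# Illustrative sketch, not taken from the original test: building MNIST data
# loaders from the FLAGS parsed above. It assumes parse_common_options exposes
# datadir and batch_size as attributes; the normalization constants are the
# usual MNIST mean/std.
from torch.utils.data import DataLoader

def make_mnist_loaders(flags):
  mnist_transform = transforms.Compose(
      [transforms.ToTensor(),
       transforms.Normalize((0.1307,), (0.3081,))])
  train_loader = DataLoader(
      datasets.MNIST(flags.datadir, train=True, download=True,
                     transform=mnist_transform),
      batch_size=flags.batch_size, shuffle=True)
  test_loader = DataLoader(
      datasets.MNIST(flags.datadir, train=False,
                     transform=mnist_transform),
      batch_size=flags.batch_size, shuffle=False)
  return train_loader, test_loader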


class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
Example #2
import test_utils

FLAGS = test_utils.parse_common_options(datadir='../cifar-data',
                                        batch_size=125,
                                        num_epochs=254,
                                        momentum=0.9,
                                        lr=0.8)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla_py.data_parallel as dp
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import torchvision
import torchvision.transforms as transforms

# Import utilities and models
from torch.optim.lr_scheduler import MultiStepLR
from utilities import Cutout, RandomPixelPad, CosineAnnealingRestartsLR
from models import WRN_McDonnell
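
# Illustrative sketch, not taken from the original script: an SGD optimizer
# driven by the lr/momentum flags above, paired with the MultiStepLR schedule
# imported here. The weight decay, milestones and gamma are placeholder
# values, not values taken from the script.
def make_optimizer(model):
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[100, 175], gamma=0.1)
    return optimizer, scheduler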


def train_cifar():
    print('==> Preparing data..')

    transform_train = transforms.Compose([
        transforms.Lambda(lambda x: RandomPixelPad(x, padding=4)),
Example #3
SUPPORTED_MODELS = [
    'resnet50',  # the default below; earlier entries are omitted in this excerpt
    'vgg16_bn',
    'vgg19',
    'vgg19_bn'
]

MODEL_OPTS = {
    '--model': {
        'choices': SUPPORTED_MODELS,
        'default': 'resnet50',
    }
}
FLAGS = test_utils.parse_common_options(
    datadir='/tmp/imagenet',
    batch_size=None,
    num_epochs=None,
    momentum=None,
    lr=None,
    target_accuracy=None,
    opts=MODEL_OPTS.items(),
)

from common_utils import TestCase, run_tests
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_xla
import torch_xla_py.data_parallel as dp
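
# Illustrative sketch of how the --model option defined above could be
# resolved into a network: each name in SUPPORTED_MODELS is assumed to match
# a torchvision.models constructor (e.g. torchvision.models.resnet50), and
# the parsed value is assumed to be exposed as FLAGS.model.
import torchvision.models

def get_model(flags):
  return getattr(torchvision.models, flags.model)()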
Example #4
import test_utils

FLAGS = test_utils.parse_common_options(datadir='/tmp/mnist-data',
                                        batch_size=256,
                                        target_accuracy=98.0)

from common_utils import TestCase, run_tests
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla_py.xla_model as xm
import unittest
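
# Illustrative single training step, not part of the original test. It
# assumes the model's forward pass returns log-probabilities (so F.nll_loss
# applies) and that an SGD optimizer has already been built from the FLAGS
# defaults above.
def train_step(model, optimizer, data, target):
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    return loss.item()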


class MNIST(nn.Module):
    def __init__(self):
        super(MNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.bn1(x)
Example #5
import test_utils

FLAGS = test_utils.parse_common_options(
    datadir='/tmp/imagenet', batch_size=128, num_epochs=15,
    target_accuracy=0.0)

from common_utils import TestCase, run_tests
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_xla
import torch_xla_py.xla_model as xm
import unittest


def _cross_entropy_loss_eval_fn(cross_entropy_loss):
    def eval_fn(output, target):
        loss = cross_entropy_loss(output, target).item()
        # Get the index of the max log-probability.
        pred = output.max(1, keepdim=True)[1]
        correct = pred.eq(target.view_as(pred)).sum().item()
        return loss, correct

    return eval_fn
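
# Illustrative use of the eval_fn factory above, not part of the original
# test: accumulating mean loss and accuracy (in percent) over a plain
# DataLoader passed in by the caller.
def evaluate(model, loader):
    eval_fn = _cross_entropy_loss_eval_fn(nn.CrossEntropyLoss())
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    with torch.no_grad():
        for data, target in loader:
            loss, correct = eval_fn(model(data), target)
            total_loss += loss * data.size(0)
            total_correct += correct
            total_samples += data.size(0)
    return total_loss / total_samples, 100.0 * total_correct / total_samples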

Example #6
import test_utils

FLAGS = test_utils.parse_common_options(datadir='/tmp/cifar-data',
                                        batch_size=128,
                                        num_epochs=20,
                                        momentum=0.9,
                                        lr=0.2,
                                        target_accuracy=80.0)

from common_utils import TestCase, run_tests
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla_py.data_parallel as dp
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import torchvision
import torchvision.transforms as transforms
import unittest
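
# Illustrative CIFAR-10 input pipeline for the FLAGS above, not part of the
# original test. The normalization constants are the commonly used CIFAR-10
# channel statistics, and datadir/batch_size are assumed to be attributes of
# the parsed FLAGS.
from torch.utils.data import DataLoader

def make_cifar_loaders(flags):
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])
    train_loader = DataLoader(
        torchvision.datasets.CIFAR10(root=flags.datadir, train=True,
                                     download=True, transform=train_transform),
        batch_size=flags.batch_size, shuffle=True)
    test_loader = DataLoader(
        torchvision.datasets.CIFAR10(root=flags.datadir, train=False,
                                     download=True, transform=test_transform),
        batch_size=flags.batch_size, shuffle=False)
    return train_loader, test_loader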


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
Example #7
import test_utils

FLAGS = test_utils.parse_common_options(datadir='/tmp/cifar-data',
                                        batch_size=128,
                                        num_epochs=15,
                                        target_accuracy=80.0)

from common_utils import TestCase, run_tests
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import torchvision
import torchvision.transforms as transforms
import unittest
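
# Illustrative end-of-training check against the target_accuracy flag above,
# not part of the original test. `accuracy` stands for the measured test
# accuracy in percent.
def meets_target(accuracy):
    return FLAGS.target_accuracy is None or accuracy >= FLAGS.target_accuracy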


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,