Example #1
import args_parse
import test_profile_mp_mnist


def train_worker():
    # `port` and `worker_started` are expected to come from the enclosing
    # test scope: a TCP port for the profiler server and an event that the
    # worker signals once training is underway.
    flags = args_parse.parse_common_options(datadir='/tmp/mnist-data',
                                            batch_size=16,
                                            momentum=0.5,
                                            lr=0.01,
                                            num_epochs=10)
    flags.fake_data = True
    flags.profiler_port = port
    test_profile_mp_mnist.train_mnist(flags,
                                      worker_started=worker_started)
Example #2
import args_parse

FLAGS = args_parse.parse_common_options(datadir='/tmp/mnist-data',
                                        batch_size=128,
                                        momentum=0.5,
                                        lr=0.01,
                                        target_accuracy=98.0,
                                        num_epochs=18)

import os
import pprint
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.experimental import pjrt
import torch_xla.test.test_utils as test_utils


class MNIST(nn.Module):
    def __init__(self):
        super(MNIST, self).__init__()
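        # --- Illustrative continuation (not from the original file) ---
        # The excerpt is cut off here; the layers below are an assumed,
        # shape-correct completion for 28x28 MNIST inputs.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.bn1 = nn.BatchNorm2d(10)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 28x28 -> conv5 -> 24x24 -> pool2 -> 12x12 -> conv5 -> 8x8 -> pool2 -> 4x4
        x = self.bn1(F.relu(F.max_pool2d(self.conv1(x), 2)))
        x = self.bn2(F.relu(F.max_pool2d(self.conv2(x), 2)))
        x = torch.flatten(x, 1)  # 20 * 4 * 4 = 320 features
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)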
Example #3
import args_parse

# The original excerpt begins inside the MODEL_OPTS dictionary; its opening
# is restored here so the fragment parses. Any earlier entries are not shown.
MODEL_OPTS = {
    '--lr_scheduler_type': {
        'type': str,
    },
    '--lr_scheduler_divide_every_n_epochs': {
        'type': int,
    },
    '--lr_scheduler_divisor': {
        'type': int,
    },
}

FLAGS = args_parse.parse_common_options(
    datadir='/tmp/imagenet',
    batch_size=None,
    num_epochs=None,
    momentum=None,
    lr=None,
    target_accuracy=None,
    opts=MODEL_OPTS.items(),
)

import os
import schedulers
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_xla
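The three scheduler flags declared above usually feed a learning-rate decay schedule. The helper below is a hypothetical way to consume them with a stock torch.optim scheduler; the function name, the step counting, and the None handling are assumptions, since the real tests use the local schedulers module imported above.

def make_step_scheduler(optimizer, flags, num_steps_per_epoch):
    # Hypothetical: divide the learning rate by `lr_scheduler_divisor` every
    # `lr_scheduler_divide_every_n_epochs` epochs. Returns None when no
    # scheduler type was requested on the command line.
    if not flags.lr_scheduler_type:
        return None
    return optim.lr_scheduler.StepLR(
        optimizer,
        step_size=flags.lr_scheduler_divide_every_n_epochs * num_steps_per_epoch,
        gamma=1.0 / flags.lr_scheduler_divisor)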
Example #4
import args_parse

MODEL_OPTS = {
    '--use_torchvision': {
        'default': False,
        'type': bool,
    },
}
FLAGS = args_parse.parse_common_options(datadir='/tmp/cifar-data',
                                        batch_size=128,
                                        num_epochs=20,
                                        momentum=0.9,
                                        lr=0.1,
                                        target_accuracy=80.0,
                                        opts=MODEL_OPTS.items())

from common_utils import TestCase, run_tests
import os
import shutil
import test_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla_py.data_parallel as dp
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import torchvision
import torchvision.transforms as transforms
import unittest
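After parsing, the returned namespace is consumed like any argparse result. The selector below is a hypothetical way the --use_torchvision option declared above might be applied; the function name and the placeholder for the locally defined network are assumptions, not taken from the original test.

def make_cifar_model(flags, local_model_fn):
    # Hypothetical selector: `local_model_fn` stands in for the ResNet defined
    # elsewhere in the test file. When --use_torchvision is set, a stock
    # torchvision ResNet-18 sized for CIFAR's 10 classes is used instead.
    if flags.use_torchvision:
        return torchvision.models.resnet18(num_classes=10)
    return local_model_fn()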
"""Fork of test_train_mp_mnist.py to demonstrate how to profile workloads."""
import args_parse

FLAGS = args_parse.parse_common_options(
    datadir="/tmp/mnist-data",
    batch_size=128,
    momentum=0.5,
    lr=0.01,
    target_accuracy=98.0,
    num_epochs=18,
    profiler_port=80,
)

import os
import shutil
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
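The profiler_port added to the common options above is what the training loop hands to the profiler server. The lines below sketch the intended flow; treat the log directory and the exact keyword arguments as assumptions, since they can vary across torch_xla releases.

# In the training process: start the profiler server on the configured port.
server = xp.start_server(FLAGS.profiler_port)

# From a separate client (another process or thread), capture a trace for a
# few seconds into a TensorBoard-readable log directory (path is illustrative).
xp.trace('localhost:{}'.format(FLAGS.profiler_port), '/tmp/profile_logs',
         duration_ms=5000)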