def train_worker():
    """Run MNIST training via test_profile_mp_mnist.train_mnist with fake data.

    Uses batch size 16, momentum 0.5, lr 0.01 and 10 epochs, and sets the
    profiler port. ``port`` and ``worker_started`` are not defined here; they
    presumably come from an enclosing scope (i.e. this is meant to be a
    closure) -- confirm at the definition site.
    """
    mnist_opts = dict(
        datadir='/tmp/mnist-data',
        batch_size=16,
        momentum=0.5,
        lr=0.01,
        num_epochs=10,
    )
    worker_flags = args_parse.parse_common_options(**mnist_opts)
    # NOTE(review): presumably makes train_mnist use synthetic inputs rather
    # than reading datadir -- confirm in test_profile_mp_mnist.
    worker_flags.fake_data = True
    # Profile on the port chosen by the enclosing scope.
    worker_flags.profiler_port = port
    test_profile_mp_mnist.train_mnist(
        worker_flags, worker_started=worker_started)
import args_parse FLAGS = args_parse.parse_common_options(datadir='/tmp/mnist-data', batch_size=128, momentum=0.5, lr=0.01, target_accuracy=98.0, num_epochs=18) import os import pprint import shutil import sys import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as pl import torch_xla.utils.utils as xu import torch_xla.core.xla_model as xm import torch_xla.distributed.xla_multiprocessing as xmp from torch_xla.experimental import pjrt import torch_xla.test.test_utils as test_utils class MNIST(nn.Module): def __init__(self): super(MNIST, self).__init__()
'--lr_scheduler_type': { 'type': str, }, '--lr_scheduler_divide_every_n_epochs': { 'type': int, }, '--lr_scheduler_divisor': { 'type': int, }, } FLAGS = args_parse.parse_common_options( datadir='/tmp/imagenet', batch_size=None, num_epochs=None, momentum=None, lr=None, target_accuracy=None, opts=MODEL_OPTS.items(), ) import os import schedulers import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import torchvision import torchvision.transforms as transforms import torch_xla
import args_parse


def _parse_bool_flag(value):
    """Convert a command-line string into a bool, for use as argparse ``type=``.

    ``type=bool`` is unsuitable: ``bool('False')`` is True, so every non-empty
    string (including "False" and "0") would enable the flag.

    Args:
      value: The raw string given on the command line.

    Returns:
      True for 'true'/'1'/'yes'/'y'/'on'; False for 'false'/'0'/'no'/'n'/'off'
      or the empty string. Matching is case-insensitive and ignores
      surrounding whitespace.

    Raises:
      argparse.ArgumentTypeError: For any other string, so argparse reports a
        usage error instead of silently treating it as True.
    """
    # Local import so the file's top-level import order stays untouched.
    import argparse

    normalized = value.strip().lower()
    if normalized in ('true', '1', 'yes', 'y', 'on'):
        return True
    # '' is kept as False to match what bool('') used to return.
    if normalized in ('false', '0', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got %r' % (value,))


# Extra command-line options for this script, passed to
# args_parse.parse_common_options via ``opts``.
MODEL_OPTS = {
    '--use_torchvision': {
        'default': False,
        'type': _parse_bool_flag,
    },
}
# NOTE(review): flags are parsed before the torch / torch_xla imports below;
# presumably intentional -- keep this ordering.
FLAGS = args_parse.parse_common_options(
    datadir='/tmp/cifar-data',
    batch_size=128,
    num_epochs=20,
    momentum=0.9,
    lr=0.1,
    target_accuracy=80.0,
    opts=MODEL_OPTS.items())

from common_utils import TestCase, run_tests
import os
import shutil
import test_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla
import torch_xla_py.data_parallel as dp
import torch_xla_py.utils as xu
import torch_xla_py.xla_model as xm
import torchvision
import torchvision.transforms as transforms
import unittest
"""Fork of test_train_mp_mnist.py to demonstrate how to profile workloads.""" import args_parse FLAGS = args_parse.parse_common_options( datadir="/tmp/mnist-data", batch_size=128, momentum=0.5, lr=0.01, target_accuracy=98.0, num_epochs=18, profiler_port=80, ) import os import shutil import sys import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms import torch_xla import torch_xla.debug.metrics as met import torch_xla.distributed.parallel_loader as pl import torch_xla.utils.utils as xu import torch_xla.core.xla_model as xm import torch_xla.debug.profiler as xp import torch_xla.distributed.xla_multiprocessing as xmp import torch_xla.test.test_utils as test_utils