def run_benchmark(args, pos_args):
  """Benchmarks host-to-device tensor transfer, one thread per XLA device.

  Pre-builds `args.prefetch` random CPU tensors per device, then each
  thread repeatedly sends its batch to its device `args.test_count` times,
  timing every send. Prints the XLA metrics report when all threads join.

  Args:
    args: Parsed flags; uses max_devices, shape, prefetch, test_count.
    pos_args: Remaining positional arguments (unused here).
  """
  devices = xm.get_xla_supported_devices(max_devices=args.max_devices)
  # args.shape is a comma-separated string, e.g. '128,1024' -> [128, 1024].
  shape = [int(x) for x in args.shape.split(',')]
  send_list = []
  for i in range(0, len(devices)):
    mb = []
    for j in range(0, args.prefetch):
      mb.append(torch.randn(*shape))
    send_list.append(mb)

  def threadfn(i):
    device = devices[i]
    # One destination device entry per tensor being sent.
    xdevices = [device] * len(send_list[i])
    for n in range(0, args.test_count):
      # Times the full CPU->device transfer of the prefetched batch.
      with xu.TimedScope(msg='Send[{}][{}]: '.format(i, n), printfn=print):
        _ = torch_xla._XLAC._xla_tensors_from_aten(send_list[i], xdevices)

  threads = []
  for i in range(0, len(devices)):
    t = threading.Thread(target=threadfn, args=(i,))
    t.start()
    threads.append(t)
  for t in threads:
    t.join()
  print(torch_xla._XLAC._xla_metrics_report())
def test(self):
  """Trains resnet18 on zero-filled fake data across all XLA devices.

  Each per-device loop times the forward pass, loss, backward pass and
  optimizer step individually, and asserts the NLL loss stays below 3.0.
  """
  devices = xm.get_xla_supported_devices()
  batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
  sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
  # Fake data: zero images and zero labels, scaled by the device count so
  # every device receives `sample_count` batches.
  train_loader = xu.SampleGenerator(
      data=(torch.zeros(batch_size, 3, 224, 224),
            torch.zeros(batch_size, dtype=torch.int64)),
      sample_count=sample_count * len(devices))

  def loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    for data, target in loader:
      with xu.TimedScope(msg='Training loop: ', printfn=None):
        optimizer.zero_grad()
        output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
        loss = xu.timed(
            lambda: loss_fn(output, target), msg='Loss: ', printfn=None)
        xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
        xu.timed(
            lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
        self.assertLess(loss.cpu().item(), 3.0)

  model_parallel = dp.DataParallel(
      torchvision.models.resnet18, device_ids=devices)
  model_parallel(loop_fn, train_loader)
def inference():
  """Runs batched inference over the RSNA test set and writes a submission.

  Loads the model weights, pads the patient list so its length is a
  multiple of num_cores * batch_size (every device must receive full
  batches), runs `infer_loop` on all devices, and writes the scores to
  `args.subm_file`.
  """
  model = model_from_name(args.model_name)
  devices = (xm.get_xla_supported_devices(
      max_devices=args.num_cores) if args.num_cores != 0 else [])
  model.load_state_dict(torch.load(args.weight_file))
  model_parallel = dp.DataParallel(model, device_ids=devices)
  patient = extract_patient(args.csv_file_path, return_label=False)
  ## For all predictions files must be multiple by num_cores*batch_size
  # BUG FIX: the original expression
  #   num_cores * batch_size - len(patient) % num_cores * batch_size
  # parsed as `... - (len(patient) % num_cores) * batch_size` because
  # binary % and * share precedence left-to-right, so the pad amount was
  # wrong. Compute the shortfall to the next multiple explicitly
  # (`-n % m` is 0 when n is already a multiple of m).
  chunk = args.num_cores * args.batch_size
  extra = -len(patient) % chunk
  patient = np.concatenate([patient, patient[:extra]])
  ds = RSNADataset(patient, path=args.path, transform=train_transforms)
  loader = D.DataLoader(
      ds,
      batch_size=args.batch_size,
      shuffle=False,
      num_workers=args.num_workers)
  result = model_parallel(infer_loop, loader)
  result = np.array(result)
  # result[:, 0] holds per-device score arrays, result[:, 1] the matching
  # patient-id arrays; flatten each across devices.
  score, patientid = [result[:, i] for i in range(result.shape[-1])]
  score = np.concatenate(score)
  patientid = np.concatenate(patientid)
  prediction_to_df(score, patientid, args.subm_file)
def _is_device_tpu() -> bool:
    """Check whether any TPU device is visible to XLA.

    Return:
        A boolean value indicating if TPU devices are available
    """
    tpu_devices = xm.get_xla_supported_devices("TPU")
    return len(tpu_devices) != 0
def test__xla_dist_model_run_parallel_n_threads_without_sync():
    # Regression test for https://github.com/pytorch/ignite/issues/1096
    import torch_xla.core.xla_model as xm
    from joblib import delayed, Parallel

    devices = xm.get_xla_supported_devices()
    # With more than 5 devices, run 5 folds offset by one; otherwise a
    # single fold starting at 0.
    if len(devices) > 5:
        folds, offset = 5, 1
    else:
        folds, offset = 1, 0
    Parallel(n_jobs=folds, backend="threading")(
        delayed(main_fold)(fold + offset) for fold in range(folds))
def _is_device_tpu() -> bool:
    """Check if TPU devices are available.

    Return:
        A boolean value indicating if TPU devices are available
    """
    # On TPU Pod training (e.g. a TPU v3-32 split across 4 VMs) the world
    # size exceeds 1 and `torch_xla.distributed.xla_dist` must be used for
    # the multiple VMs; TPU_CONFIG is then unavailable, so calling
    # `xm.get_xla_supported_devices("TPU")` is not possible. A world size
    # greater than one therefore already implies TPUs are present.
    if xm.xrt_world_size() > 1:
        return True
    return bool(xm.get_xla_supported_devices("TPU"))
def test(self):
  """Checks ParallelLoader delivers each per-device batch on its device."""
  devices = [torch.device(d) for d in xm.get_xla_supported_devices()]
  slope, offset = 3.11, 4.09
  batch_size = 128 * len(devices)
  gen = xu.FnDataGenerator(
      lambda x: x * slope + offset,
      batch_size,
      _gen_tensor,
      dims=[8],
      count=10)
  para_loader = pl.ParallelLoader(gen, devices)
  for device in devices:
    for data, target in para_loader.per_device_loader(device):
      self.assertEqual(data.device, device)
      self.assertEqual(target.device, device)
def run_thread_per_device(rank: int, processes: int,
                          fn: Callable[..., R]) -> Dict[int, R]:
  """Runs `fn` in a separate thread on each visible device.

  Args:
    rank: rank of current process
    processes: number of processes on this host
    fn: Function to run on all devices

  Returns:
    Dict of the form {thread_rank: return_value}, where return_value is
    the result of calling `fn`.
  """
  if device_type() == 'TPU':
    configure_tpu_topology(rank, processes)

  # Enable replication across all addressable devices before spawning.
  xm.set_replication(xm.xla_device(), xm.get_xla_supported_devices())
  threads = len(xm.get_xla_supported_devices())

  def _thread_fn(fn, device_index):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      # Assumes same number of threads per process, so the global ordinal
      # is the process rank scaled by thread count plus the thread index.
      set_global_ordinal(rank * threads + device_index)
      set_local_ordinal(rank * threads + device_index)
      return fn(*args, **kwargs)

    return wrapper

  with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
    # Key each future by its thread index so results can be collected
    # into {thread_rank: return_value} regardless of completion order.
    futures = {executor.submit(_thread_fn(fn, i)): i for i in range(threads)}
    results = {
        futures[f]: f.result() for f in concurrent.futures.as_completed(futures)
    }
  return results
def train():
  """Trains a segmentation model (deeplabv3/fcn resnet50) on Fashion data.

  Loads mask annotations, builds the dataset/loader, wraps the model in
  dp.DataParallel over the visible XLA devices, trains for args.epochs,
  then saves the first replica's weights to args.save_file.
  """
  set_seeds()
  logging.info('Loading masks...')
  with open(args.json_file, 'r') as f:
    masks = json.load(f)
  # for example use only 200 images
  filename = list(masks.keys())[:200]

  # These globals are read by train_loop_fn (defined elsewhere).
  global devices, num_steps_per_epoch
  devices = (xm.get_xla_supported_devices(
      max_devices=args.num_cores) if args.num_cores != 0 else [])

  logging.info('Start training model')
  if args.model_name == 'deeplabv3_resnet50':
    m = torchvision.models.segmentation.deeplabv3_resnet50(False)
  else:
    m = torchvision.models.segmentation.fcn_resnet50(False)
  # Replace the classifier head with a 46-class 1x1 convolution.
  m.classifier[-1] = torch.nn.Conv2d(m.classifier[-1].in_channels, 46, 1)
  # wrapped for parallel training
  model = dp.DataParallel(m, device_ids=devices)

  ds = FashionDataset(
      filename,
      masks,
      path=args.data_path,
      transform=train_transform,
      size=(256, 256))
  loader = D.DataLoader(
      ds,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=args.num_worker)
  # Each device processes an equal share of the loader per epoch.
  num_steps_per_epoch = len(loader) // len(devices)

  for epoch in range(1, args.epochs + 1):
    train_loss = model(train_loop_fn, loader)
    train_loss = np.array(train_loss).mean()
    logging.info('[Epoch {:3d}] Train loss: {:.3f}'.format(epoch, train_loss))

  # Save weights
  state_dict = model.models[0].to('cpu').state_dict()
  torch.save(state_dict, args.save_file)
  logging.info('')
  logging.info('Model saved\n')
def test(self):
  """Checks ParallelLoader.to() places batches on each requested device."""
  devices = xm.get_xla_supported_devices()
  slope, offset = 3.11, 4.09
  batch_size = 128 * len(devices)
  gen = xu.FnDataGenerator(
      lambda v: v * slope + offset, batch_size, _gen_tensor, dims=[8], count=10)
  para_loader = pl.ParallelLoader(gen, batch_size, devices)
  for step, (data, target) in para_loader:
    for device in devices:
      moved = para_loader.to(data, device)
      self.assertEqual(moved.device, torch.device(device))
def test(self):
  """Exercises tensor .to() transfers between CPU and XLA devices."""
  devices = xm.get_xla_supported_devices()
  # CPU tensor -> each device, walking the device list backwards.
  for device in reversed(devices):
    src = _gen_tensor(8, 12)
    moved = src.to(device=torch.device(device))
    self.assertEqual(moved.device, torch.device(device))
  # Device 0 tensor -> every other device.
  src = _gen_tensor(8, 12).to(device=torch.device(devices[0]))
  for device in devices[1:]:
    moved = src.to(device=torch.device(device))
    self.assertEqual(moved.device, torch.device(device))
  # Copies on adjacent devices stay numerically equal after updates.
  for dev0, dev1 in zip(devices, devices[1:]):
    t0 = torch.zeros(4, 4, device=torch.device(dev0))
    t1 = t0.to(device=torch.device(dev1))
    t0 = t0 + torch.ones_like(t0, device=torch.device(dev0))
    t1 = t1 + torch.ones_like(t1, device=torch.device(dev1))
    self.assertEqual(t0.cpu(), t1.cpu())
def __init__(self, network, device_ids=None):
  """Replicates `network` across the given XLA devices.

  Args:
    network: A torch.nn.Module instance, or a callable that builds one
      (e.g. a model class); a non-Module is called once and the result is
      deep-copied onto every device.
    device_ids: Devices to replicate onto; defaults to all supported XLA
      devices. When the resulting list is empty, falls back to running the
      single original module natively.
  """
  if device_ids is None:
    device_ids = xm.get_xla_supported_devices()
  self._device_ids = [str(x) for x in device_ids]
  self._native_run = False
  self._models = []
  self._contexts = []
  module = network if isinstance(network, torch.nn.Module) else network()
  for device in device_ids:
    # Deep copy per device so each replica owns independent parameters.
    device_module = deepcopy(module).to(device=torch.device(device))
    self._models.append(device_module)
    self._contexts.append(Context(torch.device(device)))
  if not self._models:
    # No XLA device, push a vanilla network in.
    device = self._get_model_device(module)
    self._models.append(module)
    self._device_ids.append(device)
    self._contexts.append(Context(torch.device(device)))
    self._native_run = True
def xla_device(n: Optional[int] = None,
               devkind: Optional[str] = None) -> torch.device:
  """Returns an XLA device.

  Args:
    n: Index of XLA device within visibible devices. If not set, use local
      ordinal (default 0) to select an addressable device.
    devkind: Type of device to return. Should match `device_type()`.

  Returns:
    A `torch.device` representing an XLA device.

  Raises:
    IndexError: If the resolved device index is outside the visible devices.
  """
  devices = xm.get_xla_supported_devices(devkind=devkind)
  # BUG FIX: use `n is not None` instead of `n or ...` so an explicit n=0
  # selects device 0 rather than silently falling back to the local ordinal.
  device_index = n if n is not None else (
      local_ordinal(default=0) % addressable_device_count())
  # BUG FIX: valid indexes are 0..len(devices)-1; the previous `>` check
  # let device_index == len(devices) through and crashed on the lookup.
  if device_index >= len(devices):
    raise IndexError('Device index {} out of range in {}'.format(
        device_index, devices))

  torch_xla._XLAC._xla_set_default_device(devices[device_index])
  return torch.device(devices[device_index])
def train():
  """Trains the RSNA classifier with dp.DataParallel over XLA devices.

  Builds the model and dataset from args, runs args.num_epochs epochs of
  train_loop_fn, then saves the first replica's weights to args.save_pht.
  """
  set_seeds()
  # Read by train_loop_fn (defined elsewhere).
  global num_steps_per_epoch

  ## Create and wrap model ##
  model = model_from_name(args.model_name)
  devices = (
      xm.get_xla_supported_devices(max_devices=args.num_cores)
      if args.num_cores != 0 else [])
  model_parallel = dp.DataParallel(model, device_ids=devices)
  logging.info('Model {} loaded and wrapped'.format(args.model_name))
  logging.info('')

  ## Create dataset and loader
  patient, label = extract_patient(args.csv_file_path, return_label=True)
  if args.use_image:
    ds = RSNAImages(patient, label, path=args.path, transform=train_transforms)
  else:
    ds = RSNADataset(patient, label, path=args.path, transform=train_transforms)
  loader = D.DataLoader(
      ds,
      batch_size=args.batch_size,
      shuffle=True,
      num_workers=args.num_workers)
  logging.info('Dataset created\n')

  logging.info('Start training model\n')
  # Each device processes an equal share of the loader per epoch.
  num_steps_per_epoch = len(loader) // len(devices)
  for epoch in range(1, args.num_epochs + 1):
    start_time = time.time()
    model_parallel(train_loop_fn, loader)
    logging.info('')
    logging.info('Epoch training time: {:.2f} minutes\n'.format(
        (time.time() - start_time) / 60**1))

  # Save weights
  state_dict = model_parallel.models[0].to('cpu').state_dict()
  torch.save(state_dict, args.save_pht)
  logging.info('')
  logging.info('Model saved')
  logging.info('')
def test_xla_sharded_tensor(self):
  """Checks XLAShardedTensor basic properties with a simple 1-D sharding."""
  num_devices = len(xm.get_xla_supported_devices("TPU"))
  # 1 x N mesh; shard along the second mesh dimension.
  mesh_shape = (1, num_devices)
  partition_spec = (1,)
  t1 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
  t1_sharded = XLAShardedTensor(t1, mesh_shape, partition_spec)
  t2 = torch.tensor([2.0, 3.0], dtype=torch.float, device=xm.xla_device())
  # Elementwise op mixing sharded and unsharded operands.
  t3 = torch.add(t1_sharded, t2)
  assert isinstance(
      t3, XLAShardedTensor), "Sharded ops should return XLAShardedTensor."
  assert t3.size() == t1.size(
  ), "Sharded output should return unpartitioned tensor size."
  # Row-major device assignment over the mesh shape.
  device_ids = np.array(range(num_devices))
  device_assignment = list(device_ids.reshape(mesh_shape))
def train_cifar():
  """Trains ResNet18 on CIFAR10 with dp.DataParallel over XLA devices.

  Uses xu.SampleGenerator fake data when FLAGS.fake_data, otherwise the
  torchvision CIFAR10 dataset with distributed samplers on multi-replica
  runs. Returns the final epoch's mean test accuracy.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the dataset so each replica sees a distinct slice.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
  model_parallel = dp.DataParallel(model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer is cached in the per-device context across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    # Mean accuracy across all device replicas.
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    if FLAGS.metrics_debug:
      print(met.metrics_report())

  return accuracy
def train_unet():
  """Trains a UNet segmentation model on satellite imagery via XLA.

  Mirrors the torch_xla example structure: optional distributed samplers,
  dp.DataParallel train/test loops, per-step logging, and a checkpoint
  saved after every epoch. Returns the last mean test accuracy.
  """
  # logging setup
  logger_name = 'train_logger'
  logger = initializeLogging(
      os.path.join(logdir, 'train_history.txt'), logger_name)

  # checkpointing setup
  checkpoint_frequency = log_steps

  torch.manual_seed(1)
  '''
  train_dataset = datasets.MNIST(
      datadir,
      train=True,
      download=True,
      transform=transforms.Compose(
          [transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))]))
  train_dataset_len = len(train_dataset)
  test_dataset = datasets.MNIST(
      datadir,
      train=False,
      transform=transforms.Compose(
          [transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))]))
  '''
  img_dim = 1024
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  train_dataset = satelliteDataSet(
      os.path.join(datadir, 'train'),
      transforms.Compose([
          transforms.RandomResizedCrop(img_dim),
          transforms.RandomHorizontalFlip(),
          transforms.ToTensor(), normalize
      ]))
  train_dataset_len = len(train_dataset)
  resize_dim = max(img_dim, 256)
  # NOTE(review): test set is read from the 'train' directory — likely a
  # leftover; confirm whether a held-out split was intended.
  test_dataset = satelliteDataSet(
      os.path.join(datadir, 'train'),
      transforms.Compose([
          transforms.Resize(resize_dim),
          transforms.CenterCrop(img_dim),
          transforms.ToTensor(),
          normalize,
      ]))

  train_sampler = None
  test_sampler = None
  if xm.xrt_world_size() > 1:
    # Shard datasets across replicas.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)

  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=batch_size,
      sampler=train_sampler,
      drop_last=drop_last,
      shuffle=False if train_sampler else True,
      num_workers=num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=batch_size,
      sampler=test_sampler,
      drop_last=drop_last,
      shuffle=False,
      num_workers=num_workers)

  maxItr = num_epochs * len(train_loader) // train_loader.batch_size + 1

  devices = (
      xm.get_xla_supported_devices(max_devices=num_cores)
      if num_cores != 0 else [])
  # Scale learning rate to num cores
  lr = 0.0001
  lr = lr * max(len(devices), 1)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(UNet, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.Adam(model.parameters(), lr=lr))
    tracker = xm.RateTracker()

    model.train()
    print('# of iterations: {}'.format(maxItr))
    logger.info('# of iterations: {}'.format(maxItr))
    optimizer.zero_grad()
    for x, (data, target) in enumerate(loader):
      # NOTE(review): the image is taken from target[0] and the mask from
      # target[1] — the loader appears to yield nested pairs; confirm
      # against satelliteDataSet's __getitem__.
      data = target[0].permute(0, 3, 1, 2)
      target = target[1]
      output = model(data)
      loss = loss_fn(output, target.long())
      #_, preds = torch.max(output, 1)
      loss.backward()
      # backprop every log_step iterations (gradient accumulation between)
      if x % log_steps == 0:
        xm.optimizer_step(optimizer)
        optimizer.zero_grad()
      tracker.add(batch_size)
      # compute the confusion matrix and IoU
      #print(preds.shape)
      #print(target.shape)
      #val_conf = np.zeros((num_classes, num_classes))
      #val_conf = val_conf + confusion_matrix(
      #    target[target >= 0].view(-1).cpu().numpy(),
      #    preds[target >= 0].view(-1).cpu().numpy())
      #pos = np.sum(val_conf, 1)
      #res = np.sum(val_conf, 0)
      #tp = np.diag(val_conf)
      #iou = np.mean(tp / np.maximum(1, pos + res - tp))
      #logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
      print('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
      if x % log_steps == 0:
        logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      data = target[0].permute(0, 3, 1, 2)
      target = target[1]
      output = model(data)
      #pred = output.max(1, keepdim=True)[1].float()
      _, preds = torch.max(output, 1)
      preds = preds.float()
      correct += preds.eq(target.view_as(preds)).sum().item()
      # Per-image pixel count; assumes square target of side target.shape[1].
      total_samples += target.shape[1]**2
      print('device: {}, Running Accuracy: {}'.format(
          device, correct / total_samples))

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    logger.info('TEST: device: {}, accuracy: {}'.format(device, accuracy))
    return accuracy

  accuracy = 0.0
  writer = test_utils.get_summary_writer(logdir)
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      batch_size * num_devices)
  print('total epochs: {}'.format(num_epochs))
  for epoch in range(1, num_epochs + 1):
    print(epoch)
    print(train_loader)
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    logger.info('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    global_step = (epoch - 1) * num_training_steps_per_epoch
    if metrics_debug:
      print(met.metrics_report())
      logger.info(met.metrics_report())
    logger.info('saving checkpoint. epoch: {}'.format(epoch))
    # Saves the whole DataParallel wrapper, not just the state_dict.
    torch.save(model_parallel, os.path.join(logdir, 'model_parallel_chkpt.pt'))
    logger.info('checkpoint saved. epoch: {}'.format(epoch))

  test_utils.close_summary_writer(writer)
  return accuracy
def train_mnist(FLAGS):
  """Trains the MNIST model on synthetic data via XLA.

  Uses xu.SampleGenerator fake batches, a hand-rolled per-epoch training
  loop (optionally autograph-converted through ptwse), and an optional
  test pass. Returns the maximum mean test accuracy observed.
  """
  DTYPE = torch.float32
  torch.manual_seed(1)

  # Flattened 28x28 images: (batch, 1, 784).
  dims = (
      FLAGS.batch_size,
      1,
      784,
  )

  train_dataset_len = FLAGS.steps_per_epoch if FLAGS.steps_per_epoch else 60000
  # Labels are float for MSE loss, int64 class ids for NLL loss.
  train_loader = xu.SampleGenerator(
      data=(
          torch.ones(dims, dtype=DTYPE),
          torch.ones(
              FLAGS.batch_size,
              dtype=torch.int64 if not _MSE_LOSS else DTYPE,
          ),
      ),
      sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size(),
  )
  test_loader = xu.SampleGenerator(
      data=(
          torch.ones(dims, dtype=DTYPE),
          torch.ones(
              FLAGS.batch_size,
              dtype=torch.int64 if not _MSE_LOSS else DTYPE,
          ),
      ),
      sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size(),
  )

  devices = (
      xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
      if FLAGS.num_cores != 0 else []
  )
  """ Non multi-processing """
  # Scale learning rate to num cores
  lr = FLAGS.lr * max(len(devices), 1)
  model = MNIST(FLAGS)
  model_parallel = dp.DataParallel(
      model,
      device_ids=devices,
  )
  writer = None
  if xm.is_master_ordinal():
    writer = test_utils.get_summary_writer(FLAGS.logdir)

  # Just some step closure output
  def train_output_fn(outputs, ctx, args, tracker):
    # Reports the measured wall-clock time per step between log points.
    if ctx.step > 0 and args.log_steps and ctx.step % args.log_steps == 0:
      now_time = time.time()
      if hasattr(ctx, 'start_time') and ctx.start_time:
        per_step_time = (now_time - ctx.start_time) / (
            ctx.step - ctx.last_step_timed
        )
        steps_per_second = 1 / per_step_time
        print(
            f'[{xm.get_ordinal()}] Round-trip step time: '
            f'{per_step_time} seconds, steps per second: {steps_per_second}'
        )
        if tracker:
          _train_update(
              device=device,
              step=ctx.step,
              loss=outputs[0],
              tracker=tracker,
              epoch=epoch,
              writer=writer,
          )
        print(f'BEGIN Train step {ctx.step}')
        ctx.start_time = time.time()
        ctx.last_step_timed = ctx.step
      else:
        ctx.start_time = time.time()
        ctx.last_step_timed = ctx.step
    ctx.step += 1

  def train_loop_fn(model, loader, device=None, context=None):
    lr_adder = 0.0
    if _MSE_LOSS:
      loss_fn = nn.MSELoss()
    else:
      loss_fn = nn.NLLLoss()

    # Optimizer is cached in the per-device context across epochs.
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(
            model.parameters(),
            lr=lr + lr_adder,
            momentum=FLAGS.momentum,
        ),
    )
    tracker = xm.RateTracker()

    model.train()

    def train_inner_loop_fn(batch, ctx):
      step = ctx.step
      print(f'Step {step}')
      data = batch[0]
      target = batch[1]
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      # barrier=False: synchronization is handled by the outer loop.
      xm.optimizer_step(
          optimizer,
          barrier=False,
      )
      if (
          FLAGS.log_steps != 0 and (
              FLAGS.log_steps == 1 or (step > 0 and step % FLAGS.log_steps == 0)
          )
      ):
        xm.add_step_closure(
            _train_update,
            args=(device, step, loss, tracker, epoch, writer),
        )
      if step == 0:
        xm.master_print(f"End TRAIN step {step}")
      ctx.step += 1
      return [loss]

    step = 0

    # Train
    # NOTE(review): this is a plain string, not an f-string, so '{epoch}'
    # prints literally — presumably unintentional, left as-is.
    print('Starting new epoch train loop... (epoch={epoch})')
    for step, (data, target) in enumerate(loader):
      if step % FLAGS.step_print_interval == 0:
        xm.master_print(f"Begin TRAIN Step: {step}")
      context.step = step
      if not FLAGS.use_autograph:
        outputs = train_inner_loop_fn((data, target), context)
      else:
        outputs = ptwse.flow.runner.maybe_run_converted(
            train_inner_loop_fn, (data, target), context
        )

    xm.master_print(f"Saving model...")
    _save_checkpoint(FLAGS, device, None, model, is_epoch=True)
    xm.master_print(f"Model saved")

  def test_loop_fn(model, loader, device, context):
    print("***********************")
    print("ENTERING TEST FUNCTION")
    print("***********************")
    print('Evaluating...')
    total_samples = 0
    correct = 0
    model.eval()
    for step, (data, target) in enumerate(loader):
      if step >= FLAGS.test_max_step:
        break
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      # In multi-process mode keep `correct` as a tensor so it can be
      # mesh-reduced below; otherwise accumulate a Python number.
      if FLAGS.mp:
        correct += pred.eq(target.view_as(pred)).sum()
      else:
        correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    if FLAGS.mp:
      this_accuracy = 100.0 * correct.item() / total_samples
      print("CALLING: mesh_reduce('test_accuracy')")
      this_accuracy = xm.mesh_reduce(
          'test_accuracy', this_accuracy, np.mean
      )
      print("BACK FROM: mesh_reduce('test_accuracy')")
    else:
      this_accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, this_accuracy)
    print("***********************")
    print("LEAVING TEST FUNCTION")
    print("***********************")
    return this_accuracy

  #
  # Set up for
  #
  accuracy = 0.0
  num_devices = (
      len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  )
  if not FLAGS.steps_per_epoch:
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * num_devices
    )
  else:
    num_training_steps_per_epoch = FLAGS.steps_per_epoch
  max_accuracy = 0.0

  #
  # Epoch loop
  #
  for epoch in range(1, FLAGS.num_epochs + 1):
    #
    # Train
    #
    device = xm.xla_device()
    ctx = dp.Context(device=device)
    ctx.tracker = xm.RateTracker()
    ctx.step = 0
    train_loop_fn(model, train_loader, device, ctx)

    #
    # Test
    #
    if FLAGS.run_test:
      with ptwse.scope.proxy_disabled(disabled=FLAGS.test_off_proxy):
        accuracies = model_parallel(test_loop_fn, test_loader)
      accuracy = mean(accuracies)
      print(
          'Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)
      )
      global_step = (epoch - 1) * num_training_steps_per_epoch
      max_accuracy = max(accuracy, max_accuracy)
      test_utils.write_to_summary(
          writer,
          global_step,
          dict_to_write={'Accuracy/test': accuracy},
          write_xla_metrics=True,
      )
    if FLAGS.metrics_debug:
      xm.master_print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Smoke test: list the supported XLA devices, then allocate a random
# tensor on the default XLA device and print it.
print("Starting...")
print("Supported xla devices")
print(xm.get_xla_supported_devices())
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
def train_imagenet():
  """Trains a torchvision model on ImageNet (or fake data) via XLA.

  Builds train/test loaders (with distributed samplers on multi-replica
  runs), trains with SGD plus an optional LR scheduler through
  dp.DataParallel, and returns the final epoch's mean test accuracy.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the dataset so each replica sees a distinct slice.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (xm.get_xla_supported_devices(
      max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer and scheduler are cached per device across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    lr_scheduler = context.getattr_or(
        'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
            optimizer,
            scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
            scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
            scheduler_divide_every_n_epochs=getattr(
                FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
            num_steps_per_epoch=num_training_steps_per_epoch,
            summary_writer=writer if xm.is_master_ordinal() else None))
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                         tracker.global_rate())
      if lr_scheduler:
        lr_scheduler.step()

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * num_devices)
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    global_step = (epoch - 1) * num_training_steps_per_epoch
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                     global_step)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
def train_mnist():
    """Train and evaluate MNIST across the available XLA devices.

    Builds either synthetic sample loaders (FLAGS.fake_data) or the real
    torchvision MNIST loaders, replicates the MNIST model over the XLA devices
    with dp.DataParallel, trains for FLAGS.num_epochs, and returns the mean
    test accuracy of the final epoch.
    """
    torch.manual_seed(1)
    if FLAGS.fake_data:
        train_dataset_len = 60000  # Number of images in MNIST dataset.
        # Synthetic data: zero tensors in MNIST shape (1x28x28); each replica
        # generates its 1/world_size share of an epoch.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
        train_dataset_len = len(train_dataset)
        test_dataset = datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ]))
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            # Shard both datasets across replicas; only training is shuffled.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        # DataLoader forbids shuffle=True together with a sampler, hence the
        # conditional below.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)
    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # Per-device training loop; the optimizer is created once per device
        # and cached in the per-device context.
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        # Per-device evaluation loop; returns this device's top-1 accuracy.
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        # model_parallel returns one loop-fn result per device.
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(met.metrics_report())
    test_utils.close_summary_writer(writer)
    return accuracy
def train_imagenet():
    """Train and evaluate an ImageNet model across the available XLA devices.

    Uses zero-filled synthetic loaders when FLAGS.fake_data is set, otherwise
    ImageFolder datasets under FLAGS.datadir. The torchvision model is
    replicated with dp.DataParallel; each epoch trains, evaluates, logs
    top-1/top-5 metrics to the summary writer, and checkpoints the model when
    test accuracy improves. Returns the best test accuracy observed.
    """
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size
            // xm.xrt_world_size(),
        )
        # NOTE(review): sample_count divides by FLAGS.batch_size although this
        # loader batches with FLAGS.test_set_batch_size -- confirm intended.
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
            ),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
        )
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(img_dim),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "val"),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose(
                [
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        )
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            # Shard both datasets across replicas; only training is shuffled.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True,
            )
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False,
            )
        # DataLoader forbids shuffle=True together with a sampler, hence the
        # conditional below.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=True,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=False,
            shuffle=False,
            num_workers=FLAGS.num_workers,
        )
    torch.manual_seed(42)
    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0
        else ["cpu"]
    )
    print("use tpu devices", devices)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property("model_fn")
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # Per-device training loop; optimizer and scheduler are created once
        # per device and cached in the per-device context.
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            "optimizer",
            lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=1e-4,
            ),
        )
        lr_scheduler = context.getattr_or(
            "lr_scheduler",
            lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
                scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, "lr_scheduler_divide_every_n_epochs", None
                ),
                num_steps_per_epoch=num_training_steps_per_epoch,
                # Only the master ordinal writes scheduler summaries.
                summary_writer=writer if xm.is_master_ordinal() else None,
            ),
        )
        tracker = xm.RateTracker()
        model.train()
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        losses = 0
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # top-1 prediction indices
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5)
            loss = loss_fn(output, target)
            loss.backward()
            losses += loss.item()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print(
                    "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}".format(
                        str(device),
                        x,
                        loss.item(),
                        (100.0 * correct / total_samples).item(),
                        tracker.rate(),
                        tracker.global_rate(),
                        time.asctime(),
                    )
                )
        if lr_scheduler:
            lr_scheduler.step()
        # Epoch means; `x` is the last step index left over from the loop.
        return (
            losses / (x + 1),
            (100.0 * correct / total_samples).item(),
            (top5_accuracys / (x + 1)).item(),
        )

    def test_loop_fn(model, loader, device, context):
        # Per-device evaluation; returns (top-1 accuracy, summed per-batch
        # top-5 values) for this device.
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5).item()
        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy, top5_accuracys

    accuracy = 0.0
    writer = SummaryWriter(FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * num_devices)
    max_accuracy = 0.0
    print("train_loader_len", len(train_loader), num_training_steps_per_epoch)
    print("test_loader_len", len(test_loader))
    for epoch in range(1, FLAGS.num_epochs + 1):
        global_step = (epoch - 1) * num_training_steps_per_epoch
        # Train evaluate
        metrics = model_parallel(train_loop_fn, train_loader)
        # One (loss, top1, top5) tuple per device.
        losses, accuracies, top5_accuracys = zip(*metrics)
        loss = mean(losses)
        accuracy = mean(accuracies)
        top5_accuracy = mean(top5_accuracys)
        print(
            "Epoch: {} (Train), Loss {}, Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
                epoch, loss, accuracy, top5_accuracy
            )
        )
        test_utils.add_scalar_to_summary(writer, "Loss/train", loss, global_step)
        test_utils.add_scalar_to_summary(
            writer, "Top-1 Accuracy/train", accuracy, global_step
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-5 Accuracy/train", top5_accuracy, global_step
        )
        # Test evaluate
        metrics = model_parallel(test_loop_fn, test_loader)
        accuracies, top5_accuracys = zip(*metrics)
        top5_accuracys = sum(top5_accuracys)
        # NOTE(review): top-5 sums per-batch values across all devices but
        # normalizes by len(test_loader) -- confirm the normalization.
        top5_accuracy = top5_accuracys / len(test_loader)
        accuracy = mean(accuracies)
        print(
            "Epoch: {} (Valid), Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
                epoch, accuracy, top5_accuracy
            )
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-1 Accuracy/test", accuracy, global_step
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-5 Accuracy/test", top5_accuracy, global_step
        )
        if FLAGS.metrics_debug:
            print(met.metrics_report())
        if accuracy > max_accuracy:
            max_accuracy = max(accuracy, max_accuracy)
            # Checkpoint from CPU, then move the first replica back to its
            # device. NOTE(review): filename hard-codes "resnet50" regardless
            # of the configured model.
            torch.save(
                model_parallel.models[0].to("cpu").state_dict(),
                f"./reports/resnet50_model-{epoch}.pt",
            )
            model_parallel.models[0].to(devices[0])
    test_utils.close_summary_writer(writer)
    # NOTE(review): label says "Max Accuracy" but this prints the last epoch's
    # accuracy; presumably max_accuracy was intended.
    print("Max Accuracy: {:.2f}%".format(accuracy))
    return max_accuracy
!pip3 install -r "/content/drive/My Drive/htfl/Freeze.txt" import os print(os.environ["COLAB_TPU_ADDR"]) import torch # imports the torch_xla package import torch_xla import torch_xla.core.xla_model as xm import torch_xla.distributed.parallel_loader as pl import torch_xla.distributed.data_parallel as dp import torch_xla.distributed.xla_multiprocessing as xmp num_cores = 10 devices = ( xm.get_xla_supported_devices( max_devices=num_cores) if num_cores != 0 else []) print("Devices: {}".format(devices)) os.chdir('/content/drive/My Drive/htfl/') !rm train.py -f import os import sys from google.colab import files uploaded = files.upload() os.chdir('/content/drive/My Drive/htfl/config') !rm bert_config.json -f import os import sys from google.colab import files uploaded = files.upload()
args.b, settings.MEAN, settings.STD, ) Flag = {} Flag['lr'] = args.lr Flag['epoch'] = args.e Flag['warm'] = args.warm Flag['batch_size'] = args.b Flag['milestones'] = settings.MILESTONES Flag['ignore_idx'] = train_data_loader.dataset.ignore_index len(train_data_loader.dataset) net = UNet(3, train_data_loader.dataset.class_num) devices = (xm.get_xla_supported_devices()) print(devices) net = dp.DataParallel(net, device_ids=devices) iter_per_epoch = len(train_data_loader) / 8 best_iou = 0 for epoch in range(1, args.e + 1): print('training epoch {}'.format(epoch)) t1 = time.time() net(train_loop_fn, train_data_loader) print(time.time() - t1) #result = net(test_loop_fn, test_data_loader)
def measure_tpu(warmups, steps, h_emb, h_indices, h_offsets, args): import torch_xla import torch_xla.core.xla_model as xm import os tsize = int(os.environ.get("MODEL_PARTITION_SIZE", 3000000)) def syncTPU(tensor): torch_xla._XLAC._xla_sync_multi([tensor], devices=[], wait=True, sync_xla_data=True) alldev = xm.get_xla_supported_devices() allrealdev = xm.xla_real_devices(alldev) print("Found {0} devices: {1}".format(len(allrealdev), allrealdev)) dev = xm.xla_device() if (args.features > tsize): if args.usexlabag: tsplit = torch.split(h_emb.embtable.weight, tsize, dim=0) else: tsplit = torch.split(h_emb.weight, tsize, dim=0) tsplit = list(tsplit) for i, chunk in enumerate(tsplit): tsplit[i] = chunk.to(dev) t = nn.Parameter(torch.ones(10, 10)) if args.usexlabag: h_emb.embtable.weight = t t_emb = h_emb.to(dev) tsplit = torch.cat(tsplit) t_emb.embtable.weight = nn.Parameter(tsplit) print("Xla EMB weight shape: ", t_emb.embtable.weight.shape, " on device: ", str(dev)) else: h_emb.weight = t t_emb = h_emb.to(dev) tsplit = torch.cat(tsplit) t_emb.weight = nn.Parameter(tsplit) print("EMB weight shape: ", t_emb.weight.shape, " on device: ", str(dev)) else: t_emb = h_emb.to(dev) t_indices = h_indices.to(dev) t_offsets = h_offsets.to(dev) emb_times = 0.0 start1 = time.perf_counter() for i in range(warmups + steps): start = time.perf_counter() results = t_emb(t_indices, t_offsets) syncTPU(results) end = time.perf_counter() print("Time: {0:.6f} ".format(end - start)) if (i >= warmups): emb_times += end - start end1 = time.perf_counter() return end1 - start1, emb_times, results
def train_unet():
    """Train and evaluate a UNet on satellite imagery across XLA devices.

    Builds train/test loaders from satelliteDataSet, replicates the UNet with
    dp.DataParallel, trains for num_epochs, and returns the mean accuracy of
    the final epoch. Relies on module-level configuration (datadir,
    batch_size, num_workers, num_cores, lr, momentum, log_steps, num_epochs,
    lr_scheduler_* settings).
    """
    print('==> Preparing data..')
    img_dim = 1024
    # ImageNet normalization constants.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = satelliteDataSet(
        os.path.join(datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]))
    train_dataset_len = len(train_dataset)
    resize_dim = max(img_dim, 256)
    # NOTE(review): the "test" dataset also loads the 'train' directory --
    # presumably a held-out split was intended; confirm.
    test_dataset = satelliteDataSet(
        os.path.join(datadir, 'train'),
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
        # Shard both datasets across replicas; only training is shuffled.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=test_set_batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=num_workers)
    torch.manual_seed(42)
    devices = (xm.get_xla_supported_devices(max_devices=num_cores))
    torchvision_model = UNet()
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        # Per-device training loop; optimizer and scheduler are created once
        # per device and cached in the per-device context.
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                              weight_decay=1e-4))
        lr_scheduler = context.getattr_or(
            'lr_scheduler',
            lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=lr_scheduler_type,
                scheduler_divisor=lr_scheduler_divisor,
                scheduler_divide_every_n_epochs=lr_scheduler_divide_every_n_epochs,
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=None))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            # The permute implies batches arrive as NHWC and the model wants
            # NCHW -- TODO confirm against satelliteDataSet.
            data = data.permute(0, 3, 1, 2)
            output = model(data)
            print('passed through model')
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(batch_size)
            if x % log_steps == 0:
                print(
                    'device: {}, x: {}, loss: {}, tracker: {}, tracker_global: {} '
                    .format(device, x, loss.item(), tracker.rate(),
                            tracker.global_rate()))
        if lr_scheduler:
            lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        # Per-device evaluation; returns this device's top-1 pixel/class
        # accuracy. NOTE(review): unlike training, no NHWC->NCHW permute here
        # -- confirm the eval path feeds the model the same layout.
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
        accuracy = 100.0 * correct / total_samples
        #test_utils.print_test_update(device, accuracy)
        print('device: {}, accuracy: {}'.format(device, accuracy))
        return accuracy

    accuracy = 0.0
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (
        batch_size * num_devices)
    for epoch in range(1, num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        # One accuracy per device.
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        print('global step: {}'.format(global_step))
    return accuracy
def worker_fn(rank, flags):
    """Per-process training entry point for single- or multi-XLA-device runs.

    Args:
        rank: Ordinal of this worker (0 is the primary).
        flags: Dict with 'args' (parsed CLI namespace) and 'world_size'.

    Sets up logging, the XLA device, (optionally sharded) data loaders, the
    model, optimizer, and the mlp trainer, then runs training to completion.
    Returns early if no XLA devices are available.
    """
    args = flags['args']
    world_size = flags['world_size']
    distributed = args.distributed
    is_primary = rank == 0
    mlp.logging.use_fancy_colors()
    # ########## EXPERIMENT SETUP ###########
    torch.random.manual_seed(args.seed)  # For reproducibility
    if distributed:
        logger_name = f"[Device {rank}] {os.path.basename(__file__)}"
    else:
        logger_name = os.path.basename(__file__)
    logger = get_logger(logger_name)
    # ########################################
    # ############## DEVICE SETUP ##############
    xla_available = len(xm.get_xla_supported_devices()) > 0
    if not xla_available:
        logger.error("No XLA devices available, unable to train")
        return
    if distributed:
        logger.info(
            f"Training using multiple XLA devices: Using XLA device {rank}/{world_size}"
        )
    else:
        logger.info(f"Single XLA device mode : Using XLA device {rank} ")
    device = xm.xla_device()
    # ########################################
    # ########## SETUP BATCH DATASETS ##########
    # Barrier choreography: non-primary ranks wait at the rendezvous while the
    # primary runs load_data() first (presumably so only one rank downloads),
    # then the primary joins the same rendezvous to release them.
    if distributed and not is_primary:
        xm.rendezvous("loading_data")
    training_data, test_data = load_data()
    if distributed and is_primary:
        xm.rendezvous("loading_data")
    training_sampler = None
    validation_sampler = None
    if distributed:
        # Shard both datasets across the world so each rank sees a distinct
        # subset.
        training_sampler = torch.utils.data.distributed.DistributedSampler(
            training_data, num_replicas=world_size, rank=rank)
        validation_sampler = torch.utils.data.distributed.DistributedSampler(
            test_data, num_replicas=world_size, rank=rank)
    # shuffle and sampler are mutually exclusive in DataLoader.
    training_dataset = torch.utils.data.DataLoader(
        training_data,
        batch_size=args.batch_size,
        shuffle=(training_sampler is None),
        sampler=training_sampler,
        num_workers=3)
    # Using the test set as a validation set, just for demonstration purposes
    validation_dataset = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=(validation_sampler is None),
        sampler=validation_sampler,
        num_workers=3)
    # ##########################################
    # ############ BUILD THE MODEL #############
    classifier = build_model(args.hidden_size)
    train_model = TrainModel(classifier, device)
    # Move model to assigned GPU (see torch.cuda.set_device(args.local_rank))
    # NOTE(review): stale comment -- `device` is an XLA device here, not CUDA.
    classifier.to(device)
    # ############################################
    # ############ SETUP OPTIMIZER #############
    optimizer = torch.optim.Adam(classifier.parameters(), lr=args.learning_rate)
    # ##########################################
    # ############# SETUP TRAINING ##############
    trainer = mlp.trainers.DefaultTrainer(
        optimizers=optimizer, model_components=classifier)
    model_hyper_parameters = {"hidden_size": args.hidden_size}
    callbacks = create_callbacks_for(trainer, args.experiment_name,
                                     model_hyper_parameters, is_primary,
                                     validation_dataset,
                                     args.progress_log_period)
    manager = mlp.trainers.TrainingManager(
        trainer,
        training_dataset,
        num_epochs=args.num_epochs,
        callbacks=callbacks,
        experiment_data={"args": args})
    trainer.set_training_model(train_model)
    # ##########################################
    # ################# START! #################
    manager.start_training()
    # ##########################################
    logger.info("DONE.")
def test_num_local_devices(self): self.assertLen(xm.get_xla_supported_devices(), pjrt.addressable_device_count())
def test_get_real_xla_devices(self): devices = xm.get_xla_supported_devices() xla_devices = torch_xla._XLAC._xla_real_devices(devices) for device, xdevice in zip(devices, xla_devices): self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', xdevice) is not None)
print("c device: ", c.device, type(c), c.dtype) print("c[2x2] : ", c.narrow(0, 0, 2).narrow(1, 0, 2)) print("------") print( "GPU Time is {0:.6f} seconds, rate {1:.3f} GFlops for iter {2} " .format(elap1, m * n * k * 2 * 1.0 / (elap1 * 1000000000 / steps), steps)) print("------\n") if (args.testtpu): # import torch_xla import torch_xla.core.xla_model as xm alldev = xm.get_xla_supported_devices() allrealdev = xm.xla_real_devices(alldev) print("Found {0} XLA devices: {1}".format(len(allrealdev), allrealdev)) print(torch.__version__) dev = xm.xla_device() a = a.to(dev) b = b.to(dev) c = c.to(dev) measure_xla(a, b, warmups, m) elap1 = measure_xla(a, b, steps, m) print("c device: ", c.device, type(c), c.dtype) print("c[2x2] : ", c.narrow(0, 0, 2).narrow(1, 0, 2)) print("------") print(