def __exit__(self, type, value, traceback):
    if getattr(self, 'scope', None):
        # In ir.cpp ResetScopeContext we ensure that we have no remaining scope
        # before marking step.
        del self.scope
    xm.mark_step()
    super().__exit__(type, value, traceback)
def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs):
    retval = None
    xm.mark_step()
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
        retval = optimizer.step(*args, **kwargs)
    return retval
def run_single(args, m, n, k):
    dtype = args.dtype
    device = args.device
    warmups = args.warmups
    steps = args.steps

    dt = torch.float32
    if dtype == "float16" or dtype == "half":
        dt = torch.float16
    elif dtype == "bfloat16":
        dt = torch.bfloat16

    torch.manual_seed(0)
    elap = 0.0

    a = torch.randn(m, k).to(dt)
    b = torch.randn(k, n).to(dt)
    c = torch.zeros(m, n).to(dt)

    if device == 'cpu':
        measure_cpu(a, b, warmups)
        elap = measure_cpu(a, b, steps)
    elif device == 'gpu':
        if torch.cuda.is_available():
            # ncuda = torch.cuda.device_count()
            # print("There are {} cuda devices".format(ncuda))
            # print("The first cuda device name is {} ".format(torch.cuda.get_device_name()))
            cuda0 = torch.device('cuda:0')
            with torch.cuda.device(cuda0):
                acuda = a.to(cuda0)
                bcuda = b.to(cuda0)
                measure_gpu(acuda, bcuda, warmups)
                elap = measure_gpu(acuda, bcuda, steps)
        else:
            print("CUDA is not available")
            sys.exit(1)
    else:
        # import torch_xla
        import torch_xla.core.xla_model as xm
        # alldev = xm.get_xla_supported_devices()
        # allrealdev = xm.xla_real_devices(alldev)
        # print("Found {0} XLA devices: {1}".format(len(allrealdev), allrealdev))
        dev = xm.xla_device()
        a = a.to(dev)
        b = b.to(dev)
        c = c.to(dev)
        measure_xla(a, b, warmups)
        xm.mark_step()
        elap = measure_xla(a, b, steps)
        xm.mark_step()

    return elap
def train_tpu(
    model,
    device,
    optimizer,
    data_type,
    input_size,
    output_size,
    batch_size,
    args,
):
    import torch_xla.core.xla_model as xm

    loss_f = nn.CrossEntropyLoss().to(device)
    # model.train()
    start_time = time.time()
    for i in range(args.steps + args.warmups):
        data = torch.randn(batch_size, input_size, device=device)
        target = torch.randint(
            output_size, [batch_size], device=device, dtype=torch.long)
        # data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data).float()
        loss = loss_f(output, target)
        loss.backward()
        optimizer.step()
        xm.mark_step()
        if i < args.warmups:
            start_time = time.time()
    return time.time() - start_time, loss
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        import resource
        print(f" CPU Usage After: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")
        if step % FLAGS.log_steps == 0:
            # _train_update(device, step, loss, tracker, epoch, writer)
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer))
def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError(
            'Cannot initialize distributed with distributed_world_size=1')

    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank,
                args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(),
                args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)

    return args.distributed_rank
def valid_step(self, sample, raise_oom=False):
    """Do forward pass in evaluation mode."""
    if self.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous("valid_step")  # wait for all workers
        xm.mark_step()

    with torch.no_grad():
        self.model.eval()
        self.criterion.eval()

        sample = self._prepare_sample(sample)
        if sample is None:
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            if self._dummy_batch == "DUMMY":
                self._dummy_batch = sample
            is_dummy_batch = False

        try:
            _loss, sample_size, logging_output = self.task.valid_step(
                sample, self.model, self.criterion)
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if not raise_oom:
                    logger.warning(
                        "ran out of memory in validation step, retrying batch")
                    for p in self.model.parameters():
                        if p.grad is not None:
                            p.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    return self.valid_step(sample, raise_oom=True)
            raise e

        logging_outputs = [logging_output]
        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0

    # gather logging outputs from all replicas
    if self.data_parallel_world_size > 1:
        logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
            logging_outputs,
            sample_size,
            ignore=is_dummy_batch,
        )

    # log validation stats
    logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)

    return logging_output
def __iter__(self):
    if self.rng_types is not None:
        synchronize_rng_states(self.rng_types, self.generator)
    state = AcceleratorState()
    for batch in super().__iter__():
        if state.distributed_type == DistributedType.TPU:
            xm.mark_step()
        yield batch if self.device is None else send_to_device(batch, self.device)
def _test_optimizer(self,
                    syncfree_optim_cls,
                    ref_optim_cls,
                    optim_kwargs={'lr': 1e-2}):
    device = xm.xla_device()
    loss_fn = nn.NLLLoss()
    # syncfree model
    torch.manual_seed(0)
    syncfree_model = MNIST().train().to(device)
    syncfree_optimizer = syncfree_optim_cls(syncfree_model.parameters(),
                                            **optim_kwargs)
    # reference model
    torch.manual_seed(0)
    ref_model = MNIST().train().to(device)
    ref_optimizer = ref_optim_cls(ref_model.parameters(), **optim_kwargs)
    # fake data
    data = torch.rand(32, 1, 28, 28).to(device)
    target = torch.zeros(32, dtype=torch.long).to(device)  # NLLLoss expects integer class targets
    # training loop
    for i in range(10):
        # syncfree step
        syncfree_optimizer.zero_grad()
        syncfree_output = syncfree_model(data)
        syncfree_loss = loss_fn(syncfree_output, target)
        syncfree_loss.backward()
        # mimic NaN in the gradients
        if i % 2 == 0:
            xm._fetch_gradients(syncfree_optimizer)[0].mul_(torch.nan)
            found_inf = torch.tensor(1.0).to(device)
        else:
            found_inf = torch.tensor(0.0).to(device)
        xm.optimizer_step(
            syncfree_optimizer, optimizer_args={"found_inf": found_inf})
        xm.mark_step()
        # reference step
        ref_optimizer.zero_grad()
        ref_output = ref_model(data)
        ref_loss = loss_fn(ref_output, target)
        ref_loss.backward()
        # mimic the effect of the found_inf tensor
        if i % 2 != 0:
            xm.optimizer_step(ref_optimizer)
        xm.mark_step()
        # check loss
        np.testing.assert_allclose(
            ref_loss.cpu().detach().numpy(),
            syncfree_loss.cpu().detach().numpy(),
            rtol=1e-2,
            atol=1e-2)
        # check weights
        for p, p_ref in zip(syncfree_model.parameters(),
                            ref_model.parameters()):
            np.testing.assert_allclose(
                p.cpu().detach().numpy(),
                p_ref.cpu().detach().numpy(),
                rtol=1e-2,
                atol=1e-2)
def next(self):
    if self._mark_step_batch_count == self._batches_yielded:
        self._batches_yielded = 0
        xm.mark_step()
    else:
        self._batches_yielded += 1
    item = self._loader.next_item(self._device)
    if item is None:
        raise StopIteration
    return item
def step(self):
    """Takes optimizer step"""
    self.optimizer.param_groups[0]['lr'] = self.get_lr()
    if self.device.type == 'xla':
        xm.optimizer_step(self.optimizer)
        xm.mark_step()
    else:
        self.optimizer.step()
def test_allgather(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = get_process_group_xla(rank=3, size=8)
    output_tensors = [torch.zeros_like(tensor)] * 8
    all_gather_pattern = r'%all\-gather\.\d+ = .+ all\-gather\('
    pg_xla.allgather([output_tensors], [tensor])
    hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
    hlo_matches(hlo, all_gather_pattern)
    # purge all computations attached to the device.
    xm.mark_step()
def xla_run_method(rank, fit_method, learner_args, add_args, fit_args, ctrl_args):
    "run fit method on spawned process"
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args, ctrl_args)
    fit_args = setup_fit_cbs(rank, fit_args)
    fit_method(learner, **fit_args)
    xm.rendezvous('xla_run_method')
    learner.save('_xla_tmp_model', rendezvous=False)
    xm.mark_step()
def tpu_data_loader(args, itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    xm.rendezvous('tpu_data_loader')  # wait for all workers
    xm.mark_step()
    device = utils.get_tpu_device(args)
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, 'n', 0),
        total=len(itr),
    )
def train_loop_fn(loader, epoch):
    if FLAGS.fine_grained_metrics:
        epoch_start_time = time.time()
        step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
    else:
        tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        if FLAGS.fine_grained_metrics:
            step_start_time = time.time()
        optimizer.zero_grad()
        if FLAGS.fine_grained_metrics:
            fwd_start_time = time.time()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        if FLAGS.fine_grained_metrics:
            fwd_end_time = time.time()
            fwd_latency = fwd_end_time - fwd_start_time
            bwd_start_time = time.time()
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        if lr_scheduler:
            lr_scheduler.step()
        if FLAGS.fine_grained_metrics:
            bwd_end_time = time.time()
            bwd_latency = bwd_end_time - bwd_start_time
            step_latency = bwd_end_time - step_start_time
            step_latency_tracker.append(step_latency)
            bwd_latency_tracker.append(bwd_latency)
            fwd_latency_tracker.append(fwd_latency)
        else:
            tracker.add(FLAGS.batch_size)
        if step % FLAGS.log_steps == 0:
            if FLAGS.fine_grained_metrics:
                print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                    epoch, step, FLAGS.batch_size / p50(step_latency_tracker),
                    FLAGS.batch_size, p50(step_latency_tracker),
                    p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
            else:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))
    if FLAGS.fine_grained_metrics:
        epoch_end_time = time.time()
        epoch_latency = epoch_end_time - epoch_start_time
        print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
            epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker),
            FLAGS.batch_size, p50(step_latency_tracker),
            p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
def test_allreduce(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = get_process_group_xla(rank=511, size=1024)
    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
    with xm_cc_op_intercepted('all_reduce'):
        pg_xla.allreduce([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)
    # purge all computations attached to the device.
    xm.mark_step()
def tpu_data_loader(itr):
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")  # wait for all workers
    xm.mark_step()
    device = xm.xla_device()
    return iterators.CountingIterator(
        pl.ParallelLoader(itr, [device]).per_device_loader(device),
        start=getattr(itr, "n", 0),
        total=len(itr),
    )
def test_reduce_scatter(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    input_list = [tensor]
    output = torch.zeros_like(tensor)
    pg_xla = get_process_group_xla(rank=0, size=len(input_list))
    opts = dist.ReduceScatterOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    reduce_scatter_pattern = r'%reduce\-scatter\.\d+ = .+ reduce\-scatter\('
    pg_xla.reduce_scatter([output], [input_list], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
    hlo_matches(hlo, reduce_scatter_pattern)
    # purge all computations attached to the device.
    xm.mark_step()
def test_asynchronous(self):
    flag = Event()
    assert not flag.is_set()

    def closure():
        sleep(1)
        assert flag.is_set()

    xm.add_step_closure(closure, run_async=True)
    xm.mark_step()

    # should get to this part and complete before closure is finished running
    assert not flag.is_set()
    flag.set()
def test_synchronous(self):
    flag = Event()
    assert not flag.is_set()

    def closure():
        sleep(1)
        assert not flag.is_set()
        flag.set()

    xm.add_step_closure(closure)
    xm.mark_step()

    # should not get to this part before closure is finished running
    assert flag.is_set()
def begin_epoch(self, epoch):
    """Called at the beginning of each epoch."""
    logger.info("begin training epoch {}".format(epoch))

    if self.quantizer is not None:
        self.quantizer.begin_epoch(epoch)

    # task specific setup per epoch
    self.task.begin_epoch(epoch, self.get_model())

    if self.tpu:
        import torch_xla.core.xla_model as xm
        xm.rendezvous('begin_epoch')  # wait for all workers
        xm.mark_step()
def next(self):
    if xp.get_tracer_marked_step():
        xp.set_tracer_marked_step(False)
        self._batches_yielded += 1
    else:
        if self._mark_step_batch_count <= self._batches_yielded:
            self._batches_yielded = 0
            xm.mark_step()
        else:
            self._batches_yielded += 1
    item = self._loader.next_item(self._device)
    if item is None:
        raise StopIteration
    return item
def test_broadcast(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    pg_xla = get_process_group_xla(rank=0, size=8)
    opts = dist.BroadcastOptions()
    opts.rootRank = 0
    opts.rootTensor = 0
    # XLA doesn't have a broadcast op; all_reduce is used to implement broadcast.
    all_reduce_pattern = r'%all\-reduce\.\d+ = .+ all\-reduce\('
    with xm_cc_op_intercepted('all_reduce'):
        pg_xla.broadcast([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)
    # purge all computations attached to the device.
    xm.mark_step()
def test_synchronous_exception(self):
    flag = Event()
    assert not flag.is_set()

    try:

        def closure():
            flag.set()
            raise RuntimeError("Simulating exception in closure")

        xm.add_step_closure(closure)
        xm.mark_step()

        assert False  # should not reach here
    except RuntimeError as e:
        assert flag.is_set(), "Should have caught exception from closure"
def loop_with_amp(model, input, positions, target, causal_mask, optimizer,
                  xla_enabled, autocast, scaler):
    with autocast():
        loss = model(input, positions, target, batch_mask=causal_mask)

    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return loss
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, writer))
def broadcast_xla_master_model_param(model):
    logger.info(
        "Broadcasting XLA model parameters and buffers from master process ...")
    parameters_and_buffers = []
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is
        # equivalent to broadcasting parameters from master to other devices.
        if not is_main():
            zero = torch.tensor(0, dtype=p.data.dtype, device=p.data.device)
            p.data.mul_(zero)
        parameters_and_buffers.append(p.data)
    xm.wait_device_ops()
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.rendezvous("mmf.trainers.core.device.broadcast_xla_master_model_param")
    logger.info("Done!")
def loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled,
                  autocast, scaler):
    with autocast():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]

    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optim)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optim)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

    return loss, optim
def xla_run_inference(rank, learner_args, add_args, inference_args, ctrl_args):
    sync_valid = True
    learner = make_xla_child_learner(rank, sync_valid, learner_args, add_args, ctrl_args)
    pred_args, master_cbs = setup_inference_args(rank, inference_args)
    if rank == 0 and len(master_cbs) > 0:
        learner.add_cbs(master_cbs)
    # learner.synced_cancel.before_fit()
    if rank == 0:
        learner.sync_recorder.orig_logger = learner.logger
    results = learner.inner_get_preds(**pred_args)
    xm.rendezvous('xla_run_inference')
    save_pred_results(rank, results)
    xm.mark_step()
def test_allreduce_with_mesh(self):
    device = xm.xla_device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    set_world_size(6)
    ranks = [2, 3]
    world_rank = 3
    set_world_rank(world_rank)
    with new_group_barrier_disabled():
        new_pg = dist.new_group(ranks=ranks)

    opts = dist.AllreduceOptions()
    opts.reduceOp = dist.ReduceOp.SUM
    all_reduce_pattern = (r'%all\-reduce\.\d+ = .+ all\-reduce\(.+\), .*'
                          r'replica_groups=\{\{0,1\},\{2,3\},\{4,5\}\}')
    with xm_cc_op_intercepted('all_reduce'):
        new_pg.allreduce([tensor], opts)
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, all_reduce_pattern)
    # purge all computations attached to the device.
    xm.mark_step()