def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
    # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
    if src != self.get_rank():
        tensor.fill_(0.0)
    xm.all_reduce("sum", [tensor])
    return tensor
def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
    # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
    group_size = self.get_world_size()
    output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
    output[self.get_rank() % group_size] = tensor
    xm.all_reduce("sum", [output])
    return output.reshape(-1, *output.shape[2:])
def reduce_gradients(params):
    # compared to xm.reduce_gradients, this takes the params directly
    # instead of extracting them from an optimizer instance
    count = torch_xla._XLAC._xla_get_replication_devices_count()
    if count > 1:
        gradients = [p.grad for p in params if p.grad is not None]
        xm.all_reduce('sum', gradients, scale=1.0 / count, groups=None)
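# A minimal usage sketch (not from the source): how reduce_gradients() above could
# be dropped into a hand-rolled XLA training step in place of
# xm.reduce_gradients(optimizer). `model`, `optimizer`, `loss_fn`, `data`, and
# `target` are assumed placeholders.
def _example_manual_allreduce_step(model, optimizer, loss_fn, data, target):
    import torch_xla.core.xla_model as xm

    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    reduce_gradients(list(model.parameters()))  # average grads across replicas
    optimizer.step()
    xm.mark_step()  # cut and execute the pending XLA graph
    return loss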
def _compute_nproc_per_node(self) -> int:
    tensor = torch.tensor([self.get_local_rank() + 1.0], dtype=torch.float).to(self.device())
    xm.all_reduce("max", [tensor])
    return int(tensor.item())
def train_loop_fn(loader, epoch):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        tracker.add(FLAGS.batch_size)
        if lr_scheduler:
            lr_scheduler.step()
        import resource
        print(f"CPU Usage After: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")
        if step % FLAGS.log_steps == 0:
            # _train_update(device, step, loss, tracker, epoch, writer)
            xm.add_step_closure(
                _train_update, args=(device, step, loss, tracker, epoch, writer)
            )
def _do_all_reduce(self, tensor: torch.Tensor, op: str = "SUM") -> torch.Tensor:
    if op not in self._reduce_op_map:
        raise ValueError(f"Unsupported reduction operation: '{op}'")
    op = self._reduce_op_map[op]
    xm.all_reduce(op, [tensor])
    return tensor
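# A plausible sketch (not from the source) of the _reduce_op_map that
# _do_all_reduce() above relies on: it maps user-facing op names to the
# reduce-type constants defined in torch_xla.core.xla_model. The exact set of
# ops supported by the original class is an assumption.
import torch_xla.core.xla_model as xm

_reduce_op_map = {
    "SUM": xm.REDUCE_SUM,
    "PRODUCT": xm.REDUCE_MUL,
    "MIN": xm.REDUCE_MIN,
    "MAX": xm.REDUCE_MAX,
    "AND": xm.REDUCE_AND,
    "OR": xm.REDUCE_OR,
}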
def broadcast(self, tensors, opts):
    root_tensor = tensors[opts.rootTensor]
    root_rank = opts.rootRank
    if root_rank != self.rank():
        with torch.no_grad():
            root_tensor.zero_()
    xm.all_reduce(xm.REDUCE_SUM, [root_tensor], groups=self._mesh)
    return WorkXla([root_tensor])
def allreduce(self, tensors, all_reduce_options):
    reduce_type = self._get_reduce_type(all_reduce_options.reduceOp)
    # TODO(hjm-aws): implement all_reduce_options.timeout.
    xm.all_reduce(reduce_type, tensors, groups=self._mesh, pin_layout=False)
    return _ret_work(tensors)
def train_loop_fn(loader, epoch):
    if FLAGS.fine_grained_metrics:
        epoch_start_time = time.time()
        step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
    else:
        tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        if FLAGS.fine_grained_metrics:
            step_start_time = time.time()
        optimizer.zero_grad()
        if FLAGS.fine_grained_metrics:
            fwd_start_time = time.time()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        if FLAGS.fine_grained_metrics:
            fwd_end_time = time.time()
            fwd_latency = fwd_end_time - fwd_start_time
            bwd_start_time = time.time()
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
        if lr_scheduler:
            lr_scheduler.step()
        if FLAGS.fine_grained_metrics:
            bwd_end_time = time.time()
            bwd_latency = bwd_end_time - bwd_start_time
            step_latency = bwd_end_time - step_start_time
            step_latency_tracker.append(step_latency)
            bwd_latency_tracker.append(bwd_latency)
            fwd_latency_tracker.append(fwd_latency)
        else:
            tracker.add(FLAGS.batch_size)
        if step % FLAGS.log_steps == 0:
            if FLAGS.fine_grained_metrics:
                print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                    epoch, step, FLAGS.batch_size / p50(step_latency_tracker), FLAGS.batch_size,
                    p50(step_latency_tracker), p50(fwd_latency_tracker), p50(bwd_latency_tracker)))
            else:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(_train_update, args=(device, step, loss, tracker, epoch, writer))
    if FLAGS.fine_grained_metrics:
        epoch_end_time = time.time()
        epoch_latency = epoch_end_time - epoch_start_time
        print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
            epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker), FLAGS.batch_size,
            p50(step_latency_tracker), p50(fwd_latency_tracker), p50(bwd_latency_tracker)))
def all_reduce(tensor, group=None):
    if isinstance(group, tuple) and group[0] == 'tpu':
        import torch_xla.core.xla_model as xm
        return xm.all_reduce('sum', [tensor], groups=group[1])
    else:
        if group is None:
            group = get_default_group()
        return dist.all_reduce(tensor, group=group)
def train_loop_fn(loader):
    tracker = xm.RateTracker()
    model.train()
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        with autocast():
            output = model(data)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        tracker.add(flags.batch_size)
        if step % flags.log_steps == 0:
            xm.add_step_closure(_train_update, args=(device, step, loss, tracker, writer))
def loop_with_amp(model, input, positions, target, causal_mask, optimizer, xla_enabled, autocast, scaler):
    with autocast():
        loss = model(input, positions, target, batch_mask=causal_mask)
    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return loss
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) != 'CPU':
        ones = torch.ones((2, 3))
        twos = ones + 1.0
        xones = ones.to(device)
        xtwos = twos.to(device)
        xm.all_reduce(xm.REDUCE_SUM, [xones, xtwos])
        if (not xones.cpu().allclose(ones * float(xm.xrt_world_size())) or
                not xtwos.cpu().allclose(twos * float(xm.xrt_world_size()))):
            print('CrossReplicaSum produced wrong reductions', file=sys.stderr)
            print(xones, file=sys.stderr)
            sys.exit(1)
    else:
        print('Default device {} does not support replication'.format(device), file=sys.stderr)
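# A minimal launcher sketch (not from the source): a per-replica entry point like
# _mp_fn() above is normally started with torch_xla's multiprocessing helper.
# nprocs=None lets torch_xla pick one process per available device.
import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=(), nprocs=None)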
def loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled, autocast, scaler):
    with autocast():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optim)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optim)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
    return loss, optim
def after_batch(self):
    if not getattr(self.learn, 'inner_xla', False):
        return  # skip if not spawned
    cancel_fit = xm.all_reduce(xm.REDUCE_SUM, self.sync_cancel_fit)
    if cancel_fit > self.zero:  # a rank triggered a cancel
        self.dl.close()  # close per device loader
        raise CancelFitException()
def broadcast_xla_master_model_param(model):
    logger.info(
        "Broadcasting XLA model parameters and buffers from master process ..."
    )
    parameters_and_buffers = []
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is equivalent
        # to broadcasting parameters from master to other devices.
        if not is_main():
            zero = torch.tensor(0, dtype=p.data.dtype, device=p.data.device)
            p.data.mul_(zero)
        parameters_and_buffers.append(p.data)
    xm.wait_device_ops()
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.rendezvous("mmf.trainers.core.device.broadcast_xla_master_model_param")
    logger.info("Done!")
def main_fold(fold):
    import time

    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm
    from ignite.engine import Engine, Events

    device = xm.xla_device(fold)
    comp_model = _XlaDistModel.create_from_context()
    assert comp_model.device() == device

    model = nn.Linear(100, 10)
    model.to(device)  # Move model before creating optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    def training_step(engine, _):
        data = torch.rand(4, 100, device=device)
        model.train()
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = output.sum()
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        return loss.item()

    trainer = Engine(training_step)

    # THIS CAN BE A CAUSE OF CRASH if DEVICE is OTHER THAN device
    tensor = torch.tensor([fold + 1.0], dtype=torch.float).to(comp_model.device())
    xm.all_reduce("max", [tensor])

    time.sleep(0.01 * fold)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_progress():
        print(".", end=" ")

    trainer.run([0] * 100, max_epochs=2)
def all_reduce_grads(self):
    gradients = []
    for p in self.parameters():
        if not p.requires_grad:
            continue
        if p.grad is None:
            p.grad = torch.zeros_like(p)
        if p.grad.requires_grad:
            raise RuntimeError(
                "TPUDistributedDataParallel only works with gradients that don't "
                "require grad"
            )
        gradients.append(p.grad)

    import torch_xla.core.xla_model as xm
    xm.all_reduce(
        'sum',
        gradients,
        scale=1.0 / self.world_size,
        groups=self.process_group[1],
    )
def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
    """Syncs ``tensor`` over ``world_size`` in distributed mode.

    Args:
        tensor: tensor to sync across the processes.
        mode: tensor synchronization type, should be one of 'sum' or 'mean'.
            Default is 'mean'.

    Returns:
        torch.Tensor with synchronized values.

    Raises:
        ValueError: if mode is out of ``sum``, ``mean``.
    """
    # return tensor
    if mode not in {"sum", "mean"}:
        raise ValueError(f"Unknown sync_type '{mode}'")
    if mode == "sum":
        return xm.all_reduce("sum", tensor)
    elif mode == "mean":
        return xm.all_reduce("sum", tensor, scale=1.0 / self.world_size)
def all_reduce(tensor, group, op="sum"):
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        tensor = [tensor]  # wrap in a list to make xm.all_reduce in-place
        return xm.all_reduce(op, tensor, groups=group[1])[0]
    else:
        if op == "sum":
            op = dist.ReduceOp.SUM
        elif op == "max":
            op = dist.ReduceOp.MAX
        else:
            raise NotImplementedError
        dist.all_reduce(tensor, op=op, group=group)
        return tensor
def validation_epoch_end(self, outputs):
    avg_loss = torch.stack([
        x['log']['val_mlm_loss'] for x in outputs if 'val_mlm_loss' in x['log']
    ]).mean()
    if self.use_ddp:
        # TODO: PTL is already doing this. Is it still needed here?
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/0.8.5/pytorch_lightning/metrics/converters.py#L251
        torch.distributed.all_reduce(avg_loss, op=torch.distributed.ReduceOp.SUM)
        avg_loss /= torch.distributed.get_world_size()
    elif self.use_tpu:
        avg_loss = xm.all_reduce(xm.REDUCE_SUM, avg_loss) / xm.xrt_world_size()

    logs = {'val_mlm_loss': avg_loss}
    return {'log': logs, 'progress_bar': logs, "val_loss": avg_loss}
def reduce_dict(dictionary):
    world_size = get_world_size()
    if world_size < 2:
        return dictionary
    with torch.no_grad():
        if len(dictionary) == 0:
            return dictionary
        keys, values = zip(*sorted(dictionary.items()))
        values = torch.stack(values, dim=0)
        if is_xla():
            values = xm.all_reduce("sum", [values], scale=1.0 / world_size)[0]
        else:
            dist.reduce(values, dst=0)
            if dist.get_rank() == 0:
                # only main process gets accumulated, so only divide by
                # world_size in this case
                values /= world_size
        reduced_dict = {k: v for k, v in zip(keys, values)}
    return reduced_dict
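# A minimal usage sketch (not from the source): feeding per-replica metric tensors
# through reduce_dict() above so every worker (or rank 0 in the non-XLA branch)
# sees the averaged values. `loss` and `accuracy` are assumed placeholder tensors
# already living on the local device.
def _example_reduce_metrics(loss, accuracy):
    reduced = reduce_dict({"loss": loss.detach(), "accuracy": accuracy.detach()})
    return {k: v.item() for k, v in reduced.items()}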
def after_epoch(self):
    if not getattr(self.learn, 'inner_xla', False):
        return  # skip if not spawned
    if 'recorder' not in self.learn.cbs.attrgot('name'):
        all_metrics = {
            'train_mets': L([]),
            'valid_mets': L([]),
        }
    else:
        all_metrics = {
            'train_mets': self.recorder._train_mets,
            'valid_mets': self.recorder._valid_mets,
        }
    # send metrics data to sync ranks across spawned processes
    device = self.learn.xla_training.pdevice
    packed_metrics = pack_metrics(all_metrics, device)  # convert metrics to tensor list on TPU
    reduced_metrics = xm.all_reduce(xm.REDUCE_SUM, packed_metrics)
    xm.mark_step()
    if xm.is_master_ordinal():
        all_metrics = restore_metrics(reduced_metrics, all_metrics)  # convert list to metric objects
        for m in self.recorder._train_mets:
            self.sync_log += _maybe_item(m)
        for m in self.recorder._valid_mets:
            self.sync_log += _maybe_item(m)
        self.learn.final_record = self.sync_log[1:].copy()
        del self.recorder.values[-1]  # remove last entry added by recorder
        self.recorder.values.append(self.learn.final_record)  # add updated metrics
        if self.recorder.add_time:
            updated_time = (time.time() - self.recorder.start_epoch)
            self.sync_log.append(format_time(updated_time))
        self.recorder.log = self.sync_log
        self._sync_stats_log(self.sync_log)  # write_stats to output
    self.learn.logger = self.orig_logger  # restore orig logger after skipping recorder.logger(log)
def allreduce(self, tensors, all_reduce_options):
    reduce_type = self._get_reduce_type(all_reduce_options.reduceOp)
    # TODO(hjm-aws): implement all_reduce_options.timeout.
    xm.all_reduce(reduce_type, tensors, groups=self._mesh)
    return WorkXla(tensors)
def all_reduce(self, collectiveArgs, retFlag=False):
    retObj = xm.all_reduce(collectiveArgs.op, [collectiveArgs.ipTensor])
    if retFlag:
        return retObj
def train(self):
    hps = self.state.hps
    ss = self.state
    current_stats = {}
    writer_stats = {}

    # for resuming the learning rate
    sorted_lr_steps = sorted(self.learning_rates.keys())
    lr_index = util.greatest_lower_bound(sorted_lr_steps, ss.data.global_step)
    ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

    if ss.model.bn_type != 'none':
        sorted_as_steps = sorted(self.anneal_schedule.keys())
        as_index = util.greatest_lower_bound(sorted_as_steps, ss.data.global_step)
        ss.model.objective.update_anneal_weight(
            self.anneal_schedule[sorted_as_steps[as_index]])

    if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
        ss.model.init_codebook(self.data_iter, 10000)

    start_time = time.time()

    for batch_num, batch in enumerate(self.device_loader):
        wav, mel, voice, jitter, position = batch
        global_step = len(ss.data.dataset) * position[0] + position[1]
        # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr)
        # stderr.flush()

        if (batch_num % hps.save_interval == 0 and batch_num != 0):
            self.save_checkpoint(position)

        if hps.skip_loop_body:
            continue

        lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step)
        ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])
        # if ss.data.global_step in self.learning_rates:
        #     ss.update_learning_rate(self.learning_rates[ss.data.global_step])

        if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule:
            ss.model.objective.update_anneal_weight(
                self.anneal_schedule[ss.data.global_step])

        ss.optim.zero_grad()
        quant, self.target, loss = self.state.model.run(wav, mel, voice, jitter)
        self.probs = self.softmax(quant)
        self.mel_enc_input = mel
        # print(f'after model.run', file=stderr)
        # stderr.flush()
        loss.backward()
        # print(f'after loss.backward()', file=stderr)
        # stderr.flush()

        if batch_num % hps.progress_interval == 0:
            pars_copy = [p.data.clone() for p in ss.model.parameters()]
            # print(f'after pars_copy', file=stderr)
            # stderr.flush()

        if self.is_tpu:
            xm.optimizer_step(ss.optim)
        else:
            ss.optim.step()
        ss.optim_step += 1

        if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000:
            ss.model.bottleneck.update_codebook()

        tprb_m = self.avg_prob_target()

        if batch_num % hps.progress_interval == 0:
            iterator = zip(pars_copy, ss.model.named_parameters())
            uw_ratio = {np[0]: t.norm(c - np[1].data) / c.norm() for c, np in iterator}

            writer_stats.update({'uwr': uw_ratio})

            if self.is_tpu:
                count = torch_xla._XLAC._xla_get_replication_devices_count()
                loss_red, tprb_red = xm.all_reduce('sum', [loss, tprb_m], scale=1.0 / count)
                # loss_red = xm.all_reduce('all_loss', loss, reduce_mean)
                # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean)
            else:
                loss_red = loss
                tprb_red = tprb_m

            writer_stats.update({
                'loss_r': loss_red,
                'tprb_r': tprb_red,
                'optim_step': ss.optim_step,
            })

            current_stats.update({
                'optim_step': ss.optim_step,
                'gstep': global_step,
                # 'gstep': ss.data.global_step,
                'epoch': position[0],
                'step': position[1],
                # 'loss': loss,
                'lrate': ss.optim.param_groups[0]['lr'],
                # 'tprb_m': tprb_m,
                # 'pk_d_m': avg_peak_dist
            })
            current_stats.update(ss.model.objective.metrics)

            if ss.model.bn_type in ('vae'):
                current_stats['free_nats'] = ss.model.objective.free_nats
                current_stats['anneal_weight'] = \
                    ss.model.objective.anneal_weight.item()

            if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                current_stats.update(ss.model.encoder.metrics)

            if self.is_tpu:
                xm.add_step_closure(self.train_update, args=(writer_stats, current_stats))
            else:
                self.train_update(writer_stats, current_stats)

        # if not self.is_tpu or xm.is_master_ordinal():
        #     if batch_num in range(25, 50) or batch_num in range(75, 100):
        stderr.flush()
        elapsed = time.time() - start_time
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.data_parallel_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
                if self.cuda:
                    torch.cuda.empty_cache()
                if self.args.distributed_world_size == 1:
                    return None
            else:
                raise e

        if self.tpu and i < len(samples) - 1:
            # tpu-comment: every XLA operation before marking step is
            # appended to the IR graph, and processing too many batches
            # before marking step can lead to OOM errors.
            # To handle gradient accumulation use case, we explicitly
            # mark step here for every forward pass without a backward pass
            import torch_xla.core.xla_model as xm
            xm.mark_step()

    if is_dummy_batch:
        if torch.is_tensor(sample_size):
            sample_size.zero_()
        else:
            sample_size *= 0.

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)

    # gather logging outputs from all replicas
    if self._sync_stats():
        train_time = self._local_cumulative_training_time()
        logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
        )
        self._cumulative_training_time = total_train_time / self.data_parallel_world_size

    if hasattr(self.model, 'all_reduce'):
        self.model.all_reduce()

    overflow = False
    try:
        if self.tpu and self.data_parallel_world_size > 1:
            import torch_xla.core.xla_model as xm
            gradients = xm._fetch_gradients(self.optimizer.optimizer)
            xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size)

        with torch.autograd.profiler.record_function("multiply-grads"):
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args.use_bmuf:
                self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.data_parallel_world_size if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

        with torch.autograd.profiler.record_function("clip-grads"):
            # clip grads
            grad_norm = self.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if (
            not self.args.use_bmuf
            and self.args.distributed_wrapper != 'SlowMo'
            and not self.tpu
        ):
            self._check_grad_norms(grad_norm)

        with torch.autograd.profiler.record_function("optimizer"):
            # take an optimization step
            self.optimizer.step()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached to print
        # out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                self.get_num_updates(), ignore_grad=False
            )
        raise
    except OverflowError as e:
        overflow = True
        logger.info("NOTE: overflow detected, " + str(e))
        grad_norm = torch.tensor(0.).cuda()
        self.zero_grad()
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
        raise e

    # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
    if hasattr(self.model, 'perform_additional_optimizer_actions'):
        if hasattr(self.optimizer, 'fp32_params'):
            self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
        else:
            self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)

    if not overflow or self.args.distributed_wrapper == 'SlowMo':
        self.set_num_updates(self.get_num_updates() + 1)

        if self.tpu:
            # mark step on TPUs
            import torch_xla.core.xla_model as xm
            xm.mark_step()

            # only log stats every log_interval steps
            # this causes wps to be misreported when log_interval > 1
            logging_output = {}
            if self.get_num_updates() % self.args.log_interval == 0:
                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

            # log whenever there's an XLA compilation, since these
            # slow down training and may indicate opportunities for
            # optimization
            self._check_xla_compilation()
        else:
            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

    # clear CUDA cache to reduce memory fragmentation
    if (
        self.cuda
        and self.args.empty_cache_freq > 0
        and (
            (self.get_num_updates() + self.args.empty_cache_freq - 1)
            % self.args.empty_cache_freq
        ) == 0
    ):
        torch.cuda.empty_cache()

    if self.args.fp16:
        metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)

    metrics.log_stop_time("train_wall")

    return logging_output
def all_reduce(self, collectiveArgs, retFlag=False):
    retObj = xm.all_reduce(collectiveArgs.op, [collectiveArgs.ipTensor])
    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)
    if retFlag:
        return retObj
def backward(ctx, grad_output):
    dim = ctx.dim
    all_grad_output = xm.all_reduce(xm.REDUCE_SUM, grad_output)
    return all_grad_output.select(dim, xm.get_ordinal()), None
def forward(ctx, input, reduce_type, scale, groups):
    ctx.reduce_type = reduce_type
    ctx.scale = scale
    output = xm.all_reduce(reduce_type, input, scale=scale, groups=groups)
    ctx.save_for_backward(input, output)
    return output