Example #1
 def _do_broadcast(self, tensor: torch.Tensor,
                   src: int) -> torch.Tensor:
     # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
     if src != self.get_rank():
         tensor.fill_(0.0)
     xm.all_reduce("sum", [tensor])
     return tensor
Example #2
 def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     # from https://github.com/jysohn23/xla/blob/model-parallel-colab/Gather_Scatter_Broadcast_PyTorch_XLA.ipynb
     group_size = self.get_world_size()
     output = torch.zeros((group_size,) + tensor.shape, dtype=tensor.dtype, device=tensor.device)
     output[self.get_rank() % group_size] = tensor
     xm.all_reduce("sum", [output,])
     return output.reshape(-1, *output.shape[2:])
Example #3
def reduce_gradients(params):
    # compared to xm.reduce_gradients, this takes the params directly
    # instead of extracting them from an optimizer instance
    count = torch_xla._XLAC._xla_get_replication_devices_count()
    if count > 1:
        gradients = [p.grad for p in params if p.grad is not None]
        xm.all_reduce('sum', gradients, scale=1.0 / count, groups=None)
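A minimal sketch of how a helper like reduce_gradients fits into a training step (illustrative only; the model, loss_fn, and optimizer names are assumptions, not part of the snippet above):

import torch_xla.core.xla_model as xm

def train_step(model, batch, target, loss_fn, optimizer):
    optimizer.zero_grad()
    loss = loss_fn(model(batch), target)
    loss.backward()
    # Average gradients across replicas before the local optimizer step.
    reduce_gradients(list(model.parameters()))
    optimizer.step()
    xm.mark_step()
    return loss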
Example #4
File: xla.py Project: uribgp/ignite
 def _compute_nproc_per_node(self) -> int:
     tensor = torch.tensor([self.get_local_rank() + 1.0],
                           dtype=torch.float).to(self.device())
     xm.all_reduce("max", [
         tensor,
     ])
     return int(tensor.item())
Example #5
    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)

            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()

            import resource
            print(f" CPU Usage After: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

            if step % FLAGS.log_steps == 0:
                # _train_update(device, step, loss, tracker, epoch, writer)
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer)
                )
Example #6
 def _do_all_reduce(self,
                    tensor: torch.Tensor,
                    op: str = "SUM") -> torch.Tensor:
     if op not in self._reduce_op_map:
         raise ValueError(f"Unsupported reduction operation: '{op}'")
     op = self._reduce_op_map[op]
     xm.all_reduce(op, [tensor])
     return tensor
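The _reduce_op_map referenced above is not shown in the snippet; a plausible definition, assuming it simply maps the public reduction names to the strings xm.all_reduce accepts, would be:

_reduce_op_map = {
    "SUM": "sum",
    "PRODUCT": "mul",
    "MIN": "min",
    "MAX": "max",
    "AND": "and",
    "OR": "or",
}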
Example #7
    def broadcast(self, tensors, opts):
        root_tensor = tensors[opts.rootTensor]
        root_rank = opts.rootRank
        if root_rank != self.rank():
            with torch.no_grad():
                root_tensor.zero_()
        xm.all_reduce(xm.REDUCE_SUM, [root_tensor], groups=self._mesh)

        return WorkXla([root_tensor])
Example #8
    def allreduce(self, tensors, all_reduce_options):
        reduce_type = self._get_reduce_type(all_reduce_options.reduceOp)

        # TODO(hjm-aws): implement all_reduce_options.timeout.
        xm.all_reduce(reduce_type,
                      tensors,
                      groups=self._mesh,
                      pin_layout=False)
        return _ret_work(tensors)
Example #9
    def train_loop_fn(loader, epoch):
        if FLAGS.fine_grained_metrics:
            epoch_start_time = time.time()
            step_latency_tracker, bwd_latency_tracker, fwd_latency_tracker = [], [], []
        else:
            tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(loader):
            if FLAGS.fine_grained_metrics:
                step_start_time = time.time()
            optimizer.zero_grad()
            if FLAGS.fine_grained_metrics:
                fwd_start_time = time.time()
            with autocast():
                output = model(data)
                loss = loss_fn(output, target)
            if FLAGS.fine_grained_metrics:
                fwd_end_time = time.time()
                fwd_latency = fwd_end_time - fwd_start_time

                bwd_start_time = time.time()
            scaler.scale(loss).backward()
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
            scaler.step(optimizer)
            scaler.update()
            xm.mark_step()
            if lr_scheduler:
                lr_scheduler.step()
            if FLAGS.fine_grained_metrics:
                bwd_end_time = time.time()
                bwd_latency = bwd_end_time - bwd_start_time

                step_latency = bwd_end_time - step_start_time
                step_latency_tracker.append(step_latency)
                bwd_latency_tracker.append(bwd_latency)
                fwd_latency_tracker.append(fwd_latency)
            else:
                tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps == 0:
                if FLAGS.fine_grained_metrics:
                    print('FineGrainedMetrics :: Epoch={} Step={} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                        epoch, step, FLAGS.batch_size / p50(step_latency_tracker), FLAGS.batch_size,
                        p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
                else:
                    # _train_update(device, step, loss, tracker, epoch, writer)
                    xm.add_step_closure(_train_update,
                                        args=(device, step, loss, tracker,
                                              epoch, writer))
        if FLAGS.fine_grained_metrics:
            epoch_end_time = time.time()
            epoch_latency = epoch_end_time - epoch_start_time
            print('FineGrainedMetrics :: Epoch={} Epoch(s)={:.2f} Rate(DataPoints/s)[p50]={:.1f} BatchSize={} Step(s/Batch)[p50]={:.2f} Fwd(s/Batch)[p50]={:.4f} Bwd(s/Batch)[p50]={:.4f}'.format(
                epoch, epoch_latency, FLAGS.batch_size / p50(step_latency_tracker), FLAGS.batch_size,
                p50(step_latency_tracker), p50(bwd_latency_tracker), p50(fwd_latency_tracker)))
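The p50 helper used for the fine-grained metrics is not shown; a minimal sketch, assuming it just returns the 50th percentile (median) of the recorded latencies:

import statistics

def p50(values):
    # Median (50th percentile) of the collected per-step latencies.
    return statistics.median(values)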
Example #10
def all_reduce(tensor, group=None):
    if isinstance(group, tuple) and group[0] == 'tpu':
        import torch_xla.core.xla_model as xm
        return xm.all_reduce('sum', [tensor], groups=group[1])
    else:
        if group is None:
            group = get_default_group()
        return dist.all_reduce(tensor, group=group)
Example #11
 def train_loop_fn(loader):
     tracker = xm.RateTracker()
     model.train()
     for step, (data, target) in enumerate(loader):
         optimizer.zero_grad()
         with autocast():
             output = model(data)
             loss = loss_fn(output, target)
         scaler.scale(loss).backward()
         gradients = xm._fetch_gradients(optimizer)
         xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
         scaler.step(optimizer)
         scaler.update()
         tracker.add(flags.batch_size)
         if step % flags.log_steps == 0:
             xm.add_step_closure(_train_update,
                                 args=(device, step, loss, tracker, writer))
Example #12
def loop_with_amp(model, input, positions, target, causal_mask, optimizer,
                  xla_enabled, autocast, scaler):
    with autocast():
        loss = model(input, positions, target, batch_mask=causal_mask)
    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optimizer)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    return loss
Example #13
def _mp_fn(index):
    device = xm.xla_device()
    if xm.xla_device_hw(device) != 'CPU':
        ones = torch.ones((2, 3))
        twos = ones + 1.0
        xones = ones.to(device)
        xtwos = twos.to(device)
        xm.all_reduce(xm.REDUCE_SUM, [xones, xtwos])

        if (not xones.cpu().allclose(ones * float(xm.xrt_world_size())) or
                not xtwos.cpu().allclose(twos * float(xm.xrt_world_size()))):
            print('CrossReplicaSum produced wrong reductions', file=sys.stderr)
            print(xones, file=sys.stderr)
            sys.exit(1)
    else:
        print('Default device {} does not support replication'.format(device),
              file=sys.stderr)
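A _mp_fn like this is normally launched with torch_xla's multiprocessing wrapper; a minimal sketch of the entry point (assumed, not part of the snippet):

import torch_xla.distributed.xla_multiprocessing as xmp

if __name__ == '__main__':
    # Spawns one process per XLA device and calls _mp_fn(index) in each.
    xmp.spawn(_mp_fn, args=())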
Example #14
def loop_with_amp(model, input_ids, attention_mask, labels, optim, xla_enabled, autocast, scaler):
    with autocast():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
    
    if xla_enabled:
        scaler.scale(loss).backward()
        gradients = xm._fetch_gradients(optim)
        xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
        scaler.step(optim)
        scaler.update()
        xm.mark_step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()

    return loss, optim
Example #15
    def after_batch(self):
        if not getattr(self.learn, 'inner_xla', False):
            return  # skip if not spawned

        cancel_fit = xm.all_reduce(xm.REDUCE_SUM, self.sync_cancel_fit)

        if cancel_fit > self.zero:  # a rank triggered a cancel
            self.dl.close()  # close per device loader
            raise CancelFitException()
Example #16
def broadcast_xla_master_model_param(model):
    logger.info(
        "Broadcasting XLA model parameters and buffers from master process ..."
    )

    parameters_and_buffers = []
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is equivalent
        # to broadcasting parameters from master to other devices.
        if not is_main():
            zero = torch.tensor(0, dtype=p.data.dtype, device=p.data.device)
            p.data.mul_(zero)
        parameters_and_buffers.append(p.data)
    xm.wait_device_ops()
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.rendezvous("mmf.trainers.core.device.broadcast_xla_master_model_param")
    logger.info("Done!")
Example #17
def main_fold(fold):
    import time
    import torch.nn as nn
    import torch.optim as optim
    import torch_xla.core.xla_model as xm
    from ignite.engine import Engine, Events

    device = xm.xla_device(fold)

    comp_model = _XlaDistModel.create_from_context()
    assert comp_model.device() == device

    model = nn.Linear(100, 10)

    model.to(device)  # Move model before creating optimizer
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    def training_step(engine, _):
        data = torch.rand(4, 100, device=device)
        model.train()
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = output.sum()
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        return loss.item()

    trainer = Engine(training_step)

    # THIS CAN BE A CAUSE OF CRASH if DEVICE is OTHER THAN device
    tensor = torch.tensor([fold + 1.0],
                          dtype=torch.float).to(comp_model.device())
    xm.all_reduce("max", [
        tensor,
    ])

    time.sleep(0.01 * fold)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_progress():
        print(".", end=" ")

    trainer.run([0] * 100, max_epochs=2)
Example #18
    def all_reduce_grads(self):
        gradients = []
        for p in self.parameters():
            if not p.requires_grad:
                continue
            if p.grad is None:
                p.grad = torch.zeros_like(p)
            if p.grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad")
            gradients.append(p.grad)

        import torch_xla.core.xla_model as xm
        xm.all_reduce(
            'sum',
            gradients,
            scale=1. / self.world_size,
            groups=self.process_group[1],
        )
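Illustrative usage of a gradient hook like all_reduce_grads (a sketch; the wrapper module and optimizer names are assumptions, not part of the snippet above):

import torch_xla.core.xla_model as xm

def backward_and_step(loss, model, optimizer):
    # `model` is assumed to be the DDP-style wrapper that defines
    # all_reduce_grads() above.
    loss.backward()
    model.all_reduce_grads()   # average gradients across replicas
    optimizer.step()
    xm.mark_step()             # flush the lazy XLA graph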
Example #19
    def sync_tensor(self, tensor: torch.Tensor, mode: str) -> torch.Tensor:
        """Syncs ``tensor`` over ``world_size`` in distributed mode.

        Args:
            tensor: tensor to sync across the processes.
            mode: tensor synchronization type,
                should be one of 'sum' or 'mean'.
                Default is 'mean'.

        Returns:
            torch.Tensor with synchronized values.

        Raises:
            ValueError: if mode is out of ``sum``, ``mean``.
        """
        # return tensor
        if mode not in {"sum", "mean"}:
            raise ValueError(f"Unknown sync_type '{mode}'")
        if mode == "sum":
            return xm.all_reduce("sum", tensor)
        elif mode == "mean":
            return xm.all_reduce("sum", tensor, scale=1.0 / self.world_size)
Example #20
def all_reduce(tensor, group, op="sum"):
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        tensor = [tensor]  # wrap in a list to make xm.all_reduce in-place
        return xm.all_reduce(op, tensor, groups=group[1])[0]
    else:
        if op == "sum":
            op = dist.ReduceOp.SUM
        elif op == "max":
            op = dist.ReduceOp.MAX
        else:
            raise NotImplementedError
        dist.all_reduce(tensor, op=op, group=group)
        return tensor
Example #21
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([
            x['log']['val_mlm_loss'] for x in outputs
            if 'val_mlm_loss' in x['log']
        ]).mean()
        if self.use_ddp:
            # TODO: PTL is already doing this. Is it still needed here?
            # https://github.com/PyTorchLightning/pytorch-lightning/blob/0.8.5/pytorch_lightning/metrics/converters.py#L251
            torch.distributed.all_reduce(avg_loss,
                                         op=torch.distributed.ReduceOp.SUM)
            avg_loss /= torch.distributed.get_world_size()
        elif self.use_tpu:
            avg_loss = xm.all_reduce(xm.REDUCE_SUM,
                                     avg_loss) / xm.xrt_world_size()

        logs = {'val_mlm_loss': avg_loss}
        return {'log': logs, 'progress_bar': logs, "val_loss": avg_loss}
Example #22
def reduce_dict(dictionary):
    world_size = get_world_size()
    if world_size < 2:
        return dictionary

    with torch.no_grad():
        if len(dictionary) == 0:
            return dictionary

        keys, values = zip(*sorted(dictionary.items()))
        values = torch.stack(values, dim=0)

        if is_xla():
            values = xm.all_reduce("sum", [values], scale=1.0 / world_size)[0]
        else:
            dist.reduce(values, dst=0)
            if dist.get_rank() == 0:
                # only main process gets accumulated, so only divide by
                # world_size in this case
                values /= world_size
        reduced_dict = {k: v for k, v in zip(keys, values)}
    return reduced_dict
Example #23
    def after_epoch(self):
        if not getattr(self.learn,'inner_xla',False):
            return # skip if not spawned

        if 'recorder' not in self.learn.cbs.attrgot('name'):
            all_metrics = {
                'train_mets': L([]),
                'valid_mets': L([]),
            }
        else:
            all_metrics = {
                'train_mets': self.recorder._train_mets,
                'valid_mets': self.recorder._valid_mets,
            }
        # send metrics data to sync ranks across spawned processes
        device = self.learn.xla_training.pdevice
        packed_metrics = pack_metrics(all_metrics, device) # convert metrics to tensor list on TPU
        reduced_metrics = xm.all_reduce(xm.REDUCE_SUM, packed_metrics)
        xm.mark_step()
        if xm.is_master_ordinal():
            all_metrics = restore_metrics(reduced_metrics, all_metrics) # convert list to metric objects
            for m in self.recorder._train_mets:
                self.sync_log += _maybe_item(m)

            for m in self.recorder._valid_mets:
                self.sync_log += _maybe_item(m)

            self.learn.final_record = self.sync_log[1:].copy()
            del self.recorder.values[-1] # remove last entry added by recorder
            self.recorder.values.append(self.learn.final_record) # add updated metrics
            if self.recorder.add_time:
                updated_time = (time.time() - self.recorder.start_epoch)
                self.sync_log.append(format_time(updated_time))
            self.recorder.log = self.sync_log
            self._sync_stats_log(self.sync_log) # write_stats to output
            self.learn.logger = self.orig_logger # restore orig logger after skipping recorder.logger(log)
Example #24
    def allreduce(self, tensors, all_reduce_options):
        reduce_type = self._get_reduce_type(all_reduce_options.reduceOp)

        # TODO(hjm-aws): implement all_reduce_options.timeout.
        xm.all_reduce(reduce_type, tensors, groups=self._mesh)
        return WorkXla(tensors)
Example #25
 def all_reduce(self, collectiveArgs, retFlag=False):
     retObj = xm.all_reduce(collectiveArgs.op, [collectiveArgs.ipTensor])
     if retFlag:
         return retObj
Example #26
    def train(self):
        hps = self.state.hps
        ss = self.state
        current_stats = {}
        writer_stats = {}

        # for resuming the learning rate
        sorted_lr_steps = sorted(self.learning_rates.keys())
        lr_index = util.greatest_lower_bound(sorted_lr_steps,
                                             ss.data.global_step)
        ss.update_learning_rate(self.learning_rates[sorted_lr_steps[lr_index]])

        if ss.model.bn_type != 'none':
            sorted_as_steps = sorted(self.anneal_schedule.keys())
            as_index = util.greatest_lower_bound(sorted_as_steps,
                                                 ss.data.global_step)
            ss.model.objective.update_anneal_weight(
                self.anneal_schedule[sorted_as_steps[as_index]])

        if ss.model.bn_type in ('vqvae', 'vqvae-ema'):
            ss.model.init_codebook(self.data_iter, 10000)

        start_time = time.time()

        for batch_num, batch in enumerate(self.device_loader):
            wav, mel, voice, jitter, position = batch
            global_step = len(ss.data.dataset) * position[0] + position[1]

            # print(f'replica {self.replica_index}, batch {batch_num}', file=stderr)
            # stderr.flush()
            if (batch_num % hps.save_interval == 0 and batch_num != 0):
                self.save_checkpoint(position)

            if hps.skip_loop_body:
                continue

            lr_index = util.greatest_lower_bound(sorted_lr_steps, global_step)
            ss.update_learning_rate(
                self.learning_rates[sorted_lr_steps[lr_index]])
            # if ss.data.global_step in self.learning_rates:
            # ss.update_learning_rate(self.learning_rates[ss.data.global_step])

            if ss.model.bn_type == 'vae' and ss.step in self.anneal_schedule:
                ss.model.objective.update_anneal_weight(
                    self.anneal_schedule[ss.data.global_step])

            ss.optim.zero_grad()
            quant, self.target, loss = self.state.model.run(
                wav, mel, voice, jitter)
            self.probs = self.softmax(quant)
            self.mel_enc_input = mel
            # print(f'after model.run', file=stderr)
            # stderr.flush()
            loss.backward()

            # print(f'after loss.backward()', file=stderr)
            # stderr.flush()

            if batch_num % hps.progress_interval == 0:
                pars_copy = [p.data.clone() for p in ss.model.parameters()]

            # print(f'after pars_copy', file=stderr)
            # stderr.flush()

            if self.is_tpu:
                xm.optimizer_step(ss.optim)
            else:
                ss.optim.step()

            ss.optim_step += 1

            if ss.model.bn_type == 'vqvae-ema' and ss.data.global_step == 10000:
                ss.model.bottleneck.update_codebook()

            tprb_m = self.avg_prob_target()

            if batch_num % hps.progress_interval == 0:
                iterator = zip(pars_copy, ss.model.named_parameters())
                uw_ratio = {
                    np[0]: t.norm(c - np[1].data) / c.norm()
                    for c, np in iterator
                }

                writer_stats.update({'uwr': uw_ratio})

                if self.is_tpu:
                    count = torch_xla._XLAC._xla_get_replication_devices_count()
                    loss_red, tprb_red = xm.all_reduce('sum', [loss, tprb_m],
                                                       scale=1.0 / count)
                    # loss_red = xm.all_reduce('all_loss', loss, reduce_mean)
                    # tprb_red = xm.all_reduce('all_tprb', tprb_m, reduce_mean)
                else:
                    loss_red = loss
                    tprb_red = tprb_m

                writer_stats.update({
                    'loss_r': loss_red,
                    'tprb_r': tprb_red,
                    'optim_step': ss.optim_step
                })

                current_stats.update({
                    'optim_step': ss.optim_step,
                    'gstep': global_step,
                    # 'gstep': ss.data.global_step,
                    'epoch': position[0],
                    'step': position[1],
                    # 'loss': loss,
                    'lrate': ss.optim.param_groups[0]['lr'],
                    # 'tprb_m': tprb_m,
                    # 'pk_d_m': avg_peak_dist
                })
                current_stats.update(ss.model.objective.metrics)

                if ss.model.bn_type in ('vae',):
                    current_stats['free_nats'] = ss.model.objective.free_nats
                    current_stats['anneal_weight'] = \
                            ss.model.objective.anneal_weight.item()

                if ss.model.bn_type in ('vqvae', 'vqvae-ema', 'ae', 'vae'):
                    current_stats.update(ss.model.encoder.metrics)

                if self.is_tpu:
                    xm.add_step_closure(self.train_update,
                                        args=(writer_stats, current_stats))
                else:
                    self.train_update(writer_stats, current_stats)

                # if not self.is_tpu or xm.is_master_ordinal():
                # if batch_num in range(25, 50) or batch_num in range(75, 100):
                stderr.flush()
                elapsed = time.time() - start_time
Example #27
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.data_parallel_world_size > 1
                    and hasattr(self.model, "no_sync")
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args.distributed_world_size == 1:
                        return None
                else:
                    raise e

            if self.tpu and i < len(samples) - 1:
                # tpu-comment: every XLA operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass without a backward pass
                import torch_xla.core.xla_model as xm
                xm.mark_step()

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            train_time = self._local_cumulative_training_time()
            logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs(
                logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch,
            )
            self._cumulative_training_time = total_train_time / self.data_parallel_world_size

        if hasattr(self.model, 'all_reduce'):
            self.model.all_reduce()

        overflow = False
        try:
            if self.tpu and self.data_parallel_world_size > 1:
                import torch_xla.core.xla_model as xm
                gradients = xm._fetch_gradients(self.optimizer.optimizer)
                xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (# GPUs / sample_size) since DDP
                # already normalizes by the number of GPUs. Thus we get
                # (sum_of_gradients / sample_size).
                if not self.args.use_bmuf:
                    self.optimizer.multiply_grads(self.data_parallel_world_size / sample_size)
                elif sample_size > 0:  # BMUF needs to check sample size
                    num = self.data_parallel_world_size if self._sync_stats() else 1
                    self.optimizer.multiply_grads(num / sample_size)

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if (
                not self.args.use_bmuf
                and self.args.distributed_wrapper != 'SlowMo'
                and not self.tpu
            ):
                self._check_grad_norms(grad_norm)

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()
        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            with NanDetector(self.model):
                self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
                    ignore_grad=False
                )
            raise
        except OverflowError as e:
            overflow = True
            logger.info("NOTE: overflow detected, " + str(e))
            grad_norm = torch.tensor(0.).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step
        if hasattr(self.model, 'perform_additional_optimizer_actions'):
            if hasattr(self.optimizer, 'fp32_params'):
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params)
            else:
                self.model.perform_additional_optimizer_actions(self.optimizer.optimizer)

        if not overflow or self.args.distributed_wrapper == 'SlowMo':
            self.set_num_updates(self.get_num_updates() + 1)

            if self.tpu:
                # mark step on TPUs
                import torch_xla.core.xla_model as xm
                xm.mark_step()

                # only log stats every log_interval steps
                # this causes wps to be misreported when log_interval > 1
                logging_output = {}
                if self.get_num_updates() % self.args.log_interval == 0:
                    logging_output = self._reduce_and_log_stats(
                        logging_outputs, sample_size, grad_norm,
                    )

                # log whenever there's an XLA compilation, since these
                # slow down training and may indicate opportunities for
                # optimization
                self._check_xla_compilation()
            else:
                # log stats
                logging_output = self._reduce_and_log_stats(
                    logging_outputs, sample_size, grad_norm,
                )

                # clear CUDA cache to reduce memory fragmentation
                if (
                    self.cuda
                    and self.args.empty_cache_freq > 0
                    and (
                        (self.get_num_updates() + self.args.empty_cache_freq - 1)
                        % self.args.empty_cache_freq
                    ) == 0
                ):
                    torch.cuda.empty_cache()

        if self.args.fp16:
            metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0)

        metrics.log_stop_time("train_wall")

        return logging_output
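Distilled from the TPU branch above, the gradient-accumulation pattern on XLA looks roughly like this (a sketch with assumed helper names, not fairseq's actual API):

import torch_xla.core.xla_model as xm

def accumulate_and_step(samples, forward_backward, optimizer, world_size):
    optimizer.zero_grad()
    for i, sample in enumerate(samples):
        forward_backward(sample)       # runs forward + loss.backward()
        if i < len(samples) - 1:
            # Cut the lazy IR graph between micro-batches so it doesn't grow unbounded.
            xm.mark_step()
    # Sum-average the accumulated gradients across replicas, then step.
    gradients = xm._fetch_gradients(optimizer)
    xm.all_reduce('sum', gradients, scale=1.0 / world_size)
    optimizer.step()
    xm.mark_step()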
Example #28
 def all_reduce(self, collectiveArgs, retFlag=False):
     retObj = xm.all_reduce(collectiveArgs.op, [collectiveArgs.ipTensor])
     if collectiveArgs.asyncOp:
         collectiveArgs.waitObj.append(retObj)
     if retFlag:
         return retObj
Example #29
 def backward(ctx, grad_output):
     dim = ctx.dim
     all_grad_output = xm.all_reduce(xm.REDUCE_SUM, grad_output)
     return all_grad_output.select(dim, xm.get_ordinal()), None
Example #30
 def forward(ctx, input, reduce_type, scale, groups):
     ctx.reduce_type = reduce_type
     ctx.scale = scale
     output = xm.all_reduce(reduce_type, input, scale=scale, groups=groups)
     ctx.save_for_backward(input, output)
     return output
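Examples #29 and #30 are forward/backward methods of custom autograd functions. Assembled into a complete torch.autograd.Function, a differentiable all-reduce might look like the sketch below (illustrative only; the class name and the backward rule are assumptions, not the original implementation):

import torch
import torch_xla.core.xla_model as xm

class AllReduceFn(torch.autograd.Function):
    # Illustrative wrapper around xm.all_reduce with a differentiable backward.

    @staticmethod
    def forward(ctx, input, reduce_type, scale, groups):
        ctx.reduce_type = reduce_type
        ctx.scale = scale
        ctx.groups = groups
        return xm.all_reduce(reduce_type, input, scale=scale, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        # For a sum reduction, one common choice is to all-reduce the incoming
        # gradient with the same scale so every replica sees the combined gradient.
        grad = xm.all_reduce(ctx.reduce_type, grad_output,
                             scale=ctx.scale, groups=ctx.groups)
        return grad, None, None, None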