Example #1
    def _finish_update(self):
        # Optionally clip gradients before the optimizer step; the current
        # loss scale is passed along so clipping can account for gradients
        # that are still scaled.
        if self.training_config.clip_gradients:
            clip_gradients(
                self.model,
                self.num_updates,
                self.logistics_callback.tb_writer,
                self.config,
                scale=self.scaler.get_scale(),
            )
        if is_xla():
            import torch_xla.core.xla_model as xm

            # Assumes no model parallel
            xm.reduce_gradients(self.optimizer)

        # GradScaler.step() skips the optimizer step if the gradients contain
        # inf/NaN values; update() then adjusts the loss scale accordingly.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.num_updates += 1
        self.profile("Finished update")
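
The scaler.step / scaler.update pair above is the standard torch.cuda.amp.GradScaler sequence for mixed-precision training. Below is a minimal standalone sketch of that pattern; model, optimizer, loss_fn, and data_loader are hypothetical placeholders, not objects from the example above.

    import torch

    # Minimal mixed-precision update loop, assuming hypothetical `model`,
    # `optimizer`, `loss_fn`, and a `data_loader` yielding (inputs, targets).
    scaler = torch.cuda.amp.GradScaler()

    for inputs, targets in data_loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(inputs), targets)
        # Backward on the scaled loss so small gradients are not flushed to zero.
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is computed on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # step() is skipped automatically if any gradient is inf/NaN.
        scaler.step(optimizer)
        scaler.update()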
Example #2
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, 'no_sync')
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)

                    if self.fast_stat_sync:
                        self._all_reduce_list[0] += sample_size
                        self._all_reduce_list[1] += logging_output.get(
                            'nsentences', 0.0)
                        self._all_reduce_list[2] += logging_output.get(
                            'loss', 0.0)
                        self._all_reduce_list[3] += logging_output.get(
                            'nll_loss', 0.0)
                        self._all_reduce_list[4] += logging_output.get(
                            'ntokens', 0.0)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    msg = ('| WARNING: ran out of memory with exception: ' +
                           '{};'.format(e) + '\n Skipping batch')
                    # TODO: print should really go to logger, this print goes
                    # to stderr, which is buffered, which in many cases is not
                    # printed out if another exception happens.
                    # NB(jerry): added a flush to mitigate this
                    print(msg, file=sys.stderr)
                    if torch.cuda.is_available() and hasattr(
                            torch.cuda, "memory_summary"):
                        for device_idx in range(torch.cuda.device_count()):
                            print(torch.cuda.memory_summary(device=device_idx),
                                  file=sys.stderr)
                    sys.stderr.flush()

                    if raise_oom:
                        raise ValueError(msg)
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

            if self.xla and len(samples) > 1:
                # tpu-comment: every xla operation before marking step is
                # appended to the IR graph, and processing too many batches
                # before marking step can lead to OOM errors.
                # To handle gradient accumulation use case, we explicitly
                # mark step here for every forward pass if we accumulate grads
                xm.mark_step()

            if self.fast_stat_sync:
                self._all_reduce_list[5] += ooms

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self.fast_stat_sync:
            # rework all_gather_list
            all_reduce_list_tensor = torch.cuda.DoubleTensor(
                self._all_reduce_list)
            if self._sync_stats():
                torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (all_reduce_list_tensor[0:1] *
                 torch.log(torch.cuda.DoubleTensor([2]))))
            self._all_reduce_list = all_reduce_list_tensor.tolist()
            logging_output = {}
            [
                sample_size,
                logging_output['nsentences'],
                logging_output['loss'],
                logging_output['nll_loss'],
                logging_output['ntokens'],
                ooms,
            ] = self._all_reduce_list
        elif self._sync_stats():
            logging_outputs, sample_sizes, ooms, prev_norms = \
                zip(*distributed_utils.all_gather_list(
                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
                ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

            if not self.args.use_bmuf:
                assert (
                    all(norm == prev_norms[0] for norm in prev_norms) or all(
                        math.isnan(norm) or math.isinf(norm)
                        for norm in prev_norms)
                ), 'Fatal error: gradients are inconsistent between workers'

        self.meters['oom'].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        if not self.fast_stat_sync:
            # aggregate logging outputs and sample sizes
            logging_output = self.task.aggregate_logging_outputs(
                logging_outputs, self.get_criterion())
            sample_size = self.task.grad_denom(sample_sizes,
                                               self.get_criterion())

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception(
                ('Please update the {}.aggregate_logging_outputs() method to '
                 'return ntokens and nsentences').format(
                     self.task.__class__.__name__))

        try:
            # normalize grads by sample size
            if sample_size > 0:
                self.optimizer.multiply_grads(
                    self.args.distributed_world_size / float(sample_size))
            if self.xla:
                # tpu-comment: for tpu training, we need to explicitly reduce
                #   gradients here
                xm.reduce_gradients(self.optimizer)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            self._prev_grad_norm = grad_norm

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            # tpu-comment: the comparison below introduces too many .item()
            # calls and slows down tpu
            self.meters['clip'].update(
                0.
                #1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['train_loss'].update(logging_output.get('loss', 0),
                                             sample_size)
            if 'train_acc' in self.meters:
                self.meters['train_acc'].update(logging_output.get('acc', 0),
                                                sample_size)

            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(
                    logging_output.get('nll_loss', 0), ntokens)

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.clear_buffered_stats()
        self.meters['train_wall'].stop()

        return logging_output
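
The maybe_no_sync closure above implements the usual gradient-accumulation pattern for DistributedDataParallel: suppress the gradient all-reduce on every micro-batch except the last one, so communication happens only once per parameter update. Below is a standalone sketch of the same idea; model (assumed to be wrapped in DistributedDataParallel), optimizer, loss_fn, and micro_batches are hypothetical placeholders.

    import contextlib

    def backward_with_accumulation(model, optimizer, loss_fn, micro_batches):
        # Hypothetical helper: `model` is assumed to be wrapped in
        # torch.nn.parallel.DistributedDataParallel, which exposes no_sync().
        optimizer.zero_grad()
        for i, (inputs, targets) in enumerate(micro_batches):
            last = i == len(micro_batches) - 1
            # no_sync() suppresses the gradient all-reduce; ExitStack() is a
            # do-nothing context manager used on the final micro-batch so the
            # all-reduce fires exactly once per update.
            ctx = contextlib.ExitStack() if last else model.no_sync()
            with ctx:
                loss = loss_fn(model(inputs), targets)
                loss.backward()
        optimizer.step()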