Example #1
    def _reduce_and_log_stats(self,
                              logging_outputs,
                              sample_size,
                              grad_norm=None):
        if grad_norm is not None:
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            if self.args['optimization']['clip_norm'] > 0:
                metrics.log_scalar(
                    "clip",
                    torch.where(
                        grad_norm > self.args['optimization']['clip_norm'],
                        grad_norm.new_tensor(100),
                        grad_norm.new_tensor(0),
                    ),
                    priority=500,
                    round=1,
                )

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())

            # support legacy interface
            logging_output = agg.get_smoothed_values()
            logging_output["sample_size"] = sample_size
            for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
                if key_to_delete in logging_output:
                    del logging_output[key_to_delete]
            return logging_output
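
Example #1 wraps the task-level reduction in metrics.aggregate() and reads the result back with get_smoothed_values(). Below is a minimal sketch of that aggregation pattern, assuming a fairseq-style metrics module; the import path is an assumption and may differ in this codebase.

from fairseq.logging import metrics  # assumed import path

with metrics.aggregate() as agg:
    # values logged inside the context are accumulated into `agg`
    metrics.log_scalar("loss", 2.5, weight=8, round=3)
    metrics.log_scalar("loss", 1.5, weight=8, round=3)
    smoothed = agg.get_smoothed_values()

# "loss" is a weighted running average: (2.5*8 + 1.5*8) / (8 + 8) == 2.0
print(smoothed["loss"])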
Example #2
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        accuracy_sum = sum(log.get('accuracy', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        metrics.log_scalar('loss', loss_sum / sample_size, sample_size, round=3)
        metrics.log_scalar('accuracy', accuracy_sum / sample_size, sample_size, round=3)
Example #3
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        # ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
Example #4
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if self.args['task']['eval_bleu']:
            def sum_logs(key):
                return sum(log.get(key, 0) for log in logging_outputs)

            metrics.log_scalar('bleu', sum_logs('bleu'))
Example #5
    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        self._num_updates = num_updates
        # self.lr_step_update() # TODO
        metrics.log_scalar("num_updates",
                           self._num_updates,
                           weight=0,
                           priority=200)
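
Example #5 logs num_updates with weight=0. With a fairseq-style AverageMeter, a zero-weight update changes the stored value but not the running sum or count, so the meter reports the latest value instead of an average, which is what a counter needs. A minimal sketch of that behaviour, again under the assumption of a fairseq-style metrics module (the import path is an assumption):

from fairseq.logging import metrics  # assumed import path

with metrics.aggregate() as agg:
    metrics.log_scalar("num_updates", 10, weight=0, priority=200)
    metrics.log_scalar("num_updates", 11, weight=0, priority=200)
    current = agg.get_smoothed_values()["num_updates"]

# current == 11: with zero weight the meter keeps the most recent value rather than averaging
print(current)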
Example #6
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        with torch.no_grad():
            if self.args['task']['eval_mrr']:
                def sum_logs(key):
                    return sum(log.get(key, 0) for log in logging_outputs)

                metrics.log_scalar('mrr', sum_logs('mrr'))
                metrics.log_scalar('sample_size', sum_logs('sample_size'))
Example #7
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            'Criterions should implement the reduce_metrics API. '
            'Falling back to deprecated aggregate_logging_outputs API.'
        )
        agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
        for k, v in agg_logging_outputs.items():
            if k in {'nsentences', 'ntokens', 'sample_size'}:
                continue
            metrics.log_scalar(k, v)
Example #8
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        # metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_scalar('loss',
                           loss_sum / sample_size,
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
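
Example #8 pairs log_scalar with log_derived: the 'loss' meter keeps a weighted running average, and 'ppl' is recomputed from that average whenever the smoothed values are read. A minimal sketch of the pattern, assuming a fairseq-style metrics module and substituting 2 ** loss for utils.get_perplexity (an assumption about what that helper computes):

from fairseq.logging import metrics  # assumed import path

with metrics.aggregate() as agg:
    metrics.log_scalar('loss', 1.0, 4, round=3)  # third positional argument is the weight
    metrics.log_scalar('loss', 3.0, 4, round=3)
    # derived meters are evaluated lazily from the other meters when values are read
    metrics.log_derived('ppl', lambda meters: round(2 ** meters['loss'].avg, 2))  # stand-in for utils.get_perplexity
    values = agg.get_smoothed_values()

print(values['loss'])  # 2.0, the weighted average
print(values['ppl'])   # 4.0, i.e. 2 ** 2.0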
Example #9
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)
        if self.args['task']['eval_bleu']:

            if self.args['task']['eval_with_sacrebleu']:

                def sum_logs(key):
                    import torch
                    result = sum(log.get(key, 0) for log in logging_outputs)
                    if torch.is_tensor(result):
                        result = result.cpu()
                    return result

                counts, totals = [], []
                for i in range(EVAL_BLEU_ORDER):
                    counts.append(sum_logs('_bleu_counts_' + str(i)))
                    totals.append(sum_logs('_bleu_totals_' + str(i)))

                if max(totals) > 0:
                    # log counts as numpy arrays -- log_scalar will sum them correctly
                    metrics.log_scalar('_bleu_counts', np.array(counts))
                    metrics.log_scalar('_bleu_totals', np.array(totals))
                    metrics.log_scalar('_bleu_sys_len',
                                       sum_logs('_bleu_sys_len'))
                    metrics.log_scalar('_bleu_ref_len',
                                       sum_logs('_bleu_ref_len'))

                    def compute_bleu(meters):
                        import inspect
                        import sacrebleu
                        fn_sig = inspect.getfullargspec(
                            sacrebleu.compute_bleu)[0]
                        if 'smooth_method' in fn_sig:
                            smooth = {'smooth_method': 'exp'}
                        else:
                            smooth = {'smooth': 'exp'}
                        bleu = sacrebleu.compute_bleu(
                            correct=meters['_bleu_counts'].sum,
                            total=meters['_bleu_totals'].sum,
                            sys_len=meters['_bleu_sys_len'].sum,
                            ref_len=meters['_bleu_ref_len'].sum,
                            **smooth)
                        return round(bleu.score, 6)

                    metrics.log_derived('bleu', compute_bleu)
            else:

                def sum_logs(key):
                    return sum(log.get(key, 0) for log in logging_outputs)

                metrics.log_scalar('bleu', sum_logs('bleu'), round=6)
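
Example #9 logs the per-order n-gram counts as numpy arrays so that the meter sums them elementwise across logging steps; corpus BLEU is then derived from the accumulated sums rather than by averaging per-batch BLEU scores. A minimal sketch of the array-summing behaviour, assuming a fairseq-style metrics module (the import path is an assumption); the .sum attribute read here is the same one compute_bleu reads above.

import numpy as np
from fairseq.logging import metrics  # assumed import path

with metrics.aggregate() as agg:
    metrics.log_scalar('_bleu_counts', np.array([4, 3, 2, 1]))
    metrics.log_scalar('_bleu_counts', np.array([6, 5, 4, 3]))
    # the meter's running sum adds the arrays elementwise
    counts_sum = agg['_bleu_counts'].sum

print(counts_sum)  # array([10, 8, 6, 4])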
Example #10
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        match = sum(log.get('match', 0) for log in logging_outputs)
        total = sum(log.get('total', 0) for log in logging_outputs)

        metrics.log_scalar('match', value=match, round=0)
        metrics.log_scalar('total', value=total, round=0)
        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        if sample_size != ntokens:
            metrics.log_scalar('nll_loss',
                               loss_sum / ntokens / math.log(2),
                               ntokens,
                               round=3)
            metrics.log_derived(
                'ppl',
                lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
        else:
            metrics.log_derived(
                'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
Example #11
    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        def sum_logs(key):
            return sum(log.get(key, 0) for log in logging_outputs)

        ntokens = sum_logs('ntokens')
        accuracy_sum = sum_logs('accuracy')
        metrics.log_scalar('accuracy',
                           accuracy_sum / ntokens,
                           ntokens,
                           round=6)

        if self.args['task']['eval_mrr']:
            mrr_sum = sum_logs('mrr')
            metrics.log_scalar('mrr', mrr_sum / ntokens, ntokens, round=6)
Example #12
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = utils.item(
            sum(log.get('loss', 0) for log in logging_outputs))
        nll_loss_sum = utils.item(
            sum(log.get('nll_loss', 0) for log in logging_outputs))
        alignment_loss_sum = utils.item(
            sum(log.get('alignment_loss', 0) for log in logging_outputs))
        ntokens = utils.item(
            sum(log.get('ntokens', 0) for log in logging_outputs))
        sample_size = utils.item(
            sum(log.get('sample_size', 0) for log in logging_outputs))

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_scalar('nll_loss',
                           nll_loss_sum / ntokens / math.log(2),
                           ntokens,
                           round=3)
        metrics.log_scalar('alignment_loss',
                           alignment_loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        metrics.log_derived(
            'ppl', lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
Example #13
    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = NccTask.aggregate_logging_outputs
        self_func = getattr(self, "aggregate_logging_outputs").__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion
            )
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any("ntokens" in log for log in logging_outputs):
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )
        else:
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)

        if not any("nsentences" in log for log in logging_outputs):
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )
        else:
            nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)

        criterion.__class__.reduce_metrics(logging_outputs)
Example #14
    def set_lr(self, new_lr):
        """Set the learning rate manually."""
        self.lr_scheduler.lr = new_lr
        self.lr_scheduler.optimizer.set_lr(new_lr)
        metrics.log_scalar("lr", new_lr, weight=0, priority=300, round=6)
        return self.lr_scheduler.optimizer.get_lr()
Example #15
    def lr_step_update(self):
        """Update the learning rate after each update."""
        new_lr = self.lr_scheduler.step_update(self.get_num_updates())
        metrics.log_scalar("lr", new_lr, weight=0, priority=300, round=6)
        return new_lr
Example #16
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample, is_dummy_batch = self._prepare_sample(sample)

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args['distributed_training']['distributed_world_size']
                        > 1 and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    LOGGER.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args['distributed_training'][
                            'distributed_world_size'] == 1:
                        return None
                else:
                    raise e

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0  # multiply by 0 to preserve device

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        overflow = False
        try:
            with torch.autograd.profiler.record_function("reduce-grads"):
                # reduce gradients across workers
                self.optimizer.all_reduce_grads(self.model)
                if utils.has_parameters(self.criterion):
                    self.optimizer.all_reduce_grads(self.criterion)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (data_parallel_size / sample_size) since
                # DDP normalizes by the number of data parallel workers for
                # improved fp16 precision.
                # Thus we get (sum_of_gradients / sample_size) at the end.
                # In case of fp16, this step also undoes loss scaling.
                # (Debugging note: Some optimizers perform this scaling on the
                # fly, so inspecting model.parameters() or optimizer.params may
                # still show the original, unscaled gradients.)
                num = (
                    self.args['distributed_training']['distributed_world_size']
                    if not self.args['optimization']['use_bmuf']
                    or self._sync_stats() else 1)
                self.optimizer.multiply_grads(num / (sample_size or 1.0))

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(
                    self.args['optimization']['clip_norm'])

            # check that grad norms are consistent across workers
            if not self.args['optimization']['use_bmuf']:
                self._check_grad_norms(grad_norm)
            if not torch.isfinite(grad_norm).all():
                if self.args['common'].get('amp', False):
                    overflow = True
                else:
                    raise FloatingPointError("gradients are Nan/Inf")

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()
                if self.args['common'].get('amp', False) and overflow:
                    if self._amp_retries == self.args['common'][
                            'amp_batch_retries']:
                        LOGGER.info("AMP: skipping this batch.")
                        self._amp_retries = 0
                    else:
                        self._amp_retries += 1
                        return self.train_step(
                            samples,
                            raise_oom)  # recursion to feed in same batch

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            self.zero_grad()
            with NanDetector(self._model):
                for _, sample in enumerate(samples):
                    sample, _ = self._prepare_sample(sample)
                    self.task.train_step(sample,
                                         self.model,
                                         self.criterion,
                                         self.optimizer,
                                         self.get_num_updates(),
                                         ignore_grad=False)
            raise
        except OverflowError as e:
            overflow = True
            LOGGER.info(
                f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
            )
            grad_norm = torch.tensor(0.0).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                LOGGER.error("OOM during optimization, irrecoverable")
            raise e

        logging_output = None
        if not overflow:
            self.set_num_updates(self.get_num_updates() + 1)

            if self.cuda and self.cuda_env is not None:
                # log minimum free memory over the iteration
                gb_used = torch.cuda.max_memory_allocated(
                ) / 1024 / 1024 / 1024
                torch.cuda.reset_peak_memory_stats()
                gb_free = self.cuda_env.total_memory_in_GB - gb_used
                metrics.log_scalar("gb_free",
                                   gb_free,
                                   priority=1500,
                                   round=1,
                                   weight=0)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.cuda and self.args['common']['empty_cache_freq'] > 0
                    and ((self.get_num_updates() +
                          self.args['common']['empty_cache_freq'] - 1) %
                         self.args['common']['empty_cache_freq']) == 0):
                torch.cuda.empty_cache()

        if self.args['common']['fp16'] or self.args['common'].get(
                'amp', False):
            metrics.log_scalar(
                "loss_scale",
                (self.optimizer.scaler.loss_scale
                 if self.args['common']['fp16'] else
                 self.optimizer.scaler.get_scale()),
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output
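
In Example #16, the multiply-grads step rescales what DDP leaves behind, a sum of gradients divided by the number of workers, into the per-sample average described in the comment. A worked sketch of that arithmetic with toy numbers, under the assumption that DDP averages gradients across workers:

# toy numbers for the rescaling in Example #16 (assumes DDP averaging across workers)
world_size = 4           # distributed_world_size
sample_size = 64.0       # total sample_size accumulated over the update
grad_sum = 128.0         # hypothetical sum of per-sample gradients over all workers

ddp_grad = grad_sum / world_size                  # what all_reduce_grads leaves behind: 32.0
rescaled = ddp_grad * (world_size / sample_size)  # multiply_grads(num / sample_size)
assert rescaled == grad_sum / sample_size         # 2.0, i.e. the gradient per sample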