示例#1
0
    def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
        """Log per-update optimizer stats and return the reduced logging dict.

        When ``grad_norm`` is given, this logs:

        * an ``ups`` speed meter;
        * the gradient norm as ``gnorm``;
        * if ``args['optimization']['clip_norm'] > 0``, a ``clip`` scalar that
          is 100 when the norm exceeded the threshold and 0 otherwise.

        The task's ``reduce_metrics`` then runs inside a ``metrics.aggregate()``
        context (skipped when ``logging_outputs`` is None). The context's
        smoothed values are returned with ``sample_size`` added and the
        ``ppl``/``wps``/``wpb``/``bsz`` keys removed, for the legacy interface.
        """
        if grad_norm is not None:
            metrics.log_speed("ups", 1.0, priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            clip_norm = self.args['optimization']['clip_norm']
            if clip_norm > 0:
                # 100 if this update was clipped, else 0 -- presumably so the
                # averaged meter reads as a clipping percentage.
                clipped = torch.where(
                    grad_norm > clip_norm,
                    grad_norm.new_tensor(100),
                    grad_norm.new_tensor(0),
                )
                metrics.log_scalar("clip", clipped, priority=500, round=1)

        with metrics.aggregate() as agg:
            if logging_outputs is not None:
                self.task.reduce_metrics(logging_outputs, self.get_criterion())

            # support legacy interface: flat dict of smoothed values
            stats = agg.get_smoothed_values()
            stats["sample_size"] = sample_size
            for legacy_key in ("ppl", "wps", "wpb", "bsz"):
                if legacy_key not in stats:
                    continue
                del stats[legacy_key]
            return stats
示例#2
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Reduce metrics via the parent task, then log the summed ``bleu``.

        ``bleu`` is only logged when ``args['task']['eval_bleu']`` is enabled.
        Logging outputs without a ``bleu`` entry contribute 0 to the sum.
        """
        super().reduce_metrics(logging_outputs, criterion)
        if not self.args['task']['eval_bleu']:
            return
        total_bleu = sum(log.get('bleu', 0) for log in logging_outputs)
        metrics.log_scalar('bleu', total_bleu)
示例#3
0
 def reduce_metrics(logging_outputs) -> None:
     """Aggregate logging outputs from data parallel training.

     Logs ``loss`` as the summed ``loss`` divided by the summed
     ``sample_size``, rounded to 3 digits. ``sample_size`` is passed as the
     third positional argument of ``metrics.log_scalar`` (presumably the
     averaging weight).

     Args:
         logging_outputs: per-worker logging dicts. Missing ``loss`` /
             ``sample_size`` entries count as 0.

     Nothing is logged when the summed ``sample_size`` is 0 (e.g. an empty
     ``logging_outputs``), because the average is undefined.
     """
     loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
     sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
     if sample_size == 0:
         # Dividing would raise ZeroDivisionError; there is nothing to average.
         return
     metrics.log_scalar('loss',
                        loss_sum / sample_size,
                        sample_size,
                        round=3)
示例#4
0
 def set_num_updates(self, num_updates):
     """Record the parameter-update count and log it.

     Stores ``num_updates`` on the instance, calls ``lr_step_update()``, then
     logs the stored counter as ``num_updates`` with ``weight=0`` and
     ``priority=200``.
     """
     self.num_updates = num_updates
     self.lr_step_update()
     metrics.log_scalar(
         "num_updates", self.num_updates, weight=0, priority=200
     )
示例#5
0
 def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
     """Aggregate logging outputs via the legacy ``aggregate_logging_outputs``.

     Emits a deprecation warning, then logs every aggregated value as a
     scalar except the bookkeeping counters ``nsentences``, ``ntokens`` and
     ``sample_size``.
     """
     utils.deprecation_warning(
         'Criterions should implement the reduce_metrics API. '
         'Falling back to deprecated aggregate_logging_outputs API.'
     )
     skipped = {'nsentences', 'ntokens', 'sample_size'}
     aggregated = cls.aggregate_logging_outputs(logging_outputs)
     for name, value in aggregated.items():
         if name not in skipped:
             metrics.log_scalar(name, value)
示例#6
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Reduce metrics via the parent task, then log the mean ``mrr``.

        When ``args['task']['eval_mrr']`` is enabled and there is at least one
        logging output, logs ``mrr`` as the average of the per-output ``mrr``
        values. Outputs without an ``mrr`` entry count as 0 and are still
        included in the denominator.
        """
        super().reduce_metrics(logging_outputs, criterion)
        with torch.no_grad():
            # An empty ``logging_outputs`` has no mean; dividing by its length
            # would raise ZeroDivisionError, so skip logging in that case.
            if self.args['task']['eval_mrr'] and logging_outputs:

                def mean_logs(key):
                    return sum(
                        log.get(key, 0)
                        for log in logging_outputs) / len(logging_outputs)

                metrics.log_scalar('mrr', mean_logs('mrr'))
示例#7
0
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Logs ``loss`` as the summed loss divided by the summed ``sample_size``
        and by ln 2 (i.e. base 2, assuming the criterion reports natural-log
        losses). When token counts are available and differ from
        ``sample_size``, it also logs ``nll_loss`` (summed loss / summed
        ``ntokens`` / ln 2). A derived ``ppl`` meter is registered from
        ``nll_loss`` in that case, and from ``loss`` otherwise.

        Args:
            logging_outputs: per-worker logging dicts. Missing ``loss``,
                ``ntokens`` and ``sample_size`` entries count as 0.

        Nothing is logged when the summed ``sample_size`` is 0 (e.g. an empty
        ``logging_outputs``), because the averages are undefined.
        """
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

        if sample_size == 0:
            # Dividing would raise ZeroDivisionError; nothing to average.
            return

        metrics.log_scalar('loss',
                           loss_sum / sample_size / math.log(2),
                           sample_size,
                           round=3)
        # ntokens == 0 means no output reported token counts, so nll_loss
        # cannot be computed (it would divide by zero). Fall back to the
        # loss-based perplexity in that case.
        if ntokens != 0 and sample_size != ntokens:
            metrics.log_scalar('nll_loss',
                               loss_sum / ntokens / math.log(2),
                               ntokens,
                               round=3)
            metrics.log_derived(
                'ppl',
                lambda meters: utils.get_perplexity(meters['nll_loss'].avg))
        else:
            metrics.log_derived(
                'ppl', lambda meters: utils.get_perplexity(meters['loss'].avg))
示例#8
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training.

        If this task still overrides the deprecated
        ``aggregate_logging_outputs`` hook, that hook wins. A deprecation
        warning is emitted, each value it returns is logged as a scalar, and
        nothing else happens.

        Otherwise:

        * ``wpb``/``wps`` are logged from the summed ``ntokens``;
        * ``bsz`` is logged from the summed ``nsentences``;
        * a warning is issued when no output carries the respective key;
        * the criterion class's ``reduce_metrics`` is called.
        """
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_impl = NccTask.aggregate_logging_outputs
        if self.aggregate_logging_outputs.__func__ is not base_impl:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API.")
            legacy_stats = self.aggregate_logging_outputs(
                logging_outputs, criterion)
            for name, value in legacy_stats.items():
                metrics.log_scalar(name, value)
            return

        if any("ntokens" in log for log in logging_outputs):
            total_tokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", total_tokens, priority=180, round=1)
            metrics.log_speed("wps", total_tokens, priority=90, round=1)
        else:
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )

        if any("nsentences" in log for log in logging_outputs):
            total_sentences = sum(
                log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", total_sentences, priority=190, round=1)
        else:
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )

        criterion.__class__.reduce_metrics(logging_outputs)
示例#9
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Reduce metrics via the parent task, then log enabled eval sums.

        Handles ``accuracy``, ``last_accuracy`` and ``mrr``. For each one whose
        ``eval_*`` flag is set in ``args['task']``, the values are summed
        across ``logging_outputs``. The sum is logged only when it is
        positive: a zero sum means no such items were reported, e.g. during
        the training stage.
        """
        super().reduce_metrics(logging_outputs, criterion)

        task_args = self.args['task']
        checks = (
            ('eval_accuracy', 'accuracy'),
            ('eval_last_accuracy', 'last_accuracy'),
            ('eval_mrr', 'mrr'),
        )
        for flag, key in checks:
            if not task_args[flag]:
                continue
            total = sum(log.get(key, 0) for log in logging_outputs)
            # == 0: no such items in the logging outputs -> training stage
            if total > 0:
                metrics.log_scalar(key, total)
示例#10
0
 def lr_step_update(self):
     """Step the LR scheduler for the current update count and log the LR.

     Returns:
         The value returned by ``self.lr_scheduler.step_update``. It is also
         logged as ``lr`` with ``weight=0`` and ``priority=300``.
     """
     current_update = self.get_num_updates()
     lr = self.lr_scheduler.step_update(current_update)
     metrics.log_scalar("lr", lr, weight=0, priority=300)
     return lr
示例#11
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Runs ``self.task.train_step`` on every mini-batch in ``samples``.
        When distributed and the model supports ``no_sync``, gradients are
        accumulated locally and only synchronised on the last mini-batch.
        It then rescales the gradients by the aggregated sample size, clips
        them, takes an optimizer step and logs training statistics.

        Args:
            samples: list of mini-batches. An entry for which
                ``_prepare_sample`` returns ``None`` is replaced by the cached
                dummy batch, and its gradients are ignored.
            raise_oom: re-raise out-of-memory ``RuntimeError``s from the
                forward/backward pass instead of trying to recover.

        Returns:
            One of:

            * the logging dict produced by ``_reduce_and_log_stats``;
            * ``None`` if the optimizer step raised ``OverflowError``;
            * ``None`` when a single-worker run recovered from an OOM in the
              forward/backward pass.

        Raises:
            FloatingPointError: re-raised after re-running the last sample
                under ``NanDetector``.
            RuntimeError: any non-OOM error from forward/backward, any OOM
                when ``raise_oom`` is set, and any error during the
                optimizer step.
        """
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        # Bound before the loop so the post-loop checks never hit an unbound
        # local (e.g. when ``samples`` is empty).
        is_dummy_batch = False
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.args['distributed_training']['distributed_world_size'] > 1
                    and hasattr(self.model, "no_sync")
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args['distributed_training']['distributed_world_size'] == 1:
                        return None
                else:
                    raise e

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)
        if is_dummy_batch:
            sample_size *= 0.  # multiply by 0 to preserve device

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
                logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
            )

        try:
            # multiply gradients by (# GPUs / sample_size) since DDP
            # already normalizes by the number of GPUs. Thus we get
            # (sum_of_gradients / sample_size).
            if not self.args['optimization']['use_bmuf']:
                # sample_size can be 0 here, e.g. a dummy batch on a single
                # worker (zeroed above). Divide by 1 instead so a plain float
                # does not raise ZeroDivisionError (which none of the handlers
                # below catch) and a tensor does not turn the gradients into inf.
                self.optimizer.multiply_grads(
                    self.args['distributed_training']['distributed_world_size']
                    / (sample_size or 1.0)
                )
            elif sample_size > 0:  # BMUF needs to check sample size
                num = self.args['distributed_training']['distributed_world_size'] if self._sync_stats() else 1
                self.optimizer.multiply_grads(num / sample_size)

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(self.args['optimization']['clip_norm'])

            # check that grad norms are consistent across workers
            if not self.args['optimization']['use_bmuf']:
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            # TODO: advancing the update counter may also step the LR scheduler
            # (set_num_updates -> lr_step_update). The original note asked for
            # the learning rate to be updated by other means -- confirm the
            # intended behaviour before relying on it.
            self.set_num_updates(self.get_num_updates() + 1)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size, grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            empty_cache_freq = self.args['common']['empty_cache_freq']
            if (
                empty_cache_freq > 0
                and (self.get_num_updates() + empty_cache_freq - 1) % empty_cache_freq == 0
                and torch.cuda.is_available()
                and not self.args['common']['cpu']
            ):
                torch.cuda.empty_cache()

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print out where it fails
            with NanDetector(self.model):
                self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, self.get_num_updates(),
                    ignore_grad=False
                )
            raise
        except OverflowError as e:
            logger.info("NOTE: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        if self.args['common']['fp16']:
            metrics.log_scalar("loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=4)

        metrics.log_stop_time("train_wall")

        return logging_output