Ejemplo n.º 1
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Reduce the collected logging outputs and log the aggregate metrics.

        When the task overrides the deprecated ``aggregate_logging_outputs``
        hook, that override is called and every key/value pair it returns is
        logged as a scalar; nothing else is logged in that case.

        Otherwise the ``ntokens`` values are summed and logged as ``wpb`` and
        ``wps``, the ``nsentences`` values are summed and logged as ``bsz``
        (a warning is issued for whichever key no entry provides), and the
        criterion class's ``reduce_metrics`` is called with the outputs.

        Args:
            logging_outputs: collection of dict-like logging outputs; it is
                iterated several times.
            criterion: criterion instance whose class provides
                ``reduce_metrics``.
        """
        # Tasks that still override the legacy hook are routed through it.
        legacy_impl = FairseqTask.aggregate_logging_outputs
        active_impl = self.aggregate_logging_outputs.__func__
        if active_impl is not legacy_impl:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            aggregated = self.aggregate_logging_outputs(logging_outputs, criterion)
            for name, value in aggregated.items():
                metrics.log_scalar(name, value)
            return

        # words per batch / words per second
        if any("ntokens" in entry for entry in logging_outputs):
            token_total = sum(entry.get("ntokens", 0) for entry in logging_outputs)
            metrics.log_scalar("wpb", token_total, priority=180, round=1)
            metrics.log_speed("wps", token_total, priority=90, round=1)
        else:
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )

        # batch size in sentences
        if any("nsentences" in entry for entry in logging_outputs):
            sentence_total = sum(
                entry.get("nsentences", 0) for entry in logging_outputs
            )
            metrics.log_scalar("bsz", sentence_total, priority=190, round=1)
        else:
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )

        # the criterion logs its own metrics (loss etc.)
        criterion_cls = criterion.__class__
        criterion_cls.reduce_metrics(logging_outputs)
Ejemplo n.º 2
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Sum the per-batch counters in *logging_outputs* and log them.

        Logged metrics (each only when at least one entry carries the key):

        * ``ntokens``         -> ``wpb`` (scalar) and ``wps`` (speed)
        * ``nsentences``      -> ``ns``
        * ``sample_size``     -> ``bsz``
        * ``n_bor_instances`` -> ``bsz_bor``
        * ``ntokens_AB``      -> ``wpb_AB``
        * ``ntokens_mem``     -> ``wpb_mem``

        A warning is issued when one of the first three keys is missing from
        every entry; the last three are optional and skipped silently.
        Finally the criterion class's ``reduce_metrics`` is called with the
        logging outputs and ``self.split``.

        Args:
            logging_outputs: sequence of dict-like logging outputs; may be
                empty.
            criterion: criterion instance whose class provides
                ``reduce_metrics(logging_outputs, split)``.
        """
        def has_key(key):
            # True when any entry (not just the first one) reports the key.
            return any(key in log for log in logging_outputs)

        def total(key):
            return utils.item(sum(log.get(key, 0) for log in logging_outputs))

        if has_key('ntokens'):
            ntokens = total('ntokens')
            metrics.log_scalar('wpb', ntokens, priority=180, round=1)
            # TODO(urikz): Latest version of fairseq also has additional argument "ignore_first"
            metrics.log_speed('wps', ntokens, priority=90, round=1)
        else:
            warnings.warn('ntokens not found in Criterion logging outputs, cannot log wpb or wps')

        if has_key('nsentences'):
            metrics.log_scalar('ns', total('nsentences'), priority=190, round=1)
        else:
            # this branch feeds the 'ns' metric ('bsz' comes from sample_size)
            warnings.warn('nsentences not found in Criterion logging outputs, cannot log ns')

        if has_key('sample_size'):
            metrics.log_scalar('bsz', total('sample_size'), priority=190, round=1)
        else:
            warnings.warn('sample_size not found in Criterion logging outputs, cannot log bsz')

        # Optional counters: (logging key, metric name, priority). Probing all
        # entries instead of logging_outputs[0] keeps an empty list from
        # raising IndexError and picks up keys absent from the first entry.
        optional_metrics = (
            ('n_bor_instances', 'bsz_bor', 195),
            ('ntokens_AB', 'wpb_AB', 200),
            ('ntokens_mem', 'wpb_mem', 200),
        )
        for key, metric_name, priority in optional_metrics:
            if has_key(key):
                metrics.log_scalar(metric_name, total(key), priority=priority, round=1)

        criterion.__class__.reduce_metrics(logging_outputs, self.split)
Ejemplo n.º 3
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Log summed counters for every (optionally prefixed) counter key.

        Every key found in any logging output is inspected. The part of the
        key before a recognised suffix is reused as the metric-name prefix:

        * ``<prefix>ntokens``     -> ``<prefix>wpb`` and ``<prefix>wps``
        * ``<prefix>nsentences``  -> ``<prefix>ns``
        * ``<prefix>sample_size`` -> ``<prefix>bsz``

        Keys with any other ending are ignored. Afterwards
        ``criterion.reduce_metrics`` is called with the logging outputs and
        ``self.split``.
        """
        observed = {
            name for entry in logging_outputs
            for name in entry.keys()
        }

        def summed(name):
            # Total of one counter over all entries (missing counts as 0).
            return utils.item(
                sum(entry.get(name, 0) for entry in logging_outputs))

        for name in observed:
            if name.endswith('ntokens'):
                stem = name[:-len('ntokens')]
                token_total = summed(name)
                metrics.log_scalar(stem + 'wpb', token_total,
                                   priority=80, round=1)
                metrics.log_speed(stem + 'wps', token_total,
                                  priority=50, round=1)
                continue

            if name.endswith('nsentences'):
                stem = name[:-len('nsentences')]
                metrics.log_scalar(stem + 'ns', summed(name),
                                   priority=90, round=1)
                continue

            if name.endswith('sample_size'):
                stem = name[:-len('sample_size')]
                metrics.log_scalar(stem + 'bsz', summed(name),
                                   priority=190, round=1)

        criterion.reduce_metrics(logging_outputs, self.split)
Ejemplo n.º 4
0
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update.

        Args:
            samples: sequence of mini-batches; gradients are accumulated over
                all of them before a single optimizer step.
            dummy_batch (bool): when true, wall-clock timing is not started
                and the method returns ``None`` right after the
                forward/backward passes (no stats sync, no optimizer step).
            raise_oom (bool): when true, an out-of-memory ``RuntimeError``
                raised during forward/backward is re-raised instead of being
                counted and skipped.

        Returns:
            The result of ``self._reduce_and_log_stats(...)`` after a
            successful update, or ``None`` for a dummy batch, when the OOM
            count equals ``distributed_world_size * len(samples)``, or when
            an ``OverflowError`` was raised during the update.

        Raises:
            RuntimeError: non-OOM runtime errors from forward/backward, OOM
                errors when *raise_oom* is set, and any runtime error raised
                during the optimization phase.
        """
        # remember the first batch seen so it can stand in for empty samples
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)

                # dummy batches contribute neither stats nor sample size
                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_size += sample_size_i
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise
                    print(
                        "| WARNING: attempting to recover from OOM in forward/backward pass",
                        file=sys.stderr,
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, sample_size, ooms = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ooms,
            )

        metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

        try:
            # normalize grads by sample size. Skipped when sample_size is 0
            # (no mini-batch contributed gradients) so that the division
            # below cannot raise ZeroDivisionError.
            if sample_size > 0:
                if not self.args.use_bmuf or self._sync_stats():
                    # multiply gradients by (# GPUs / sample_size) since DDP
                    # already normalizes by the number of GPUs. Thus we get
                    # (sum_of_gradients / sample_size). With BMUF this scaling
                    # is only applied on steps where stats are synced.
                    self.optimizer.multiply_grads(
                        self.args.distributed_world_size / sample_size)
                else:
                    # BMUF without sync: gradients are only divided by the
                    # sample size
                    self.optimizer.multiply_grads(1 / sample_size)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size)
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
            # "clip" is 100 only when clipping is enabled (clip_norm > 0) and
            # the gradient norm exceeded it
            metrics.log_scalar(
                "clip",
                100 if grad_norm > self.args.clip_norm > 0 else 0,
                priority=500,
                round=1,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            print("| WARNING: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                print("| ERROR: OOM during optimization, irrecoverable")
            raise

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        self.clear_buffered_stats()
        metrics.log_stop_time("train_wall")

        return logging_output
Ejemplo n.º 5
0
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update.

        Args:
            samples: sequence of mini-batches; gradients are accumulated over
                all of them before a single optimizer step.
            raise_oom (bool): when true, an out-of-memory ``RuntimeError``
                raised during forward/backward is re-raised instead of being
                counted and skipped.

        Returns:
            The result of ``self._reduce_and_log_stats(...)`` after a
            successful update, or ``None`` when the OOM count equals
            ``distributed_world_size * len(samples)`` or when an
            ``OverflowError`` was raised during the update.

        Raises:
            RuntimeError: non-OOM runtime errors from forward/backward, OOM
                errors when *raise_oom* is set, and any runtime error raised
                during the optimization phase.
        """
        # "DUMMY" is the placeholder value; replace it with the first batch
        # seen so that batch can stand in for empty samples later
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        ignore_grad=is_dummy_batch,
                    )
                    # the loss itself is not used here; presumably dropped
                    # early to release its memory -- TODO confirm
                    del loss

                # the logging output is kept even for dummy batches; only
                # the sample size excludes them
                logging_outputs.append(logging_output)
                if not is_dummy_batch:
                    sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    logger.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        # gather logging outputs from all replicas
        # NOTE(review): is_dummy_batch is the flag of the LAST mini-batch in
        # the loop above, and it is unbound if *samples* is empty -- confirm
        # callers never pass an empty list.
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
        # every mini-batch on every worker ran out of memory: nothing to apply
        # NOTE(review): this early return does not call
        # metrics.log_stop_time("train_wall") -- confirm that is intended.
        if ooms == self.args.distributed_world_size * len(samples):
            logger.warning("OOM in all workers, skipping update")
            self.zero_grad()
            return None

        try:
            # normalize grads by sample size
            # (skipped when sample_size is 0, which avoids dividing by zero)
            if sample_size > 0:
                if self._sync_stats():
                    # multiply gradients by (# GPUs / sample_size) since DDP
                    # already normalizes by the number of GPUs. Thus we get
                    # (sum_of_gradients / sample_size).
                    self.optimizer.multiply_grads(
                        self.args.distributed_world_size / sample_size)
                else:
                    self.optimizer.multiply_grads(1 / sample_size)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self.get_num_updates())

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size)
            metrics.log_speed("ups",
                              1.,
                              ignore_first=10,
                              priority=100,
                              round=2)
            metrics.log_scalar("gnorm",
                               utils.item(grad_norm),
                               priority=400,
                               round=3)
            # "clip" is 100 only when clipping is enabled (clip_norm > 0) and
            # the gradient norm exceeded it
            metrics.log_scalar(
                "clip",
                100 if grad_norm > self.args.clip_norm > 0 else 0,
                priority=500,
                round=1,
            )

            # clear CUDA cache to reduce memory fragmentation
            # (fires when num_updates % empty_cache_freq == 1 % empty_cache_freq)
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            # overflow is treated as recoverable: drop the gradients and
            # report no stats for this step
            logger.info("NOTE: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            # any runtime error in the update phase is fatal; OOM just gets
            # extra logging before being re-raised
            if "out of memory" in str(e):
                self._log_oom(e)
                logger.error("OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        metrics.log_stop_time("train_wall")

        return logging_output