Exemplo n.º 1
0
    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training.

        Logs ``wpb``/``wps`` from the summed ``ntokens`` entries and ``bsz``
        from the summed ``nsentences`` entries of *logging_outputs*, then hands
        the outputs to the ``reduce_metrics`` of the criterion's class. If the
        task overrides the deprecated ``aggregate_logging_outputs``, that
        legacy method is used instead and each aggregated value is logged as a
        scalar.
        """
        # Backward compatibility: has a subclass replaced the legacy hook?
        legacy_impl = getattr(self, 'aggregate_logging_outputs').__func__
        if legacy_impl is not FairseqTask.aggregate_logging_outputs:
            utils.deprecation_warning(
                'Tasks should implement the reduce_metrics API. '
                'Falling back to deprecated aggregate_logging_outputs API.')
            aggregated = self.aggregate_logging_outputs(
                logging_outputs, criterion)
            for name, value in aggregated.items():
                metrics.log_scalar(name, value)
            return

        # Words per batch / words per second.
        if any('ntokens' in log for log in logging_outputs):
            token_total = utils.item(
                sum(log.get('ntokens', 0) for log in logging_outputs))
            metrics.log_scalar('wpb', token_total, priority=180, round=1)
            metrics.log_speed('wps', token_total, priority=90, round=1)
        else:
            warnings.warn(
                'ntokens not found in Criterion logging outputs, cannot log wpb or wps'
            )

        # Batch size in sentences.
        if any('nsentences' in log for log in logging_outputs):
            sentence_total = utils.item(
                sum(log.get('nsentences', 0) for log in logging_outputs))
            metrics.log_scalar('bsz', sentence_total, priority=190, round=1)
        else:
            warnings.warn(
                'nsentences not found in Criterion logging outputs, cannot log bsz'
            )

        criterion.__class__.reduce_metrics(logging_outputs)
Exemplo n.º 2
0
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample, plus an optional alignment term.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training

        The logged ``'loss'`` entry is captured before the alignment term is
        added to the returned loss; the alignment term is logged separately
        under ``'alignment_loss'``.
        """

        def logged(value):
            # Scalars are logged as Python numbers; unreduced losses stay tensors.
            return utils.item(value.data) if reduce else value.data

        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(
            model, net_output, sample, reduce=reduce)

        if self.args.sentence_avg:
            sample_size = sample['target'].size(0)
        else:
            sample_size = sample['ntokens']

        logging_output = {
            'loss': logged(loss),
            'nll_loss': logged(nll_loss),
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }

        # Alignment loss is only computed when the batch carries alignments
        # (training set, non dummy batches).
        if sample.get('alignments') is not None:
            alignment_loss = self.compute_alignment_loss(sample, net_output)
            if alignment_loss is not None:
                logging_output['alignment_loss'] = utils.item(
                    alignment_loss.data)
                loss += self.alignment_lambda * alignment_loss

        return loss, sample_size, logging_output
Exemplo n.º 3
0
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Mutates and returns *state_dict*: drops stored sinusoidal position
        weights, renames the per-layer ``layer_norms.<n>`` entries to their
        named equivalents, and disables the final layer norm on this module
        when the stored version is at most 2.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            # Discard any stored table and install a placeholder buffer.
            state_dict.pop('{}.embed_positions.weights'.format(name), None)
            placeholder_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[placeholder_key] = torch.FloatTensor(1)

        # Old checkpoints index layer norms by position; map them to names.
        renamed_norms = (
            ('0', 'self_attn_layer_norm'),
            ('1', 'encoder_attn_layer_norm'),
            ('2', 'final_layer_norm'),
        )
        for layer_idx in range(len(self.layers)):
            for old_idx, new_name in renamed_norms:
                for param in ('weight', 'bias'):
                    old_key = '{}.layers.{}.layer_norms.{}.{}'.format(
                        name, layer_idx, old_idx, param)
                    if old_key not in state_dict:
                        continue
                    new_key = '{}.layers.{}.{}.{}'.format(
                        name, layer_idx, new_name, param)
                    state_dict[new_key] = state_dict.pop(old_key)

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) <= 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])

        return state_dict
Exemplo n.º 4
0
    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Each of the ``num_classes`` candidates is scored by the model's
        ranking head from ``sample['net_input<k>']`` (k starting at 1); the
        scores are concatenated into logits and trained with summed NLL when
        targets are present. Predictions are optionally written to
        ``self.prediction_h``.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, 'classification_heads')
            and self.args.ranking_head_name in model.classification_heads
        ), 'model must provide sentence ranking head for --criterion=sentence_ranking'

        scores = []
        for choice in range(1, self.args.num_classes + 1):
            choice_score, _ = model(
                **sample['net_input{idx}'.format(idx=choice)],
                classification_head_name=self.args.ranking_head_name,
            )
            scores.append(choice_score)

        logits = torch.cat(scores, dim=1)
        sample_size = logits.size(0)

        if 'target' not in sample:
            # No supervision available: return a zero loss that still
            # participates in autograd.
            targets = None
            loss = torch.tensor(0.0, requires_grad=True)
        else:
            targets = model.get_targets(sample, [logits]).view(-1)
            log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            loss = F.nll_loss(log_probs, targets, reduction='sum')

        if self.prediction_h is not None:
            # One tab-separated line per example: id, prediction[, label].
            preds = logits.argmax(dim=1)
            pairs = zip(sample['id'].tolist(), preds.tolist())
            for row, (sample_id, pred) in enumerate(pairs):
                if targets is None:
                    print('{}\t{}'.format(sample_id, pred),
                          file=self.prediction_h)
                else:
                    label = targets[row].item()
                    print('{}\t{}\t{}'.format(sample_id, pred, label),
                          file=self.prediction_h)

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample_size,
            'sample_size': sample_size,
        }
        if targets is not None:
            hits = logits.argmax(dim=1) == targets
            logging_output['ncorrect'] = hits.sum()

        return loss, sample_size, logging_output
Exemplo n.º 5
0
 def upgrade_state_dict(self, state_dict):
     """Re-apply weight norm along dim 0 when loading a pre-version-2 state dict.

     Reads ``'decoder.version'`` from *state_dict* (defaulting to 1). For
     versions below 2, every module in ``self.convolutions`` has its weight
     norm removed and re-applied with ``dim=0``, and the stored version is
     written back as 1. Returns *state_dict*.
     """
     stored_version = state_dict.get('decoder.version', torch.Tensor([1]))
     if not (utils.item(stored_version[0]) < 2):
         return state_dict

     # old models use incorrect weight norm dimension
     for idx, layer in enumerate(self.convolutions):
         # reconfigure weight norm
         nn.utils.remove_weight_norm(layer)
         self.convolutions[idx] = nn.utils.weight_norm(layer, dim=0)
     state_dict['decoder.version'] = torch.Tensor([1])
     return state_dict
Exemplo n.º 6
0
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        def logged(value):
            # Scalars are logged as Python numbers; unreduced losses stay tensors.
            return utils.item(value.data) if reduce else value.data

        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(
            model, net_output, sample, reduce=reduce)

        # Normalise by sentences or by tokens depending on configuration.
        if self.args.sentence_avg:
            sample_size = sample['target'].size(0)
        else:
            sample_size = sample['ntokens']

        logging_output = {
            'loss': logged(loss),
            'nll_loss': logged(nll_loss),
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Exemplo n.º 7
0
    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        **kwargs
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Generate hypotheses for a batch of tokenized sentences.

        A single 1-D tensor is treated as a batch of one and the hypotheses
        for that one sentence are returned directly. Otherwise the result has
        one list of hypotheses per generated input, ordered by batch id.

        Extra keyword arguments are set as attributes on a copy of
        ``self.args`` before the generator is built. With ``verbose=True`` the
        source, hypotheses, positional scores and (optionally) alignments are
        printed.
        """
        is_single_sentence = (
            torch.is_tensor(tokenized_sentences)
            and tokenized_sentences.dim() == 1
        )
        if is_single_sentence:
            # NOTE(review): skip_invalid_size_inputs is not forwarded to the
            # recursive call, so it falls back to its default here.
            batched = self.generate(
                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
            )
            return batched[0]

        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for key, value in kwargs.items():
            setattr(gen_args, key, value)
        generator = self.task.build_generator(gen_args)

        indexed_hypos = []
        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(generator, self.models, batch)
            indexed_hypos.extend(zip(batch["id"].tolist(), translations))

        # sort output to match input order
        indexed_hypos.sort(key=lambda pair: pair[0])
        outputs = [hypos for _, hypos in indexed_hypos]

        if not verbose:
            return outputs

        def getarg(name, default):
            # Prefer the per-call override, then the hub args, then the default.
            return getattr(gen_args, name, getattr(self.args, name, default))

        for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
            print('S\t{}'.format(self.string(source_tokens)))
            for hypo in target_hypotheses:
                hypo_str = self.decode(hypo['tokens'])
                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
                positional = ' '.join(
                    '{:.4f}'.format(score)
                    for score in hypo['positional_scores'].tolist()
                )
                print('P\t{}'.format(positional))
                if hypo['alignment'] is not None and getarg('print_alignment', False):
                    alignment = ' '.join(
                        str(utils.item(pos))
                        for pos in hypo['alignment'].int().cpu()
                    )
                    print('A\t{}'.format(alignment))
        return outputs
Exemplo n.º 8
0
    def forward(self, model, sample, reduce=True):
        """Compute the adaptive-softmax loss for the given sample.

        The decoder must expose a non-``None`` ``adaptive_softmax``; it splits
        the flattened targets into per-cluster logits/targets whose
        cross-entropies are accumulated into a single loss tensor.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        assert hasattr(
            model.decoder,
            'adaptive_softmax') and model.decoder.adaptive_softmax is not None
        adaptive_softmax = model.decoder.adaptive_softmax

        net_output = model(**sample['net_input'])
        orig_target = model.get_targets(sample, net_output)

        # Remember the batch dimension before flattening the targets.
        nsentences = orig_target.size(0)
        orig_target = orig_target.view(-1)
        bsz = orig_target.size(0)

        logits, target = adaptive_softmax(net_output[0], orig_target)
        assert len(target) == len(logits)

        loss = net_output[0].new(1 if reduce else bsz).zero_()
        reduction = 'sum' if reduce else 'none'

        for idx, cluster_target in enumerate(target):
            if cluster_target is None:
                continue
            cluster_logits = logits[idx]
            assert (cluster_target.min() >= 0
                    and cluster_target.max() <= cluster_logits.size(1))
            loss += F.cross_entropy(
                cluster_logits,
                cluster_target,
                ignore_index=self.padding_idx,
                reduction=reduction,
            )

        # Token count excludes padding positions.
        ntokens = utils.strip_pad(orig_target, self.padding_idx).numel()
        if self.args.sentence_avg:
            sample_size = sample['target'].size(0)
        else:
            sample_size = ntokens

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
    def forward(self, model, sample, reduce=True):
        """Compute the sentence-prediction loss for the given sample.

        The model's classification head named by
        ``args.classification_head_name`` produces the logits. With
        ``args.regression_target`` unset the loss is summed NLL over
        log-softmax and the number of correct predictions is logged as
        ``'ncorrect'``; otherwise the loss is summed MSE.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        assert (
            hasattr(model, 'classification_heads') and
            self.args.classification_head_name in model.classification_heads
        ), 'model must provide sentence classification head for --criterion=sentence_prediction'

        logits, _ = model(
            **sample['net_input'],
            features_only=True,
            classification_head_name=self.args.classification_head_name,
        )
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        if not self.args.regression_target:
            loss = F.nll_loss(
                F.log_softmax(logits, dim=-1, dtype=torch.float32),
                targets,
                reduction='sum',
            )
        else:
            # Flatten with view(-1) rather than squeeze(): squeeze() turns a
            # batch of one into a 0-dim tensor, which no longer matches the
            # 1-D targets and makes mse_loss fall back to broadcasting.
            logits = logits.view(-1).float()
            targets = targets.float()
            loss = F.mse_loss(
                logits,
                targets,
                reduction='sum',
            )

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample_size,
            'sample_size': sample_size,
        }
        if not self.args.regression_target:
            preds = logits.argmax(dim=1)
            logging_output['ncorrect'] = (preds == targets).sum()

        return loss, sample_size, logging_output
Exemplo n.º 10
0
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq.

        Mutates and returns *state_dict*: removes stored sinusoidal position
        weights (announcing the deletion on stdout), lets every layer upgrade
        its own entries, and disables the final layer norm on this module when
        the stored version is below 2.
        """
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            stale_key = '{}.embed_positions.weights'.format(name)
            if stale_key in state_dict:
                print('deleting {0}'.format(stale_key))
                del state_dict[stale_key]
            placeholder_key = '{}.embed_positions._float_tensor'.format(name)
            state_dict[placeholder_key] = torch.FloatTensor(1)

        # Each layer upgrades the keys under its own prefix.
        for idx, layer in enumerate(self.layers):
            layer.upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, idx))

        version_key = '{}.version'.format(name)
        stored_version = state_dict.get(version_key, torch.Tensor([1]))
        if utils.item(stored_version[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Exemplo n.º 11
0
    def forward(self, model, sample, reduce=True):
        """Compute the masked-LM loss for the given sample.

        Positions whose target differs from ``self.padding_idx`` count as
        masked; their number is the sample size.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        masked_tokens = sample['target'].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum().item()
        nothing_masked = sample_size == 0

        # (Rare case) When no position is masked, selecting by the mask would
        # leave the model with an empty tensor (and a CUDA error), so the mask
        # is dropped and the model runs over every position instead.
        if nothing_masked:
            masked_tokens = None

        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])
        if not nothing_masked:
            targets = targets[masked_tokens]

        log_probs = F.log_softmax(
            logits.view(-1, logits.size(-1)),
            dim=-1,
            dtype=torch.float32,
        )
        loss = F.nll_loss(
            log_probs,
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Exemplo n.º 12
0
    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the (optionally weighted) binary cross-entropy loss.

        The element-wise loss between the model's logits and targets is
        multiplied by ``model.get_target_weights(...)`` when the model
        provides that method, and summed when *reduce* is true. With
        *log_pred* the logits and targets are added to the logging output as
        NumPy arrays.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        logits = model.get_logits(net_output).float()
        target = model.get_targets(sample, net_output,
                                   expand_steps=False).float()

        if hasattr(model, 'get_target_weights'):
            weights = model.get_target_weights(target, net_output)
            if torch.is_tensor(weights):
                weights = weights.float()
        else:
            weights = 1.

        # reduction='none' keeps the per-element losses; it replaces the
        # deprecated reduce=False keyword, which yields the same values but
        # emits a deprecation warning on every call.
        loss = F.binary_cross_entropy_with_logits(
            logits, target, reduction='none')

        loss = loss * weights

        if reduce:
            loss = loss.sum()

        sample_size = target.numel()
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample_size,
            'nsentences': logits.size(0),
            'sample_size': sample_size,
        }
        if log_pred:
            logging_output['logits'] = logits.cpu().numpy()
            logging_output['target'] = target.cpu().numpy()
        return loss, sample_size, logging_output
Exemplo n.º 13
0
            def forward(self, model, sample, reduce=True):
                """Average the underlying criterion over several outputs.

                ``model(...)`` is expected to return a pair whose first item
                holds one output per entry of ``sample['target']``. Each
                output/target pair is wrapped in a ``FakeModel`` and scored
                by ``self.underlying_criterion``; the losses and sample sizes
                are then averaged over the number of targets.

                Returns ``(loss, sample_size, logging_output)``, where
                ``logging_output`` is the one produced by the last underlying
                call with its ``'loss'`` entry replaced by the averaged loss.
                """
                net_outputs = model(**sample['net_input'])
                targets = sample['target']

                bsz = targets[0].size(0)
                loss = net_outputs[0][0].new(
                    1 if reduce else bsz).float().zero_()

                sample_size = 0
                logging_output = {}
                try:
                    for output, target in zip(net_outputs[0], targets):
                        wrapped = FakeModel(model, (output, net_outputs[1]),
                                            target)
                        # The underlying criterion reads sample['target'], so
                        # point it at the current target for this call.
                        sample['target'] = target
                        part_loss, part_size, logging_output = \
                            self.underlying_criterion(wrapped, sample, reduce)
                        loss += part_loss
                        sample_size += part_size
                finally:
                    # Hand the caller's sample back with its original list of
                    # targets instead of leaving the last one in place.
                    sample['target'] = targets

                loss.div_(len(targets))
                sample_size /= len(targets)

                logging_output['loss'] = utils.item(
                    loss.data) if reduce else loss.data
                return loss, sample_size, logging_output
Exemplo n.º 14
0
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        The model returns a mapping from objective name to a dict describing
        that objective. When the dict carries a precomputed ``"loss"`` it is
        passed to ``self._custom_loss``; otherwise ``self._compute_loss`` is
        called with its ``"out"``, ``"tgt"``, ``"mask"`` and ``"ls"`` entries.
        Every objective is also logged under ``"<name>-loss"`` with its
        ``"factor"`` divided out.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        nsentences, ntokens = sample["nsentences"], sample["ntokens"]

        # B x T
        src_tokens, src_lengths = (
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"],
        )
        tgt_tokens, prev_output_tokens = sample["target"], sample[
            "prev_target"]

        outputs = model(src_tokens, src_lengths, prev_output_tokens,
                        tgt_tokens)
        losses, nll_loss = [], []

        for obj in outputs:
            out = outputs[obj]
            factor = out.get("factor", 1.0)
            if out.get("loss", None) is None:
                _losses = self._compute_loss(out.get("out"),
                                             out.get("tgt"),
                                             out.get("mask", None),
                                             out.get("ls", 0.0),
                                             name=obj + '-loss',
                                             factor=factor)
            else:
                _losses = self._custom_loss(out.get("loss"),
                                            name=obj + '-loss',
                                            factor=factor)

            losses.append(_losses)
            if out.get("nll_loss", False):
                nll_loss.append(_losses.get("nll_loss", 0.0))

        loss = sum(entry["loss"] for entry in losses)
        nll_loss = sum(nll_loss) if len(nll_loss) > 0 \
            else loss.new_tensor(0)

        # NOTE:
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        for entry in losses:
            # Log each objective with its weighting factor divided out. The
            # unreduced branch used to index the dict with a list
            # (entry[["loss"]]), which raised TypeError.
            unscaled = entry["loss"].data / entry["factor"]
            logging_output[entry["name"]] = (utils.item(unscaled)
                                             if reduce else unscaled)

        return loss, sample_size, logging_output
Exemplo n.º 15
0
    def _get_loss(self, sample, model, criterion):
        """Compute the mixture-of-experts loss for *sample*.

        Expert responsibilities are first computed with the model in eval
        mode and autograd disabled; the loss is then computed in a second
        pass, either for the single most responsible expert per sentence
        (``self.hard_selection``) or through ``LogSumExpMoE`` over all experts.

        Returns ``(loss, sample_size, logging_output)``. ``logging_output``
        additionally carries ``'posterior'``, the responsibilities summed over
        the batch dimension and moved to CPU.
        """
        assert hasattr(criterion, 'compute_loss'), \
            'translation_moe task requires the criterion to implement the compute_loss() method'

        k = self.args.num_experts
        bsz = sample['target'].size(0)

        def get_lprob_y(encoder_out, prev_output_tokens_k):
            """Decode with the given previous tokens and return the negated
            unreduced criterion loss, summed per sentence."""
            net_output = model.decoder(
                prev_output_tokens=prev_output_tokens_k,
                encoder_out=encoder_out,
            )
            loss, _ = criterion.compute_loss(model,
                                             net_output,
                                             sample,
                                             reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)  # -> B x 1

        def get_lprob_yz(winners=None):
            """Run the encoder and score the target under every expert, or
            only under the experts in *winners* when it is given."""
            encoder_out = model.encoder(
                src_tokens=sample['net_input']['src_tokens'],
                src_lengths=sample['net_input']['src_lengths'],
            )

            if winners is None:
                lprob_y = []
                for i in range(k):
                    # The expert is selected by overwriting the first
                    # previous-output token with that expert's index.
                    prev_output_tokens_k = sample['net_input'][
                        'prev_output_tokens'].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(
                        get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
            else:
                prev_output_tokens_k = sample['net_input'][
                    'prev_output_tokens'].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out,
                                      prev_output_tokens_k)  # -> B x 1

            if self.uniform_prior:
                lprob_yz = lprob_y
            else:
                # Add the gating network's output for each expert.
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1,
                                             index=winners.unsqueeze(-1))
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K (B x 1 with winners)

            return lprob_yz

        # compute responsibilities without dropout
        with utils.eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # compute loss with dropout
        if self.hard_selection:
            # Train only the expert with the highest responsibility.
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
            lprob_yz = get_lprob_yz()  # B x K
            loss = -modules.LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

        loss = loss.sum()
        sample_size = sample['target'].size(
            0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data),
            'ntokens': sample['ntokens'],
            'nsentences': bsz,
            'sample_size': sample_size,
            'posterior': prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output
Exemplo n.º 16
0
    def forward(self, model, sample, reduce=True):
        """Compute the masked-LM (and optional sentence-level) loss.

        The LM loss is divided by the number of non-padding target tokens.
        Unless ``args.masked_lm_only`` is set and provided the model returned
        sentence logits, a sentence loss divided by the number of sentence
        targets and scaled by ``args.nsp_loss_weight`` is added.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        def logged(value):
            # Scalars are logged as Python numbers; unreduced losses stay tensors.
            return utils.item(value.data) if reduce else value.data

        lm_logits, output_metadata = model(**sample["net_input"])

        # reshape lm_logits from (N,T,C) to (N*T,C)
        flat_logits = lm_logits.view(-1, lm_logits.size(-1))
        lm_targets = sample['lm_target'].view(-1)
        lm_loss = compute_cross_entropy_loss(flat_logits, lm_targets,
                                             self.padding_idx)

        # number of tokens the loss is computed over; used to normalize it
        ntokens = utils.strip_pad(lm_targets, self.padding_idx).numel()
        loss = lm_loss / ntokens
        nsentences = sample['nsentences']

        # Compute sentence loss if masked_lm_only is False
        sentence_loss = None
        if not self.args.masked_lm_only:
            sentence_logits = output_metadata['sentence_logits']
            sentence_targets = sample['sentence_target'].view(-1)
            # This needs to be recomputed due to some differences between
            # TokenBlock and BlockPair dataset. This can be resolved with a
            # refactor of BERTModel which we will do in the future.
            # TODO: Remove this after refactor of BERTModel
            nsentences = sentence_targets.size(0)

            # Logits can be None when remove_heads is set to true in the BERT
            # model. Ideally masked_lm_only would be true in that case, but
            # that requires some refactor in the BERT model.
            if sentence_logits is not None:
                sentence_loss = compute_cross_entropy_loss(
                    sentence_logits, sentence_targets)
                weighted = self.args.nsp_loss_weight * (sentence_loss /
                                                        nsentences)
                loss += weighted

        # NOTE: as we are summing up per token mlm loss and per sentence nsp loss
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            'loss': logged(loss),
            'lm_loss': logged(lm_loss),
            # sentence loss is not always computed
            'sentence_loss':
            0.0 if sentence_loss is None else logged(sentence_loss),
            'ntokens': ntokens,
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Exemplo n.º 17
0
    def train_step(self, samples, dummy_batch=False, raise_oom=False):
        """Do forward, backward and parameter update.

        Args:
            samples: sequence of mini-batches; each one is run through
                ``self.task.train_step`` and their gradients accumulate
                before a single optimizer step.
            dummy_batch: when true, only the forward/backward passes run and
                ``None`` is returned before any statistics are aggregated or
                the optimizer is stepped.
            raise_oom: when true, an out-of-memory ``RuntimeError`` raised
                during forward/backward is re-raised instead of being
                counted and recovered from.

        Returns:
            The logging output produced by ``self._reduce_and_log_stats``, or
            ``None`` when *dummy_batch* is set, when every mini-batch on every
            worker ran out of memory, or when an ``OverflowError`` was caught
            during the update.

        Raises:
            RuntimeError: any non-OOM error from forward/backward, and any
                ``RuntimeError`` raised during the optimization phase.
        """
        # Remember the first batch ever seen so it can stand in for empty ones.
        if self._dummy_batch is None:
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        if not dummy_batch:
            metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args.distributed_world_size > 1
                        and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer,
                        ignore_grad)

                # Dummy-batch results are excluded from the statistics.
                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_size += sample_size_i
            except RuntimeError as e:
                # Only out-of-memory errors are recoverable; the gradients
                # accumulated so far are discarded.
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    print(
                        "| WARNING: attempting to recover from OOM in forward/backward pass",
                        file=sys.stderr,
                    )
                    ooms += 1
                    self.zero_grad()
                else:
                    raise e

        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, sample_size, ooms = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ooms,
            )

        metrics.log_scalar("oom", ooms, len(samples), priority=600, round=3)
        # Every mini-batch on every worker failed: nothing to update with.
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

        try:
            # normalize grads by sample size
            if sample_size > 0:
                if self._sync_stats():
                    # multiply gradients by (# GPUs / sample_size) since DDP
                    # already normalizes by the number of GPUs. Thus we get
                    # (sum_of_gradients / sample_size).
                    self.optimizer.multiply_grads(
                        self.args.distributed_world_size / sample_size)
                else:
                    self.optimizer.multiply_grads(1 / sample_size)

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # check that grad norms are consistent across workers
            if not self.args.use_bmuf:
                self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()
            self.set_num_updates(self.get_num_updates() + 1)

            # task specific update per step
            self.task.update_step(self._num_updates)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs, sample_size)
            metrics.log_speed("ups", 1., priority=100, round=2)
            metrics.log_scalar("gnorm",
                               utils.item(grad_norm),
                               priority=400,
                               round=3)
            # "clip" is 100 when clipping is enabled (clip_norm > 0) and the
            # gradient norm exceeded it, else 0.
            metrics.log_scalar(
                "clip",
                100 if grad_norm > self.args.clip_norm > 0 else 0,
                priority=500,
                round=1,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.args.empty_cache_freq > 0 and
                ((self.get_num_updates() + self.args.empty_cache_freq - 1) %
                 self.args.empty_cache_freq) == 0
                    and torch.cuda.is_available() and not self.args.cpu):
                torch.cuda.empty_cache()
        except OverflowError as e:
            # The update is skipped: gradients are cleared and no stats are
            # returned for this step.
            print("| WARNING: overflow detected, " + str(e))
            self.zero_grad()
            logging_output = None
        except RuntimeError as e:
            # Any RuntimeError here is re-raised; OOM is additionally logged.
            if "out of memory" in str(e):
                self._log_oom(e)
                print("| ERROR: OOM during optimization, irrecoverable")
            raise e

        if self.args.fp16:
            metrics.log_scalar("loss_scale",
                               self.optimizer.scaler.loss_scale,
                               priority=700,
                               round=0)

        self.clear_buffered_stats()
        metrics.log_stop_time("train_wall")

        return logging_output