Code example #1
File: distributed_utils.py Project: fyabc/fairseq
def all_gather_list(data, max_size=16384):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.size():
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255*256
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    result = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
        result.append(
            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
        )
    return result
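The two-byte header above stores the pickled payload's length in base 255, which is why max_size must stay below 255*256. A minimal, self-contained sketch of that encode/decode round trip (no distributed setup required; the function names are illustrative, not from fairseq):

import pickle

def encode_with_header(data, max_size=16384):
    # pickle the payload and prepend a 2-byte, base-255 length header,
    # mirroring the in_buffer layout in all_gather_list above
    enc = pickle.dumps(data)
    assert len(enc) + 2 <= max_size and max_size < 255 * 256
    return bytes([len(enc) // 255, len(enc) % 255]) + enc

def decode_with_header(buf):
    # recover the payload length from the header, then unpickle
    size = 255 * buf[0] + buf[1]
    return pickle.loads(bytes(buf[2:size + 2]))

assert decode_with_header(encode_with_header({'rank': 0})) == {'rank': 0}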
Code example #2
File: transformer.py Project: fyabc/fairseq
    def upgrade_state_dict(self, state_dict):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'decoder.embed_positions.weights' in state_dict:
                del state_dict['decoder.embed_positions.weights']
            state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor(1)

        for i in range(len(self.layers)):
            # update layer norms
            layer_norm_map = {
                '0': 'self_attn_layer_norm',
                '1': 'encoder_attn_layer_norm',
                '2': 'final_layer_norm'
            }
            for old, new in layer_norm_map.items():
                for m in ('weight', 'bias'):
                    k = 'decoder.layers.{}.layer_norms.{}.{}'.format(i, old, m)
                    if k in state_dict:
                        state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
                        del state_dict[k]
        if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict['decoder.version'] = torch.Tensor([1])

        return state_dict
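A toy illustration of the layer-norm key migration performed above, using a plain dict in place of a real checkpoint (the keys are hypothetical):

layer_norm_map = {
    '0': 'self_attn_layer_norm',
    '1': 'encoder_attn_layer_norm',
    '2': 'final_layer_norm',
}
state_dict = {'decoder.layers.0.layer_norms.0.weight': 'w'}
for old, new in layer_norm_map.items():
    for m in ('weight', 'bias'):
        k = 'decoder.layers.0.layer_norms.{}.{}'.format(old, m)
        if k in state_dict:
            # move the value to the new key and drop the old one
            state_dict['decoder.layers.0.{}.{}'.format(new, m)] = state_dict.pop(k)
assert state_dict == {'decoder.layers.0.self_attn_layer_norm.weight': 'w'}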
Code example #3
def make_result(src_str, hypos, align_dict, tgt_dict, args):
    result = Translation(
        src_str='O\t{}'.format(src_str),
        hypos=[],
        pos_scores=[],
        alignments=[],
    )

    # Process top predictions
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
            hypo_tokens=hypo['tokens'].int().cpu(),
            src_str=src_str,
            alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
            align_dict=align_dict,
            tgt_dict=tgt_dict,
            remove_bpe=args.remove_bpe,
        )
        # result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
        # only keep the translation, not the score
        result.hypos.append(hypo_str)
        result.pos_scores.append('P\t{}'.format(
            ' '.join(map(
                lambda x: '{:.4f}'.format(x),
                hypo['positional_scores'].tolist(),
            ))
        ))
        result.alignments.append(
            'A\t{}'.format(' '.join(map(lambda x: str(utils.item(x)), alignment)))
            if args.print_alignment else None
        )
    return result
Code example #4
File: adaptive_loss.py Project: fyabc/fairseq
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        assert hasattr(model.decoder, 'adaptive_softmax') and model.decoder.adaptive_softmax is not None
        adaptive_softmax = model.decoder.adaptive_softmax

        net_output = model(**sample['net_input'])
        target = model.get_targets(sample, net_output).view(-1)

        bsz = target.size(0)

        logits, target = adaptive_softmax(net_output[0], target)
        assert len(target) == len(logits)

        loss = net_output[0].new(1 if reduce else bsz).zero_()

        for i in range(len(target)):
            if target[i] is not None:
                assert (target[i].min() >= 0 and target[i].max() <= logits[i].size(1))
                loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx,
                                        reduce=reduce)

        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #5
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
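Across these criteria the returned sample_size acts as the denominator that the caller later divides gradients by, so losses are returned summed rather than averaged. A hedged, self-contained sketch of that convention (the toy model and criterion are invented for illustration; real fairseq trainers differ in detail):

import torch

w = torch.zeros(1, requires_grad=True)

def toy_criterion(targets):
    loss = ((w - targets) ** 2).sum()   # summed, un-normalized loss
    sample_size = targets.numel()       # denominator for the gradient
    return loss, sample_size, {'loss': loss.item(), 'sample_size': sample_size}

loss, sample_size, log = toy_criterion(torch.ones(8))
(loss / sample_size).backward()         # normalizing by sample_size makes the
print(w.grad)                           # gradient per-token: tensor([-2.]) for any length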
Code example #6
File: transformer.py Project: fyabc/fairseq
    def upgrade_state_dict(self, state_dict):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            if 'encoder.embed_positions.weights' in state_dict:
                del state_dict['encoder.embed_positions.weights']
            state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
        if utils.item(state_dict.get('encoder.version', torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict['encoder.version'] = torch.Tensor([1])
        return state_dict
Code example #7
File: interactive.py Project: zzzzxciid/fairseq
    def make_result(src_str, hypos):
        result = Translation(
            src_str='O\t{}'.format(src_str),
            hypos=[],
            alignments=[],
        )

        # Process top predictions
        for hypo in hypos[:min(len(hypos), args.nbest)]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                tgt_dict=tgt_dict,
                remove_bpe=args.remove_bpe,
            )
            result.hypos.append('H\t{}\t{}'.format(hypo['score'], hypo_str))
            result.alignments.append('A\t{}'.format(' '.join(
                map(lambda x: str(utils.item(x)), alignment))))
        return result
Code example #8
    def forward(self, model, sample, reduce=True):
        net_outputs = model(**sample['net_input'])
        targets = sample['target']

        bsz = targets[0].size(0)
        loss = net_outputs[0][0].new(1 if reduce else bsz).float().zero_()

        sample_size = 0
        logging_output = {}
        for o, t in zip(net_outputs[0], targets):
            m = FakeModel(model, (o, net_outputs[1]), t)
            sample['target'] = t
            l, ss, logging_output = self.underlying_criterion(m, sample, reduce)
            loss += l
            sample_size += ss

        loss.div_(len(targets))
        sample_size /= len(targets)

        logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
        return loss, sample_size, logging_output
Code example #9
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        target = sample['target']

        loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
        loss = (loss * (target != self.padding_idx)).sum()
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #10
File: cross_entropy.py Project: fyabc/fairseq
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
                          reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
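The pattern above, log-probabilities from get_normalized_probs followed by F.nll_loss, is mathematically the same as cross-entropy over the logits; a quick runnable check on toy tensors:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                 # (batch, vocab)
target = torch.randint(10, (4,))
nll = F.nll_loss(F.log_softmax(logits, dim=-1), target, reduction='sum')
ce = F.cross_entropy(logits, target, reduction='sum')
assert torch.allclose(nll, ce)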
Code example #11
File: transformer.py Project: lahiruts/espresso
    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
            weights_key = '{}.embed_positions.weights'.format(name)
            if weights_key in state_dict:
                print('deleting {0}'.format(weights_key))
                del state_dict[weights_key]
            state_dict['{}.embed_positions._float_tensor'.format(
                name)] = torch.FloatTensor(1)
        for i in range(len(self.layers)):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i))

        version_key = '{}.version'.format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict
Code example #12
File: masked_lm.py Project: verashira/TSPNet
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        # compute MLM loss
        masked_tokens = sample['target'].ne(self.padding_idx)
        sample_size = masked_tokens.int().sum().item()

        # (Rare case) When all tokens are masked, the model results in empty
        # tensor and gives CUDA error.
        if sample_size == 0:
            masked_tokens = None

        logits = model(**sample['net_input'], masked_tokens=masked_tokens)[0]
        targets = model.get_targets(sample, [logits])

        if sample_size != 0:
            targets = targets[masked_tokens]

        loss = F.nll_loss(
            F.log_softmax(
                logits.view(-1, logits.size(-1)),
                dim=-1,
                dtype=torch.float32,
            ),
            targets.view(-1),
            reduction='sum',
            ignore_index=self.padding_idx,
        )
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
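A small sketch of the masked-token bookkeeping above: only positions whose target differs from the padding index are scored, and sample_size counts them (the shapes and padding index are illustrative):

import torch

padding_idx = 1
target = torch.tensor([[5, 1, 7],
                       [1, 1, 9]])        # padding_idx marks unmasked positions
masked_tokens = target.ne(padding_idx)    # bool mask of scored positions
sample_size = masked_tokens.int().sum().item()
assert sample_size == 3
assert target[masked_tokens].tolist() == [5, 7, 9]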
Code example #13
File: dictionary.py Project: StatNLP/ada4asr
    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                # forward unk_string too so per-row conversion matches the 1-D case
                self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore, unk_string)
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = " ".join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)
Code example #14
    def iw_eval(self, model, sample, data_len, iw_nsample, reduce=True):
        """Compute the importance-weighted loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        tmp = []
        for _ in range(int(iw_nsample / model.infer_ns)):
            net_output = model.iw_forward(**sample['net_input'],
                                          data_len=data_len)

            # log [p(x, t, z) / q(t, z |x)]
            # (batch, infer_ns)
            log_ratio = self._compulte_iw_loss(model,
                                               net_output,
                                               sample,
                                               reduce=reduce)
            tmp.append(log_ratio)

        # (batch)
        ll_iw = torch.logsumexp(torch.cat(tmp, dim=-1),
                                dim=-1) - math.log(iw_nsample)
        ll_iw = -ll_iw.sum()

        sample_size = sample['target'].size(
            0) if self.args.sentence_avg else sample['ntokens']

        nsentences = sample['target'].size(0) / model.infer_ns

        logging_output = {
            'nll_iw': utils.item(ll_iw.data) if reduce else ll_iw.data,
            'ntokens': sample['ntokens'] / model.infer_ns,
            'nsentences': nsentences,
            'sample_size': sample_size / model.infer_ns,
        }

        return ll_iw, sample_size, logging_output
Code example #15
File: nat_ctc_loss.py Project: George0828Zhang/NAT
    def forward(self, model, sample, reduce=True):
        pad_idx = model.pad
        eos_idx = model.eos
        blank_idx = model.blank_idx

        net_output = model(**sample["net_input"])
        model_logits, input_lengths = (net_output["word_ins"]["out"],
                                       net_output["word_ins"]["src_lengths"])

        # (T, B, C) layout required by F.ctc_loss
        lprobs = utils.log_softmax(model_logits, dim=-1).transpose(1, 0).contiguous()

        pad_mask = (sample["target"] != pad_idx) & (sample["target"] != eos_idx)
        targets_flat = sample["target"].masked_select(pad_mask)
        target_lengths = pad_mask.long().sum(-1)

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = target_lengths.sum().item()

        sample_size = sample["target"].size(
            0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        return loss, sample_size, logging_output
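F.ctc_loss expects (T, B, C) log-probabilities plus a flattened target vector, which is why the code above transposes the logits and masked_selects the targets. A runnable toy example (all sizes are hypothetical):

import torch
import torch.nn.functional as F

T, B, C = 12, 2, 5                          # time steps, batch, vocab incl. blank
lprobs = F.log_softmax(torch.randn(T, B, C), dim=-1)
targets_flat = torch.tensor([1, 2, 3, 2])   # both sequences concatenated
input_lengths = torch.tensor([12, 12])
target_lengths = torch.tensor([3, 1])       # how to split targets_flat per sequence
loss = F.ctc_loss(lprobs, targets_flat, input_lengths, target_lengths,
                  blank=0, reduction='sum', zero_infinity=True)
print(loss.item())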
Code example #16
    def forward(self, model, sample, reduce=True):
        net_output, probability = model(**sample["net_input"])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        alignment_loss = None

        # Calculate alignment loss for training set
        if "alignments" in sample and sample["alignments"] is not None:
            alignment_loss = compute_alignment_loss(sample, probability)

        if alignment_loss is not None:
            logging_output["alignment_loss"] = utils.item(alignment_loss.data)
            loss += self.align_lambda * alignment_loss

        return loss, sample_size, logging_output
Code example #17
File: trainer.py Project: zsquaredz/XSum
    def _backward_and_opt(self, loss, grad_denom):
        oom = 0
        if loss is not None:
            try:
                # backward pass
                loss.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    oom = 1
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    self.optimizer.zero_grad()
                else:
                    raise e

        # all-reduce grads and rescale by grad_denom
        if self.args.distributed_world_size > 1:
            grads = [p.grad.data for p in self.model.parameters() if p.requires_grad]
            distributed_utils.all_reduce_and_rescale_tensors(grads, grad_denom)
        else:
            for p in self.model.parameters():
                if p.requires_grad:
                    p.grad.data.div_(grad_denom)

        # clip grads
        if self.args.clip_norm > 0:
            grad_norm = utils.item(torch.nn.utils.clip_grad_norm(self.model.parameters(), self.args.clip_norm))
        else:
            grad_norm = math.sqrt(sum(p.grad.data.norm()**2 for p in self.model.parameters()))

        # take an optimization step
        self.optimizer.step()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

        return grad_norm, oom
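A minimal single-process sketch of the grad_denom rescaling and clipping steps above (the all-reduce is omitted; the numbers are illustrative, and the in-place clip_grad_norm_ is used in place of the older deprecated spelling):

import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

grad_denom = 8.0                          # e.g. total tokens in the batch
for p in model.parameters():
    if p.requires_grad and p.grad is not None:
        p.grad.data.div_(grad_denom)      # same rescale as the else-branch above

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
print(grad_norm)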
Code example #18
    def swap_sample(self, sample):
        target = sample["target"]
        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
        src_tokens = torch.cat(
            (prev_output_tokens[:, :1], sample["net_input"]["src_tokens"]),
            dim=-1)
        return {
            "net_input": {
                "src_tokens": target.contiguous(),
                "src_lengths": (target != self.padding_idx).int().sum(dim=1),
                "prev_output_tokens": src_tokens[:, :-1].contiguous(),
            },
            "nsentences": sample["nsentences"],
            "ntokens": utils.item((src_tokens[:, 1:] != self.padding_idx).int().sum().data),
            "target": src_tokens[:, 1:].contiguous(),
            "id": sample["id"],
        }
Code example #19
File: hub_utils.py Project: ChenDdon/AGBTcode
    def generate(self,
                 tokens: torch.LongTensor,
                 beam: int = 5,
                 verbose: bool = False,
                 **kwargs) -> torch.LongTensor:
        sample = self._build_sample(tokens)

        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
        generator = self.task.build_generator(gen_args)

        translations = self.task.inference_step(generator, self.models, sample)

        if verbose:
            src_str_with_unk = self.string(tokens)
            print('S\t{}'.format(src_str_with_unk))

        def getarg(name, default):
            return getattr(gen_args, name, getattr(self.args, name, default))

        # Process top predictions
        hypos = translations[0]
        if verbose:
            for hypo in hypos:
                hypo_str = self.decode(hypo['tokens'])
                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
                print('P\t{}'.format(' '.join(
                    map(lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist()))))
                if hypo['alignment'] is not None and getarg(
                        'print_alignment', False):
                    print('A\t{}'.format(' '.join(
                        map(lambda x: str(utils.item(x)),
                            hypo['alignment'].int().cpu()))))

        return hypos
Code example #20
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = model.get_targets(sample, net_output).view(-1)
        loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
                          reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #21
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        copy_alpha = (net_output[1]['copy_alpha'].mean().item()
                      if net_output[1]['copy_alpha'] is not None else -1)
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
            'copy_alpha': copy_alpha,
        }
        return loss, sample_size, logging_output
Code example #22
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss, as a Variable
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        translations, bleu_scores = self.generate_translations(model, sample)
        nll_loss = self.compute_nll(model, sample, translations)
        loss = nll_loss[:, 0] + torch.logsumexp(-nll_loss, 1)
        if reduce:
            loss = loss.sum()
        sample_size = (sample["target"].size(0)
                       if self.args.sentence_avg else sample["ntokens"])
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output
Code example #23
    def forward(self, model, sample, reduce=True):
        logits, _ = model(
            **sample['net_input'],
            features_only=True,
            classification_head_name='question_answer_head',
        )

        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        start_positions = sample['targets']['starts']
        end_positions = sample['targets']['ends']

        ignored_index = start_logits.size(1)
        start_positions.clamp_(0, ignored_index)
        end_positions.clamp_(0, ignored_index)

        start_loss = self.compute_loss(start_logits, start_positions)
        end_loss = self.compute_loss(end_logits, end_positions)

        loss = (start_loss + end_loss) / 2

        sample_size = sample['nsentences']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['nsentences'],
            'sample_size': sample_size,
        }
        if self.do_evaluate:
            logging_output.update(starts=start_logits.detach())
            logging_output.update(ends=end_logits.detach())
        return loss, sample_size, logging_output
Code example #24
    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        logits = model.get_logits(net_output).float()
        target = model.get_targets(sample, net_output,
                                   expand_steps=False).float()

        if hasattr(model, 'get_target_weights'):
            weights = model.get_target_weights(target, net_output)
            if torch.is_tensor(weights):
                weights = weights.float()
        else:
            weights = 1.

        loss = F.binary_cross_entropy_with_logits(logits, target, reduce=False)

        loss = loss * weights

        if reduce:
            loss = loss.sum()

        sample_size = target.numel()
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample_size,
            'nsentences': logits.size(0),
            'sample_size': sample_size,
        }
        if log_pred:
            logging_output['logits'] = logits.cpu().numpy()
            logging_output['target'] = target.cpu().numpy()
        return loss, sample_size, logging_output
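A toy check of the weight-then-sum pattern above, using the modern reduction='none' spelling in place of the deprecated reduce=False (shapes and weights are illustrative):

import torch
import torch.nn.functional as F

logits = torch.randn(3, 4)
target = torch.randint(2, (3, 4)).float()
weights = torch.full_like(target, 0.5)

loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
loss = (loss * weights).sum()     # element-wise weighting, then reduce
sample_size = target.numel()      # every element is scored here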
Code example #25
File: write_loss.py Project: sbsyangjian/RetrieveNMT
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """



        net_output = model(**sample['net_input'])
        loss, _ = self.compute_loss(model, net_output, sample, reduce=False)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']

        batch_size, length = model.get_targets(sample, net_output).size()
        sentence_loss = loss.view(batch_size, length).sum(1)
        id_list = sample["id"].tolist()
        loss_list = sentence_loss.tolist()
        length_list = sample["net_input"]["src_lengths"].tolist()
        results = []
        for i, j, k in zip(id_list, loss_list, length_list):
            results.append((i, j / k))
        results.sort(key=lambda x: x[0])
        for i, j in results:
            self.w.write("{} {}\n".format(i, j))
            self.w.flush()
            if i % 1000000 == 0:
                print("id:{}".format(i))
        loss = loss.sum()

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #26
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model,
                                           net_output,
                                           sample,
                                           reduce=reduce)
        encoder_out = model.encoder.forward(
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"]).encoder_out
        reverse_sample = self.swap_sample(sample)
        reversed_encoder_out = model.encoder.forward(
            reverse_sample["net_input"]["src_tokens"],
            reverse_sample["net_input"]["src_lengths"]).encoder_out
        contrastive_loss = self.get_contrastive_loss(
            encoder_out,
            reversed_encoder_out,
            sample,
            reverse_sample,
        )
        sample_size = (sample["target"].size(0)
                       if self.sentence_avg else sample["ntokens"])
        nsentences = sample["target"].size(0)
        ntokens = sample["ntokens"]
        all_loss = loss + contrastive_loss * self.contrastive_lambda * ntokens / nsentences
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        if isinstance(contrastive_loss, int):
            logging_output["contrastive_loss"] = 0
        else:
            logging_output["contrastive_loss"] = utils.item(
                contrastive_loss.data)

        return all_loss, sample_size, logging_output
Code example #27
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        real_input = sample['net_input']
        real_input['prev_output_tokens'] = sample['target']
        real_output = model(**real_input)[0]
        fake_input = deepcopy(sample['net_input'])
        fake_input['prev_output_tokens'] = sample['generated_tokens']
        fake_output = model(**fake_input)[0]

        loss, _ = self.compute_loss(model, real_output, fake_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #28
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss = self.compute_loss(model,
                                           net_output[0],
                                           sample,
                                           reduce=reduce)
        gsnn_loss, gsnn_nll_loss = None, None
        if net_output[1] is not None:
            gsnn_loss, gsnn_nll_loss = self.compute_loss(model,
                                                         net_output[1],
                                                         sample,
                                                         reduce=reduce)
        sample_size = sample['target'].size(
            0) if self.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        KL_div = self.compute_kl_divergence(net_output[2], net_output[3])

        if KL_div is not None:
            logging_output["kl_div"] = utils.item(KL_div.data)
            loss += self.KL_lambda * KL_div
            if gsnn_loss is not None:
                logging_output['gsnn_loss'] = utils.item(
                    gsnn_loss.data) if reduce else gsnn_loss.data
                logging_output['gsnn_nll_loss'] = utils.item(
                    gsnn_nll_loss.data) if reduce else gsnn_nll_loss.data
                loss += self.alpha * gsnn_loss

        logging_output['total_loss'] = utils.item(
            loss.data) if reduce else loss.data
        return loss, sample_size, logging_output
Code example #29
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss, nll_loss, loss_sen_piece, nll_loss_sen_piece, overall_loss, overall_nll_loss = self.compute_loss(
            model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        sample_size_sen_piece = sample['target_sen_piece'].size(0) if self.args.sentence_avg else sample['ntokens_sen_piece']
        sample_size_overall = sample_size + sample_size_sen_piece
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'loss_sen_piece': utils.item(loss_sen_piece.data) if reduce else loss_sen_piece.data,
            'nll_loss_sen_piece': utils.item(nll_loss_sen_piece.data) if reduce else nll_loss_sen_piece.data,
            'overall_loss': utils.item(overall_loss.data) if reduce else overall_loss.data,
            'overall_nll_loss': utils.item(overall_nll_loss.data) if reduce else overall_nll_loss.data,
            'ntokens': sample['ntokens'],
            'ntokens_sen_piece': sample['ntokens_sen_piece'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
            'sample_size_sen_piece': sample_size_sen_piece,
            'sample_size_overall': sample_size_overall,
        }
        return loss, sample_size, loss_sen_piece, sample_size_sen_piece, overall_loss, sample_size_overall, logging_output
Code example #30
    def forward(self, model, sample, reduce=True):
        net_outputs = model(**sample['net_input'])
        targets = sample['target']

        bsz = targets[0].size(0)
        loss = net_outputs[0][0].new(1 if reduce else bsz).zero_()

        sample_size = 0
        logging_output = {}
        for i, (o, t) in enumerate(zip(net_outputs[0], targets)):
            m = CompositeLoss.FakeModel(model, (o, net_outputs[1]), t)
            l, ss, logging_output = self.underlying_criterion(
                m, sample, reduce)
            if self.weights is not None:
                l *= self.weights[i]
            loss += l
            sample_size += ss

        loss.div_(len(targets))
        sample_size /= len(targets)

        logging_output['loss'] = utils.item(loss.data) if reduce else loss.data
        return loss, sample_size, logging_output
Code example #31
    def _get_qe_loss(self,
                     sample,
                     predictor,
                     estimator,
                     criterion,
                     xml_model=None,
                     valid=False,
                     xml_estimator=None):
        assert hasattr(criterion, 'compute_loss'), \
            'translation_moe task requires the criterion to implement the compute_loss() method'

        k = self.args.num_experts
        bsz = sample['target'].size(0)

        # compute loss with dropout
        ter_prediction = self.get_experts_decoder_output(
            predictor,
            estimator,
            sample,
            xml_model=xml_model,
            xml_estimator=xml_estimator).squeeze(-1)
        ter_gt = sample['ter']
        # ter_gt = torch.ones_like(ter_gt)*3
        # if not valid:
        loss = self.loss_fn(ter_prediction, ter_gt)
        loss = loss.sum()
        # else:
        #     loss = np.corrcoef(ter_prediction.squeeze(-1).cpu().data, torch.Tensor(sample['ter']).squeeze(-1).data)[0, 1]
        # sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        sample_size = bsz
        logging_output = {
            'loss': utils.item(loss if valid else loss.data),
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
            'posterior': ter_prediction.cpu(),
        }
        return loss, sample_size, logging_output
Code example #32
    def forward(self, model, sample, reduce=True, compute_custom_metrics=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        logits = net_output[0].view(-1, net_output[0].size(-1))
        target = model.get_targets(sample, net_output)
        target = target.view(-1)
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        sample_size = sample['target'].size(
            0) if self.args.sentence_avg else sample['ntokens']

        true_token_logits = -F.nll_loss(
            logits,
            target,
            ignore_index=self.padding_idx,
            reduction='none',  # I think this needs to be mean for batch case?
        )
        orig = utils.strip_pad(target, self.padding_idx)
        ntokens = orig.numel()

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        if compute_custom_metrics:
            custom_output = TrainingMetrics.ranking_metrics(
                logits, true_token_logits, sample, ntokens, target)
            for k, v in custom_output.items():
                logging_output[k] = v
        return loss, sample_size, logging_output
Code example #33
def processFlatHypo(sample_id, src_tokens, target_tokens, hypos, src_str,
                    align_dict, tgt_dict, remove_bpe, has_target, target_str):
    """Not used"""
    for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
            hypo_tokens=hypo['tokens'].int().cpu(),
            src_str=src_str,
            alignment=hypo['alignment'].int().cpu(),
            align_dict=align_dict,
            tgt_dict=tgt_dict,
            remove_bpe=remove_bpe,
        )

        if not args.quiet:
            print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
            print('P-{}\t{}'.format(
                sample_id, ' '.join(
                    map(
                        lambda x: '{:.4f}'.format(x),
                        hypo['positional_scores'].tolist(),
                    ))))
            print('A-{}\t{}'.format(
                sample_id,
                ' '.join(map(lambda x: str(utils.item(x)), alignment))))

        # Score only the top hypothesis
        if has_target and i == 0:
            if align_dict is not None or args.remove_bpe is not None:
                # Convert back to tokens for evaluation with unk replacement and/or without BPE
                target_tokens = tokenizer.Tokenizer.tokenize(
                    target_str, tgt_dict, add_if_not_exist=True)

        # write files for ROUGE
        with open(os.path.join(args.decode_dir, "{}.dec".format(sample_id)),
                  'w') as f:
            f.write(make_html_safe(hypo_str))
Code example #34
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        loss_seq, nll_loss_seq = self.compute_loss(model,
                                                   net_output,
                                                   sample,
                                                   reduce=reduce)
        loss_pos, nll_loss_pos = self.compute_pointer_loss(net_output,
                                                           sample,
                                                           reduce=reduce)
        loss = loss_seq + self.loss_coef * loss_pos
        nll_loss = nll_loss_seq + self.loss_coef * nll_loss_pos
        # TODO use different normalization factor for two types of losses
        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'loss_seq': utils.item(loss_seq.data) if reduce else loss_seq.data,
            'nll_loss_seq': utils.item(nll_loss_seq.data) if reduce else nll_loss_seq.data,
            'loss_pos': utils.item(loss_pos.data) if reduce else loss_pos.data,
            'nll_loss_pos': utils.item(nll_loss_pos.data) if reduce else nll_loss_pos.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #35
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        target = model.get_targets(sample, net_output).view(-1, 1)
        if self.xentropy_func is not None:
            assert (net_output[0].dtype == torch.float16) or (net_output[0].dtype == torch.float32), "Unsupported data types"
            output = net_output[0].view(net_output[0].size(0) * net_output[0].size(1), net_output[0].size(2))
            labels = target.view(target.size(0) * target.size(1))
            losses = self.xentropy_func(output, labels, self.eps, self.padding_idx, net_output[0].dtype == torch.float16)
            loss = losses.sum()
        else:
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            non_pad_mask = target.ne(self.padding_idx)
            nll_loss = -lprobs.gather(dim=-1, index=target)[non_pad_mask]
            smooth_loss = -lprobs.sum(dim=-1, keepdim=True)[non_pad_mask]
            if reduce:
                nll_loss = nll_loss.sum()
                smooth_loss = smooth_loss.sum()
            eps_i = self.eps / lprobs.size(-1)
            loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss

        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            #'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
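The label-smoothing arithmetic in the else-branch above, checked in isolation on toy tensors (eps is illustrative):

import torch

eps = 0.1
lprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
target = torch.tensor([[1], [3]])                      # (batch, 1)
nll_loss = -lprobs.gather(dim=-1, index=target).sum()  # true-class term
smooth_loss = -lprobs.sum(dim=-1, keepdim=True).sum()  # uniform-prior term
eps_i = eps / lprobs.size(-1)
loss = (1. - eps) * nll_loss + eps_i * smooth_loss
print(loss.item())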
Code example #36
    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample['net_input'])
        # `target` is used for the padding mask below but was missing from this
        # excerpt; it is reconstructed following the sibling examples here
        target = model.get_targets(sample, net_output).view(-1)
        loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
        eta = model.eta if self.eta_hardcoded is None else self.eta_hardcoded
        residual = loss - eta
        loss = self.relu(residual) / self.alpha + eta
        loss = loss * (target != self.padding_idx).float()
        loss = torch.sum(loss)
        sample_size = sample['target'].size(
            0) if self.args.sentence_avg else sample['ntokens']
        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
Code example #37
    def forward(self, model, sample, reduce=True):
        """Compute ranking loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        scores = []
        for idx in range(self.args.num_classes):
            score, _ = model(
                **sample['net_input{idx}'.format(idx=idx + 1)],
                features_only=True,
                classification_head_name='sentence_classification_head',
            )
            scores.append(score)

        logits = torch.cat(scores, dim=1)
        targets = model.get_targets(sample, [logits]).view(-1)
        sample_size = targets.numel()

        loss = F.nll_loss(
            F.log_softmax(logits, dim=-1, dtype=torch.float32),
            targets,
            reduction='sum',
        )

        logging_output = {
            'loss': utils.item(loss.data) if reduce else loss.data,
            'ntokens': sample['ntokens'],
            'nsentences': sample_size,
            'sample_size': sample_size,
        }
        logging_output.update(ncorrect=(logits.max(
            dim=1)[1] == targets).sum().item())
        return loss, sample_size, logging_output
Code example #38
File: test_utils.py Project: fyabc/fairseq
    def assertAlmostEqual(self, t1, t2):
        self.assertEqual(t1.size(), t2.size(), "size mismatch")
        self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
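Hypothetical usage of this helper inside a unittest.TestCase (the surrounding test class is invented for illustration):

import unittest
import torch

class TestTensors(unittest.TestCase):
    def assertTensorAlmostEqual(self, t1, t2):
        # same idea as above: compare shapes, then the max absolute difference
        self.assertEqual(t1.size(), t2.size(), 'size mismatch')
        self.assertLess((t1 - t2).abs().max().item(), 1e-4)

    def test_close(self):
        a = torch.ones(3)
        self.assertTensorAlmostEqual(a, a + 1e-5)

if __name__ == '__main__':
    unittest.main()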
Code example #39
File: eval_lm.py Project: fyabc/fairseq
def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)

    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)
Code example #40
File: generate.py Project: fyabc/fairseq
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=8,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    if args.score_reference:
        translator = SequenceScorer(models, task.target_dictionary)
    else:
        translator = SequenceGenerator(
            models, task.target_dictionary, beam_size=args.beam, minlen=args.min_len,
            stop_early=(not args.no_early_stop), normalize_scores=(not args.unnormalized),
            len_penalty=args.lenpen, unk_penalty=args.unkpen,
            sampling=args.sampling, sampling_topk=args.sampling_topk, sampling_temperature=args.sampling_temperature,
            diverse_beam_groups=args.diverse_beam_groups, diverse_beam_strength=args.diverse_beam_strength,
        )

    if use_cuda:
        translator.cuda()

    # Generate and compute BLEU score
    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
    num_sentences = 0
    has_target = True
    with progress_bar.build_progress_bar(args, itr) as t:
        if args.score_reference:
            translations = translator.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        else:
            translations = translator.generate_batched_itr(
                t, maxlen_a=args.max_len_a, maxlen_b=args.max_len_b,
                cuda=use_cuda, timer=gen_timer, prefix_size=args.prefix_size,
            )

        wps_meter = TimeMeter()
        for sample_id, src_tokens, target_tokens, hypos in translations:
            # Process input and ground truth
            has_target = target_tokens is not None
            target_tokens = target_tokens.int().cpu() if has_target else None

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                src_str = src_dict.string(src_tokens, args.remove_bpe)
                if has_target:
                    target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

            if not args.quiet:
                print('S-{}\t{}'.format(sample_id, src_str))
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str))

            # Process top predictions
            for i, hypo in enumerate(hypos[:min(len(hypos), args.nbest)]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                )

                if not args.quiet:
                    print('H-{}\t{}\t{}'.format(sample_id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            hypo['positional_scores'].tolist(),
                        ))
                    ))

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(map(lambda x: str(utils.item(x)), alignment))
                        ))

                # Score only the top hypothesis
                if has_target and i == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tokenizer.Tokenizer.tokenize(
                            target_str, tgt_dict, add_if_not_exist=True)
                    scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += 1

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))