Example #1
    def build_embedding(args,
                        dictionary: Dictionary,
                        embedding_size,
                        path=None):
        emb = nn.Embedding(len(dictionary),
                           embedding_size,
                           padding_idx=dictionary.pad())
        if path is not None:
            embed_dict = nmtg.data.data_utils.parse_embedding(path)
            nmtg.data.data_utils.load_embedding(embed_dict, dictionary, emb)
        elif args.init_embedding == 'xavier':
            nn.init.xavier_uniform_(emb.weight)
        elif args.init_embedding == 'normal':
            nn.init.normal_(emb.weight, mean=0, std=embedding_size**-0.5)
        else:
            raise ValueError('Unknown initialization {}'.format(
                args.init_embedding))

        if args.freeze_embeddings:
            emb.weight.requires_grad_(False)

        return emb
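
A minimal usage sketch (assuming `dictionary` is an already-loaded Dictionary instance and the option names match the example above; the Namespace construction here is purely illustrative):

    from argparse import Namespace
    import torch.nn as nn

    # Illustrative options: random normal initialization, embeddings kept trainable
    args = Namespace(init_embedding='normal', freeze_embeddings=False)

    # dictionary: assumed to be a loaded Dictionary; path=None means no pretrained vectors
    emb = build_embedding(args, dictionary, embedding_size=512)
    assert isinstance(emb, nn.Embedding)
    assert emb.weight.requires_grad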
Example #2
class NMTTrainer(Trainer):

    @classmethod
    def _add_inference_data_options(cls, parser, argv=None):
        parser.add_argument('-src_seq_length_trunc', type=int, default=0,
                            help='Truncate source sequences to this length. 0 (default) to disable')
        parser.add_argument('-tgt_seq_length_trunc', type=int, default=0,
                            help='Truncate target sequences to this length. 0 (default) to disable')

    @classmethod
    def add_inference_options(cls, parser, argv=None):
        super().add_inference_options(parser, argv)
        cls._add_inference_data_options(parser, argv)
        parser.add_argument('-input_type', default='word', choices=['word', 'char'],
                            help='Type of dictionary to create.')
        parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
        parser.add_argument('-alpha', type=float, default=0.6,
                            help='Length Penalty coefficient')
        parser.add_argument('-beta', type=float, default=0.0,
                            help='Coverage penalty coefficient')
        parser.add_argument('-normalize', action='store_true',
                            help='Normalize scores by output length')
        parser.add_argument('-n_best', type=int, default=1,
                            help='Will output the n_best decoded sentences')
        parser.add_argument('-label_smoothing', type=float, default=0.0,
                            help='Label smoothing value for loss functions.')
        parser.add_argument('-print_translations', action='store_true',
                            help='Output finished translations as they are generated')
        parser.add_argument('-return_scores', action='store_true',
                            help='Return scores in the online translation')

        parser.add_argument('-eval_noise', action='store_true',
                            help='Also apply noise when evaluating')
        parser.add_argument('-word_shuffle', type=int, default=3,
                            help='Maximum number of positions a word can move (0 to disable)')
        parser.add_argument('-word_blank', type=float, default=0.1,
                            help='Probability to replace a word with the unknown word (0 to disable)')
        parser.add_argument('-noise_word_dropout', type=float, default=0.1,
                            help='Probability to remove a word (0 to disable)')

    @classmethod
    def _add_train_data_options(cls, parser, argv=None):
        parser.add_argument('-train_src', type=str, required=True,
                            help='Path to the training source file')
        parser.add_argument('-train_tgt', type=str, required=True,
                            help='Path to the training target file')
        parser.add_argument('-join_vocab', action='store_true',
                            help='Share dictionary for source and target')
        parser.add_argument('-src_seq_length', type=int, default=64,
                            help='Discard source sequences above this length')
        parser.add_argument('-tgt_seq_length', type=int, default=64,
                            help='Discard target sequences above this length')
        parser.add_argument('-translation_noise', action='store_true',
                            help='Apply noise to the source when translating')
        parser.add_argument('-pre_word_vecs_enc', type=str,
                            help='If a valid path is specified, then this will load '
                                 'pretrained word embeddings on the encoder side. '
                                 'See README for specific formatting instructions.')
        parser.add_argument('-pre_word_vecs_dec', type=str,
                            help='If a valid path is specified, then this will load '
                                 'pretrained word embeddings on the decoder side. '
                                 'See README for specific formatting instructions.')

    @classmethod
    def add_training_options(cls, parser, argv=None):
        super().add_training_options(parser, argv)
        cls._add_train_data_options(parser, argv)
        parser.add_argument('-data_dir', type=str, required=True,
                            help='Path to the auxiliary data directory')
        parser.add_argument('-load_into_memory', action='store_true',
                            help='Load the dataset into memory')
        parser.add_argument('-batch_size_words', type=int, default=2048,
                            help='Maximum number of words in a batch')
        parser.add_argument('-batch_size_sents', type=int, default=128,
                            help='Maximum number of sentences in a batch')
        parser.add_argument('-batch_size_multiplier', type=int, default=1,
                            help='Number of sentences in a batch must be divisible by this number')
        parser.add_argument('-batch_size_update', type=int, default=20000,
                            help='Perform a learning step after this many tokens')
        parser.add_argument('-normalize_gradient', action='store_true',
                            help='Divide gradient by the number of tokens')
        parser.add_argument('-pad_count', action='store_true',
                            help='Count padding words when batching')

        parser.add_argument('-tie_weights', action='store_true',
                            help='Share weights between embedding and softmax')
        parser.add_argument('-freeze_embeddings', action='store_true',
                            help='Do not train word embeddings')
        parser.add_argument('-word_vec_size', type=int,
                            help='Word embedding size')
        parser.add_argument('-word_dropout', type=float, default=0.0,
                            help='Dropout probability applied to embedding indices.')
        parser.add_argument('-init_embedding', default='normal', choices=['xavier', 'normal'],
                            help="How to init the embedding matrices.")
        parser.add_argument('-copy_decoder', action='store_true',
                            help='Use a decoder that can copy tokens from the input when it deems it appropriate')
        parser.add_argument('-freeze_model', action='store_true',
                            help='Only used when upgrading an NMT model without a copy decoder. '
                                 'Freeze the model and only learn the copy decoder parameters')
        parser.add_argument('-extra_attention', action='store_true',
                            help='Add an extra attention layer at the end of the model to predict alignment for '
                                 'the copy decoder. Intended for models, such as the Transformer, that have no '
                                 'clear attention alignment.')

    def _build_data(self):
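        # Load the preprocessed dictionaries from data_dir (shared if join_vocab) and build the loss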
        super()._build_data()

        if self.args.join_vocab:
            self.src_dict = Dictionary.load(os.path.join(self.args.data_dir, 'dict'))
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary.load(os.path.join(self.args.data_dir, 'src.dict'))
            self.tgt_dict = Dictionary.load(os.path.join(self.args.data_dir, 'tgt.dict'))
        logger.info('Vocabulary size: {:,d}|{:,d}'.format(len(self.src_dict), len(self.tgt_dict)))
        self._build_loss()

    def _load_data(self, checkpoint):
        super()._load_data(checkpoint)
        args = checkpoint['args']
        self.args.join_vocab = args.join_vocab

        if args.join_vocab:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(checkpoint['dict'])
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(checkpoint['src_dict'])
            self.tgt_dict = Dictionary()
            self.tgt_dict.load_state_dict(checkpoint['tgt_dict'])
        self._build_loss()

    def _save_data(self, checkpoint):
        super()._save_data(checkpoint)
        if self.args.join_vocab:
            checkpoint['dict'] = self.src_dict.state_dict()
        else:
            checkpoint['src_dict'] = self.src_dict.state_dict()
            checkpoint['tgt_dict'] = self.tgt_dict.state_dict()

    def _build_loss(self):
        logger.info('Building loss')
        loss = NMTLoss(len(self.tgt_dict), self.tgt_dict.pad(), self.args.label_smoothing)
        if self.args.cuda:
            loss.cuda()
        self.loss = loss

    def _build_model(self, model_args):
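        # Assemble the final encoder-decoder: bare model plus embeddings, output projection and optional copy decoder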
        logger.info('Building {} model'.format(model_args.model))
        model = build_model(model_args.model, model_args)

        embedding_size = model_args.word_vec_size or getattr(model_args, 'model_size', None)
        if embedding_size is None:
            raise ValueError('Could not infer embedding size')

        if model_args.copy_decoder and not model_args.join_vocab:
            raise NotImplementedError('In order to use the copy decoder, the source and target language must '
                                      'use the same vocabulary')

        if model_args.join_vocab and model_args.pre_word_vecs_dec:
            raise ValueError('Cannot join vocabularies when loading pre-trained target embeddings')

        # Probe the bare model with a dummy forward pass to determine its output feature size
        dummy_input = torch.zeros(1, 1, embedding_size)
        dummy_output, _ = model(dummy_input, dummy_input)
        output_size = dummy_output.size(-1)

        src_embedding = self._get_embedding(model_args, self.src_dict, embedding_size,
                                            getattr(self.args, 'pre_word_vecs_enc', None))

        if model_args.join_vocab:
            tgt_embedding = src_embedding
        else:
            tgt_embedding = self._get_embedding(model_args, self.tgt_dict, embedding_size,
                                                getattr(self.args, 'pre_word_vecs_dec', None))

        tgt_linear = XavierLinear(output_size, len(self.tgt_dict))

        if model_args.tie_weights:
            tgt_linear.weight = tgt_embedding.weight

        encoder = NMTEncoder(model.encoder, src_embedding, model_args.word_dropout)

        if model_args.copy_decoder:
            masked_layers = getattr(model_args, 'masked_layers', False)
            attention_dropout = getattr(model_args, 'attn_dropout', 0.0)
            decoder = NMTDecoder(model.decoder, tgt_embedding, model_args.word_dropout, tgt_linear,
                                 copy_decoder=True,
                                 batch_first=model_args.batch_first,
                                 extra_attention=model_args.extra_attention,
                                 masked_layers=masked_layers,
                                 attention_dropout=attention_dropout)
        else:
            decoder = NMTDecoder(model.decoder, tgt_embedding, model_args.word_dropout, tgt_linear)

        if model_args.freeze_model:
            logger.info('Freezing model parameters')
            for param in itertools.chain(encoder.parameters(), decoder.decoder.parameters(),
                                         tgt_embedding.parameters(),
                                         tgt_linear.parameters()):
                param.requires_grad_(False)

        self.model = EncoderDecoderModel(encoder, decoder)
        self.model.batch_first = model_args.batch_first

    @staticmethod
    def _get_embedding(args, dictionary, embedding_size, path):
        emb = nn.Embedding(len(dictionary), embedding_size, padding_idx=dictionary.pad())
        if path is not None:
            embed_dict = data_utils.parse_embedding(path)
            data_utils.load_embedding(embed_dict, dictionary, emb)
        elif args.init_embedding == 'xavier':
            nn.init.xavier_uniform_(emb.weight)
        elif args.init_embedding == 'normal':
            nn.init.normal_(emb.weight, mean=0, std=embedding_size ** -0.5)
        else:
            raise ValueError('Unknown initialization {}'.format(args.init_embedding))

        if args.freeze_embeddings:
            emb.weight.requires_grad_(False)

        return emb

    def _get_text_lookup_dataset(self, task, text_dataset, src=True):
        split_words = self.args.input_type == 'word'
        dataset = TextLookupDataset(text_dataset,
                                    self.src_dict if src else self.tgt_dict,
                                    words=split_words,
                                    lower=task.lower,
                                    bos=not src, eos=not src,
                                    trunc_len=self.args.src_seq_length_trunc if src else self.args.tgt_seq_length_trunc)
        if text_dataset.in_memory:
            if split_words:
                lengths = np.array([len(sample.split()) for sample in text_dataset])
            else:
                lengths = np.array([len(sample) for sample in text_dataset])
        else:
            basename = os.path.basename(text_dataset.filename)
            lengths = np.load(os.path.join(self.args.data_dir, basename + '.len.npy'))
        dataset.lengths = lengths
        return dataset

    def _get_train_dataset(self):
        logger.info('Loading training data')
        split_words = self.args.input_type == 'word'

        src_data, src_lengths = TextLookupDataset.load(self.args.train_src, self.src_dict, self.args.data_dir,
                                                       self.args.load_into_memory, split_words,
                                                       bos=False, eos=False, trunc_len=self.args.src_seq_length_trunc,
                                                       lower=self.args.lower)

        if self.args.translation_noise:
            src_data = NoisyTextDataset(src_data, self.args.word_shuffle, self.args.noise_word_dropout,
                                        self.args.word_blank, self.args.bpe_symbol)

        tgt_data, tgt_lengths = TextLookupDataset.load(self.args.train_tgt, self.tgt_dict, self.args.data_dir,
                                                       self.args.load_into_memory, split_words,
                                                       bos=True, eos=True, trunc_len=self.args.tgt_seq_length_trunc,
                                                       lower=self.args.lower)
        src_data.lengths = src_lengths
        tgt_data.lengths = tgt_lengths
        dataset = ParallelDataset(src_data, tgt_data)
        logger.info('Number of training sentences: {:,d}'.format(len(dataset)))
        return dataset

    def _get_eval_dataset(self, task: TranslationTask):
        split_words = self.args.input_type == 'word'
        src_dataset = TextLookupDataset(task.src_dataset,
                                        self.src_dict,
                                        words=split_words,
                                        lower=task.lower,
                                        bos=False, eos=False,
                                        trunc_len=self.args.src_seq_length_trunc)

        if self.args.eval_noise:
            src_dataset = NoisyTextDataset(src_dataset, self.args.word_shuffle, self.args.noise_word_dropout,
                                           self.args.word_blank, self.args.bpe_symbol)

        if task.tgt_dataset is not None:
            tgt_dataset = TextLookupDataset(task.tgt_dataset,
                                            self.tgt_dict,
                                            words=split_words,
                                            lower=task.lower,
                                            bos=True, eos=True,
                                            trunc_len=self.args.tgt_seq_length_trunc)
        else:
            tgt_dataset = None
        dataset = ParallelDataset(src_dataset, tgt_dataset)
        return dataset

    def _get_train_sampler(self, dataset: ParallelDataset):
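        # Pre-generate length-based batches, capped by token and sentence counts, skipping pairs over the length limits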
        src_lengths = dataset.src_data.lengths
        tgt_lengths = dataset.tgt_data.lengths

        def filter_fn(i):
            return src_lengths[i] <= self.args.src_seq_length and tgt_lengths[i] <= self.args.tgt_seq_length

        logger.info('Generating batches')
        batches = data_utils.generate_length_based_batches_from_lengths(
            np.maximum(src_lengths, tgt_lengths), self.args.batch_size_words,
            self.args.batch_size_sents,
            self.args.batch_size_multiplier,
            self.args.pad_count,
            key_fn=lambda i: (tgt_lengths[i], src_lengths[i]),
            filter_fn=filter_fn)
        logger.info('Number of training batches: {:,d}'.format(len(batches)))

        filtered = len(src_lengths) - sum(len(batch) for batch in batches)
        logger.info('Filtered {:,d}/{:,d} training examples for length'.format(filtered, len(src_lengths)))
        sampler = PreGeneratedBatchSampler(batches, self.args.curriculum == 0)
        return sampler

    def _get_training_metrics(self):
        metrics = super()._get_training_metrics()
        metrics['nll'] = AverageMeter()
        metrics['src_tps'] = AverageMeter()
        metrics['tgt_tps'] = AverageMeter()
        metrics['total_words'] = AverageMeter()
        return metrics

    def _reset_training_metrics(self, metrics):
        super()._reset_training_metrics(metrics)
        metrics['src_tps'].reset()
        metrics['tgt_tps'].reset()
        metrics['nll'].reset()

    def _format_train_metrics(self, metrics):
        formatted = super()._format_train_metrics(metrics)
        perplexity = math.exp(metrics['nll'].avg)
        formatted.insert(1, 'ppl {:6.2f}'.format(perplexity))

        srctok = metrics['src_tps'].sum / metrics['it_wall'].elapsed_time
        tgttok = metrics['tgt_tps'].sum / metrics['it_wall'].elapsed_time
        formatted.append('{:5.0f}|{:5.0f} tok/s'.format(srctok, tgttok))
        return formatted

    def _forward(self, batch, training=True):
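        # Run the encoder-decoder on a batch and compute the loss from the model's log-probabilities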
        encoder_input = batch.get('src_indices')
        decoder_input = batch.get('tgt_input')
        targets = batch.get('tgt_output')

        if not self.model.batch_first:
            encoder_input = encoder_input.transpose(0, 1).contiguous()
            decoder_input = decoder_input.transpose(0, 1).contiguous()
            targets = targets.transpose(0, 1).contiguous()

        encoder_mask = encoder_input.ne(self.src_dict.pad())
        decoder_mask = decoder_input.ne(self.tgt_dict.pad())
        outputs, attn_out = self.model(encoder_input, decoder_input, encoder_mask, decoder_mask)
        lprobs = self.model.get_normalized_probs(outputs, attn_out, encoder_input,
                                                 encoder_mask, decoder_mask, log_probs=True)
        if training:
            targets = targets.masked_select(decoder_mask)
        return self.loss(lprobs, targets)

    def _forward_backward_pass(self, batch, metrics):
        src_size = batch.get('src_size')
        tgt_size = batch.get('tgt_size')
        loss, display_loss = self._forward(batch)
        self.optimizer.backward(loss)
        metrics['nll'].update(display_loss, tgt_size)
        metrics['src_tps'].update(src_size)
        metrics['tgt_tps'].update(tgt_size)
        metrics['total_words'].update(tgt_size)

    def _do_training_step(self, metrics, batch):
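        # Take an optimizer step only once at least batch_size_update target tokens have been accumulated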
        return metrics['total_words'].sum >= self.args.batch_size_update

    def _learning_step(self, metrics):
        if self.args.normalize_gradient:
            self.optimizer.multiply_grads(1 / metrics['total_words'].sum)
        super()._learning_step(metrics)
        metrics['total_words'].reset()

    def _get_eval_metrics(self):
        metrics = super()._get_eval_metrics()
        metrics['nll'] = AverageMeter()
        return metrics

    def format_eval_metrics(self, metrics):
        formatted = super().format_eval_metrics(metrics)
        formatted.append('Validation perplexity: {:.2f}'.format(math.exp(metrics['nll'].avg)))
        return formatted

    def _eval_pass(self, task, batch, metrics):
        tgt_size = batch.get('tgt_size')
        _, display_loss = self._forward(batch, training=False)
        metrics['nll'].update(display_loss, tgt_size)

    def _get_sequence_generator(self, task):
        return SequenceGenerator([self.model], self.tgt_dict, self.model.batch_first,
                                 self.args.beam_size, maxlen_b=20, normalize_scores=self.args.normalize,
                                 len_penalty=self.args.alpha, unk_penalty=self.args.beta)

    def _restore_src_string(self, task, output, join_str, bpe_symbol):
        return self.src_dict.string(output, join_str=join_str, bpe_symbol=bpe_symbol)

    def _restore_tgt_string(self, task, output, join_str, bpe_symbol):
        return self.tgt_dict.string(output, join_str=join_str, bpe_symbol=bpe_symbol)

    def solve(self, test_task):
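        # Translate a full test set batch by batch, optionally printing sources and hypotheses as they are produced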
        self.model.eval()

        generator = self._get_sequence_generator(test_task)

        test_dataset = self._get_eval_dataset(test_task)
        test_sampler = self._get_eval_sampler(test_dataset)
        test_iterator = self._get_iterator(test_dataset, test_sampler)

        results = []
        for batch in tqdm(test_iterator, desc='inference', disable=self.args.no_progress):

            res, src = self._inference_pass(test_task, batch, generator)

            if self.args.print_translations:
                for i, source in enumerate(src):
                    tqdm.write("Src {}: {}".format(len(results) + i, source))
                    for j in range(self.args.n_best):
                        translation = res[i * self.args.n_best + j]['tokens']
                        tqdm.write("Hyp {}.{}: {}".format(len(results) + i, j + 1,
                                                          translation.replace(self.args.bpe_symbol, '')))
                    tqdm.write("")

            results.extend(beam['tokens'] for beam in res)

        return results

    def online_translate(self, in_stream, **kwargs):
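        # Translate the input stream line by line, yielding the decoded hypotheses (and scores if requested)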
        self.model.eval()
        split_words = self.args.input_type == 'word'

        task = TranslationTask(in_stream, bpe_symbol=self.args.bpe_symbol, lower=self.args.lower, **kwargs)

        generator = self._get_sequence_generator(task)

        for j, line in enumerate(in_stream):
            line = line.rstrip()
            if self.args.lower:
                line = line.lower()
            if split_words:
                line = line.split()

            src_indices = self.src_dict.to_indices(line, bos=False, eos=False)
            encoder_inputs = src_indices.unsqueeze(0 if self.model.batch_first else 1)
            source_lengths = torch.tensor([len(line)])

            if self.args.cuda:
                encoder_inputs = encoder_inputs.cuda()
                source_lengths = source_lengths.cuda()

            batch = {'src_indices': encoder_inputs, 'src_lengths': source_lengths}

            res, src = self._inference_pass(task, batch, generator)
            source = src[0]

            if self.args.print_translations:
                tqdm.write("Src {}: {}".format(j, source))
                for i in range(self.args.n_best):
                    translation = res[i]['tokens']
                    tqdm.write("Hyp {}.{}: {}".format(j, i + 1,
                                                      translation.replace(self.args.bpe_symbol, '')))
                tqdm.write("")

            scores = [r['scores'] for r in res]
            positional_scores = [r['positional_scores'] for r in res]

            if len(res) == 1:
                res = res[0]
                scores = scores[0]
                positional_scores = positional_scores[0]

            if self.args.return_scores:
                yield res, scores, positional_scores.tolist()
            else:
                yield res

    def _inference_pass(self, task, batch, generator):
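        # Decode the n best hypotheses for each source sentence and map token indices back to strings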
        encoder_input = batch.get('src_indices')
        source_lengths = batch.get('src_lengths')
        join_str = ' ' if self.args.input_type == 'word' else ''

        if not generator.batch_first:
            encoder_input = encoder_input.transpose(0, 1).contiguous()

        encoder_mask = encoder_input.ne(self.src_dict.pad())

        res = [tr
               for beams in generator.generate(encoder_input, source_lengths, encoder_mask)
               for tr in beams[:self.args.n_best]]

        for beam in res:
            beam['tokens'] = self.tgt_dict.string(beam['tokens'], join_str=join_str)

        src = []
        if self.args.print_translations:
            for i in range(len(batch['src_indices'])):
                ind = batch['src_indices'][i][:batch['src_lengths'][i]]
                ind = self.src_dict.string(ind, join_str=join_str, bpe_symbol=self.args.bpe_symbol)
                src.append(ind)
        return res, src

    @classmethod
    def upgrade_checkpoint(cls, checkpoint):
        super().upgrade_checkpoint(checkpoint)
        args = checkpoint['args']
        if 'freeze_model' not in args:
            args.freeze_model = False
            args.copy_decoder = False
            args.extra_attention = False
        if 'eval_noise' not in args:
            args.translation_noise = getattr(args, 'translation_noise', False)
            args.eval_noise = False
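
As a rough sketch, these option hooks would typically be driven from a command-line entry point along the following lines (the `NMTTrainer(args)` construction and the `train()` call are assumptions about the surrounding framework, not shown in the excerpt):

    import argparse

    parser = argparse.ArgumentParser(description='Train an NMT model')
    # Registers the data, batching, embedding and copy-decoder options shown above
    NMTTrainer.add_training_options(parser)
    args = parser.parse_args()

    trainer = NMTTrainer(args)  # constructor signature assumed
    trainer.train()             # hypothetical entry point provided by the base Trainer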