def preprocess(cls, args):
    split_words = args.input_type == 'word'

    os.makedirs(args.data_dir_out, exist_ok=True)
    basename = os.path.basename(args.in_data)

    ((offsets, lengths, counter),) = cls.get_indices_and_vocabulary(
        [args.in_data], split_words, args.lower, not args.no_progress, args.report_every)

    out_offsets = os.path.join(args.data_dir_out, basename + '.idx.npy')
    out_lengths = os.path.join(args.data_dir_out, basename + '.len.npy')
    np.save(out_offsets, offsets)
    np.save(out_lengths, lengths)

    dictionary = Dictionary()
    for word, count in counter.items():
        dictionary.add_symbol(word, count)
    dictionary.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
    dictionary.save(os.path.join(args.data_dir_out, 'dict'))
def convert_checkpoint(checkpoint):
    logger.info('Converting old checkpoint...')
    train_data = {
        'model': flatten_state_dict(
            convert_nmt_model(checkpoint['opt'], unflatten_state_dict(checkpoint['model']))),
        'lr_scheduler': {'best': None},
        'training_time': 0.0
    }

    if 'optim' in checkpoint:
        num_updates = checkpoint['optim']['_step']
        del checkpoint['optim']['_step']
        train_data['optimizer'] = checkpoint['optim']
        train_data['num_updates'] = num_updates

    if 'epoch' in checkpoint:
        train_data['epoch'] = checkpoint['epoch']

    if 'iteration' in checkpoint:
        train_data['sampler'] = {
            'index': checkpoint['iteration'],
            'batch_order': checkpoint['batchOrder']
        }

    new_checkpoint = {'train_data': train_data}

    # Dictionaries
    src_state_dict = Dictionary.convert(checkpoint['dicts']['src']).state_dict()
    join_vocab = checkpoint['dicts']['src'].labelToIdx == checkpoint['dicts']['tgt'].labelToIdx

    if join_vocab:
        new_checkpoint['dict'] = src_state_dict
    else:
        new_checkpoint['src_dict'] = src_state_dict
        tgt_state_dict = Dictionary.convert(checkpoint['dicts']['tgt']).state_dict()
        new_checkpoint['tgt_dict'] = tgt_state_dict

    args = checkpoint['opt']
    args.join_vocab = join_vocab
    args.word_vec_size = None

    input_chars = all(len(x[0]) == 1 for x in src_state_dict['dict'])
    args.input_type = 'char' if input_chars else 'word'

    new_checkpoint['args'] = args
    return new_checkpoint
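# Hedged usage sketch (not part of the original code): how convert_checkpoint might be applied
# to a legacy checkpoint file. The file names and the map_location choice are assumptions for
# illustration; only torch.load/torch.save and the convert_checkpoint function above are taken
# as given.
import torch

old_checkpoint = torch.load('model_old.pt', map_location='cpu')
new_checkpoint = convert_checkpoint(old_checkpoint)
torch.save(new_checkpoint, 'model_converted.pt')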
def preprocess(cls, args):
    dictionaries = [Dictionary.load(filename) for filename in args.dicts]
    dictionary = dictionaries[0]
    for x in dictionaries[1:]:
        dictionary.update(x)
    dictionary.save(os.path.join(args.data_dir_out, args.out_name))
def preprocess(args):
    split_words = args.input_type == 'word'

    os.makedirs(args.data_dir_out, exist_ok=True)
    train_clean_name = os.path.basename(args.train_clean)

    source_files = [args.train_clean]
    if args.train_noisy is not None:
        source_files.append(args.train_noisy)

    outputs = get_indices_and_vocabulary(source_files, split_words, args.lower,
                                         not args.no_progress, args.report_every)

    if args.train_noisy is not None:
        train_noisy_name = os.path.basename(args.train_noisy)
        (offsets, lengths, counter), \
            (noisy_offsets, noisy_lengths, noisy_counter) = outputs
        counter.update(noisy_counter)
        noisy_offset_filename = os.path.join(args.data_dir_out, train_noisy_name + '.idx.npy')
        np.save(noisy_offset_filename, noisy_offsets)
    else:
        ((offsets, lengths, counter),) = outputs

    out_offsets = os.path.join(args.data_dir_out, train_clean_name + '.idx.npy')
    out_lengths = os.path.join(args.data_dir_out, train_clean_name + '.len.npy')
    np.save(out_offsets, offsets)
    np.save(out_lengths, lengths)

    if args.vocab is not None:
        dictionary = Dictionary.load(args.vocab)
    else:
        dictionary = Dictionary()
        for word, count in counter.items():
            dictionary.add_symbol(word, count)
        dictionary.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
    dictionary.save(os.path.join(args.data_dir_out, 'dict'))
def build_embedding(args, dictionary: Dictionary, embedding_size, path=None):
    emb = nn.Embedding(len(dictionary), embedding_size, padding_idx=dictionary.pad())
    if path is not None:
        embed_dict = nmtg.data.data_utils.parse_embedding(path)
        nmtg.data.data_utils.load_embedding(embed_dict, dictionary, emb)
    elif args.init_embedding == 'xavier':
        nn.init.xavier_uniform_(emb.weight)
    elif args.init_embedding == 'normal':
        nn.init.normal_(emb.weight, mean=0, std=embedding_size ** -0.5)
    else:
        raise ValueError('Unknown initialization {}'.format(args.init_embedding))
    if args.freeze_embeddings:
        emb.weight.requires_grad_(False)
    return emb
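# Hedged usage sketch (not part of the original code): building a small embedding table with
# build_embedding above. The argparse.Namespace stand-in, the toy vocabulary and the embedding
# size of 8 are illustrative assumptions; Dictionary(), add_symbol() and pad() are used only as
# shown elsewhere in this code.
from argparse import Namespace

toy_dict = Dictionary()
for word in ('hello', 'world'):
    toy_dict.add_symbol(word, 1)

toy_args = Namespace(init_embedding='normal', freeze_embeddings=False)
toy_emb = build_embedding(toy_args, toy_dict, 8)
print(toy_emb.weight.shape)  # (len(toy_dict), 8)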
def preprocess(cls, args, save_data=True):
    split_words = args.input_type == 'word'

    os.makedirs(args.data_dir_out, exist_ok=True)

    dictionaries = []
    dataset_lengths = []
    for filename in args.in_data:
        ((offsets, lengths, counter),) = cls.get_indices_and_vocabulary(
            [filename], split_words, args.lower, not args.no_progress, args.report_every)

        basename = os.path.basename(filename)
        out_offsets = os.path.join(args.data_dir_out, basename + '.idx.npy')
        out_lengths = os.path.join(args.data_dir_out, basename + '.len.npy')
        np.save(out_offsets, offsets)
        np.save(out_lengths, lengths)

        dictionaries.append(counter)
        dataset_lengths.append(len(lengths))

    pairs = len(args.langs[0].split('-')) == 2
    data_mode = 'pairs' if pairs else 'all_to_all'

    if data_mode == 'all_to_all':
        assert len(set(args.langs)) == len(args.langs)
        if not args.join_src_tgt_vocab:
            raise ValueError('In order to use all_to_all data mode, vocabularies must be shared '
                             'across source and target languages')
        if not len(set(dataset_lengths)) == 1:
            raise ValueError('Datasets are not the same length')
        src_langs = args.langs
        tgt_langs = args.langs
        src_counters = dictionaries
        tgt_counters = dictionaries
    else:
        src_langs, tgt_langs = zip(*(lang.split('-') for lang in args.langs))
        src_counters, tgt_counters = {}, {}
        for lang, counter in zip(src_langs, dictionaries[::2]):
            if lang in src_counters:
                src_counters[lang].update(counter)
            else:
                src_counters[lang] = counter
        for lang, counter in zip(tgt_langs, dictionaries[1::2]):
            if lang in tgt_counters:
                tgt_counters[lang].update(counter)
            else:
                tgt_counters[lang] = counter
        src_langs = list(src_counters.keys())
        tgt_langs = list(tgt_counters.keys())
        src_counters = [src_counters[lang] for lang in src_langs]
        tgt_counters = [tgt_counters[lang] for lang in tgt_langs]

    if args.join_lang_vocab and args.join_src_tgt_vocab:
        dictionary = Dictionary.from_counters(*(src_counters + tgt_counters))
        dictionary.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
        dictionary.save(os.path.join(args.data_dir_out, 'dict'))
    elif args.join_lang_vocab and not args.join_src_tgt_vocab:
        src_dict = Dictionary.from_counters(*src_counters)
        tgt_dict = Dictionary.from_counters(*tgt_counters)
        src_dict.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
        src_dict.save(os.path.join(args.data_dir_out, 'src.dict'))
        tgt_dict.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
        tgt_dict.save(os.path.join(args.data_dir_out, 'tgt.dict'))
    elif not args.join_lang_vocab and args.join_src_tgt_vocab:
        vocabs = {}
        for lang, counter in zip(src_langs + tgt_langs, src_counters + tgt_counters):
            if lang in vocabs:
                for word, count in counter.items():
                    vocabs[lang].add_symbol(word, count)
            else:
                vocabs[lang] = Dictionary.from_counters(counter)
        for lang, vocab in vocabs.items():
            vocab.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
            vocab.save(os.path.join(args.data_dir_out, lang + '.dict'))
    else:
        vocabs = {}
        for lang, counter in zip(src_langs, src_counters):
            if lang + '.src' in vocabs:
                voc = vocabs[lang + '.src']
                for word, count in counter.items():
                    voc.add_symbol(word, count)
            else:
                vocabs[lang + '.src'] = Dictionary.from_counters(counter)
        for lang, counter in zip(tgt_langs, tgt_counters):
            if lang + '.tgt' in vocabs:
                voc = vocabs[lang + '.tgt']
                for word, count in counter.items():
                    voc.add_symbol(word, count)
            else:
                vocabs[lang + '.tgt'] = Dictionary.from_counters(counter)
        for lang, vocab in vocabs.items():
            vocab.finalize(nwords=args.vocab_size, threshold=args.vocab_threshold or -1)
            vocab.save(os.path.join(args.data_dir_out, lang + '.dict'))
class NMTTrainer(Trainer):
    @classmethod
    def _add_inference_data_options(cls, parser, argv=None):
        parser.add_argument('-src_seq_length_trunc', type=int, default=0,
                            help='Truncate source sequences to this length. 0 (default) to disable')
        parser.add_argument('-tgt_seq_length_trunc', type=int, default=0,
                            help='Truncate target sequences to this length. 0 (default) to disable')

    @classmethod
    def add_inference_options(cls, parser, argv=None):
        super().add_inference_options(parser, argv)
        cls._add_inference_data_options(parser, argv)
        parser.add_argument('-input_type', default='word', choices=['word', 'char'],
                            help='Type of dictionary to create.')
        parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
        parser.add_argument('-alpha', type=float, default=0.6, help='Length penalty coefficient')
        parser.add_argument('-beta', type=float, default=0.0, help='Coverage penalty coefficient')
        parser.add_argument('-normalize', action='store_true',
                            help='Normalize the scores by output length')
        parser.add_argument('-n_best', type=int, default=1,
                            help='Output the n best decoded sentences')
        parser.add_argument('-label_smoothing', type=float, default=0.0,
                            help='Label smoothing value for loss functions.')
        parser.add_argument('-print_translations', action='store_true',
                            help='Output finished translations as they are generated')
        parser.add_argument('-return_scores', action='store_true',
                            help='Return scores in the online translation')
        parser.add_argument('-eval_noise', action='store_true',
                            help='Also apply noise when evaluating')
        parser.add_argument('-word_shuffle', type=int, default=3,
                            help='Maximum number of positions a word can move (0 to disable)')
        parser.add_argument('-word_blank', type=float, default=0.1,
                            help='Probability to replace a word with the unknown word (0 to disable)')
        parser.add_argument('-noise_word_dropout', type=float, default=0.1,
                            help='Probability to remove a word (0 to disable)')

    @classmethod
    def _add_train_data_options(cls, parser, argv=None):
        parser.add_argument('-train_src', type=str, required=True,
                            help='Path to the training source file')
        parser.add_argument('-train_tgt', type=str, required=True,
                            help='Path to the training target file')
        parser.add_argument('-join_vocab', action='store_true',
                            help='Share the dictionary between source and target')
        parser.add_argument('-src_seq_length', type=int, default=64,
                            help='Discard source sequences above this length')
        parser.add_argument('-tgt_seq_length', type=int, default=64,
                            help='Discard target sequences above this length')
        parser.add_argument('-translation_noise', action='store_true',
                            help='Apply noise to the source when translating')
        parser.add_argument('-pre_word_vecs_enc', type=str,
                            help='If a valid path is specified, then this will load '
                                 'pretrained word embeddings on the encoder side. '
                                 'See README for specific formatting instructions.')
        parser.add_argument('-pre_word_vecs_dec', type=str,
                            help='If a valid path is specified, then this will load '
                                 'pretrained word embeddings on the decoder side. '
                                 'See README for specific formatting instructions.')

    @classmethod
    def add_training_options(cls, parser, argv=None):
        super().add_training_options(parser, argv)
        cls._add_train_data_options(parser, argv)
        parser.add_argument('-data_dir', type=str, required=True,
                            help='Path to the auxiliary data directory')
        parser.add_argument('-load_into_memory', action='store_true',
                            help='Load the dataset into memory')
        parser.add_argument('-batch_size_words', type=int, default=2048,
                            help='Maximum number of words in a batch')
        parser.add_argument('-batch_size_sents', type=int, default=128,
                            help='Maximum number of sentences in a batch')
        parser.add_argument('-batch_size_multiplier', type=int, default=1,
                            help='The number of sentences in a batch must be divisible by this number')
        parser.add_argument('-batch_size_update', type=int, default=20000,
                            help='Perform a learning step after this many tokens')
        parser.add_argument('-normalize_gradient', action='store_true',
                            help='Divide the gradient by the number of tokens')
        parser.add_argument('-pad_count', action='store_true',
                            help='Count padding words when batching')
        parser.add_argument('-tie_weights', action='store_true',
                            help='Share weights between embedding and softmax')
        parser.add_argument('-freeze_embeddings', action='store_true',
                            help='Do not train word embeddings')
        parser.add_argument('-word_vec_size', type=int, help='Word embedding size')
        parser.add_argument('-word_dropout', type=float, default=0.0,
                            help='Dropout probability, applied to embedding indices')
        parser.add_argument('-init_embedding', default='normal', choices=['xavier', 'normal'],
                            help='How to initialize the embedding matrices')
        parser.add_argument('-copy_decoder', action='store_true',
                            help='Use a decoder that can copy tokens from the input when appropriate')
        parser.add_argument('-freeze_model', action='store_true',
                            help='Only used when upgrading an NMT model without a copy decoder. '
                                 'Freeze the model and only learn the copy decoder parameters')
        parser.add_argument('-extra_attention', action='store_true',
                            help='Add an extra attention layer at the end of the model to predict '
                                 'alignment for the copy decoder. For models like the transformer, '
                                 'which have no clear attention alignment.')

    def _build_data(self):
        super()._build_data()
        if self.args.join_vocab:
            self.src_dict = Dictionary.load(os.path.join(self.args.data_dir, 'dict'))
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary.load(os.path.join(self.args.data_dir, 'src.dict'))
            self.tgt_dict = Dictionary.load(os.path.join(self.args.data_dir, 'tgt.dict'))
        logger.info('Vocabulary size: {:,d}|{:,d}'.format(len(self.src_dict), len(self.tgt_dict)))
        self._build_loss()

    def _load_data(self, checkpoint):
        super()._load_data(checkpoint)
        args = checkpoint['args']
        self.args.join_vocab = args.join_vocab
        if args.join_vocab:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(checkpoint['dict'])
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(checkpoint['src_dict'])
            self.tgt_dict = Dictionary()
            self.tgt_dict.load_state_dict(checkpoint['tgt_dict'])
        self._build_loss()

    def _save_data(self, checkpoint):
        super()._save_data(checkpoint)
        if self.args.join_vocab:
            checkpoint['dict'] = self.src_dict.state_dict()
        else:
            checkpoint['src_dict'] = self.src_dict.state_dict()
            checkpoint['tgt_dict'] = self.tgt_dict.state_dict()

    def _build_loss(self):
        logger.info('Building loss')
        loss = NMTLoss(len(self.tgt_dict), self.tgt_dict.pad(), self.args.label_smoothing)
        if self.args.cuda:
            loss.cuda()
        self.loss = loss

    def _build_model(self, model_args):
        logger.info('Building {} model'.format(model_args.model))
        model = build_model(model_args.model, model_args)

        embedding_size = model_args.word_vec_size or getattr(model_args, 'model_size', None)
        if embedding_size is None:
            raise ValueError('Could not infer embedding size')

        if model_args.copy_decoder and not model_args.join_vocab:
            raise NotImplementedError('In order to use the copy decoder, the source and target '
                                      'language must use the same vocabulary')
        if model_args.join_vocab and model_args.pre_word_vecs_dec:
            raise ValueError('Cannot join vocabularies when loading pre-trained target embeddings')

        dummy_input = torch.zeros(1, 1, embedding_size)
        dummy_output, _ = model(dummy_input, dummy_input)
        output_size = dummy_output.size(-1)

        src_embedding = self._get_embedding(model_args, self.src_dict, embedding_size,
                                            getattr(self.args, 'pre_word_vecs_enc', None))
        if model_args.join_vocab:
            tgt_embedding = src_embedding
        else:
            tgt_embedding = self._get_embedding(model_args, self.tgt_dict, embedding_size,
                                                getattr(self.args, 'pre_word_vecs_dec', None))

        tgt_linear = XavierLinear(output_size, len(self.tgt_dict))
        if model_args.tie_weights:
            tgt_linear.weight = tgt_embedding.weight

        encoder = NMTEncoder(model.encoder, src_embedding, model_args.word_dropout)

        if model_args.copy_decoder:
            masked_layers = getattr(model_args, 'masked_layers', False)
            attention_dropout = getattr(model_args, 'attn_dropout', 0.0)
            decoder = NMTDecoder(model.decoder, tgt_embedding, model_args.word_dropout, tgt_linear,
                                 copy_decoder=True,
                                 batch_first=model_args.batch_first,
                                 extra_attention=model_args.extra_attention,
                                 masked_layers=masked_layers,
                                 attention_dropout=attention_dropout)
        else:
            decoder = NMTDecoder(model.decoder, tgt_embedding, model_args.word_dropout, tgt_linear)

        if model_args.freeze_model:
            logger.info('Freezing model parameters')
            for param in itertools.chain(encoder.parameters(),
                                         decoder.decoder.parameters(),
                                         tgt_embedding.parameters(),
                                         tgt_linear.parameters()):
                param.requires_grad_(False)

        self.model = EncoderDecoderModel(encoder, decoder)
        self.model.batch_first = model_args.batch_first

    @staticmethod
    def _get_embedding(args, dictionary, embedding_size, path):
        emb = nn.Embedding(len(dictionary), embedding_size, padding_idx=dictionary.pad())
        if path is not None:
            embed_dict = data_utils.parse_embedding(path)
            data_utils.load_embedding(embed_dict, dictionary, emb)
        elif args.init_embedding == 'xavier':
            nn.init.xavier_uniform_(emb.weight)
        elif args.init_embedding == 'normal':
            nn.init.normal_(emb.weight, mean=0, std=embedding_size ** -0.5)
        else:
            raise ValueError('Unknown initialization {}'.format(args.init_embedding))
        if args.freeze_embeddings:
            emb.weight.requires_grad_(False)
        return emb

    def _get_text_lookup_dataset(self, task, text_dataset, src=True):
        split_words = self.args.input_type == 'word'
        dataset = TextLookupDataset(text_dataset, self.src_dict if src else self.tgt_dict,
                                    words=split_words, lower=task.lower,
                                    bos=not src, eos=not src,
                                    trunc_len=self.args.src_seq_length_trunc if src
                                    else self.args.tgt_seq_length_trunc)
        if text_dataset.in_memory:
            if split_words:
                lengths = np.array([len(sample.split()) for sample in text_dataset])
            else:
                lengths = np.array([len(sample) for sample in text_dataset])
        else:
            basename = os.path.basename(text_dataset.filename)
            lengths = np.load(os.path.join(self.args.data_dir, basename + '.len.npy'))
        dataset.lengths = lengths
        return dataset

    def _get_train_dataset(self):
        logger.info('Loading training data')
        split_words = self.args.input_type == 'word'

        src_data, src_lengths = TextLookupDataset.load(self.args.train_src, self.src_dict,
                                                       self.args.data_dir, self.args.load_into_memory,
                                                       split_words, bos=False, eos=False,
                                                       trunc_len=self.args.src_seq_length_trunc,
                                                       lower=self.args.lower)
        if self.args.translation_noise:
            src_data = NoisyTextDataset(src_data, self.args.word_shuffle, self.args.noise_word_dropout,
                                        self.args.word_blank, self.args.bpe_symbol)

        tgt_data, tgt_lengths = TextLookupDataset.load(self.args.train_tgt, self.tgt_dict,
                                                       self.args.data_dir, self.args.load_into_memory,
                                                       split_words, bos=True, eos=True,
                                                       trunc_len=self.args.tgt_seq_length_trunc,
                                                       lower=self.args.lower)

        src_data.lengths = src_lengths
        tgt_data.lengths = tgt_lengths

        dataset = ParallelDataset(src_data, tgt_data)
        logger.info('Number of training sentences: {:,d}'.format(len(dataset)))
        return dataset

    def _get_eval_dataset(self, task: TranslationTask):
        split_words = self.args.input_type == 'word'
        src_dataset = TextLookupDataset(task.src_dataset, self.src_dict, words=split_words,
                                        lower=task.lower, bos=False, eos=False,
                                        trunc_len=self.args.src_seq_length_trunc)
        if self.args.eval_noise:
            src_dataset = NoisyTextDataset(src_dataset, self.args.word_shuffle,
                                           self.args.noise_word_dropout, self.args.word_blank,
                                           self.args.bpe_symbol)
        if task.tgt_dataset is not None:
            tgt_dataset = TextLookupDataset(task.tgt_dataset, self.tgt_dict, words=split_words,
                                            lower=task.lower, bos=True, eos=True,
                                            trunc_len=self.args.tgt_seq_length_trunc)
        else:
            tgt_dataset = None
        dataset = ParallelDataset(src_dataset, tgt_dataset)
        return dataset

    def _get_train_sampler(self, dataset: ParallelDataset):
        src_lengths = dataset.src_data.lengths
        tgt_lengths = dataset.tgt_data.lengths

        def filter_fn(i):
            return src_lengths[i] <= self.args.src_seq_length and tgt_lengths[i] <= self.args.tgt_seq_length

        logger.info('Generating batches')
        batches = data_utils.generate_length_based_batches_from_lengths(
            np.maximum(src_lengths, tgt_lengths), self.args.batch_size_words,
            self.args.batch_size_sents,
            self.args.batch_size_multiplier,
            self.args.pad_count,
            key_fn=lambda i: (tgt_lengths[i], src_lengths[i]),
            filter_fn=filter_fn)
        logger.info('Number of training batches: {:,d}'.format(len(batches)))

        filtered = len(src_lengths) - sum(len(batch) for batch in batches)
        logger.info('Filtered {:,d}/{:,d} training examples for length'.format(filtered, len(src_lengths)))

        sampler = PreGeneratedBatchSampler(batches, self.args.curriculum == 0)
        return sampler

    def _get_training_metrics(self):
        metrics = super()._get_training_metrics()
        metrics['nll'] = AverageMeter()
        metrics['src_tps'] = AverageMeter()
        metrics['tgt_tps'] = AverageMeter()
        metrics['total_words'] = AverageMeter()
        return metrics

    def _reset_training_metrics(self, metrics):
        super()._reset_training_metrics(metrics)
        metrics['src_tps'].reset()
        metrics['tgt_tps'].reset()
        metrics['nll'].reset()

    def _format_train_metrics(self, metrics):
        formatted = super()._format_train_metrics(metrics)
        perplexity = math.exp(metrics['nll'].avg)
        formatted.insert(1, 'ppl {:6.2f}'.format(perplexity))

        srctok = metrics['src_tps'].sum / metrics['it_wall'].elapsed_time
        tgttok = metrics['tgt_tps'].sum / metrics['it_wall'].elapsed_time
        formatted.append('{:5.0f}|{:5.0f} tok/s'.format(srctok, tgttok))
        return formatted

    def _forward(self, batch, training=True):
        encoder_input = batch.get('src_indices')
        decoder_input = batch.get('tgt_input')
        targets = batch.get('tgt_output')

        if not self.model.batch_first:
            encoder_input = encoder_input.transpose(0, 1).contiguous()
            decoder_input = decoder_input.transpose(0, 1).contiguous()
            targets = targets.transpose(0, 1).contiguous()

        encoder_mask = encoder_input.ne(self.src_dict.pad())
        decoder_mask = decoder_input.ne(self.tgt_dict.pad())

        outputs, attn_out = self.model(encoder_input, decoder_input, encoder_mask, decoder_mask)

        lprobs = self.model.get_normalized_probs(outputs, attn_out, encoder_input,
                                                 encoder_mask, decoder_mask, log_probs=True)

        if training:
            targets = targets.masked_select(decoder_mask)
        return self.loss(lprobs, targets)

    def _forward_backward_pass(self, batch, metrics):
        src_size = batch.get('src_size')
        tgt_size = batch.get('tgt_size')

        loss, display_loss = self._forward(batch)
        self.optimizer.backward(loss)

        metrics['nll'].update(display_loss, tgt_size)
        metrics['src_tps'].update(src_size)
        metrics['tgt_tps'].update(tgt_size)
        metrics['total_words'].update(tgt_size)

    def _do_training_step(self, metrics, batch):
        return metrics['total_words'].sum >= self.args.batch_size_update

    def _learning_step(self, metrics):
        if self.args.normalize_gradient:
            self.optimizer.multiply_grads(1 / metrics['total_words'].sum)
        super()._learning_step(metrics)
        metrics['total_words'].reset()

    def _get_eval_metrics(self):
        metrics = super()._get_eval_metrics()
        metrics['nll'] = AverageMeter()
        return metrics

    def format_eval_metrics(self, metrics):
        formatted = super().format_eval_metrics(metrics)
        formatted.append('Validation perplexity: {:.2f}'.format(math.exp(metrics['nll'].avg)))
        return formatted

    def _eval_pass(self, task, batch, metrics):
        tgt_size = batch.get('tgt_size')
        _, display_loss = self._forward(batch, training=False)
        metrics['nll'].update(display_loss, tgt_size)

    def _get_sequence_generator(self, task):
        return SequenceGenerator([self.model], self.tgt_dict, self.model.batch_first,
                                 self.args.beam_size, maxlen_b=20,
                                 normalize_scores=self.args.normalize,
                                 len_penalty=self.args.alpha, unk_penalty=self.args.beta)

    def _restore_src_string(self, task, output, join_str, bpe_symbol):
        return self.src_dict.string(output, join_str=join_str, bpe_symbol=bpe_symbol)

    def _restore_tgt_string(self, task, output, join_str, bpe_symbol):
        return self.tgt_dict.string(output, join_str=join_str, bpe_symbol=bpe_symbol)

    def solve(self, test_task):
        self.model.eval()
        generator = self._get_sequence_generator(test_task)

        test_dataset = self._get_eval_dataset(test_task)
        test_sampler = self._get_eval_sampler(test_dataset)
        test_iterator = self._get_iterator(test_dataset, test_sampler)

        results = []
        for batch in tqdm(test_iterator, desc='inference', disable=self.args.no_progress):
            res, src = self._inference_pass(test_task, batch, generator)
            if self.args.print_translations:
                for i, source in enumerate(src):
                    tqdm.write("Src {}: {}".format(len(results) + i, source))
                    for j in range(self.args.n_best):
                        translation = res[i * self.args.n_best + j]['tokens']
                        tqdm.write("Hyp {}.{}: {}".format(len(results) + i, j + 1,
                                                          translation.replace(self.args.bpe_symbol, '')))
                    tqdm.write("")
            results.extend(beam['tokens'] for beam in res)
        return results

    def online_translate(self, in_stream, **kwargs):
        self.model.eval()
        split_words = self.args.input_type == 'word'
        task = TranslationTask(in_stream, bpe_symbol=self.args.bpe_symbol, lower=self.args.lower, **kwargs)
        generator = self._get_sequence_generator(task)

        for j, line in enumerate(in_stream):
            line = line.rstrip()
            if self.args.lower:
                line = line.lower()
            if split_words:
                line = line.split()

            src_indices = self.src_dict.to_indices(line, bos=False, eos=False)
            encoder_inputs = src_indices.unsqueeze(0 if self.model.batch_first else 1)
            source_lengths = torch.tensor([len(line)])

            if self.args.cuda:
                encoder_inputs = encoder_inputs.cuda()
                source_lengths = source_lengths.cuda()

            batch = {'src_indices': encoder_inputs, 'src_lengths': source_lengths}

            res, src = self._inference_pass(task, batch, generator)
            source = src[0]
            if self.args.print_translations:
                tqdm.write("Src {}: {}".format(j, source))
                for i in range(self.args.n_best):
                    translation = res[i]['tokens']
                    tqdm.write("Hyp {}.{}: {}".format(j, i + 1,
                                                      translation.replace(self.args.bpe_symbol, '')))
                tqdm.write("")

            scores = [r['scores'] for r in res]
            positional_scores = [r['positional_scores'] for r in res]

            if len(res) == 1:
                res = res[0]
                scores = scores[0]
                positional_scores = positional_scores[0]

            if self.args.return_scores:
                yield res, scores, positional_scores.tolist()
            else:
                yield res

    def _inference_pass(self, task, batch, generator):
        encoder_input = batch.get('src_indices')
        source_lengths = batch.get('src_lengths')
        join_str = ' ' if self.args.input_type == 'word' else ''

        if not generator.batch_first:
            encoder_input = encoder_input.transpose(0, 1).contiguous()

        encoder_mask = encoder_input.ne(self.src_dict.pad())

        res = [tr for beams in generator.generate(encoder_input, source_lengths, encoder_mask)
               for tr in beams[:self.args.n_best]]
        for beam in res:
            beam['tokens'] = self.tgt_dict.string(beam['tokens'], join_str=join_str)

        src = []
        if self.args.print_translations:
            for i in range(len(batch['src_indices'])):
                ind = batch['src_indices'][i][:batch['src_lengths'][i]]
                ind = self.src_dict.string(ind, join_str=join_str, bpe_symbol=self.args.bpe_symbol)
                src.append(ind)
        return res, src

    @classmethod
    def upgrade_checkpoint(cls, checkpoint):
        super().upgrade_checkpoint(checkpoint)
        args = checkpoint['args']
        if 'freeze_model' not in args:
            args.freeze_model = False
            args.copy_decoder = False
            args.extra_attention = False
        if 'eval_noise' not in args:
            args.translation_noise = getattr(args, 'translation_noise', False)
            args.eval_noise = False
# quan_transformer.encoder.layer_modules[0].multihead.attn_dropout.register_forward_hook(lambda m, i, o: print(i, o))
inputs = {'source': encoder_input, 'target_input': decoder_input}
output_dict = quan_transformer(inputs)
outputs_quan = generator(output_dict["hiddens"], False).clone().detach().cpu()
loss_quan = loss_function_quan(output_dict, decoder_input, generator,
                               backward=True)['loss'].clone().detach().cpu()
grads_quan = encoder.layer_modules[0].multihead.fc_query.function.linear.weight.grad.clone().detach().cpu()
grads_quan2 = decoder.layer_modules[-1].multihead_src.fc_concat.function.linear.weight.grad.clone().detach().cpu()
optim.zero_grad()

print("Making Felix Transformer")
dictionary = Dictionary()
felix_transformer = Transformer.build_model(args)
felix_transformer = NMTModel(NMTEncoder(felix_transformer.encoder, embedding_src, args.word_dropout),
                             NMTDecoder(felix_transformer.decoder, embedding_tgt, args.word_dropout,
                                        generator.linear),
                             dictionary, dictionary)
loss_function_felix = NMTLoss(30000, onmt.Constants.PAD, 0.0)
felix_transformer.cuda()
loss_function_felix.cuda()

print(len(list(felix_transformer.parameters())), len(list(quan_transformer.parameters())))
print(sum(p.numel() for p in felix_transformer.parameters()))
print(sum(p.numel() for p in quan_transformer.parameters()))

# Share parameters between the two implementations
felix_transformer.encoder.encoder.postprocess.layer_norm.function.weight = \
    quan_transformer.encoder.postprocess_layer.layer_norm.function.weight
felix_transformer.encoder.encoder.postprocess.layer_norm.function.bias = \
    quan_transformer.encoder.postprocess_layer.layer_norm.function.bias
class NMTTrainer(Trainer):
    @classmethod
    def add_preprocess_options(cls, parser):
        super().add_preprocess_options(parser)
        parser.add_argument('-train_src', type=str, required=True,
                            help='Path to the training source file')
        parser.add_argument('-train_tgt', type=str, required=True,
                            help='Path to the training target file')
        parser.add_argument('-src_vocab', type=str,
                            help='Path to an existing source vocabulary')
        parser.add_argument('-tgt_vocab', type=str,
                            help='Path to an existing target vocabulary')
        parser.add_argument('-data_dir_out', type=str, required=True,
                            help='Output directory for auxiliary data')
        parser.add_argument('-lower', action='store_true',
                            help='Construct a lower-case vocabulary')
        parser.add_argument('-vocab_threshold', type=int,
                            help='Discard vocabulary words that occur less often than this threshold')
        # parser.add_argument('-remove_duplicate', action='store_true',
        #                     help='Remove examples where source and target are the same')
        parser.add_argument('-join_vocab', action='store_true',
                            help='Share the dictionary between source and target')
        parser.add_argument('-src_vocab_size', type=int, default=50000,
                            help='Size of the source vocabulary')
        parser.add_argument('-tgt_vocab_size', type=int, default=50000,
                            help='Size of the target vocabulary')
        parser.add_argument('-input_type', default='word', choices=['word', 'char'],
                            help='Type of dictionary to create.')
        parser.add_argument('-report_every', type=int, default=100000,
                            help='Report status every this many sentences')

    @classmethod
    def add_general_options(cls, parser):
        super().add_general_options(parser)
        parser.add_argument('-input_type', default='word', choices=['word', 'char'],
                            help='Type of dictionary to create.')
        parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
        parser.add_argument('-alpha', type=float, default=0.6, help='Length penalty coefficient')
        parser.add_argument('-beta', type=float, default=0.0, help='Coverage penalty coefficient')
        parser.add_argument('-normalize', action='store_true',
                            help='Normalize the scores by output length')
        parser.add_argument('-n_best', type=int, default=1,
                            help='Output the n best decoded sentences')
        parser.add_argument('-label_smoothing', type=float, default=0.0,
                            help='Label smoothing value for loss functions.')
        parser.add_argument('-print_translations', action='store_true',
                            help='Output finished translations as they are generated')
        # Currently used, but pointless
        parser.add_argument('-diverse_beam_strength', type=float, default=0.5,
                            help='Diverse beam strength in decoding')

    @classmethod
    def add_training_options(cls, parser):
        super().add_training_options(parser)
        NMTModel.add_options(parser)
        parser.add_argument('-train_src', type=str, required=True,
                            help='Path to the training source file')
        parser.add_argument('-train_tgt', type=str, required=True,
                            help='Path to the training target file')
        parser.add_argument('-data_dir', type=str, required=True,
                            help='Path to the auxiliary data directory')
        parser.add_argument('-load_into_memory', action='store_true',
                            help='Load the dataset into memory')
        parser.add_argument('-join_vocab', action='store_true',
                            help='Share the dictionary between source and target')
        parser.add_argument('-batch_size_words', type=int, default=2048,
                            help='Maximum number of words in a batch')
        parser.add_argument('-batch_size_sents', type=int, default=128,
                            help='Maximum number of sentences in a batch')
        parser.add_argument('-batch_size_multiplier', type=int, default=1,
                            help='The number of sentences in a batch must be divisible by this number')
        parser.add_argument('-pad_count', action='store_true',
                            help='Count padding words when batching')
        parser.add_argument('-src_seq_length', type=int, default=64,
                            help='Discard examples with a source sequence length above this value')
        parser.add_argument('-src_seq_length_trunc', type=int, default=0,
                            help='Truncate source sequences to this length. 0 (default) to disable')
        parser.add_argument('-tgt_seq_length', type=int, default=64,
                            help='Discard examples with a target sequence length above this value')
        parser.add_argument('-tgt_seq_length_trunc', type=int, default=0,
                            help='Truncate target sequences to this length. 0 (default) to disable')

    @classmethod
    def add_eval_options(cls, parser):
        super().add_eval_options(parser)

    @staticmethod
    def preprocess(args):
        split_words = args.input_type == 'word'

        # Since the input and output dir are the same, this is no longer needed
        os.makedirs(args.data_dir_out, exist_ok=True)
        train_src_name = os.path.basename(args.train_src)
        train_tgt_name = os.path.basename(args.train_tgt)

        (src_offsets, src_lengths, src_counter), \
            (tgt_offsets, tgt_lengths, tgt_counter) = \
            get_indices_and_vocabulary((args.train_src, args.train_tgt), split_words, args.lower,
                                       not args.no_progress, args.report_every)

        out_offsets_src = os.path.join(args.data_dir_out, train_src_name + '.idx.npy')
        out_lengths_src = os.path.join(args.data_dir_out, train_src_name + '.len.npy')
        np.save(out_offsets_src, src_offsets)
        np.save(out_lengths_src, src_lengths)
        if args.src_vocab is not None:
            src_dictionary = Dictionary.load(args.src_vocab)
        else:
            src_dictionary = Dictionary()
            for word, count in src_counter.items():
                src_dictionary.add_symbol(word, count)

        out_offsets_tgt = os.path.join(args.data_dir_out, train_tgt_name + '.idx.npy')
        out_lengths_tgt = os.path.join(args.data_dir_out, train_tgt_name + '.len.npy')
        np.save(out_offsets_tgt, tgt_offsets)
        np.save(out_lengths_tgt, tgt_lengths)
        if args.tgt_vocab is not None:
            tgt_dictionary = Dictionary.load(args.tgt_vocab)
        else:
            tgt_dictionary = Dictionary()
            for word, count in tgt_counter.items():
                tgt_dictionary.add_symbol(word, count)

        if args.join_vocab:
            # If we explicitly load a target dictionary to merge
            # or we are inferring both dictionaries
            if args.tgt_vocab is not None or args.src_vocab is None:
                src_dictionary.update(tgt_dictionary)
            src_dictionary.finalize(nwords=args.src_vocab_size, threshold=args.vocab_threshold or -1)
            src_dictionary.save(os.path.join(args.data_dir_out, 'dict'))
        else:
            src_dictionary.finalize(nwords=args.src_vocab_size, threshold=args.vocab_threshold or -1)
            tgt_dictionary.finalize(nwords=args.tgt_vocab_size, threshold=args.vocab_threshold or -1)
            src_dictionary.save(os.path.join(args.data_dir_out, 'src.dict'))
            tgt_dictionary.save(os.path.join(args.data_dir_out, 'tgt.dict'))

    def __init__(self, args):
        super().__init__(args)

        if hasattr(args, 'data_dir'):
            logger.info('Loading vocabularies from {}'.format(args.data_dir))
            if args.join_vocab:
                self.src_dict = Dictionary.load(os.path.join(args.data_dir, 'dict'))
                self.tgt_dict = self.src_dict
            else:
                self.src_dict = Dictionary.load(os.path.join(args.data_dir, 'src.dict'))
                self.tgt_dict = Dictionary.load(os.path.join(args.data_dir, 'tgt.dict'))
            self.loss = self._build_loss()
            logger.info('Vocabulary size: {:,d}|{:,d}'.format(len(self.src_dict), len(self.tgt_dict)))
        else:
            self.src_dict = None
            self.tgt_dict = None

    def online_translate(self, model_or_ensemble, in_stream):
        models = model_or_ensemble
        if not isinstance(models, Sequence):
            models = [model_or_ensemble]

        for model in models:
            model.eval()

        split_words = self.args.input_type == 'word'

        generator = SequenceGenerator(models, self.tgt_dict, models[0].batch_first,
                                      self.args.beam_size, maxlen_b=20,
                                      normalize_scores=self.args.normalize,
                                      len_penalty=self.args.alpha, unk_penalty=self.args.beta,
                                      diverse_beam_strength=self.args.diverse_beam_strength)

        join_str = ' ' if self.args.input_type == 'word' else ''

        for line in in_stream:
            line = line.rstrip()
            if self.args.lower:
                line = line.lower()
            if split_words:
                line = line.split(' ')

            src_indices = self.src_dict.to_indices(line, bos=False, eos=False)
            encoder_inputs = src_indices.unsqueeze(0 if models[0].batch_first else 1)
            source_lengths = torch.tensor([len(line)])
            encoder_mask = encoder_inputs.ne(self.src_dict.pad())

            if self.args.cuda:
                encoder_inputs = encoder_inputs.cuda()
                source_lengths = source_lengths.cuda()
                encoder_mask = encoder_mask.cuda()

            res = [self.tgt_dict.string(tr['tokens'], join_str=join_str)
                   for tr in generator.generate(encoder_inputs, source_lengths,
                                                encoder_mask)[0][:self.args.n_best]]

            if self.args.print_translations:
                tqdm.write(line)
                for i, hyp in enumerate(res):
                    tqdm.write("Hyp {}/{}: {}".format(i + 1, len(res), hyp))

            if len(res) == 1:
                res = res[0]
            yield res

    def _build_loss(self):
        loss = NMTLoss(len(self.tgt_dict), self.tgt_dict.pad(), self.args.label_smoothing)
        if self.args.cuda:
            loss.cuda()
        return loss

    def _build_model(self, args):
        model = super()._build_model(args)
        logger.info('Building embeddings and softmax')
        return NMTModel.wrap_model(args, model, self.src_dict, self.tgt_dict)

    def load_data(self, model_args=None):
        logger.info('Loading training data')
        split_words = self.args.input_type == 'word'

        train_src_name = os.path.basename(self.args.train_src)
        train_tgt_name = os.path.basename(self.args.train_tgt)

        if self.args.load_into_memory:
            src_data = TextLineDataset.load_into_memory(self.args.train_src)
            tgt_data = TextLineDataset.load_into_memory(self.args.train_tgt)
        else:
            offsets_src = os.path.join(self.args.data_dir, train_src_name + '.idx.npy')
            offsets_tgt = os.path.join(self.args.data_dir, train_tgt_name + '.idx.npy')
            src_data = TextLineDataset.load_indexed(self.args.train_src, offsets_src)
            tgt_data = TextLineDataset.load_indexed(self.args.train_tgt, offsets_tgt)
        src_data = TextLookupDataset(src_data, self.src_dict, words=split_words, bos=False, eos=False,
                                     trunc_len=self.args.src_seq_length_trunc, lower=self.args.lower)
        tgt_data = TextLookupDataset(tgt_data, self.tgt_dict, words=split_words, bos=True, eos=True,
                                     trunc_len=self.args.tgt_seq_length_trunc, lower=self.args.lower)
        dataset = ParallelDataset(src_data, tgt_data)
        logger.info('Number of training sentences: {:,d}'.format(len(dataset)))

        src_len_filename = os.path.join(self.args.data_dir, train_src_name + '.len.npy')
        tgt_len_filename = os.path.join(self.args.data_dir, train_tgt_name + '.len.npy')
        src_lengths = np.load(src_len_filename)
        tgt_lengths = np.load(tgt_len_filename)

        def filter_fn(i):
            return src_lengths[i] <= self.args.src_seq_length and tgt_lengths[i] <= self.args.tgt_seq_length

        logger.info('Generating batches')
        batches = data_utils.generate_length_based_batches_from_lengths(
            np.maximum(src_lengths, tgt_lengths), self.args.batch_size_words,
            self.args.batch_size_sents,
            self.args.batch_size_multiplier,
            self.args.pad_count,
            key_fn=lambda i: (tgt_lengths[i], src_lengths[i]),
            filter_fn=filter_fn)
        logger.info('Number of training batches: {:,d}'.format(len(batches)))

        filtered = len(src_lengths) - sum(len(batch) for batch in batches)
        logger.info('Filtered {:,d}/{:,d} training examples for length'.format(filtered, len(src_lengths)))

        sampler = PreGeneratedBatchSampler(batches, self.args.curriculum == 0)

        model = self.build_model(model_args)
        params = list(filter(lambda p: p.requires_grad, model.parameters()))
        lr_scheduler, optimizer = self._build_optimizer(params)

        return TrainData(model, dataset, sampler, lr_scheduler, optimizer, self._get_training_metrics())

    def _get_loss(self, model, batch) -> (Tensor, float):
        encoder_input = batch.get('src_indices')
        decoder_input = batch.get('tgt_input')
        targets = batch.get('tgt_output')

        if not model.batch_first:
            encoder_input = encoder_input.transpose(0, 1).contiguous()
            decoder_input = decoder_input.transpose(0, 1).contiguous()
            targets = targets.transpose(0, 1).contiguous()

        decoder_mask = decoder_input.ne(self.tgt_dict.pad())

        logits = model(encoder_input, decoder_input, decoder_mask=decoder_mask, optimized_decoding=True)
        targets = targets.masked_select(decoder_mask)

        lprobs = model.get_normalized_probs(logits, log_probs=True)
        return self.loss(lprobs, targets)

    def _get_batch_weight(self, batch):
        return batch['tgt_size']

    def _get_training_metrics(self):
        meters = super()._get_training_metrics()
        meters['srctok'] = AverageMeter()
        meters['tgttok'] = AverageMeter()
        return meters

    def _update_training_metrics(self, train_data, batch):
        meters = train_data.meters
        batch_time = meters['fwbw_wall'].val
        src_tokens = batch['src_size']
        tgt_tokens = batch['tgt_size']

        meters['srctok'].update(src_tokens, batch_time)
        meters['tgttok'].update(tgt_tokens, batch_time)

        return ['{:5.0f}|{:5.0f} tok/s'.format(meters['srctok'].avg, meters['tgttok'].avg)]

    def _reset_training_metrics(self, train_data):
        meters = train_data.meters
        meters['srctok'].reset()
        meters['tgttok'].reset()
        super()._reset_training_metrics(train_data)

    def solve(self, model_or_ensemble, task):
        models = model_or_ensemble
        if not isinstance(models, Sequence):
            models = [model_or_ensemble]

        for model in models:
            model.eval()

        generator = SequenceGenerator(models, self.tgt_dict, models[0].batch_first,
                                      self.args.beam_size, maxlen_b=20,
                                      normalize_scores=self.args.normalize,
                                      len_penalty=self.args.alpha, unk_penalty=self.args.beta,
                                      diverse_beam_strength=self.args.diverse_beam_strength)

        iterator = self._get_eval_iterator(task)
        join_str = ' ' if self.args.input_type == 'word' else ''

        results = []
        for batch in tqdm(iterator, desc='inference', disable=self.args.no_progress):
            encoder_inputs = batch['src_indices']
            if not generator.batch_first:
                encoder_inputs = encoder_inputs.transpose(0, 1)
            source_lengths = batch['src_lengths']
            encoder_mask = encoder_inputs.ne(self.src_dict.pad())

            res = [self.tgt_dict.string(tr['tokens'], join_str=join_str)
                   for beams in generator.generate(encoder_inputs, source_lengths, encoder_mask)
                   for tr in beams[:self.args.n_best]]

            if self.args.print_translations:
                for i in range(len(batch['src_indices'])):
                    reference = batch['src_indices'][i][:batch['src_lengths'][i]]
                    reference = self.src_dict.string(reference, join_str=join_str,
                                                     bpe_symbol=self.args.bpe_symbol)
                    tqdm.write("Ref {}: {}".format(len(results) + i, reference))
                    for j in range(self.args.n_best):
                        translation = res[i * self.args.n_best + j]
                        tqdm.write("Hyp {}.{}: {}".format(len(results) + i, j + 1,
                                                          translation.replace(self.args.bpe_symbol, '')))

            results.extend(res)
        return results

    def _get_eval_iterator(self, task):
        split_words = self.args.input_type == 'word'
        src_data = TextLookupDataset(task.src_dataset, self.src_dict, split_words,
                                     bos=False, eos=False, lower=self.args.lower)
        tgt_data = None
        if task.tgt_dataset is not None:
            tgt_data = TextLookupDataset(task.tgt_dataset, self.tgt_dict, split_words, lower=self.args.lower)
        dataset = ParallelDataset(src_data, tgt_data)
        return dataset.get_iterator(batch_size=self.args.batch_size,
                                    num_workers=self.args.data_loader_threads,
                                    cuda=self.args.cuda)

    def state_dict(self):
        res = super().state_dict()
        if self.args.join_vocab:
            res['dict'] = self.src_dict.state_dict()
        else:
            res['src_dict'] = self.src_dict.state_dict()
            res['tgt_dict'] = self.tgt_dict.state_dict()
        return res

    def load_args(self, args):
        self.args.join_vocab = args.join_vocab
        self.args.input_type = args.input_type

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)

        if self.args.join_vocab:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(state_dict['dict'])
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(state_dict['src_dict'])
            self.tgt_dict = Dictionary()
            self.tgt_dict.load_state_dict(state_dict['tgt_dict'])

        self.loss = self._build_loss()
output_dict = quan_transformer(inputs)
outputs_quan = generator(output_dict["hiddens"], False).clone().detach().cpu()
loss_quan = loss_function_quan(output_dict, decoder_input, generator,
                               backward=True)['loss'].clone().detach().cpu()
grads_quan = encoder.layer_modules[0].multihead.fc_query.function.linear.weight.grad.clone().detach().cpu()
grads_quan2 = decoder.layer_modules[-1].multihead_src.fc_concat.function.linear.weight.grad.clone().detach().cpu()
optim.zero_grad()

print("Making Felix Transformer")
dictionary = Dictionary()
dictionary.pad_index = onmt.Constants.PAD
felix_transformer = Transformer.build_model(args)
felix_transformer = EncoderDecoderModel(
    NMTEncoder(felix_transformer.encoder, embedding_src, args.word_dropout),
    NMTDecoder(felix_transformer.decoder, embedding_tgt, args.word_dropout, generator.linear))
felix_transformer.eval()
loss_function_felix = NMTLoss(30000, onmt.Constants.PAD, 0.0)
felix_transformer.cuda()
loss_function_felix.cuda()

print(len(list(felix_transformer.parameters())), len(list(quan_transformer.parameters())))
print(sum(p.numel() for p in felix_transformer.parameters()))
print(sum(p.numel() for p in quan_transformer.parameters()))
import argparse

from nmtg.data import Dictionary
from nmtg.data.noisy_text import NoisyTextDataset
from nmtg.data.text_lookup_dataset import TextLookupDataset
from nmtg.tasks.denoising_text_task import DenoisingTextTask

parser = argparse.ArgumentParser()
DenoisingTextTask.add_options(parser)
args = parser.parse_args()

task = DenoisingTextTask.setup_task(args)
dictionary = Dictionary.infer_from_text(task.tgt_dataset)

noisy_text = NoisyTextDataset(
    TextLookupDataset(task.src_dataset, dictionary, True, args.lower, False, False, False),
    args.word_shuffle, args.noise_word_dropout, args.word_blank, args.bpe_symbol)

for i in range(len(noisy_text)):
    print(task.tgt_dataset[i])
    print(dictionary.string(noisy_text[i]))
    input()
import argparse
import functools

import editdistance
from tqdm import tqdm

from nmtg.data import Dictionary

parser = argparse.ArgumentParser()
parser.add_argument('input')
parser.add_argument('other_language_dict')
parser.add_argument('output', default='-', nargs='?')
parser.add_argument('-threshold', type=int, default=0)
parser.add_argument('-prob', type=float, default=0.1)
parser.add_argument('-num_variants', type=int, default=1)
args = parser.parse_args()

with open(args.input) as f:
    main_dictionary = Dictionary.infer_from_text(f)
main_symbols = main_dictionary.symbols[main_dictionary.nspecial:]
del main_dictionary

dictionary = Dictionary.load(args.other_language_dict)
if args.threshold != 0:
    dictionary.finalize(threshold=args.threshold)
symbols = dictionary.symbols[dictionary.nspecial:]
del dictionary


def get_nearest(pool, symbol):
    # Find the closest symbol in the other language's vocabulary by edit distance
    return symbol, min(pool, key=lambda x: editdistance.eval(x, symbol))


partial = functools.partial(get_nearest, symbols)