Example #1
    def preprocess(cls, args):
        split_words = args.input_type == 'word'

        os.makedirs(args.data_dir_out, exist_ok=True)
        src_name = os.path.basename(args.in_src)
        tgt_name = os.path.basename(args.in_tgt)

        (src_offsets, src_lengths, src_counter), \
        (tgt_offsets, tgt_lengths, tgt_counter) = \
            cls.get_indices_and_vocabulary((args.in_src, args.in_tgt),
                                           split_words,
                                           args.lower,
                                           not args.no_progress,
                                           args.report_every)

        out_offsets_src = os.path.join(args.data_dir_out,
                                       src_name + '.idx.npy')
        out_lengths_src = os.path.join(args.data_dir_out,
                                       src_name + '.len.npy')
        np.save(out_offsets_src, src_offsets)
        np.save(out_lengths_src, src_lengths)

        src_dictionary = Dictionary()
        for word, count in src_counter.items():
            src_dictionary.add_symbol(word, count)

        out_offsets_tgt = os.path.join(args.data_dir_out,
                                       tgt_name + '.idx.npy')
        out_lengths_tgt = os.path.join(args.data_dir_out,
                                       tgt_name + '.len.npy')
        np.save(out_offsets_tgt, tgt_offsets)
        np.save(out_lengths_tgt, tgt_lengths)

        tgt_dictionary = Dictionary()
        for word, count in tgt_counter.items():
            tgt_dictionary.add_symbol(word, count)

        if args.join_vocab:
            # Merge the target-side counts into the source dictionary so the
            # single shared dictionary covers both sides of the data.
            src_dictionary.update(tgt_dictionary)
            src_dictionary.finalize(nwords=args.src_vocab_size,
                                    threshold=args.vocab_threshold or -1)

            src_dictionary.save(os.path.join(args.data_dir_out, 'dict'))

        else:
            src_dictionary.finalize(nwords=args.src_vocab_size,
                                    threshold=args.vocab_threshold or -1)
            tgt_dictionary.finalize(nwords=args.tgt_vocab_size,
                                    threshold=args.vocab_threshold or -1)

            src_dictionary.save(os.path.join(args.data_dir_out, 'src.dict'))
            tgt_dictionary.save(os.path.join(args.data_dir_out, 'tgt.dict'))
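
For reference, here is a minimal self-contained sketch of the per-file pass that get_indices_and_vocabulary appears to perform in these snippets: record the byte offset and token length of every line and count the symbols. The helper name index_text_file and all details below are illustrative assumptions, not the actual function from the codebase.

from collections import Counter

import numpy as np


def index_text_file(path, split_words=True, lower=False):
    # Illustrative stand-in for one file's pass inside
    # get_indices_and_vocabulary: byte offsets, token lengths, vocabulary.
    offsets, lengths, counter = [], [], Counter()
    offset = 0
    with open(path, 'rb') as f:
        for raw in f:
            offsets.append(offset)
            offset += len(raw)
            line = raw.decode('utf-8').strip()
            if lower:
                line = line.lower()
            tokens = line.split() if split_words else list(line)
            lengths.append(len(tokens))
            counter.update(tokens)
    return np.array(offsets), np.array(lengths), counter

The returned arrays correspond to what the example saves as <basename>.idx.npy and <basename>.len.npy.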
Example #2
    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        if self.args.join_vocab:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(state_dict['dict'])
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(state_dict['src_dict'])
            self.tgt_dict = Dictionary()
            self.tgt_dict.load_state_dict(state_dict['tgt_dict'])
        self.loss = self._build_loss()
Example #3
    def _load_data(self, checkpoint):
        super()._load_data(checkpoint)
        args = checkpoint['args']
        self.args.join_vocab = args.join_vocab

        if args.join_vocab:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(checkpoint['dict'])
            self.tgt_dict = self.src_dict
        else:
            self.src_dict = Dictionary()
            self.src_dict.load_state_dict(checkpoint['src_dict'])
            self.tgt_dict = Dictionary()
            self.tgt_dict.load_state_dict(checkpoint['tgt_dict'])
        self._build_loss()
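
Examples #2 and #3 restore the dictionaries from a checkpoint through a state_dict round trip. The toy class below only illustrates the interface those snippets rely on (add_symbol, update, finalize, state_dict/load_state_dict, save/load); it is a hypothetical stand-in, not the Dictionary class from the codebase.

import pickle
from collections import Counter


class ToyDictionary:
    # Hypothetical stand-in for the Dictionary interface used above; the
    # real class also handles special symbols (pad/unk/...), file formats, etc.
    def __init__(self):
        self.counts = Counter()
        self.symbols = []

    def add_symbol(self, word, count=1):
        self.counts[word] += count

    def update(self, other):
        # Merge counts from another dictionary (used for join_vocab).
        self.counts.update(other.counts)

    def finalize(self, nwords=-1, threshold=-1):
        # Keep the most frequent symbols, optionally capped and thresholded.
        kept = [w for w, c in self.counts.most_common()
                if threshold < 0 or c >= threshold]
        if nwords is not None and nwords > 0:
            kept = kept[:nwords]
        self.symbols = kept

    def state_dict(self):
        return {'symbols': self.symbols, 'counts': dict(self.counts)}

    def load_state_dict(self, state):
        self.symbols = list(state['symbols'])
        self.counts = Counter(state['counts'])

    def save(self, path):
        with open(path, 'wb') as f:
            pickle.dump(self.state_dict(), f)

    @classmethod
    def load(cls, path):
        d = cls()
        with open(path, 'rb') as f:
            d.load_state_dict(pickle.load(f))
        return d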
Example #4
    def preprocess(cls, args):
        split_words = args.input_type == 'word'

        os.makedirs(args.data_dir_out, exist_ok=True)
        basename = os.path.basename(args.in_data)

        ((offsets, lengths,
          counter), ) = cls.get_indices_and_vocabulary([args.in_data],
                                                       split_words, args.lower,
                                                       not args.no_progress,
                                                       args.report_every)

        out_offsets = os.path.join(args.data_dir_out, basename + '.idx.npy')
        out_lengths = os.path.join(args.data_dir_out, basename + '.len.npy')
        np.save(out_offsets, offsets)
        np.save(out_lengths, lengths)

        dictionary = Dictionary()
        for word, count in counter.items():
            dictionary.add_symbol(word, count)

        dictionary.finalize(nwords=args.vocab_size,
                            threshold=args.vocab_threshold or -1)

        dictionary.save(os.path.join(args.data_dir_out, 'dict'))
Example #5
    def preprocess(args):
        split_words = args.input_type == 'word'

        os.makedirs(args.data_dir_out, exist_ok=True)
        train_clean_name = os.path.basename(args.train_clean)
        source_files = [args.train_clean]
        if args.train_noisy is not None:
            source_files.append(args.train_noisy)

        outputs = get_indices_and_vocabulary(source_files, split_words,
                                             args.lower, not args.no_progress,
                                             args.report_every)

        if args.train_noisy is not None:
            train_noisy_name = os.path.basename(args.train_noisy)
            (offsets, lengths, counter), \
            (noisy_offsets, noisy_lengths, noisy_counter) = outputs
            counter.update(noisy_counter)

            noisy_offset_filename = os.path.join(args.data_dir_out,
                                                 train_noisy_name + '.idx.npy')
            np.save(noisy_offset_filename, noisy_offsets)
        else:
            ((offsets, lengths, counter), ) = outputs

        out_offsets = os.path.join(args.data_dir_out,
                                   train_clean_name + '.idx.npy')
        out_lengths = os.path.join(args.data_dir_out,
                                   train_clean_name + '.len.npy')
        np.save(out_offsets, offsets)
        np.save(out_lengths, lengths)
        if args.vocab is not None:
            dictionary = Dictionary.load(args.vocab)
        else:
            dictionary = Dictionary()
            for word, count in counter.items():
                dictionary.add_symbol(word, count)

        dictionary.finalize(nwords=args.vocab_size,
                            threshold=args.vocab_threshold or -1)
        dictionary.save(os.path.join(args.data_dir_out, 'dict'))
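
Once the offsets are on disk, a dataset can fetch individual lines lazily. The sketch below shows how such an index is typically consumed; the dataset class that actually reads these .idx.npy files is not part of the snippets above, so treat the helper as illustrative.

import numpy as np


def read_indexed_line(text_path, offsets_path, index):
    # Illustrative reader: seek to the stored byte offset of line `index`
    # in the raw corpus file and return that line.
    offsets = np.load(offsets_path)
    with open(text_path, 'rb') as f:
        f.seek(int(offsets[index]))
        return f.readline().decode('utf-8').rstrip('\n')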
Example #6
# Run the reference ("quan") transformer once and keep CPU copies of its
# output, its loss and the gradients of two representative weights, so they
# can be compared against the reimplementation built below.
output_dict = quan_transformer(inputs)
outputs_quan = generator(output_dict["hiddens"], False).clone().detach().cpu()
loss_quan = loss_function_quan(output_dict,
                               decoder_input,
                               generator,
                               backward=True)['loss'].clone().detach().cpu()
grads_quan = (encoder.layer_modules[0].multihead.fc_query
              .function.linear.weight.grad.clone().detach().cpu())
grads_quan2 = (decoder.layer_modules[-1].multihead_src.fc_concat
               .function.linear.weight.grad.clone().detach().cpu())
optim.zero_grad()

print("Making Felix Transformer")

# Build the reimplemented ("Felix") transformer around the same embeddings
# and generator, then move the model and its loss to the GPU.
dictionary = Dictionary()
dictionary.pad_index = onmt.Constants.PAD
felix_transformer = Transformer.build_model(args)
felix_transformer = EncoderDecoderModel(
    NMTEncoder(felix_transformer.encoder, embedding_src, args.word_dropout),
    NMTDecoder(felix_transformer.decoder, embedding_tgt, args.word_dropout,
               generator.linear))
felix_transformer.eval()
loss_function_felix = NMTLoss(30000, onmt.Constants.PAD, 0.0)
felix_transformer.cuda()
loss_function_felix.cuda()

# Sanity check: both implementations should expose the same number of
# parameter tensors and the same total parameter count.
print(len(list(felix_transformer.parameters())),
      len(list(quan_transformer.parameters())))
print(sum(p.numel() for p in felix_transformer.parameters()))
print(sum(p.numel() for p in quan_transformer.parameters()))
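
The script captures outputs_quan, loss_quan and the two gradient tensors so they can be checked against the corresponding values from felix_transformer. That comparison is not shown in the snippet; a hypothetical helper for it might look like this (torch.allclose is the real PyTorch call, everything else is assumed):

import torch


def report_match(name, a, b, atol=1e-5):
    # Compare two captured tensors, e.g. outputs_quan vs. the Felix model's
    # outputs, and report whether they agree within tolerance.
    same = torch.allclose(a, b, atol=atol)
    max_diff = (a - b).abs().max().item()
    print(f"{name}: allclose={same}, max abs diff={max_diff:.3e}")

Usage would be along the lines of report_match("outputs", outputs_quan, outputs_felix), where outputs_felix is obtained from felix_transformer the same way outputs_quan was obtained above.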