Example #1
def build_model(opt, dicts):
    # add missing options if opt was built by an older version (for loading old models)
    opt = backward_compatible(opt)

    onmt.constants.layer_norm = opt.layer_norm
    onmt.constants.weight_norm = opt.weight_norm
    onmt.constants.activation_layer = opt.activation_layer
    onmt.constants.version = 1.0
    onmt.constants.attention_out = opt.attention_out
    onmt.constants.residual_type = opt.residual_type
    onmt.constants.fused_ffn = opt.fused_ffn
    opt.nce = opt.nce_noise > 0

    if 'langs' not in dicts:
        dicts['langs'] = {'src': 0, 'tgt': 1}
    opt.n_languages = len(dicts['langs'])

    if opt.bayes_by_backprop:
        from onmt.bayesian_factory import build_model as build_bayesian_model
        model = build_bayesian_model(opt, dicts)
        return model

    if not opt.fusion:
        model = build_tm_model(opt, dicts)
    else:
        raise NotImplementedError
        model = build_fusion(opt, dicts)

    return model
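
For reference, build_model() is meant to be called on a loaded checkpoint, as the later examples on this page do. A minimal usage sketch (the checkpoint path is a placeholder):

import torch

# Minimal usage sketch, mirroring the checkpoint-loading pattern shown in the later examples.
checkpoint = torch.load('model.pt', map_location=lambda storage, loc: storage)
model_opt = backward_compatible(checkpoint['opt'])
onmt.constants = add_tokenidx(model_opt, onmt.constants, checkpoint['dicts'])
model = build_model(model_opt, checkpoint['dicts'])
model.load_state_dict(checkpoint['model'])
model.eval()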
Example #2
def build_model(opt, dicts):
    opt = backward_compatible(opt)

    onmt.constants.layer_norm = opt.layer_norm
    onmt.constants.weight_norm = opt.weight_norm
    onmt.constants.activation_layer = opt.activation_layer
    onmt.constants.version = 1.0
    onmt.constants.attention_out = opt.attention_out
    onmt.constants.residual_type = opt.residual_type

    if not opt.fusion:
        model = build_tm_model(opt, dicts)
    else:
        raise NotImplementedError
        model = build_fusion(opt, dicts)

    return model
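
backward_compatible() itself is not shown on this page; conceptually it fills in options that older checkpoints were saved without, so that newer code can read them. A hypothetical sketch of the idea (the option names and defaults below are illustrative, not the actual list):

# Hypothetical sketch only; option names and defaults are illustrative.
def backward_compatible_sketch(opt):
    defaults = {'fused_ffn': False, 'nce_noise': 0, 'residual_type': 'regular'}
    for name, value in defaults.items():
        if not hasattr(opt, name):
            setattr(opt, name, value)
    return opt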
Example #3
def build_language_model(opt, dicts):
    opt = backward_compatible(opt)

    onmt.constants.layer_norm = opt.layer_norm
    onmt.constants.weight_norm = opt.weight_norm
    onmt.constants.activation_layer = opt.activation_layer
    onmt.constants.version = 1.0
    onmt.constants.attention_out = opt.attention_out
    onmt.constants.residual_type = opt.residual_type

    from onmt.models.transformer_xl import TransformerXL

    embedding_tgt = nn.Embedding(dicts['tgt'].size(),
                                 opt.model_size,
                                 padding_idx=onmt.constants.TGT_PAD)

    if opt.use_language_embedding:
        print("* Create language embeddings with %d languages" %
              len(dicts['langs']))
        language_embeddings = nn.Embedding(len(dicts['langs']), opt.model_size)
    else:
        language_embeddings = None

    generators = [
        onmt.modules.base_seq2seq.Generator(opt.model_size,
                                            dicts['tgt'].size())
    ]

    model = TransformerXL(opt,
                          embedding_tgt,
                          nn.ModuleList(generators),
                          language_embeddings=language_embeddings)

    model.tgt_dict = dicts['tgt']

    if opt.tie_weights:
        print("* Joining the weights of decoder input and output embeddings")
        model.tie_weights()

    return model
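
tie_weights() shares the decoder input embedding with the output projection so that both use a single vocabulary-sized matrix. The actual TransformerXL method is not part of this excerpt; the sketch below only illustrates the general weight-tying technique with made-up sizes:

import torch.nn as nn

# Illustrative sketch of weight tying (not the repository's tie_weights() code).
# 32000 and 512 stand in for dicts['tgt'].size() and opt.model_size.
embedding = nn.Embedding(32000, 512, padding_idx=0)
projection = nn.Linear(512, 32000, bias=False)
projection.weight = embedding.weight  # both modules now share one 32000 x 512 parameter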
Example #4
onmt.markdown.add_md_help_argument(parser)

parser.add_argument('-model_src', required=True,
                    help='Path to the source model .pt file')
parser.add_argument('-model_tgt', required=True,
                    help='Path to the target model .pt file')
parser.add_argument('-model_out', required=True,
                    help='Path to write the output model .pt file')

opt = parser.parse_args()
# first, load the source model
print(opt.model_src)
checkpoint = torch.load(opt.model_src, map_location=lambda storage, loc: storage)

model_opt = checkpoint['opt']
model_opt = backward_compatible(model_opt)

src_dicts = checkpoint['dicts']
# update special tokens
onmt.constants = add_tokenidx(model_opt, onmt.constants, src_dicts)

model = build_model(model_opt, checkpoint['dicts'])
model.load_state_dict(checkpoint['model'])

# now load the 2nd model
print(opt.model_tgt)
checkpoint = torch.load(opt.model_tgt, map_location=lambda storage, loc: storage)
# model_opt = checkpoint['opt']
# model_opt = backward_compatible(model_opt)
tgt_dicts = checkpoint['dicts']
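
The excerpt stops after loading the second checkpoint, so the actual combination step is not shown here. As a hedged sketch only, the result would typically be written under -model_out with the same keys the loading code reads:

# Hedged sketch: the real merging logic is not part of this excerpt.
# It only shows saving a checkpoint with the keys the loader expects.
torch.save({'model': model.state_dict(),
            'opt': model_opt,
            'dicts': tgt_dicts}, opt.model_out)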
Example #5
    def __init__(self, opt):
        self.opt = opt
        self.tt = torch.cuda if opt.cuda else torch
        self.beam_accum = None
        self.beta = opt.beta
        self.alpha = opt.alpha
        self.start_with_bos = opt.start_with_bos
        self.fp16 = opt.fp16
        self.attributes = opt.attributes  # attributes separated by "|", for example: de|domain1
        # self.bos_token = opt.bos_token
        self.sampling = opt.sampling
        self.src_lang = opt.src_lang
        self.tgt_lang = opt.tgt_lang

        if self.attributes:
            self.attributes = self.attributes.split("|")

        self.models = list()
        self.model_types = list()

        # model paths are given as one string with "|" as the delimiter
        models = opt.model.split("|")

        print(models)
        self.n_models = len(models)
        self._type = 'text'

        for i, model_path in enumerate(models):
            checkpoint = torch.load(model_path,
                                    map_location=lambda storage, loc: storage)

            model_opt = checkpoint['opt']
            model_opt = backward_compatible(model_opt)
            if hasattr(model_opt, "enc_state_dict"):
                model_opt.enc_state_dict = None
                model_opt.dec_state_dict = None

            self.main_model_opt = model_opt
            dicts = checkpoint['dicts']

            # update special tokens
            onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)
            self.bos_token = model_opt.tgt_bos_word

            if i == 0:
                if "src" in checkpoint['dicts']:
                    self.src_dict = checkpoint['dicts']['src']
                else:
                    self._type = "audio"
                    # self.src_dict = self.tgt_dict

                self.tgt_dict = checkpoint['dicts']['tgt']

                if "langs" in checkpoint["dicts"]:
                    self.lang_dict = checkpoint['dicts']['langs']

                else:
                    self.lang_dict = {'src': 0, 'tgt': 1}

                self.bos_id = self.tgt_dict.labelToIdx[self.bos_token]

            model = build_model(model_opt, checkpoint['dicts'])
            optimize_model(model)
            if opt.verbose:
                print('Loading model from %s' % model_path)
            model.load_state_dict(checkpoint['model'])

            if model_opt.model in model_list:
                # if model.decoder.positional_encoder.len_max < self.opt.max_sent_length:
                #     print("Not enough len to decode. Renewing .. ")
                #     model.decoder.renew_buffer(self.opt.max_sent_length)
                model.renew_buffer(self.opt.max_sent_length)

            # model.convert_autograd()

            if opt.fp16:
                model = model.half()

            if opt.cuda:
                model = model.cuda()
            else:
                model = model.cpu()

            if opt.dynamic_quantile == 1:

                engines = torch.backends.quantized.supported_engines
                if 'fbgemm' in engines:
                    torch.backends.quantized.engine = 'fbgemm'
                else:
                    print(
                        "[INFO] fbgemm is not found in the available engines. Possibly the CPU does not support AVX2."
                        " It is recommended to disable Quantization (set to 0)."
                    )
                    torch.backends.quantized.engine = 'qnnpack'

                # convert the custom functions to their autograd equivalent first
                model.convert_autograd()

                model = torch.quantization.quantize_dynamic(
                    model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)

            model.eval()

            self.models.append(model)
            self.model_types.append(model_opt.model)

        # language model
        if opt.lm is not None:
            if opt.verbose:
                print('Loading language model from %s' % opt.lm)

            lm_chkpoint = torch.load(opt.lm,
                                     map_location=lambda storage, loc: storage)

            lm_opt = lm_chkpoint['opt']

            lm_model = build_language_model(lm_opt, checkpoint['dicts'])

            if opt.fp16:
                lm_model = lm_model.half()

            if opt.cuda:
                lm_model = lm_model.cuda()
            else:
                lm_model = lm_model.cpu()

            self.lm_model = lm_model

        self.cuda = opt.cuda
        self.ensemble_op = opt.ensemble_op

        if opt.autoencoder is not None:
            if opt.verbose:
                print('Loading autoencoder from %s' % opt.autoencoder)
            checkpoint = torch.load(opt.autoencoder,
                                    map_location=lambda storage, loc: storage)
            model_opt = checkpoint['opt']

            # posSize= checkpoint['autoencoder']['nmt.decoder.positional_encoder.pos_emb'].size(0)
            # self.models[0].decoder.renew_buffer(posSize)
            # self.models[0].decoder.renew_buffer(posSize)

            # Build model from the saved option
            self.autoencoder = Autoencoder(self.models[0], model_opt)

            self.autoencoder.load_state_dict(checkpoint['autoencoder'])

            if opt.cuda:
                self.autoencoder = self.autoencoder.cuda()
                self.models[0] = self.models[0].cuda()
            else:
                self.autoencoder = self.autoencoder.cpu()
                self.models[0] = self.models[0].cpu()

            self.models[0].autoencoder = self.autoencoder
        if opt.verbose:
            print('Done')
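
A hedged usage sketch of this constructor: the class name below (EnsembleTranslator) is an assumption, since the excerpt does not show it, and the option values are placeholders; the fields are the ones the code above reads, with model paths joined by "|" exactly as the code splits them.

from argparse import Namespace

# Hypothetical usage; the class name and concrete values are assumptions.
opt = Namespace(model='model_a.pt|model_b.pt', lm=None, autoencoder=None,
                cuda=True, fp16=False, verbose=True, alpha=0.6, beta=0.0,
                start_with_bos=False, attributes=None, sampling=False,
                src_lang='src', tgt_lang='tgt', dynamic_quantile=0,
                max_sent_length=256, ensemble_op='mean')
translator = EnsembleTranslator(opt)  # loads both checkpoints and prepares them for decoding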
Example #6
    def __init__(self, opt):

        super().__init__(opt)

        # self.eos = onmt.constants.EOS
        # self.pad = onmt.constants.PAD
        # self.bos = self.bos_id

        self.src_bos = onmt.constants.SRC_BOS
        self.src_eos = onmt.constants.SRC_EOS
        self.src_pad = onmt.constants.SRC_PAD
        self.src_unk = onmt.constants.SRC_UNK

        self.tgt_bos = self.bos_id
        self.tgt_pad = onmt.constants.TGT_PAD
        self.tgt_eos = onmt.constants.TGT_EOS
        self.tgt_unk = onmt.constants.TGT_UNK

        self.search = BeamSearch(self.tgt_dict)

        self.vocab_size = self.tgt_dict.size()
        self.min_len = 1
        self.normalize_scores = opt.normalize
        self.len_penalty = opt.alpha
        self.buffering = not opt.no_buffering
        # self.buffering = False  # buffering is currently bugged

        if hasattr(opt, 'no_repeat_ngram_size'):
            self.no_repeat_ngram_size = opt.no_repeat_ngram_size
        else:
            self.no_repeat_ngram_size = 0

        if hasattr(opt, 'dynamic_max_len'):
            self.dynamic_max_len = opt.dynamic_max_len
        else:
            self.dynamic_max_len = False

        if hasattr(opt, 'dynamic_max_len_scale'):
            self.dynamic_max_len_scale = opt.dynamic_max_len_scale
        else:
            self.dynamic_max_len_scale = 1.2

        if opt.verbose:
            # print('* Current bos id is: %d, default bos id is: %d' % (self.tgt_bos, onmt.constants.BOS))
            print(
                "src bos id is %d; src eos id is %d; src pad id is %d; src unk id is %d"
                % (self.src_bos, self.src_eos, self.src_pad, self.src_unk))
            print(
                "tgt bos id is %d; tgt eos id is %d; tgt pad id is %d; tgt unk id is %d"
                % (self.tgt_bos, self.tgt_eos, self.tgt_pad, self.tgt_unk))
            print('* Using fast beam search implementation')

        if opt.vocab_list:
            word_list = list()
            for line in open(opt.vocab_list).readlines():
                word = line.strip()
                word_list.append(word)

            self.filter = torch.Tensor(self.tgt_dict.size()).zero_()
            for word_idx in [self.tgt_eos, self.tgt_unk]:
                self.filter[word_idx] = 1

            for word in word_list:
                idx = self.tgt_dict.lookup(word)
                if idx is not None:
                    self.filter[idx] = 1

            self.filter = self.filter.bool()
            # print(self.filter)
            if opt.cuda:
                self.filter = self.filter.cuda()

            self.use_filter = True
        else:
            self.use_filter = False

        if opt.sub_model:
            self.sub_models = list()
            self.sub_model_types = list()

            # sub-model paths are given as one string with "|" as the delimiter
            sub_models = opt.sub_model.split("|")

            print("Loading sub models ... ")
            self.n_sub_models = len(sub_models)
            self.sub_type = 'text'

            for i, model_path in enumerate(sub_models):
                checkpoint = torch.load(
                    model_path, map_location=lambda storage, loc: storage)

                model_opt = checkpoint['opt']
                model_opt = backward_compatible(model_opt)
                if hasattr(model_opt, "enc_not_load_state"):
                    model_opt.enc_not_load_state = True
                    model_opt.dec_not_load_state = True

                dicts = checkpoint['dicts']

                # update special tokens
                onmt.constants = add_tokenidx(model_opt, onmt.constants, dicts)
                # self.bos_token = model_opt.tgt_bos_word
                """"BE CAREFUL: the sub-models might mismatch with the main models in terms of language dict"""
                """"REQUIRE RE-matching"""

                if i == 0:
                    if "src" in checkpoint['dicts']:
                        self.src_dict = checkpoint['dicts']['src']
                #     else:
                #         self._type = "audio"
                #     self.tgt_dict = checkpoint['dicts']['tgt']
                #
                #     if "langs" in checkpoint["dicts"]:
                #         self.lang_dict = checkpoint['dicts']['langs']
                #
                #     else:
                #         self.lang_dict = {'src': 0, 'tgt': 1}
                #
                #     self.bos_id = self.tgt_dict.labelToIdx[self.bos_token]
                if opt.verbose:
                    print('Loading sub-model from %s' % model_path)

                model = build_model(model_opt, checkpoint['dicts'])
                optimize_model(model)
                model.load_state_dict(checkpoint['model'])

                if model_opt.model in model_list:
                    # if model.decoder.positional_encoder.len_max < self.opt.max_sent_length:
                    #     print("Not enough len to decode. Renewing .. ")
                    #     model.decoder.renew_buffer(self.opt.max_sent_length)
                    model.renew_buffer(self.opt.max_sent_length)

                if opt.fp16:
                    model = model.half()

                if opt.cuda:
                    model = model.cuda()
                else:
                    model = model.cpu()

                if opt.dynamic_quantile == 1:

                    engines = torch.backends.quantized.supported_engines
                    if 'fbgemm' in engines:
                        torch.backends.quantized.engine = 'fbgemm'
                    else:
                        print(
                            "[INFO] fbgemm is not found in the available engines."
                            " Possibly the CPU does not support AVX2."
                            " It is recommended to disable Quantization (set to 0)."
                        )
                        torch.backends.quantized.engine = 'qnnpack'

                    model = torch.quantization.quantize_dynamic(
                        model, {torch.nn.LSTM, torch.nn.Linear},
                        dtype=torch.qint8)

                model.eval()

                self.sub_models.append(model)
                self.sub_model_types.append(model_opt.model)
        else:
            self.n_sub_models = 0
            self.sub_models = []

        if opt.ensemble_weight:
            ensemble_weight = [
                float(item) for item in opt.ensemble_weight.split("|")
            ]
            assert len(ensemble_weight) == self.n_models

            if opt.sub_ensemble_weight:
                sub_ensemble_weight = [
                    float(item) for item in opt.sub_ensemble_weight.split("|")
                ]
                assert len(sub_ensemble_weight) == self.n_sub_models
                ensemble_weight = ensemble_weight + sub_ensemble_weight

            total = sum(ensemble_weight)
            self.ensemble_weight = [item / total for item in ensemble_weight]
        else:
            self.ensemble_weight = None

        print(self.main_model_opt)
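
To make the ensemble weight handling concrete, a small worked example of the normalization above, assuming two main models and one sub-model:

# Worked example (values are illustrative):
# -ensemble_weight "2|1" and -sub_ensemble_weight "1"
ensemble_weight = [2.0, 1.0]                    # from opt.ensemble_weight.split("|")
ensemble_weight += [1.0]                        # appended sub-model weight
total = sum(ensemble_weight)                    # 4.0
weights = [w / total for w in ensemble_weight]  # [0.5, 0.25, 0.25]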