Exemplo n.º 1
0
    def load_model(cls, loader: Loader, fields: Dict[str, Field]):
        """Assemble an encoder-decoder model from loader params and restore weights.

        The target side reuses the source embedding table unless the
        options declare two separate vocabularies.
        """
        opt = loader.params

        src_emb = Embedding(embedding_dim=opt.hidden_size,
                            dropout=opt.dropout,
                            padding_idx=fields["src"].pad_id,
                            vocab_size=len(fields["src"].vocab))

        # Two vocab files -> independent target embedding; otherwise the
        # source and target sides share one table.
        if len(opt.vocab) == 2:
            tgt_emb = Embedding(embedding_dim=opt.hidden_size,
                                dropout=opt.dropout,
                                padding_idx=fields["tgt"].pad_id,
                                vocab_size=len(fields["tgt"].vocab))
        else:
            tgt_emb = src_emb

        encoder = Encoder(opt.layers, opt.heads, opt.hidden_size,
                          opt.dropout, opt.ff_size, src_emb)
        decoder = Decoder(opt.layers, opt.heads, opt.hidden_size,
                          opt.dropout, opt.ff_size, tgt_emb)
        generator = Generator(opt.hidden_size, len(fields["tgt"].vocab))

        model = cls(encoder, decoder, generator)
        # Restore trained parameters when the loader carries a checkpoint.
        if not loader.empty:
            model.load_state_dict(loader.checkpoint['model'])
        return model
    def load_model(cls,
                   model_opt,
                   pad_ids: Dict[str, int],
                   vocab_sizes: Dict[str, int],
                   checkpoint=None,
                   cls_id1=None,
                   cls_id2=None):
        """Build the two-task model (shared encoder, partially shared decoders)
        and optionally restore weights from *checkpoint* or ``model_opt.train_from``.
        """
        src_emb = Embedding(embedding_dim=model_opt.hidden_size,
                            dropout=model_opt.dropout,
                            padding_idx=pad_ids["src"],
                            vocab_size=vocab_sizes["src"])

        task2_emb = Embedding(embedding_dim=model_opt.hidden_size,
                              dropout=model_opt.dropout,
                              padding_idx=pad_ids["task2_tgt"],
                              vocab_size=vocab_sizes["task2_tgt"])

        # Monolingual summarization: task 1 shares the source embedding;
        # otherwise both tasks share the task-2 target embedding.
        task1_emb = src_emb if model_opt.mono else task2_emb

        encoder = Encoder(model_opt.layers, model_opt.heads,
                          model_opt.hidden_size, model_opt.dropout,
                          model_opt.ff_size, src_emb)

        task1_decoder = Decoder(model_opt.layers, model_opt.heads,
                                model_opt.hidden_size, model_opt.dropout,
                                model_opt.ff_size, task1_emb)

        # Task-2 decoder shares its bottom 4 layers with the task-1 decoder.
        task2_decoder = Decoder(model_opt.layers,
                                model_opt.heads,
                                model_opt.hidden_size,
                                model_opt.dropout,
                                model_opt.ff_size,
                                task2_emb,
                                share_decoder=task1_decoder,
                                num_share_layer=4)

        task1_generator = Generator(model_opt.hidden_size,
                                    vocab_sizes["task1_tgt"])
        task2_generator = Generator(model_opt.hidden_size,
                                    vocab_sizes["task2_tgt"])

        task1_extractor = Extractor(model_opt.hidden_size)
        task2_extractor = Extractor(model_opt.hidden_size)

        model = cls(encoder, task1_decoder, task2_decoder, task1_generator,
                    task2_generator, task1_extractor, task2_extractor, cls_id1,
                    cls_id2)

        # Prefer an explicitly supplied state dict; fall back to loading the
        # checkpoint file named by the training options.
        if checkpoint is None and model_opt.train_from:
            checkpoint = torch.load(model_opt.train_from,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint["model"])
        elif checkpoint is not None:
            model.load_state_dict(checkpoint)
        return model
Exemplo n.º 3
0
    def load_model(cls,
                   model_opt,
                   pad_ids: Dict[str, int],
                   vocab_sizes: Dict[str, int],
                   checkpoint=None):
        """Build the cross-lingual summarization model (one encoder, CN and EN
        decoders) and optionally restore weights.
        """
        src_emb = Embedding(embedding_dim=model_opt.hidden_size,
                            dropout=model_opt.dropout,
                            padding_idx=pad_ids["source"],
                            vocab_size=vocab_sizes["source"])

        en_emb = Embedding(embedding_dim=model_opt.hidden_size,
                           dropout=model_opt.dropout,
                           padding_idx=pad_ids["summary_en"],
                           vocab_size=vocab_sizes["summary_en"])

        # The Chinese summary side may reuse the source embedding table.
        if model_opt.share_cn_embedding:
            cn_emb = src_emb
        else:
            cn_emb = Embedding(embedding_dim=model_opt.hidden_size,
                               dropout=model_opt.dropout,
                               padding_idx=pad_ids["summary_cn"],
                               vocab_size=vocab_sizes["summary_cn"])

        encoder = Encoder(model_opt.layers, model_opt.heads,
                          model_opt.hidden_size, model_opt.dropout,
                          model_opt.ff_size, src_emb)

        cn_decoder = Decoder(model_opt.layers, model_opt.heads,
                             model_opt.hidden_size, model_opt.dropout,
                             model_opt.ff_size, cn_emb)
        en_decoder = Decoder(model_opt.layers, model_opt.heads,
                             model_opt.hidden_size, model_opt.dropout,
                             model_opt.ff_size, en_emb)

        cn_generator = Generator(model_opt.hidden_size,
                                 vocab_sizes["summary_cn"])
        en_generator = Generator(model_opt.hidden_size,
                                 vocab_sizes["summary_en"])

        model = cls(encoder, cn_decoder, en_decoder, cn_generator,
                    en_generator)

        # Prefer an explicitly supplied state dict; otherwise load the
        # checkpoint file named by the training options, if any.
        if checkpoint is None and model_opt.train_from:
            checkpoint = torch.load(model_opt.train_from,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint["model"])
        elif checkpoint is not None:
            model.load_state_dict(checkpoint)
        return model
Exemplo n.º 4
0
    def load_model(cls, model_opt,
                   pad_ids: Dict[str, int],
                   vocab_sizes: Dict[str, int],
                   checkpoint=None):
        """Construct the encoder/decoder/translation-attention model and
        optionally restore weights from *checkpoint* or ``model_opt.train_from``.
        """
        src_emb = Embedding(embedding_dim=model_opt.hidden_size,
                            dropout=model_opt.dropout,
                            padding_idx=pad_ids["src"],
                            vocab_size=vocab_sizes["src"])

        # With two vocabularies the target gets its own embedding; otherwise
        # source and target share one table.
        if len(model_opt.vocab) == 2:
            tgt_emb = Embedding(embedding_dim=model_opt.hidden_size,
                                dropout=model_opt.dropout,
                                padding_idx=pad_ids["tgt"],
                                vocab_size=vocab_sizes["tgt"])
        else:
            tgt_emb = src_emb

        encoder = Encoder(model_opt.layers, model_opt.heads,
                          model_opt.hidden_size, model_opt.dropout,
                          model_opt.ff_size, src_emb)

        decoder = Decoder(model_opt.layers, model_opt.heads,
                          model_opt.hidden_size, model_opt.dropout,
                          model_opt.ff_size, tgt_emb)

        transattn = TransAttn(model_opt.heads, model_opt.hidden_size,
                              model_opt.dropout, tgt_emb)

        generator = Generator(model_opt.hidden_size, vocab_sizes["tgt"])

        model = cls(encoder, decoder, generator, transattn)

        # Prefer an explicitly supplied state dict; otherwise load the
        # checkpoint file named by the training options, if any.
        if model_opt.train_from and checkpoint is None:
            checkpoint = torch.load(model_opt.train_from,
                                    map_location=lambda storage, loc: storage)
            model.load_state_dict(checkpoint["model"])
        elif checkpoint is not None:
            model.load_state_dict(checkpoint)
        return model
    def load_model(cls,
                   model_opt,
                   pad_ids: Dict[str, int],
                   vocab_sizes: Dict[str, int],
                   checkpoint=None):
        """Build the two-task summarization model and restore weights.

        Args:
            model_opt: hyper-parameter namespace (layers, heads, hidden_size,
                dropout, ff_size, mono, train_from).
            pad_ids: padding token ids keyed by field name
                ("src", "task2_tgt").
            vocab_sizes: vocabulary sizes keyed by field name
                ("src", "task1_tgt", "task2_tgt").
            checkpoint: optional state dict to restore. When None and
                ``model_opt.train_from`` is set, the checkpoint is read from
                that file instead.

        Returns:
            The constructed model, with weights restored when available.
        """
        source_embedding = Embedding(embedding_dim=model_opt.hidden_size,
                                     dropout=model_opt.dropout,
                                     padding_idx=pad_ids["src"],
                                     vocab_size=vocab_sizes["src"])

        target_embedding_task2 = Embedding(embedding_dim=model_opt.hidden_size,
                                           dropout=model_opt.dropout,
                                           padding_idx=pad_ids["task2_tgt"],
                                           vocab_size=vocab_sizes["task2_tgt"])
        if model_opt.mono:
            # Monolingual summarization: task 1 shares the source embedding.
            target_embedding_task1 = source_embedding
        else:
            target_embedding_task1 = target_embedding_task2

        encoder = Encoder(model_opt.layers, model_opt.heads,
                          model_opt.hidden_size, model_opt.dropout,
                          model_opt.ff_size, source_embedding)

        task1_decoder = Decoder(model_opt.layers, model_opt.heads,
                                model_opt.hidden_size, model_opt.dropout,
                                model_opt.ff_size, target_embedding_task1)

        task2_decoder = Decoder(model_opt.layers, model_opt.heads,
                                model_opt.hidden_size, model_opt.dropout,
                                model_opt.ff_size, target_embedding_task2)

        task1_generator = Generator(model_opt.hidden_size,
                                    vocab_sizes["task1_tgt"])
        task2_generator = Generator(model_opt.hidden_size,
                                    vocab_sizes["task2_tgt"])

        model = cls(encoder, task1_decoder, task2_decoder, task1_generator,
                    task2_generator)

        if checkpoint is None and model_opt.train_from:
            checkpoint = torch.load(model_opt.train_from,
                                    map_location=lambda storage, loc: storage)
            # Partial load: copy only the parameters this model declares.
            # A key missing from the checkpoint raises KeyError, surfacing
            # incompatible checkpoints instead of silently skipping them.
            state = {k: checkpoint['model'][k] for k in model.state_dict()}
            model.load_state_dict(state)
        elif checkpoint is not None:
            try:
                model.load_state_dict(checkpoint)
            except (RuntimeError, KeyError):
                # Strict load failed (extra/missing keys); fall back to a
                # filtered load restricted to this model's own parameters.
                state = {k: checkpoint[k] for k in model.state_dict()}
                model.load_state_dict(state)
        return model