Пример #1
0
    def __init__(self, args, device, checkpoint):
        """Build a BERT/Longformer sentence encoder with a transformer head.

        Args:
            args: Config namespace. Reads ``large``, ``model_name``,
                ``temp_dir``, ``finetune_bert``, ``ext_ff_size``,
                ``ext_heads``, ``ext_dropout``, ``ext_layers`` and
                ``max_pos``.
            device: Stored on ``self.device``; this constructor does not
                move the module to it.
            checkpoint: Accepted for interface compatibility; not used by
                this constructor.
        """
        super(SentenceEncoder, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.model_name, args.temp_dir,
                         args.finetune_bert)
        hidden_size = self.bert.model.config.hidden_size
        self.ext_transformer_layer = ExtTransformerEncoder(
            hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)

        if args.max_pos > 512 and args.model_name == 'bert':
            # Grow BERT's position table to max_pos rows: the first 512 rows
            # are copied from the current table (which must therefore have
            # exactly 512 rows); every new row is a copy of its last row.
            old_weight = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings = nn.Embedding(args.max_pos, hidden_size)
            my_pos_embeddings.weight.data[:512] = old_weight
            my_pos_embeddings.weight.data[512:] = old_weight[-1][None, :].repeat(
                args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        if args.max_pos > 4096 and args.model_name == 'longformer':
            # The new table has max_pos + 2 rows. The slice shapes below only
            # line up when the current table has 4098 rows and
            # max_pos <= 8192; otherwise the assignments raise a shape error.
            # NOTE(review): confirm those bounds against the checkpoint used.
            old_weight = self.bert.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings = nn.Embedding(args.max_pos + 2, hidden_size)
            # Rows 0..4096: every current row except the last one.
            my_pos_embeddings.weight.data[:4097] = old_weight[:-1]
            # Remaining rows: current rows starting again from index 1.
            my_pos_embeddings.weight.data[4097:] = old_weight[1:args.max_pos + 2 - 4096]
            self.bert.model.embeddings.position_embeddings = my_pos_embeddings

        self.sigmoid = nn.Sigmoid()
Пример #2
0
    def __init__(self, args, device, checkpoint):
        """Extractive summarizer: a BERT encoder plus a sentence-scoring head.

        Args:
            args: Config namespace. Reads ``large``, ``temp_dir``,
                ``finetune_bert``, ``ext_ff_size``, ``ext_heads``,
                ``ext_dropout``, ``ext_layers``, ``encoder``,
                ``ext_hidden_size``, ``max_pos``, ``param_init`` and
                ``param_init_glorot``.
            device: Stored on ``self.device``; the module is moved there at
                the end of construction.
            checkpoint: Dict whose ``'model'`` entry is loaded with
                ``strict=True``, or ``None`` to initialise the head from
                scratch.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        pretrained_cfg = self.bert.model.config
        self.ext_layer = ExtTransformerEncoder(
            pretrained_cfg.hidden_size, args.ext_ff_size, args.ext_heads,
            args.ext_dropout, args.ext_layers)

        if args.encoder == 'baseline':
            # Replace the pretrained encoder with a freshly configured
            # BertModel and score sentences with a plain Classifier.
            baseline_cfg = BertConfig(
                pretrained_cfg.vocab_size,
                hidden_size=args.ext_hidden_size,
                num_hidden_layers=args.ext_layers,
                num_attention_heads=args.ext_heads,
                intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(baseline_cfg)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if args.max_pos > 512:
            # Grow the position table: keep the first 512 rows, then fill the
            # rest with copies of the last existing row.
            extended = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            current = self.bert.model.embeddings.position_embeddings.weight.data
            extended.weight.data[:512] = current
            tail = current[-1][None, :].repeat(args.max_pos - 512, 1)
            extended.weight.data[512:] = tail
            self.bert.model.embeddings.position_embeddings = extended

        if checkpoint is None:
            if args.param_init != 0.0:
                bound = args.param_init
                for param in self.ext_layer.parameters():
                    param.data.uniform_(-bound, bound)
            if args.param_init_glorot:
                # Glorot only applies to matrices; 1-D params keep their values.
                matrices = (p for p in self.ext_layer.parameters() if p.dim() > 1)
                for param in matrices:
                    xavier_uniform_(param)
        else:
            self.load_state_dict(checkpoint['model'], strict=True)

        self.to(device)
Пример #3
0
    def __init__(self, args, device, ckpt):
        """Build the extractive summarizer (BERT encoder + extraction head).

        Args:
            args: Config namespace. Reads ``bert_path``, ``finetune_bert``,
                ``param_init`` and ``param_init_glorot``.
            device: Stored on ``self.device``; the module is moved there at
                the end of construction.
            ckpt: Dict whose ``'model'`` entry is loaded with ``strict=True``,
                or ``None`` to initialise the extraction head from scratch.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.bert_path, args.finetune_bert)
        self.ext_layer = ExtTransformerEncoder()
        # This variant neither swaps in a baseline encoder nor extends the
        # position embeddings beyond the pretrained size.

        if ckpt is not None:
            # strict=True requires the model's keys and the checkpoint's keys
            # to match one-to-one; strict=False would tolerate mismatches.
            self.load_state_dict(ckpt['model'], strict=True)
        else:
            if args.param_init != 0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            elif args.param_init_glorot:
                # Xavier-uniform init; only reached when param_init == 0.
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        # Move all parameters and buffers to the target device.
        self.to(device)
Пример #4
0
    def __init__(self, device, checkpoint=None, bert_type='bertbase'):
        """Extractive summarizer with fixed extraction-head hyperparameters.

        Args:
            device: Stored on ``self.device``; the module is moved there at
                the end of construction.
            checkpoint: A state dict loaded directly with ``strict=True``, or
                ``None`` to keep the freshly constructed weights.
            bert_type: Forwarded to ``Bert(bert_type=...)``.
        """
        super().__init__()
        self.device = device
        self.bert = Bert(bert_type=bert_type)

        head_kwargs = dict(d_ff=2048, heads=8, dropout=0.2, num_inter_layers=2)
        hidden_size = self.bert.model.config.hidden_size
        self.ext_layer = ExtTransformerEncoder(hidden_size, **head_kwargs)

        if checkpoint is not None:
            # Here the checkpoint *is* the state dict (no 'model' key lookup).
            self.load_state_dict(checkpoint, strict=True)

        self.to(device)
Пример #5
0
def only_model(args, device_id):

    logger.info('Loading checkpoint from %s' % args.test_from)
    checkpoint = torch.load(args.test_from,
                            map_location=lambda storage, loc: storage)

    ### We load our ExtSummarizer model
    model = ExtSummarizer(args, device, checkpoint)
    model.eval()

    ### We create an encoder and a decoder like those of ExtSummarizer and load the latter parameters into the former
    ### This is for the test sake
    encoder = Bert(False, '/tmp', True)
    load_my_state_dict(encoder, checkpoint['model'])

    decoder = ExtTransformerEncoder(encoder.model.config.hidden_size,
                                    args.ext_ff_size, args.ext_heads,
                                    args.ext_dropout, args.ext_layers)
    load_my_state_dict_decoder(decoder, checkpoint['model'])

    encoder.eval()
    decoder.eval()

    seq_len = 250

    ### We test if the parameters have been well loaded
    input_ids = torch.tensor([np.random.randint(100, 15000, seq_len)],
                             dtype=torch.long)
    mask = torch.ones(1, seq_len, dtype=torch.float)
    clss = torch.tensor([[20, 36, 55, 100, 122, 130, 200, 222]],
                        dtype=torch.long)
    mask_cls = torch.tensor([[1] * len(clss[0])], dtype=torch.long)
    """## test encoder
    top_vec = model.bert(input_ids, mask)
    top_vec1 = encoder(input_ids, mask)
    logger.info((top_vec-top_vec1).sum())

    ## test decoder
    sents_vec = top_vec[torch.arange(top_vec.size(0)).unsqueeze(1), clss]
    sents_vec = sents_vec * mask_cls[:, :, None].float()

    sents_vec1 = top_vec1[torch.arange(top_vec1.size(0)).unsqueeze(1), clss]
    sents_vec1 = sents_vec1 * mask_cls[:, :, None].float()


    scores = model.ext_layer(sents_vec, mask_cls)
    scores1  = decoder(sents_vec1, mask_cls)
    logger.info((scores-scores1).sum())"""

    ################# ONNX ########################"

    ## Now we are exporting the encoder and the decoder into onnx
    """input_names = ["input_ids", "mask"]
    output_names = ["hidden_outputs"]
    torch.onnx.export(model.bert.to('cpu'), (input_ids, mask), "/tmp/encoder5.onnx", verbose=True, 
                      input_names=input_names, output_names=output_names, export_params=True, keep_initializers_as_inputs=True)"""

    k_model = pytorch_to_keras(model.bert.to('cpu'), [input_ids, mask], [(
        1,
        250,
    ), (
        1,
        250,
    )],
                               verbose=True)

    print("okkk")
    """logger.info("Load onnx and test")
Пример #6
0
    def __init__(self, args, device, checkpoint):
        """Build the extractive summarizer for several encoder backbones.

        Args:
            args: Config namespace. Reads ``model_name``, ``pretrained_name``,
                ``temp_dir``, ``finetune_bert``, ``ext_ff_size``, ``ext_heads``,
                ``ext_dropout``, ``ext_layers``, ``encoder``, ``max_pos``,
                ``max_model_pos``, ``bert_baseline``, ``param_init`` and
                ``param_init_glorot``. Note: ``args.max_pos`` is overwritten
                with ``args.max_model_pos`` on the ``bert_baseline == 1`` path.
            device: Stored on ``self.device``; the module is moved there at
                the end of construction.
            checkpoint: Dict whose ``'model'`` entry is loaded with
                ``strict=True``, or ``None`` to initialise the extraction head
                from scratch.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.model_name, args.pretrained_name, args.temp_dir,
                         args.finetune_bert)

        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)
        if args.encoder == 'baseline':
            # The baseline keeps the pretrained encoder and only swaps the
            # transformer head for a plain Classifier.
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if args.model_name == 'bert':
            if args.max_pos > args.max_model_pos:
                if args.bert_baseline == 1:
                    # Do not extend: clamp the requested length to the
                    # pretrained table size.
                    args.max_pos = args.max_model_pos
                    self.bert.model.config.max_position_embeddings = args.max_model_pos
                else:
                    my_pos_embeddings = nn.Embedding(
                        args.max_pos, self.bert.model.config.hidden_size)
                    pretrained = self.bert.model.embeddings.position_embeddings.weight.data
                    # Tile the new table in chunks of max_model_pos rows. Every
                    # chunk except the last copies the first max_model_pos
                    # pretrained rows; the last chunk (even when it is a full
                    # chunk, since the test below is a strict `<`) is filled
                    # with copies of the last pretrained row.
                    for offset in range(0, args.max_pos, args.max_model_pos):
                        if offset + args.max_model_pos < args.max_pos:
                            my_pos_embeddings.weight.data[offset:offset + args.max_model_pos] = \
                                pretrained[:args.max_model_pos].contiguous()
                        else:
                            my_pos_embeddings.weight.data[offset:] = \
                                pretrained[-1][None, :].repeat(args.max_pos - offset, 1)
                    self.bert.model.embeddings.position_embeddings = my_pos_embeddings
                    self.bert.model.config.max_position_embeddings = args.max_pos
        elif args.model_name == 'bert_lstm':
            self.bert.model.config.max_position_embeddings = args.max_model_pos
            # The embeddings module and each encoder layer presumably copy the
            # configured length at construction time (as
            # max_position_embeddings and chunk_size respectively), so they
            # are patched directly as well — NOTE(review): confirm in
            # the bert_lstm model implementation.
            self.bert.model.embeddings.max_position_embeddings = args.max_model_pos
            for layer in self.bert.model.encoder.layer:
                layer.chunk_size = args.max_model_pos

        if checkpoint is not None:
            self.load_state_dict(checkpoint['model'], strict=True)
        else:
            if args.param_init != 0.0:
                for p in self.ext_layer.parameters():
                    p.data.uniform_(-args.param_init, args.param_init)
            if args.param_init_glorot:
                # Glorot only applies to matrices; 1-D params are left as is.
                for p in self.ext_layer.parameters():
                    if p.dim() > 1:
                        xavier_uniform_(p)

        self.to(device)
Пример #7
0
    def __init__(self, args, device, checkpoint, lamb=0.8):
        """Extractive summarizer with extra content/similarity/relevance/novelty weights.

        Args:
            args: Config namespace. Reads ``large``, ``temp_dir``,
                ``finetune_bert``, ``ext_ff_size``, ``ext_heads``,
                ``ext_dropout``, ``ext_layers``, ``encoder``,
                ``ext_hidden_size``, ``max_pos``, ``param_init`` and
                ``param_init_glorot``.
            device: Stored on ``self.device``; the module is moved there at
                the end of construction.
            checkpoint: Dict whose ``'model'`` entry is loaded with
                ``strict=True``, or ``None`` to initialise from scratch.
            lamb: Stored on ``self.lamb``; not used by this constructor.
        """
        super(ExtSummarizer, self).__init__()
        self.args = args
        self.device = device
        self.lamb = lamb

        # Encoder.
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)

        # Extraction head.
        self.ext_layer = ExtTransformerEncoder(
            self.bert.model.config.hidden_size, args.ext_ff_size,
            args.ext_heads, args.ext_dropout, args.ext_layers)

        if args.encoder == 'baseline':
            baseline_cfg = BertConfig(
                self.bert.model.config.vocab_size,
                hidden_size=args.ext_hidden_size,
                num_hidden_layers=args.ext_layers,
                num_attention_heads=args.ext_heads,
                intermediate_size=args.ext_ff_size)
            self.bert.model = BertModel(baseline_cfg)
            self.ext_layer = Classifier(self.bert.model.config.hidden_size)

        if args.max_pos > 512:
            # Grow the position table: first 512 rows copied, the rest are
            # copies of the last existing row.
            grown = nn.Embedding(args.max_pos, self.bert.model.config.hidden_size)
            existing = self.bert.model.embeddings.position_embeddings.weight.data
            grown.weight.data[:512] = existing
            grown.weight.data[512:] = existing[-1][None, :].repeat(args.max_pos - 512, 1)
            self.bert.model.embeddings.position_embeddings = grown

        # Extra scoring parameters. torch.Tensor(...) leaves them
        # uninitialised until the init below or the checkpoint load.
        # Registration order is kept fixed (state_dict / parameters() order).
        hidden = self.bert.model.config.hidden_size
        self.W_cont = nn.Parameter(torch.Tensor(1, hidden))
        self.W_sim = nn.Parameter(torch.Tensor(hidden, hidden))
        self.Sim_layer = nn.Linear(hidden, hidden)
        self.W_rel = nn.Parameter(torch.Tensor(hidden, hidden))
        self.Rel_layer = nn.Linear(hidden, hidden)
        self.W_novel = nn.Parameter(torch.Tensor(hidden, hidden))
        self.b_matrix = nn.Parameter(torch.Tensor(1, 1))

        self.q_transform = nn.Linear(100, 1)
        self.bq = nn.Parameter(torch.Tensor(1, 1))
        self.brel = nn.Parameter(torch.Tensor(1, 1))
        self.bsim = nn.Parameter(torch.Tensor(1, 1))
        self.bcont = nn.Parameter(torch.Tensor(1, 1))

        if checkpoint is None:
            if args.param_init != 0.0:
                bound = args.param_init
                for param in self.ext_layer.parameters():
                    param.data.uniform_(-bound, bound)
            if args.param_init_glorot:
                # Glorot on the matrices of each layer, in this fixed order.
                for layer in (self.ext_layer, self.Rel_layer, self.Sim_layer):
                    for param in layer.parameters():
                        if param.dim() > 1:
                            xavier_uniform_(param)
            # Always initialise the standalone parameters (fixed order).
            standalone = (self.bq, self.W_cont, self.W_sim, self.W_rel,
                          self.W_novel, self.b_matrix, self.bcont,
                          self.brel, self.bsim)
            for param in standalone:
                nn.init.xavier_uniform_(param)
        else:
            self.load_state_dict(checkpoint['model'], strict=True)
            print("checkpoint loaded! ")
        self.to(device)